|
27 | 27 |
|
28 | 28 | import com.google.api.client.util.BackOff; |
29 | 29 | import com.google.cloud.spanner.TransactionRunner.TransactionCallable; |
| 30 | +import com.google.cloud.spanner.TransactionRunnerImpl.TransactionContextImpl; |
30 | 31 | import com.google.cloud.spanner.spi.v1.SpannerRpc; |
| 32 | +import com.google.common.base.Preconditions; |
| 33 | +import com.google.protobuf.ByteString; |
| 34 | +import com.google.protobuf.Timestamp; |
| 35 | +import com.google.rpc.Code; |
| 36 | +import com.google.spanner.v1.CommitRequest; |
| 37 | +import com.google.spanner.v1.CommitResponse; |
| 38 | +import com.google.spanner.v1.ExecuteBatchDmlRequest; |
| 39 | +import com.google.spanner.v1.ExecuteBatchDmlResponse; |
| 40 | +import com.google.spanner.v1.ResultSet; |
| 41 | +import com.google.spanner.v1.ResultSetStats; |
31 | 42 | import io.grpc.Context; |
32 | 43 | import io.grpc.Status; |
33 | 44 | import io.grpc.StatusRuntimeException; |
| 45 | +import java.util.Arrays; |
34 | 46 | import java.util.concurrent.atomic.AtomicInteger; |
35 | 47 | import org.junit.Before; |
36 | 48 | import org.junit.Test; |
37 | 49 | import org.junit.runner.RunWith; |
38 | 50 | import org.junit.runners.JUnit4; |
39 | 51 | import org.mockito.Mock; |
| 52 | +import org.mockito.Mockito; |
40 | 53 | import org.mockito.MockitoAnnotations; |
41 | 54 |
|
42 | 55 | /** Unit test for {@link com.google.cloud.spanner.SpannerImpl.TransactionRunnerImpl} */ |
@@ -141,6 +154,77 @@ public void runResourceExhaustedNoRetry() throws Exception { |
141 | 154 | verify(txn).rollback(); |
142 | 155 | } |
143 | 156 |
|
| 157 | + @Test |
| 158 | + public void batchDmlAborted() { |
| 159 | + long updateCount[] = batchDmlException(Code.ABORTED_VALUE); |
| 160 | + assertThat(updateCount.length).isEqualTo(2); |
| 161 | + assertThat(updateCount[0]).isEqualTo(1L); |
| 162 | + assertThat(updateCount[1]).isEqualTo(1L); |
| 163 | + } |
| 164 | + |
| 165 | + @Test |
| 166 | + public void batchDmlFailedPrecondition() { |
| 167 | + try { |
| 168 | + batchDmlException(Code.FAILED_PRECONDITION_VALUE); |
| 169 | + fail("Expected exception"); |
| 170 | + } catch (SpannerBatchUpdateException e) { |
| 171 | + assertThat(e.getUpdateCounts().length).isEqualTo(1); |
| 172 | + assertThat(e.getUpdateCounts()[0]).isEqualTo(1L); |
| 173 | + assertThat(e.getCode() == Code.FAILED_PRECONDITION_VALUE); |
| 174 | + } |
| 175 | + } |
| 176 | + |
| 177 | + @SuppressWarnings("unchecked") |
| 178 | + private long[] batchDmlException(int status) { |
| 179 | + Preconditions.checkArgument(status != Code.OK_VALUE); |
| 180 | + TransactionContextImpl transaction = |
| 181 | + new TransactionContextImpl(session, ByteString.copyFromUtf8("test"), rpc, 10); |
| 182 | + when(session.newTransaction()).thenReturn(transaction); |
| 183 | + when(session.getName()).thenReturn("test"); |
| 184 | + TransactionRunnerImpl runner = new TransactionRunnerImpl(session, rpc, 10); |
| 185 | + ExecuteBatchDmlResponse response1 = |
| 186 | + ExecuteBatchDmlResponse.newBuilder() |
| 187 | + .addResultSets( |
| 188 | + ResultSet.newBuilder() |
| 189 | + .setStats(ResultSetStats.newBuilder().setRowCountExact(1L)) |
| 190 | + .build()) |
| 191 | + .setStatus(com.google.rpc.Status.newBuilder().setCode(status).build()) |
| 192 | + .build(); |
| 193 | + ExecuteBatchDmlResponse response2 = |
| 194 | + ExecuteBatchDmlResponse.newBuilder() |
| 195 | + .addResultSets( |
| 196 | + ResultSet.newBuilder() |
| 197 | + .setStats(ResultSetStats.newBuilder().setRowCountExact(1L)) |
| 198 | + .build()) |
| 199 | + .addResultSets( |
| 200 | + ResultSet.newBuilder() |
| 201 | + .setStats(ResultSetStats.newBuilder().setRowCountExact(1L)) |
| 202 | + .build()) |
| 203 | + .setStatus(com.google.rpc.Status.newBuilder().setCode(Code.OK_VALUE).build()) |
| 204 | + .build(); |
| 205 | + when(rpc.executeBatchDml(Mockito.any(ExecuteBatchDmlRequest.class), Mockito.anyMap())) |
| 206 | + .thenReturn(response1, response2); |
| 207 | + CommitResponse commitResponse = |
| 208 | + CommitResponse.newBuilder().setCommitTimestamp(Timestamp.getDefaultInstance()).build(); |
| 209 | + when(rpc.commit(Mockito.any(CommitRequest.class), Mockito.anyMap())).thenReturn(commitResponse); |
| 210 | + final Statement statement = Statement.of("UPDATE FOO SET BAR=1"); |
| 211 | + final AtomicInteger numCalls = new AtomicInteger(0); |
| 212 | + long updateCount[] = |
| 213 | + runner.run( |
| 214 | + new TransactionCallable<long[]>() { |
| 215 | + @Override |
| 216 | + public long[] run(TransactionContext transaction) throws Exception { |
| 217 | + numCalls.incrementAndGet(); |
| 218 | + return transaction.batchUpdate(Arrays.asList(statement, statement)); |
| 219 | + } |
| 220 | + }); |
| 221 | + if (status == Code.ABORTED_VALUE) { |
| 222 | + // Assert that the method ran twice because the first response aborted. |
| 223 | + assertThat(numCalls.get()).isEqualTo(2); |
| 224 | + } |
| 225 | + return updateCount; |
| 226 | + } |
| 227 | + |
144 | 228 | private void runTransaction(final Exception exception) { |
145 | 229 | transactionRunner.run( |
146 | 230 | new TransactionCallable<Void>() { |
|
0 commit comments