3939import com .google .common .collect .ImmutableList ;
4040import com .google .common .util .concurrent .MoreExecutors ;
4141import com .google .protobuf .AbstractMessage ;
42+ import com .google .protobuf .ByteString ;
4243import com .google .protobuf .ListValue ;
4344import com .google .spanner .v1 .BeginTransactionRequest ;
4445import com .google .spanner .v1 .CommitRequest ;
@@ -123,6 +124,20 @@ public static Collection<Object[]> data() {
123124 .build ())
124125 .setMetadata (SELECT1_METADATA )
125126 .build ();
127+ private static final Statement SELECT1_UNION_ALL_SELECT2 =
128+ Statement .of ("SELECT 1 AS COL1 UNION ALL SELECT 2 AS COL1" );
129+ private static final com .google .spanner .v1 .ResultSet SELECT1_UNION_ALL_SELECT2_RESULTSET =
130+ com .google .spanner .v1 .ResultSet .newBuilder ()
131+ .addRows (
132+ ListValue .newBuilder ()
133+ .addValues (com .google .protobuf .Value .newBuilder ().setStringValue ("1" ).build ())
134+ .build ())
135+ .addRows (
136+ ListValue .newBuilder ()
137+ .addValues (com .google .protobuf .Value .newBuilder ().setStringValue ("2" ).build ())
138+ .build ())
139+ .setMetadata (SELECT1_METADATA )
140+ .build ();
126141 private static final Statement INVALID_SELECT = Statement .of ("SELECT * FROM NON_EXISTING_TABLE" );
127142 private static final Statement READ_STATEMENT = Statement .of ("SELECT ID FROM FOO WHERE 1=1" );
128143
@@ -134,6 +149,8 @@ public static void startStaticServer() throws IOException {
134149 mockSpanner .setAbortProbability (0.0D ); // We don't want any unpredictable aborted transactions.
135150 mockSpanner .putStatementResult (StatementResult .update (UPDATE_STATEMENT , UPDATE_COUNT ));
136151 mockSpanner .putStatementResult (StatementResult .query (SELECT1 , SELECT1_RESULTSET ));
152+ mockSpanner .putStatementResult (
153+ StatementResult .query (SELECT1_UNION_ALL_SELECT2 , SELECT1_UNION_ALL_SELECT2_RESULTSET ));
137154 mockSpanner .putStatementResult (StatementResult .query (READ_STATEMENT , SELECT1_RESULTSET ));
138155 mockSpanner .putStatementResult (
139156 StatementResult .exception (
@@ -1257,6 +1274,45 @@ public Long run(TransactionContext transaction) throws Exception {
12571274 assertThat (((ExecuteSqlRequest ) requests .get (1 )).getSql ()).isEqualTo (SELECT1 .getSql ());
12581275 }
12591276
1277+ @ Test
1278+ public void testInlinedBeginTxWithStreamRetry () {
1279+ mockSpanner .setExecuteStreamingSqlExecutionTime (
1280+ SimulatedExecutionTime .ofStreamException (Status .UNAVAILABLE .asRuntimeException (), 1 ));
1281+
1282+ DatabaseClient client = spanner .getDatabaseClient (DatabaseId .of ("p" , "i" , "d" ));
1283+ client
1284+ .readWriteTransaction ()
1285+ .run (
1286+ new TransactionCallable <Void >() {
1287+ @ Override
1288+ public Void run (TransactionContext transaction ) throws Exception {
1289+ try (ResultSet rs = transaction .executeQuery (SELECT1_UNION_ALL_SELECT2 )) {
1290+ while (rs .next ()) {}
1291+ }
1292+ return null ;
1293+ }
1294+ });
1295+ assertThat (countRequests (BeginTransactionRequest .class )).isEqualTo (0 );
1296+ assertThat (countRequests (ExecuteSqlRequest .class )).isEqualTo (2 );
1297+ assertThat (countRequests (CommitRequest .class )).isEqualTo (1 );
1298+ assertThat (countTransactionsStarted ()).isEqualTo (1 );
1299+
1300+ List <AbstractMessage > requests = mockSpanner .getRequestsOfType (ExecuteSqlRequest .class );
1301+ assertThat (requests .get (0 )).isInstanceOf (ExecuteSqlRequest .class );
1302+ ExecuteSqlRequest request1 = (ExecuteSqlRequest ) requests .get (0 );
1303+ assertThat (request1 .getSql ()).isEqualTo (SELECT1_UNION_ALL_SELECT2 .getSql ());
1304+ assertThat (request1 .getTransaction ().getBegin ().hasReadWrite ()).isTrue ();
1305+ assertThat (request1 .getTransaction ().getId ()).isEqualTo (ByteString .EMPTY );
1306+ assertThat (request1 .getResumeToken ()).isEqualTo (ByteString .EMPTY );
1307+
1308+ assertThat (requests .get (1 )).isInstanceOf (ExecuteSqlRequest .class );
1309+ ExecuteSqlRequest request2 = (ExecuteSqlRequest ) requests .get (1 );
1310+ assertThat (request2 .getSql ()).isEqualTo (SELECT1_UNION_ALL_SELECT2 .getSql ());
1311+ assertThat (request2 .getTransaction ().hasBegin ()).isFalse ();
1312+ assertThat (request2 .getTransaction ().getId ()).isNotEqualTo (ByteString .EMPTY );
1313+ assertThat (request2 .getResumeToken ()).isNotEqualTo (ByteString .EMPTY );
1314+ }
1315+
12601316 private int countRequests (Class <? extends AbstractMessage > requestType ) {
12611317 int count = 0 ;
12621318 for (AbstractMessage msg : mockSpanner .getRequests ()) {
0 commit comments