@@ -19,26 +19,17 @@ package org.apache.spark.sql.streaming
1919
2020import java .util .TimeZone
2121
22- import scala .collection .mutable
23- import scala .reflect .runtime .{universe => ru }
24-
25- import org .apache .hadoop .conf .Configuration
26- import org .mockito .Mockito
27- import org .mockito .invocation .InvocationOnMock
28- import org .mockito .stubbing .Answer
2922import org .scalatest .BeforeAndAfterAll
30- import org .scalatest .PrivateMethodTester ._
3123
3224import org .apache .spark .SparkException
3325import org .apache .spark .sql .AnalysisException
34- import org .apache .spark .sql .catalyst .util ._
26+ import org .apache .spark .sql .catalyst .util .DateTimeUtils
3527import org .apache .spark .sql .execution .SparkPlan
3628import org .apache .spark .sql .execution .streaming ._
37- import org .apache .spark .sql .execution .streaming .state ._
29+ import org .apache .spark .sql .execution .streaming .state .StateStore
3830import org .apache .spark .sql .expressions .scalalang .typed
3931import org .apache .spark .sql .functions ._
4032import org .apache .spark .sql .streaming .OutputMode ._
41- import org .apache .spark .sql .types ._
4233
4334object FailureSinglton {
4435 var firstTime = true
@@ -344,67 +335,4 @@ class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll {
344335 CheckLastBatch ((90L , 1 ), (100L , 1 ), (105L , 1 ))
345336 )
346337 }
347-
348- test(" abort StateStore in case of error" ) {
349- quietly {
350- val inputData = MemoryStream [Long ]
351- val aggregated =
352- inputData.toDS()
353- .groupBy($" value" )
354- .agg(count(" *" ))
355- var aborted = false
356- testStream(aggregated, Complete )(
357- // This whole `AssertOnQuery` is used to inject a mock state store
358- AssertOnQuery (execution => {
359- // (1) Use reflection to get `StateStore.loadedProviders`
360- val loadedProviders = {
361- val field = ru.typeOf[StateStore .type ].decl(ru.TermName (" loadedProviders" )).asTerm
362- ru.runtimeMirror(StateStore .getClass.getClassLoader)
363- .reflect(StateStore )
364- .reflectField(field)
365- .get
366- .asInstanceOf [mutable.HashMap [StateStoreId , StateStoreProvider ]]
367- }
368- // (2) Make a storeId
369- val storeId = {
370- val checkpointLocation =
371- execution invokePrivate PrivateMethod [String ](' checkpointFile )(" state" )
372- StateStoreId (checkpointLocation, 0L , 0 )
373- }
374- // (3) Make `mockStore` and `mockProvider`
375- val (mockStore, mockProvider) = {
376- val keySchema = StructType (Seq (
377- StructField (" value" , LongType , false )))
378- val valueSchema = StructType (Seq (
379- StructField (" value" , LongType , false ), StructField (" count" , LongType , false )))
380- val storeConf = StateStoreConf .empty
381- val hadoopConf = new Configuration
382- (Mockito .spy(
383- StateStore .get(storeId, keySchema, valueSchema, version = 0 , storeConf, hadoopConf)),
384- Mockito .spy(loadedProviders.get(storeId).get))
385- }
386- // (4) Setup `mockStore` and `mockProvider`
387- Mockito .doAnswer(new Answer [Long ] {
388- override def answer (invocationOnMock : InvocationOnMock ): Long = {
389- sys.error(" injected error on commit()" )
390- }
391- }).when(mockStore).commit()
392- Mockito .doAnswer(new Answer [Unit ] {
393- override def answer (invocationOnMock : InvocationOnMock ): Unit = {
394- invocationOnMock.callRealMethod()
395- // Mark the flag for later check
396- aborted = true
397- }
398- }).when(mockStore).abort()
399- Mockito .doReturn(mockStore).when(mockProvider).getStore(version = 0 )
400- // (5) Inject `mockProvider`, which later on would inject `mockStore`
401- loadedProviders.put(storeId, mockProvider)
402- true
403- }), // End of AssertOnQuery, i.e. end of injecting `mockStore`
404- AddData (inputData, 1L , 2L , 3L ),
405- ExpectFailure [SparkException ](),
406- AssertOnQuery { _ => aborted } // Check that `mockStore.abort()` is called upon error
407- )
408- }
409- }
410338}
0 commit comments