@@ -139,7 +139,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar
     // the fetch failure. The executor should still tell the driver that the task failed due to a
     // fetch failure, not a generic exception from user code.
     val inputRDD = new FetchFailureThrowingRDD(sc)
-    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false)
     val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
     val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
     val task = new ResultTask(
@@ -173,8 +173,26 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar
   }
 
   test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
+    val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true)
+    assert(failReason.isInstanceOf[ExceptionFailure])
+    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
+    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
+    assert(exceptionCaptor.getAllValues.size === 1)
+    assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError])
+  }
+
+  test(s"SPARK-23816: interrupts are not masked by a FetchFailure") {
+    // If killing the task causes a fetch failure, we still treat it as a task that was killed,
+    // as the fetch failure could easily be caused by interrupting the thread.
+    val (failReason, _) = testFetchFailureHandling(false)
+    assert(failReason.isInstanceOf[TaskKilled])
+  }
+
+  def testFetchFailureHandling(oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
     // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
     // may be a false positive. And we should call the uncaught exception handler.
+    // SPARK-23816: we also handle interrupts the same way, since killing an obsolete speculative
+    // task does not represent a real fetch failure.
     val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
     sc = new SparkContext(conf)
     val serializer = SparkEnv.get.closureSerializer.newInstance()
@@ -183,7 +201,13 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar
     // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat
     // the fetch failure as a false positive, and just do normal OOM handling.
     val inputRDD = new FetchFailureThrowingRDD(sc)
-    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
+    if (!oom) {
+      // we are trying to set up a case where a task is killed after a fetch failure -- this
+      // is just a helper to coordinate between the task thread and this thread that will
+      // kill the task
+      ExecutorSuiteHelper.latches = new ExecutorSuiteHelper()
+    }
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom)
     val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
     val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
     val task = new ResultTask(
@@ -200,15 +224,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar
     val serTask = serializer.serialize(task)
     val taskDescription = createFakeTaskDescription(serTask)
 
-    val (failReason, uncaughtExceptionHandler) =
-      runTaskGetFailReasonAndExceptionHandler(taskDescription)
-    // make sure the task failure just looks like a OOM, not a fetch failure
-    assert(failReason.isInstanceOf[ExceptionFailure])
-    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
-    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
-    assert(exceptionCaptor.getAllValues.size === 1)
-    assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
-  }
+    runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom)
+  }
 
   test("Gracefully handle error in task deserialization") {
     val conf = new SparkConf
@@ -257,19 +274,32 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar
   }
 
   private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
-    runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
+    runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1
   }
 
   private def runTaskGetFailReasonAndExceptionHandler(
-      taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
+      taskDescription: TaskDescription,
+      killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
     val mockBackend = mock[ExecutorBackend]
     val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
     var executor: Executor = null
+    var killingThread: Thread = null
     try {
       executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
         uncaughtExceptionHandler = mockUncaughtExceptionHandler)
       // the task will be launched in a dedicated worker thread
       executor.launchTask(mockBackend, taskDescription)
+      if (killTask) {
+        killingThread = new Thread("kill-task") {
+          override def run(): Unit = {
+            // wait to kill the task until it has thrown a fetch failure
+            ExecutorSuiteHelper.latches.latch1.await()
+            // now we can kill the task
+            executor.killAllTasks(true, "Killed task, eg. because of speculative execution")
+          }
+        }
+        killingThread.start()
+      }
       eventually(timeout(5.seconds), interval(10.milliseconds)) {
         assert(executor.numRunningTasks === 0)
       }
@@ -282,8 +312,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar
       val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
       orderedMock.verify(mockBackend)
         .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
+      val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED
       orderedMock.verify(mockBackend)
-        .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
+        .statusUpdate(meq(0L), meq(finalState), statusCaptor.capture())
       // first statusUpdate for RUNNING has empty data
       assert(statusCaptor.getAllValues().get(0).remaining() === 0)
       // second update is more interesting
@@ -321,7 +352,8 @@ class SimplePartition extends Partition {
 class FetchFailureHidingRDD(
     sc: SparkContext,
     val input: FetchFailureThrowingRDD,
-    throwOOM: Boolean) extends RDD[Int](input) {
+    throwOOM: Boolean,
+    interrupt: Boolean) extends RDD[Int](input) {
   override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
     val inItr = input.compute(split, context)
     try {
@@ -330,6 +362,15 @@ class FetchFailureHidingRDD(
       case t: Throwable =>
         if (throwOOM) {
           throw new OutOfMemoryError("OOM while handling another exception")
+        } else if (interrupt) {
+          // make sure our test is set up correctly
+          assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined)
+          // signal that our test is ready for the task to get killed
+          ExecutorSuiteHelper.latches.latch1.countDown()
+          // then wait for another thread in the test to kill the task -- this latch
+          // is never actually decremented; we just wait here until we get killed
+          ExecutorSuiteHelper.latches.latch2.await()
+          throw new IllegalStateException("impossible")
         } else {
           throw new RuntimeException("User Exception that hides the original exception", t)
         }
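
The two tests above pin down how the executor resolves competing failure signals: a kill interrupt must win over a recorded fetch failure (SPARK-23816), while a fatal error such as an OOM must bypass fetch-failure reporting and reach the uncaught exception handler (SPARK-19276). Below is a minimal, self-contained sketch of that precedence. It is illustrative only; the names (`classify`, `reportKilled`, and friends) are invented for the sketch and are not taken from Spark's actual TaskRunner.

```scala
// Illustrative sketch of the failure precedence verified by the tests above.
// NOT the actual Spark TaskRunner logic; all names here are hypothetical.
object FailurePrecedenceSketch {
  def reportKilled(reason: String): String = s"KILLED: $reason"
  def reportFetchFailed(): String = "FAILED: FetchFailed"
  def reportException(t: Throwable): String =
    s"FAILED: ExceptionFailure(${t.getClass.getSimpleName})"

  def classify(t: Throwable, killReason: Option[String], hasFetchFailure: Boolean): String =
    t match {
      // A kill (and the InterruptedException it triggers) takes precedence,
      // even when a fetch failure was recorded first (SPARK-23816): the fetch
      // failure may simply be fallout from interrupting the task thread.
      case _ if killReason.isDefined => reportKilled(killReason.get)
      // A non-fatal error with a recorded fetch failure is reported as a real
      // FetchFailed, since user code may have swallowed the original cause.
      case _ if hasFetchFailure && !t.isInstanceOf[OutOfMemoryError] => reportFetchFailed()
      // A fatal error such as an OOM bypasses fetch-failure handling and is
      // reported as a plain task failure (SPARK-19276).
      case _ => reportException(t)
    }
}
```

Under this shape, `classify(new OutOfMemoryError(), None, hasFetchFailure = true)` takes the plain-failure path checked by the SPARK-19276 test, while `classify(new InterruptedException(), Some("speculation"), hasFetchFailure = true)` takes the killed path checked by the SPARK-23816 test.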
@@ -352,6 +393,11 @@ private class ExecutorSuiteHelper {
   @volatile var testFailedReason: TaskFailedReason = _
 }
 
+// helper for coordinating killing tasks
+private object ExecutorSuiteHelper {
+  var latches: ExecutorSuiteHelper = null
+}
+
 private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
   def writeExternal(out: ObjectOutput): Unit = {}
   def readExternal(in: ObjectInput): Unit = {
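
Note that the coordination fields `latch1` and `latch2` used above live in `ExecutorSuiteHelper` but fall outside the diff context, so they are not visible here. A sketch of what they presumably look like, assuming plain single-count `CountDownLatch`es (check the full file for the actual definition):

```scala
import java.util.concurrent.CountDownLatch

// Sketch only; the real ExecutorSuiteHelper in the full file may differ.
private class ExecutorSuiteHelperSketch {
  // Counted down by the task thread once the fetch failure has been recorded,
  // telling the "kill-task" thread it is now safe to kill the task.
  val latch1 = new CountDownLatch(1)
  // Never counted down: the task thread blocks on it until killAllTasks
  // interrupts the thread, surfacing as an InterruptedException.
  val latch2 = new CountDownLatch(1)
}
```

The handshake in the diff is then: the task thread records the fetch failure, counts down `latch1`, and parks on `latch2`; the killing thread awaits `latch1` and calls `executor.killAllTasks(...)`, whose interrupt breaks the task thread out of `latch2.await()` so the failure is reported as `TaskKilled` rather than a fetch failure.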