@@ -713,10 +713,10 @@ class DAGScheduler(
713713 // cancelling the stages because if the DAG scheduler is stopped, the entire application
714714 // is in the process of getting stopped.
715715 val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
716- runningStages.foreach { stage =>
717- stage.latestInfo.stageFailed(stageFailedMessage)
718- outputCommitCoordinator.stageEnd(stage.id)
719- listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
716+ // The `toArray` here is necessary so that we don't iterate over `runningStages` while
717+ // mutating it.
718+ runningStages.toArray.foreach { stage =>
719+ markStageAsFinished(stage, Some(stageFailedMessage))
720720 }
721721 listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error)))
722722 }
@@ -891,10 +891,9 @@ class DAGScheduler(
891891 new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
892892 stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
893893 } else {
894- // Because we posted SparkListenerStageSubmitted earlier, we should post
895- // SparkListenerStageCompleted here in case there are no tasks to run.
896- outputCommitCoordinator.stageEnd(stage.id)
897- listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
894+ // Because we posted SparkListenerStageSubmitted earlier, we should mark
895+ // the stage as completed here in case there are no tasks to run
896+ markStageAsFinished(stage, None)
898897
899898 val debugString = stage match {
900899 case stage: ShuffleMapStage =>
@@ -906,7 +905,6 @@
906905 s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})"
907906 }
908907 logDebug(debugString)
909- runningStages -= stage
910908 }
911909 }
912910
@@ -972,23 +970,6 @@ class DAGScheduler(
972970 }
973971
974972 val stage = stageIdToStage(task.stageId)
975-
976- def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = {
977- val serviceTime = stage.latestInfo.submissionTime match {
978- case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0)
979- case _ => "Unknown"
980- }
981- if (errorMessage.isEmpty) {
982- logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
983- stage.latestInfo.completionTime = Some(clock.getTimeMillis())
984- } else {
985- stage.latestInfo.stageFailed(errorMessage.get)
986- logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime))
987- }
988- outputCommitCoordinator.stageEnd(stage.id)
989- listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
990- runningStages -= stage
991- }
992973 event.reason match {
993974 case Success =>
994975 listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType,
@@ -1104,7 +1085,6 @@ class DAGScheduler(
11041085 logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
11051086 s"due to a fetch failure from $mapStage (${mapStage.name})")
11061087 markStageAsFinished(failedStage, Some(failureMessage))
1107- runningStages -= failedStage
11081088 }
11091089
11101090 if (disallowStageRetryForTest) {
@@ -1220,6 +1200,26 @@ class DAGScheduler(
12201200 submitWaitingStages()
12211201 }
12221202
1203+ /**
1204+ * Marks a stage as finished and removes it from the list of running stages.
1205+ */
1206+ private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = {
1207+ val serviceTime = stage.latestInfo.submissionTime match {
1208+ case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0)
1209+ case _ => "Unknown"
1210+ }
1211+ if (errorMessage.isEmpty) {
1212+ logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
1213+ stage.latestInfo.completionTime = Some(clock.getTimeMillis())
1214+ } else {
1215+ stage.latestInfo.stageFailed(errorMessage.get)
1216+ logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime))
1217+ }
1218+ outputCommitCoordinator.stageEnd(stage.id)
1219+ listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
1220+ runningStages -= stage
1221+ }
1222+
12231223 /**
12241224 * Aborts all jobs depending on a particular Stage. This is called in response to a task set
12251225 * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
@@ -1269,9 +1269,7 @@ class DAGScheduler(
12691269 if (runningStages.contains(stage)) {
12701270 try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
12711271 taskScheduler.cancelTasks(stageId, shouldInterruptThread)
1272- stage.latestInfo.stageFailed(failureReason)
1273- outputCommitCoordinator.stageEnd(stage.id)
1274- listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
1272+ markStageAsFinished(stage, Some(failureReason))
12751273 } catch {
12761274 case e: UnsupportedOperationException =>
12771275 logInfo(s"Could not cancel tasks for stage $stageId", e)