Skip to content

Commit e96ce3a (parent: 3052aea)

Consolidate stage completion handling code in a single method.

File tree

1 file changed: +28 additions, -30 deletions

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -713,10 +713,10 @@ class DAGScheduler(
713713
// cancelling the stages because if the DAG scheduler is stopped, the entire application
714714
// is in the process of getting stopped.
715715
val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
716-
runningStages.foreach { stage =>
717-
stage.latestInfo.stageFailed(stageFailedMessage)
718-
outputCommitCoordinator.stageEnd(stage.id)
719-
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
716+
// The `toArray` here is necessary so that we don't iterate over `runningStages` while
717+
// mutating it.
718+
runningStages.toArray.foreach { stage =>
719+
markStageAsFinished(stage, Some(stageFailedMessage))
720720
}
721721
listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error)))
722722
}
@@ -891,10 +891,9 @@ class DAGScheduler(
891891
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
892892
stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
893893
} else {
894-
// Because we posted SparkListenerStageSubmitted earlier, we should post
895-
// SparkListenerStageCompleted here in case there are no tasks to run.
896-
outputCommitCoordinator.stageEnd(stage.id)
897-
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
894+
// Because we posted SparkListenerStageSubmitted earlier, we should mark
895+
// the stage as completed here in case there are no tasks to run
896+
markStageAsFinished(stage, None)
898897

899898
val debugString = stage match {
900899
case stage: ShuffleMapStage =>
@@ -906,7 +905,6 @@ class DAGScheduler(
906905
s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})"
907906
}
908907
logDebug(debugString)
909-
runningStages -= stage
910908
}
911909
}
912910

@@ -972,23 +970,6 @@ class DAGScheduler(
972970
}
973971

974972
val stage = stageIdToStage(task.stageId)
975-
976-
def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = {
977-
val serviceTime = stage.latestInfo.submissionTime match {
978-
case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0)
979-
case _ => "Unknown"
980-
}
981-
if (errorMessage.isEmpty) {
982-
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
983-
stage.latestInfo.completionTime = Some(clock.getTimeMillis())
984-
} else {
985-
stage.latestInfo.stageFailed(errorMessage.get)
986-
logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime))
987-
}
988-
outputCommitCoordinator.stageEnd(stage.id)
989-
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
990-
runningStages -= stage
991-
}
992973
event.reason match {
993974
case Success =>
994975
listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType,
@@ -1104,7 +1085,6 @@ class DAGScheduler(
11041085
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
11051086
s"due to a fetch failure from $mapStage (${mapStage.name})")
11061087
markStageAsFinished(failedStage, Some(failureMessage))
1107-
runningStages -= failedStage
11081088
}
11091089

11101090
if (disallowStageRetryForTest) {
@@ -1220,6 +1200,26 @@ class DAGScheduler(
12201200
submitWaitingStages()
12211201
}
12221202

1203+
/**
1204+
* Marks a stage as finished and removes it from the list of running stages.
1205+
*/
1206+
private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = {
1207+
val serviceTime = stage.latestInfo.submissionTime match {
1208+
case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0)
1209+
case _ => "Unknown"
1210+
}
1211+
if (errorMessage.isEmpty) {
1212+
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
1213+
stage.latestInfo.completionTime = Some(clock.getTimeMillis())
1214+
} else {
1215+
stage.latestInfo.stageFailed(errorMessage.get)
1216+
logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime))
1217+
}
1218+
outputCommitCoordinator.stageEnd(stage.id)
1219+
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
1220+
runningStages -= stage
1221+
}
1222+
12231223
/**
12241224
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
12251225
* being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
@@ -1269,9 +1269,7 @@ class DAGScheduler(
12691269
if (runningStages.contains(stage)) {
12701270
try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
12711271
taskScheduler.cancelTasks(stageId, shouldInterruptThread)
1272-
stage.latestInfo.stageFailed(failureReason)
1273-
outputCommitCoordinator.stageEnd(stage.id)
1274-
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
1272+
markStageAsFinished(stage, Some(failureReason))
12751273
} catch {
12761274
case e: UnsupportedOperationException =>
12771275
logInfo(s"Could not cancel tasks for stage $stageId", e)

0 commit comments

Comments (0)