Commit b148c6f

Commit message: address comments

1 parent: eb427f7

8 files changed (+24, -15 lines)

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 2 additions & 1 deletion

@@ -1394,7 +1394,8 @@ private[spark] class DAGScheduler(
       // finished. Here we notify the task scheduler to skip running tasks for the same partition,
       // to save resource.
       if (task.stageAttemptId < stage.latestInfo.attemptNumber()) {
-        taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
+        taskScheduler.notifyPartitionCompletion(
+          stageId, task.partitionId, event.taskInfo.duration)
       }

       task match {
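The value being forwarded, `event.taskInfo.duration`, is the wall-clock time the original task took. As a rough sketch of the idea (a simplified stand-in, not Spark's actual `TaskInfo`):

// Simplified model of a task-info record: `duration` is the span between
// launch and finish, and is what gets forwarded above so that a later
// stage attempt can record a realistic duration for the skipped partition.
class TaskInfoSketch(launchTime: Long) {
  private var finishTime: Long = 0L

  def markFinished(now: Long): Unit = { finishTime = now }

  def duration: Long = {
    require(finishTime > 0L, "task has not finished yet")
    finishTime - launchTime
  }
}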

core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala

Lines changed: 6 additions & 2 deletions

@@ -155,9 +155,13 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
     }
   }

-  def enqueuePartitionCompletionNotification(stageId: Int, partitionId: Int): Unit = {
+  // This method calls `TaskSchedulerImpl.handlePartitionCompleted` asynchronously. We do not want
+  // DAGScheduler to call `TaskSchedulerImpl.handlePartitionCompleted` directly, as it's
+  // synchronized and may hurt the throughput of the scheduler.
+  def enqueuePartitionCompletionNotification(
+      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
     getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions {
-      scheduler.handlePartitionCompleted(stageId, partitionId)
+      scheduler.handlePartitionCompleted(stageId, partitionId, taskDuration)
     })
   }
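The new comment spells out the design choice: `handlePartitionCompleted` is synchronized, so calling it from the DAGScheduler's event loop would make that loop contend for the scheduler lock. Instead the call is queued onto the task-result thread pool. A minimal sketch of this hand-off pattern, using hypothetical names (`SyncHandler`, `AsyncNotifier`):

import java.util.concurrent.Executors

// A handler whose work happens under a lock, like TaskSchedulerImpl.
class SyncHandler {
  def handle(stageId: Int, partitionId: Int, duration: Long): Unit = synchronized {
    // ... update per-stage bookkeeping while holding the lock ...
  }
}

// The caller only enqueues the work and returns immediately, so its own
// event loop never blocks on the handler's lock.
class AsyncNotifier(handler: SyncHandler) {
  private val pool = Executors.newSingleThreadExecutor()

  def notifyCompletion(stageId: Int, partitionId: Int, duration: Long): Unit =
    pool.execute(() => handler.handle(stageId, partitionId, duration))
}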

core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala

Lines changed: 1 addition & 1 deletion

@@ -70,7 +70,7 @@ private[spark] trait TaskScheduler {

   // Notify the corresponding `TaskSetManager`s of the stage, that a partition has already completed
   // and they can skip running tasks for it.
-  def notifyPartitionCompletion(stageId: Int, partitionId: Int)
+  def notifyPartitionCompletion(stageId: Int, partitionId: Int, taskDuration: Long)

   // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
   def setDAGScheduler(dagScheduler: DAGScheduler): Unit

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 6 additions & 4 deletions

@@ -301,8 +301,9 @@ private[spark] class TaskSchedulerImpl(
     }
   }

-  override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
-    taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId)
+  override def notifyPartitionCompletion(
+      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
+    taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId, taskDuration)
   }

   /**
@@ -652,9 +653,10 @@ private[spark] class TaskSchedulerImpl(
    */
   private[scheduler] def handlePartitionCompleted(
       stageId: Int,
-      partitionId: Int) = synchronized {
+      partitionId: Int,
+      taskDuration: Long) = synchronized {
     taskSetsByStageIdAndAttempt.get(stageId).foreach(_.values.filter(!_.isZombie).foreach { tsm =>
-      tsm.markPartitionCompleted(partitionId)
+      tsm.markPartitionCompleted(partitionId, taskDuration)
     })
   }
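Note the fan-out in `handlePartitionCompleted`: a stage can have several concurrently live attempts, and every non-zombie `TaskSetManager` among them is told about the finished partition. A compressed illustration of that shape (illustrative names, not Spark's internals):

import scala.collection.mutable

class StageRegistrySketch {
  // stageId -> (stageAttemptId -> manager), mirroring taskSetsByStageIdAndAttempt.
  private val byStageAndAttempt =
    mutable.HashMap.empty[Int, mutable.HashMap[Int, TsmSketch]]

  def onPartitionCompleted(stageId: Int, partitionId: Int, duration: Long): Unit =
    synchronized {
      // Every live attempt learns the partition is done, so none re-runs it.
      byStageAndAttempt.get(stageId).foreach(_.values.filterNot(_.isZombie).foreach {
        tsm => tsm.markPartitionCompleted(partitionId, duration)
      })
    }
}

class TsmSketch(var isZombie: Boolean = false) {
  def markPartitionCompleted(partitionId: Int, duration: Long): Unit = ()
}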

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 2 additions & 3 deletions

@@ -816,12 +816,11 @@ private[spark] class TaskSetManager(
     maybeFinishTaskSet()
   }

-  private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
+  private[scheduler] def markPartitionCompleted(partitionId: Int, taskDuration: Long): Unit = {
     partitionToIndex.get(partitionId).foreach { index =>
       if (!successful(index)) {
         if (speculationEnabled && !isZombie) {
-          // The task is skipped, its duration should be 0.
-          successfulTaskDurations.insert(0)
+          successfulTaskDurations.insert(taskDuration)
         }
         tasksSuccessful += 1
         successful(index) = true
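This is the substantive fix: previously a partition completed by an earlier attempt was recorded with duration 0, which drags down the statistics that speculation decisions are based on (`checkSpeculatableTasks` compares running tasks against a multiple of the median successful duration). A self-contained illustration of the skew, using a plain sorted median in place of Spark's MedianHeap:

object SpeculationSkewDemo {
  // Median over a sorted copy; Spark keeps these durations in a MedianHeap.
  def median(xs: Seq[Long]): Long = xs.sorted.apply(xs.length / 2)

  def main(args: Array[String]): Unit = {
    val real = Seq(100L, 110L, 120L, 130L) // durations from this attempt
    val multiplier = 1.5                   // cf. spark.speculation.multiplier

    // Old behavior: five partitions finished by an earlier attempt, recorded as 0.
    val skewed = real ++ Seq.fill(5)(0L)
    println(multiplier * median(skewed))   // 0.0: every running task looks speculatable

    // New behavior: record the durations those tasks actually took.
    val faithful = real ++ Seq(95L, 100L, 105L, 115L, 125L)
    println(multiplier * median(faithful)) // 165.0: a sensible threshold
  }
}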

core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala

Lines changed: 4 additions & 2 deletions

@@ -158,7 +158,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLimits {
       taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
     override def killAllTaskAttempts(
       stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
-    override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+    override def notifyPartitionCompletion(
+        stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
       taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
         val tasks = ts.tasks.filter(_.partitionId == partitionId)
         assert(tasks.length == 1)
@@ -668,7 +669,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLimits {
       stageId: Int, interruptThread: Boolean, reason: String): Unit = {
       throw new UnsupportedOperationException
     }
-    override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+    override def notifyPartitionCompletion(
+        stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
       throw new UnsupportedOperationException
     }
     override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}

core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala

Lines changed: 2 additions & 1 deletion

@@ -84,7 +84,8 @@ private class DummyTaskScheduler extends TaskScheduler {
     taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
   override def killAllTaskAttempts(
     stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
-  override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {}
+  override def notifyPartitionCompletion(
+      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {}
   override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
   override def defaultParallelism(): Int = 2
   override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}

core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -1394,7 +1394,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logging {

     val taskSetManager = sched.taskSetManagerForAttempt(0, 0).get
     assert(taskSetManager.runningTasks === 8)
-    taskSetManager.markPartitionCompleted(8)
+    taskSetManager.markPartitionCompleted(8, 0)
     assert(!taskSetManager.successfulTaskDurations.isEmpty())
     taskSetManager.checkSpeculatableTasks(0)
   }
