
Commit 6334f80

[CORE] Fix regressions in decommissioning
The DecommissionWorkerSuite started becoming flaky, and it revealed a real regression. Recent PRs (#28085 and #29211) necessitate a small reworking of the decommissioning logic.

Before getting into that, let me describe the intended behavior of decommissioning: if a fetch failure happens and the source executor was decommissioned, we want to treat that as an eager signal to clear all shuffle state associated with that executor. In addition, if we know that the host was decommissioned, we want to forget the map statuses of all other executors on that decommissioned host. This is what the test "decommission workers ensure that fetch failures lead to rerun" is trying to verify. This invariant is important to ensure that decommissioning a host does not lead to multiple fetch failures that might fail the job.

- Per #29211, executors now eagerly exit on decommissioning, so the executor is lost before the fetch failure even happens. (I tested this by waiting some seconds before triggering the fetch failure.) When an executor is lost, we forget its decommissioning information. The fix is to keep the decommissioning information around for some time after removal, with some extra logic to finally purge it after a timeout.

- Per #28085, when the executor is lost, the scheduler forgets the shuffle state of just that executor and increments the shuffleFileLostEpoch. This incrementing precludes clearing the state of the entire host when the fetch failure happens. This PR elects to change that codepath only for the special case of decommissioning, without any other side effects. (This whole epoch-keeping logic is complex and has effectively not been semantically changed since 2013!) The fix here is also simple: ignore the shuffleFileLostEpoch when the shuffle status is being cleared due to a fetch failure resulting from host decommission.

These two fixes are local to decommissioning only and don't change other behavior. Also added some more tests to TaskSchedulerImplSuite to ensure that the decommissioning information is indeed purged after a timeout, and hardened the DecommissionWorkerSuite to make it wait for successful job completion.
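To make the first fix concrete, here is a minimal, self-contained sketch of the idea: remember decommission info past executor removal and purge it lazily on lookup once a retention window has elapsed. It is illustrative only, not the actual TaskSchedulerImpl code; the DecomInfoCache class, the String-valued info, and the clockMillis parameter are hypothetical stand-ins for the scheduler's state, ExecutorDecommissionInfo, and Clock.

import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
import scala.collection.mutable

// Sketch: remember decommission info for a while after the executor is
// removed, and lazily purge entries whose retention window has elapsed.
class DecomInfoCache(clockMillis: () => Long, retentionSeconds: Long = 60L) {
  // execId -> decommission info (a plain String here for illustration)
  private val pending = mutable.HashMap[String, String]()
  // gc second -> executors whose info should be dropped once that second passes
  private val toGc = new java.util.TreeMap[Long, mutable.ArrayBuffer[String]]()

  def decommission(execId: String, info: String): Unit =
    pending(execId) = info

  // Called when the executor is removed: schedule the entry for purging
  // instead of dropping it immediately.
  def executorLost(execId: String): Unit = {
    if (pending.contains(execId)) {
      val gcSecond = TimeUnit.MILLISECONDS.toSeconds(clockMillis()) + retentionSeconds
      toGc.computeIfAbsent(gcSecond, _ => mutable.ArrayBuffer.empty) += execId
    }
  }

  // Lookups first purge everything older than the retention window.
  def get(execId: String): Option[String] = {
    val nowSeconds = TimeUnit.MILLISECONDS.toSeconds(clockMillis())
    val expired = toGc.headMap(nowSeconds)
    expired.values().asScala.flatten.foreach(pending -= _)
    expired.clear()
    pending.get(execId)
  }
}

With a manually controlled clockMillis, the retention behavior can be tested deterministically, which mirrors what the new ManualClock-based test in TaskSchedulerImplSuite below does.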
1 parent 0c850c7 commit 6334f80

File tree

4 files changed: 113 additions & 34 deletions

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 21 additions & 12 deletions
@@ -1846,7 +1846,8 @@ private[spark] class DAGScheduler(
           execId = bmAddress.executorId,
           fileLost = true,
           hostToUnregisterOutputs = hostToUnregisterOutputs,
-          maybeEpoch = Some(task.epoch))
+          maybeEpoch = Some(task.epoch),
+          ignoreShuffleVersion = isHostDecommissioned)
       }
     }

@@ -2012,7 +2013,8 @@ private[spark] class DAGScheduler(
       execId: String,
       fileLost: Boolean,
       hostToUnregisterOutputs: Option[String],
-      maybeEpoch: Option[Long] = None): Unit = {
+      maybeEpoch: Option[Long] = None,
+      ignoreShuffleVersion: Boolean = false): Unit = {
     val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
     logDebug(s"Considering removal of executor $execId; " +
       s"fileLost: $fileLost, currentEpoch: $currentEpoch")
@@ -2022,16 +2024,23 @@ private[spark] class DAGScheduler(
       blockManagerMaster.removeExecutor(execId)
       clearCacheLocs()
     }
-    if (fileLost &&
-        (!shuffleFileLostEpoch.contains(execId) || shuffleFileLostEpoch(execId) < currentEpoch)) {
-      shuffleFileLostEpoch(execId) = currentEpoch
-      hostToUnregisterOutputs match {
-        case Some(host) =>
-          logInfo(s"Shuffle files lost for host: $host (epoch $currentEpoch)")
-          mapOutputTracker.removeOutputsOnHost(host)
-        case None =>
-          logInfo(s"Shuffle files lost for executor: $execId (epoch $currentEpoch)")
-          mapOutputTracker.removeOutputsOnExecutor(execId)
+    if (fileLost) {
+      val remove = if (!shuffleFileLostEpoch.contains(execId) ||
+          shuffleFileLostEpoch(execId) < currentEpoch) {
+        shuffleFileLostEpoch(execId) = currentEpoch
+        true
+      } else {
+        ignoreShuffleVersion
+      }
+      if (remove) {
+        hostToUnregisterOutputs match {
+          case Some(host) =>
+            logInfo(s"Shuffle files lost for host: $host (epoch $currentEpoch)")
+            mapOutputTracker.removeOutputsOnHost(host)
+          case None =>
+            logInfo(s"Shuffle files lost for executor: $execId (epoch $currentEpoch)")
+            mapOutputTracker.removeOutputsOnExecutor(execId)
+        }
       }
     }
   }

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 21 additions & 4 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler

 import java.nio.ByteBuffer
+import java.util
 import java.util.{Timer, TimerTask}
 import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
 import java.util.concurrent.atomic.AtomicLong
@@ -136,7 +137,9 @@ private[spark] class TaskSchedulerImpl(
   // IDs of the tasks running on each executor
   private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]

-  private val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
+  val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
+  // map of second to list of executors to clear from the above map
+  val decommissioningExecutorsToGc = new util.TreeMap[Long, mutable.ArrayBuffer[String]]()

   def runningTasksByExecutors: Map[String, Int] = synchronized {
     executorIdToRunningTaskIds.toMap.mapValues(_.size).toMap
@@ -910,7 +913,7 @@ private[spark] class TaskSchedulerImpl(
       // if we heard isHostDecommissioned ever true, then we keep that one since it is
       // most likely coming from the cluster manager and thus authoritative
       val oldDecomInfo = executorsPendingDecommission.get(executorId)
-      if (oldDecomInfo.isEmpty || !oldDecomInfo.get.isHostDecommissioned) {
+      if (!oldDecomInfo.exists(_.isHostDecommissioned)) {
        executorsPendingDecommission(executorId) = decommissionInfo
      }
    }
@@ -921,7 +924,13 @@ private[spark] class TaskSchedulerImpl(

   override def getExecutorDecommissionInfo(executorId: String)
     : Option[ExecutorDecommissionInfo] = synchronized {
-      executorsPendingDecommission.get(executorId)
+      import scala.collection.JavaConverters._
+      // Garbage collect old decommissioning entries
+      val secondsToGcUptil = TimeUnit.MILLISECONDS.toSeconds(clock.getTimeMillis())
+      val headMap = decommissioningExecutorsToGc.headMap(secondsToGcUptil)
+      headMap.values().asScala.flatten.foreach(executorsPendingDecommission -= _)
+      headMap.clear()
+      executorsPendingDecommission.get(executorId)
   }

   override def executorLost(executorId: String, givenReason: ExecutorLossReason): Unit = {
@@ -1027,7 +1036,15 @@ private[spark] class TaskSchedulerImpl(
       }
     }

-    executorsPendingDecommission -= executorId
+
+    val decomInfo = executorsPendingDecommission.get(executorId)
+    if (decomInfo.isDefined) {
+      val rememberSeconds =
+        conf.getInt("spark.decommissioningRememberAfterRemoval.seconds", 60)
+      val gcSecond = TimeUnit.MILLISECONDS.toSeconds(clock.getTimeMillis()) + rememberSeconds
+      decommissioningExecutorsToGc.computeIfAbsent(gcSecond, _ => mutable.ArrayBuffer.empty) +=
+        executorId
+    }

     if (reason != LossReasonPending) {
       executorIdToHost -= executorId

core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala

Lines changed: 41 additions & 10 deletions
@@ -84,6 +84,19 @@ class DecommissionWorkerSuite
     }
   }

+  // Unlike TestUtils.withListener, it also waits for the job to be done
+  def withListener(sc: SparkContext, listener: RootStageAwareListener)
+      (body: SparkListener => Unit): Unit = {
+    sc.addSparkListener(listener)
+    try {
+      body(listener)
+      sc.listenerBus.waitUntilEmpty()
+      listener.waitForJobDone()
+    } finally {
+      sc.listenerBus.removeListener(listener)
+    }
+  }
+
   test("decommission workers should not result in job failure") {
     val maxTaskFailures = 2
     val numTimesToKillWorkers = maxTaskFailures + 1
@@ -109,7 +122,7 @@ class DecommissionWorkerSuite
         }
       }
     }
-    TestUtils.withListener(sc, listener) { _ =>
+    withListener(sc, listener) { _ =>
       val jobResult = sc.parallelize(1 to 1, 1).map { _ =>
         Thread.sleep(5 * 1000L); 1
       }.count()
@@ -164,7 +177,7 @@ class DecommissionWorkerSuite
         }
       }
     }
-    TestUtils.withListener(sc, listener) { _ =>
+    withListener(sc, listener) { _ =>
       val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((pid, _) => {
         val sleepTimeSeconds = if (pid == 0) 1 else 10
         Thread.sleep(sleepTimeSeconds * 1000L)
@@ -212,22 +225,27 @@ class DecommissionWorkerSuite
       override def handleRootTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
         val taskInfo = taskEnd.taskInfo
         if (taskInfo.executorId == executorToDecom && taskInfo.attemptNumber == 0 &&
-          taskEnd.stageAttemptId == 0) {
+          taskEnd.stageAttemptId == 0 && taskEnd.stageId == 0) {
           decommissionWorkerOnMaster(workerToDecom,
             "decommission worker after task on it is done")
         }
       }
     }
-    TestUtils.withListener(sc, listener) { _ =>
+    withListener(sc, listener) { _ =>
       val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((_, _) => {
         val executorId = SparkEnv.get.executorId
-        val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
-        Thread.sleep(sleepTimeSeconds * 1000L)
+        val context = TaskContext.get()
+        if (context.attemptNumber() == 0 && context.stageAttemptNumber() == 0) {
+          val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
+          Thread.sleep(sleepTimeSeconds * 1000L)
+        }
         List(1).iterator
       }, preservesPartitioning = true)
         .repartition(1).mapPartitions(iter => {
         val context = TaskContext.get()
         if (context.attemptNumber == 0 && context.stageAttemptNumber() == 0) {
+          // Wait a bit for the decommissioning to be triggered in the listener
+          Thread.sleep(5000)
           // MapIndex is explicitly -1 to force the entire host to be decommissioned
           // However, this will cause both the tasks in the preceding stage since the host here is
           // "localhost" (shortcoming of this single-machine unit test in that all the workers
@@ -265,23 +283,31 @@ class DecommissionWorkerSuite
     override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
       jobEnd.jobResult match {
         case JobSucceeded => jobDone.set(true)
+        case JobFailed(exception) => logError(s"Job failed", exception)
       }
     }

     protected def handleRootTaskEnd(end: SparkListenerTaskEnd) = {}

     protected def handleRootTaskStart(start: SparkListenerTaskStart) = {}

+    private def getSignature(taskInfo: TaskInfo, stageId: Int, stageAttemptId: Int):
+      String = {
+      s"${stageId}:${stageAttemptId}:" +
+        s"${taskInfo.index}:${taskInfo.attemptNumber}-${taskInfo.status}"
+    }
+
     override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+      val signature = getSignature(taskStart.taskInfo, taskStart.stageId, taskStart.stageAttemptId)
+      logInfo(s"Task started: $signature")
       if (isRootStageId(taskStart.stageId)) {
         rootTasksStarted.add(taskStart.taskInfo)
         handleRootTaskStart(taskStart)
       }
     }

     override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
-      val taskSignature = s"${taskEnd.stageId}:${taskEnd.stageAttemptId}:" +
-        s"${taskEnd.taskInfo.index}:${taskEnd.taskInfo.attemptNumber}"
+      val taskSignature = getSignature(taskEnd.taskInfo, taskEnd.stageId, taskEnd.stageAttemptId)
       logInfo(s"Task End $taskSignature")
       tasksFinished.add(taskSignature)
       if (isRootStageId(taskEnd.stageId)) {
@@ -291,8 +317,13 @@ class DecommissionWorkerSuite
     }

     def getTasksFinished(): Seq[String] = {
-      assert(jobDone.get(), "Job isn't successfully done yet")
-      tasksFinished.asScala.toSeq
+      tasksFinished.asScala.toList
+    }
+
+    def waitForJobDone(): Unit = {
+      eventually(timeout(10.seconds), interval(100.milliseconds)) {
+        assert(jobDone.get(), "Job isn't successfully done yet")
+      }
     }
   }

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 30 additions & 8 deletions
@@ -34,7 +34,7 @@ import org.apache.spark.internal.config
 import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile, TaskResourceRequests}
 import org.apache.spark.resource.ResourceUtils._
 import org.apache.spark.resource.TestResourceIDs._
-import org.apache.spark.util.ManualClock
+import org.apache.spark.util.{Clock, ManualClock, SystemClock}

 class FakeSchedulerBackend extends SchedulerBackend {
   def start(): Unit = {}
@@ -88,10 +88,15 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
   }

   def setupSchedulerWithMaster(master: String, confs: (String, String)*): TaskSchedulerImpl = {
+    setupSchedulerWithMasterAndClock(master, new SystemClock, confs: _*)
+  }
+
+  def setupSchedulerWithMasterAndClock(master: String, clock: Clock, confs: (String, String)*):
+    TaskSchedulerImpl = {
     val conf = new SparkConf().setMaster(master).setAppName("TaskSchedulerImplSuite")
     confs.foreach { case (k, v) => conf.set(k, v) }
     sc = new SparkContext(conf)
-    taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler = new TaskSchedulerImpl(sc, sc.conf.get(config.TASK_MAX_FAILURES), clock = clock)
     setupHelper()
   }

@@ -1802,9 +1807,10 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     assert(2 == taskDescriptions.head.resources(GPU).addresses.size)
   }

-  private def setupSchedulerForDecommissionTests(): TaskSchedulerImpl = {
-    val taskScheduler = setupSchedulerWithMaster(
+  private def setupSchedulerForDecommissionTests(clock: Clock): TaskSchedulerImpl = {
+    val taskScheduler = setupSchedulerWithMasterAndClock(
       s"local[2]",
+      clock,
       config.CPUS_PER_TASK.key -> 1.toString)
     taskScheduler.submitTasks(FakeTask.createTaskSet(2))
     val multiCoreWorkerOffers = IndexedSeq(WorkerOffer("executor0", "host0", 1),
@@ -1815,7 +1821,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
   }

   test("scheduler should keep the decommission info where host was decommissioned") {
-    val scheduler = setupSchedulerForDecommissionTests()
+    val scheduler = setupSchedulerForDecommissionTests(new SystemClock)

     scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("0", false))
     scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("1", true))
@@ -1829,8 +1835,9 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     assert(scheduler.getExecutorDecommissionInfo("executor2").isEmpty)
   }

-  test("scheduler should ignore decommissioning of removed executors") {
-    val scheduler = setupSchedulerForDecommissionTests()
+  test("scheduler should eventually purge removed and decommissioned executors") {
+    val clock = new ManualClock(10000L)
+    val scheduler = setupSchedulerForDecommissionTests(clock)

     // executor 0 is decommissioned after losing
     assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty)
@@ -1839,14 +1846,29 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("", false))
     assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty)

+    assert(scheduler.executorsPendingDecommission.isEmpty)
+    clock.advance(5000)
+
     // executor 1 is decommissioned before losing
     assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
     scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false))
     assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined)
+    clock.advance(2000)
     scheduler.executorLost("executor1", ExecutorExited(0, false, "normal"))
-    assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
+    assert(scheduler.decommissioningExecutorsToGc.size === 1)
+    assert(scheduler.executorsPendingDecommission.size === 1)
+    clock.advance(2000)
+    // It hasn't been 60 seconds since removal yet
+    assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined)
     scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false))
+    clock.advance(2000)
+    assert(scheduler.decommissioningExecutorsToGc.size === 1)
+    assert(scheduler.executorsPendingDecommission.size === 1)
+    assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined)
+    clock.advance(61000)
     assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
+    assert(scheduler.decommissioningExecutorsToGc.isEmpty)
+    assert(scheduler.executorsPendingDecommission.isEmpty)
   }

   /**
