
Commit b14cd4b

fix rdd repartition
1 parent f825847 commit b14cd4b

File tree

6 files changed: +68 -14 lines changed


core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 11 additions & 0 deletions
@@ -434,6 +434,17 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
+  /** Unregister all map output information of the given shuffle. */
+  def unregisterAllMapOutput(shuffleId: Int) {
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.removeOutputsByFilter(x => true)
+        incrementEpoch()
+      case None =>
+        throw new SparkException("unregisterAllMapOutput called for nonexistent shuffle ID")
+    }
+  }
+
   /** Unregister shuffle data */
   def unregisterShuffle(shuffleId: Int) {
     shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
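
For context, a minimal sketch of how the two unregister paths differ on the driver. MapOutputTrackerMaster, BlockManagerId, and the two unregister methods are real Spark names; the helper function wrapping them is a hypothetical illustration, not commit code.

import org.apache.spark.MapOutputTrackerMaster
import org.apache.spark.storage.BlockManagerId

// Hypothetical helper: on a fetch failure for `shuffleId`, either drop the one
// failed map output, or wipe the whole shuffle so every map task reruns.
def dropOutputs(
    tracker: MapOutputTrackerMaster,
    shuffleId: Int,
    mapId: Int,
    bmAddress: BlockManagerId,
    recomputeAll: Boolean): Unit = {
  if (recomputeAll) {
    // Removes every registered map output and bumps the epoch, so executors
    // refetch shuffle locations and all map partitions are recomputed.
    tracker.unregisterAllMapOutput(shuffleId)
  } else {
    // Only the single failed map output is recomputed on retry.
    tracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
  }
}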

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 12 additions & 0 deletions
@@ -559,4 +559,16 @@ package object config {
     .intConf
     .checkValue(v => v > 0, "The value should be a positive integer.")
     .createWithDefault(2000)
+
+  private[spark] val RECOMPUTE_ALL_PARTITIONS_ON_REPARTITION_FAILURE =
+    ConfigBuilder("spark.shuffle.recomputeAllPartitionsOnRepartitionFailure")
+      .internal()
+      .doc("When performing a repartition on an RDD, there may be a data correctness issue " +
+        "if only a subset of the partitions is recomputed on fetch failure and the input " +
+        "data sequence is not deterministic. Turn on this config to always recompute all " +
+        "the partitions before the repartition shuffle on fetch failure, so that we always " +
+        "get a correct result. Note that turning on this config may increase the risk of " +
+        "job failure due to reaching the max consecutive stage failure limit.")
+      .booleanConf
+      .createWithDefault(true)
 }
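
As a usage sketch, the flag can be read from a plain SparkConf on the driver; inside Spark the typed entry RECOMPUTE_ALL_PARTITIONS_ON_REPARTITION_FAILURE serves the same purpose. The opt-out shown here is an illustrative assumption, not something the commit does.

import org.apache.spark.SparkConf

val conf = new SparkConf()
// Defaults to true; jobs that prefer cheaper partial recovery over the
// round-robin correctness guarantee can opt out explicitly.
conf.set("spark.shuffle.recomputeAllPartitionsOnRepartitionFailure", "false")
val recomputeAll =
  conf.getBoolean("spark.shuffle.recomputeAllPartitionsOnRepartitionFailure", true)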

core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala

Lines changed: 4 additions & 1 deletion
@@ -27,7 +27,8 @@ import org.apache.spark.{Partition, TaskContext}
 private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     var prev: RDD[T],
     f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
-    preservesPartitioning: Boolean = false)
+    preservesPartitioning: Boolean = false,
+    recomputeOnFailure: Boolean = false)
   extends RDD[U](prev) {
 
   override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None

@@ -41,4 +42,6 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     super.clearDependencies()
     prev = null
   }
+
+  override def recomputeAllPartitionsOnFailure(): Boolean = recomputeOnFailure
 }

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 28 additions & 4 deletions
@@ -452,6 +452,9 @@ abstract class RDD[T: ClassTag](
       /** Distributes elements evenly across output partitions, starting from a random partition. */
       val distributePartition = (index: Int, items: Iterator[T]) => {
         var position = new Random(hashing.byteswap32(index)).nextInt(numPartitions)
+        // TODO: Insert a local sort before the shuffle to make the input data sequence
+        // deterministic, so that the config
+        // "spark.shuffle.recomputeAllPartitionsOnRepartitionFailure" can be disabled.
         items.map { t =>
           // Note that the hash code of the key will just be the key itself. The HashPartitioner
           // will mod it with the number of total partitions.

@@ -461,9 +464,12 @@ abstract class RDD[T: ClassTag](
       } : Iterator[(Int, T)]
 
       // include a shuffle step so that our upstream tasks are still distributed
+      val recomputeOnFailure =
+        conf.getBoolean("spark.shuffle.recomputeAllPartitionsOnRepartitionFailure", true)
       new CoalescedRDD(
-        new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition),
-          new HashPartitioner(numPartitions)),
+        new ShuffledRDD[Int, T, T](
+          mapPartitionsWithIndex(distributePartition, recomputeOnFailure),
+          new HashPartitioner(numPartitions)),
         numPartitions,
         partitionCoalescer).values
     } else {
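
To see why the round-robin distribution is order-sensitive, here is a self-contained sketch; plain Scala, no Spark. roundRobin is a made-up stand-in for distributePartition that applies the modulo directly instead of delegating to HashPartitioner. The partition a record lands in depends only on its position in the iterator, so a recomputed map task that emits the same records in a different order sends individual records to different reducers.

import scala.util.Random
import scala.util.hashing

// Stand-in for distributePartition: assign each record the next partition in
// round-robin order, starting from a position seeded by the partition index.
def roundRobin[T](index: Int, items: Iterator[T], numPartitions: Int): Iterator[(Int, T)] = {
  var position = new Random(hashing.byteswap32(index)).nextInt(numPartitions)
  items.map { t =>
    position += 1
    (position % numPartitions, t)
  }
}

// The same multiset of records in two different orders: "a" and "b" swap
// assigned partitions while "c" keeps its slot. After a partial recompute,
// this kind of drift duplicates some records downstream and loses others.
val first  = roundRobin(0, Iterator("a", "b", "c"), numPartitions = 2).toList
val second = roundRobin(0, Iterator("b", "a", "c"), numPartitions = 2).toList
println(first)  // the records with their assigned partitions
println(second) // "a" and "b" have traded partitions relative to `first`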
@@ -837,15 +843,21 @@ abstract class RDD[T: ClassTag](
    *
    * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
    * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
+   *
+   * `recomputeOnFailure` indicates whether to recompute all the partitions on failure
+   * recovery, which should be `false` unless the output is repartitioned and is either
+   * not sorted or not sortable.
    */
   def mapPartitionsWithIndex[U: ClassTag](
       f: (Int, Iterator[T]) => Iterator[U],
-      preservesPartitioning: Boolean = false): RDD[U] = withScope {
+      preservesPartitioning: Boolean = false,
+      recomputeOnFailure: Boolean = false): RDD[U] = withScope {
     val cleanedF = sc.clean(f)
     new MapPartitionsRDD(
       this,
       (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter),
-      preservesPartitioning)
+      preservesPartitioning,
+      recomputeOnFailure)
   }
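
A hedged usage sketch of the widened signature. The local-mode boilerplate and the tagging function are illustrative assumptions; only the parameter names come from the diff above.

import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
val rdd = sc.parallelize(1 to 100, 4)

// A transformation whose output order is position-dependent: opt in to full
// recomputation so a downstream fetch failure never sees a reshuffled input.
val tagged = rdd.mapPartitionsWithIndex(
  (index, iter) => iter.map(x => (index, x)),
  preservesPartitioning = false,
  recomputeOnFailure = true)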
@@ -1839,6 +1851,18 @@ abstract class RDD[T: ClassTag](
   def toJavaRDD() : JavaRDD[T] = {
     new JavaRDD(this)(elementClassTag)
   }
+
+  /**
+   * Whether the RDD is required to recompute all partitions on failure. Repartition of an RDD
+   * proceeds in a round-robin manner, so there may be a data correctness issue if only a
+   * subset of the partitions is recomputed on failure and the input data sequence is not
+   * deterministic. Please refer to SPARK-23207 and SPARK-23243 for related discussion.
+   *
+   * All partitions must be recomputed on failure if a repartition operation is called on this
+   * RDD and the result sequence of this RDD is not deterministic (or the data type of the
+   * output of this RDD is not sortable).
+   */
+  private[spark] def recomputeAllPartitionsOnFailure(): Boolean = false
 }

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 9 additions & 5 deletions
@@ -1323,17 +1323,17 @@ class DAGScheduler(
       }
 
     case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) =>
-      val failedStage = stageIdToStage(task.stageId)
+      val failedStage = stage
       val mapStage = shuffleIdToMapStage(shuffleId)
 
       if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) {
         logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
           s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
           s"(attempt ${failedStage.latestInfo.attemptNumber}) running")
       } else {
-        failedStage.fetchFailedAttemptIds.add(task.stageAttemptId)
+        failedStage.failedAttemptIds.add(task.stageAttemptId)
         val shouldAbortStage =
-          failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts ||
+          failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
             disallowStageRetryForTest
 
         // It is likely that we receive multiple FetchFailed for a single stage (because we have

@@ -1386,8 +1386,12 @@ class DAGScheduler(
           )
         }
       }
-      // Mark the map whose fetch failed as broken in the map stage
-      if (mapId != -1) {
+
+      if (mapStage.rdd.recomputeAllPartitionsOnFailure()) {
+        // Mark every map output as broken in the map stage, to ensure that all the
+        // partitions are recomputed on the resubmitted stage attempt.
+        mapOutputTracker.unregisterAllMapOutput(shuffleId)
+      } else if (mapId != -1) {
         mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
       }
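
Why clearing every output forces a full rerun: a resubmitted map stage re-executes exactly the partitions that have no registered output. A toy model of that bookkeeping, purely illustrative; none of these names are Spark APIs.

import scala.collection.mutable

// Toy model of the tracker bookkeeping, not Spark code.
class ToyTracker(numMaps: Int) {
  private val outputs = mutable.Map[Int, String]() // mapId -> executor host

  def register(mapId: Int, host: String): Unit = outputs(mapId) = host
  def unregisterMapOutput(mapId: Int): Unit = outputs -= mapId // partial recovery
  def unregisterAllMapOutput(): Unit = outputs.clear()         // full recovery

  // A resubmitted map stage reruns exactly the partitions with no output.
  def missingPartitions: Seq[Int] = (0 until numMaps).filterNot(outputs.contains)
}

val t = new ToyTracker(numMaps = 3)
(0 until 3).foreach(i => t.register(i, s"exec-$i"))
t.unregisterMapOutput(1)   // missingPartitions == Seq(1): rerun only partition 1
t.unregisterAllMapOutput() // missingPartitions == Seq(0, 1, 2): rerun everything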

core/src/main/scala/org/apache/spark/scheduler/Stage.scala

Lines changed: 4 additions & 4 deletions
@@ -82,15 +82,15 @@ private[scheduler] abstract class Stage(
   private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId)
 
   /**
-   * Set of stage attempt IDs that have failed with a FetchFailure. We keep track of these
-   * failures in order to avoid endless retries if a stage keeps failing with a FetchFailure.
+   * Set of stage attempt IDs that have failed. We keep track of these failures in order to avoid
+   * endless retries if a stage keeps failing.
    * We keep track of each attempt ID that has failed to avoid recording duplicate failures if
    * multiple tasks from the same stage attempt fail (SPARK-5945).
    */
-  val fetchFailedAttemptIds = new HashSet[Int]
+  val failedAttemptIds = new HashSet[Int]
 
   private[scheduler] def clearFailures() : Unit = {
-    fetchFailedAttemptIds.clear()
+    failedAttemptIds.clear()
   }
 
   /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */
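
The rename matters because the abort bookkeeping now counts any failed attempt toward the retry limit, not only fetch failures. A condensed sketch of the accounting: maxConsecutiveStageAttempts mirrors spark.stage.maxConsecutiveAttempts, whose default in Spark is 4; the function itself is illustrative.

import scala.collection.mutable.HashSet

val failedAttemptIds = new HashSet[Int]
val maxConsecutiveStageAttempts = 4 // default of spark.stage.maxConsecutiveAttempts

// Returns true when the stage (and hence the job) should be aborted. Adding
// the same attempt ID twice is a no-op, so duplicate task failures from one
// attempt are not double-counted (SPARK-5945).
def recordFailure(stageAttemptId: Int): Boolean = {
  failedAttemptIds.add(stageAttemptId)
  failedAttemptIds.size >= maxConsecutiveStageAttempts
}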
