Skip to content

Commit 0423970

Browse files
committed
Addressed some of the comments both from Mridul and Thomas
1 parent 2688df2 commit 0423970

File tree

9 files changed

+65
-44
lines changed

9 files changed

+65
-44
lines changed

core/src/main/scala/org/apache/spark/Dependency.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
9898
shuffleId, this)
9999

100100
// By default, shuffle merge is enabled for ShuffleDependency if push based shuffle is enabled
101-
private[this] var _shuffleMergeEnabled = Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf)
101+
private[spark] var _shuffleMergeEnabled = Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf)
102102

103103
def setShuffleMergeEnabled(shuffleMergeEnabled: Boolean): Unit = {
104104
_shuffleMergeEnabled = shuffleMergeEnabled
@@ -110,7 +110,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
110110
* Stores the location of the list of chosen external shuffle services for handling the
111111
* shuffle merge requests from mappers in this shuffle map stage.
112112
*/
113-
private[this] var _mergerLocs: Seq[BlockManagerId] = Nil
113+
private[spark] var _mergerLocs: Seq[BlockManagerId] = Nil
114114

115115
def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = {
116116
if (mergerLocs != null && mergerLocs.length > 0) {

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1929,29 +1929,32 @@ package object config {
19291929
.createWithDefault(false)
19301930

19311931
private[spark] val PUSH_BASED_SHUFFLE_ENABLED =
1932-
ConfigBuilder("spark.shuffle.push.based.enabled")
1932+
ConfigBuilder("spark.shuffle.push.enabled")
19331933
.doc("Set to 'true' to enable push based shuffle")
19341934
.booleanConf
19351935
.createWithDefault(false)
19361936

19371937
private[spark] val MAX_MERGER_LOCATIONS_CACHED =
19381938
ConfigBuilder("spark.shuffle.push.retainedMergerLocations")
1939-
.doc("Max number of shuffle services hosts info cached to determine the locations of" +
1940-
" shuffle services when pushing the blocks.")
1939+
.doc("Maximum number of shuffle push mergers locations cached for push based shuffle." +
1940+
"Currently Shuffle push merger locations are nothing but shuffle services where an" +
1941+
"executor is launched in the case of Push based shuffle.")
19411942
.intConf
19421943
.createWithDefault(500)
19431944

19441945
private[spark] val MERGER_LOCATIONS_MIN_THRESHOLD_RATIO =
1945-
ConfigBuilder("spark.shuffle.push.mergerLocations.minThresholdRatio")
1946-
.doc("Minimum percentage of shuffle services (merger locations) should be available with" +
1947-
" respect to numPartitions in order to enable push based shuffle for a stage.")
1946+
ConfigBuilder("spark.shuffle.push.mergersMinThresholdRatio")
1947+
.doc("Minimum percentage of shuffle push mergers locations required to enable push based" +
1948+
"shuffle for the stage with respect to number of partitions of the child stage. This is" +
1949+
" the number of unique Node Manager locations needed to enable push based shuffle.")
19481950
.doubleConf
19491951
.createWithDefault(0.05)
19501952

19511953
private[spark] val MERGER_LOCATIONS_MIN_STATIC_THRESHOLD =
1952-
ConfigBuilder("spark.shuffle.push.mergerLocations.minStaticThreshold")
1953-
.doc("Minimum number of shuffle services (merger locations) should be available in order" +
1954-
"to enable push based shuffle for a stage.")
1954+
ConfigBuilder("spark.shuffle.push.mergersMinStaticThreshold")
1955+
.doc("Minimum static number of of shuffle push mergers locations should be available in" +
1956+
" order to enable push based shuffle for a stage. Note this config works in" +
1957+
" conjunction with spark.shuffle.push.mergersMinThresholdRatio")
19551958
.doubleConf
19561959
.createWithDefault(5)
19571960
}

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,15 +1256,19 @@ private[spark] class DAGScheduler(
12561256

12571257
/**
12581258
* If push based shuffle is enabled, set the shuffle services to be used for the given
1259-
* shuffle map stage. The list of shuffle services is determined based on the list of
1260-
* active executors tracked by block manager master at the start of the stage.
1259+
* shuffle map stage for block push/merge.
1260+
*
1261+
* Even with DRA kicking in and significantly reducing the number of available active
1262+
* executors, we would still be able to get sufficient shuffle service locations for
1263+
* block push/merge by getting the historical locations of past executors.
12611264
*/
12621265
private def prepareShuffleServicesForShuffleMapStage(stage: ShuffleMapStage) {
12631266
// TODO: Handle stage reuse/retry cases separately as without finalize changes we cannot
12641267
// TODO: disable shuffle merge for the retry/reuse cases
1265-
val mergerLocs = sc.schedulerBackend.getMergerLocations(
1268+
val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
12661269
stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
1267-
logDebug(s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
1270+
logDebug(s"List of shuffle push merger locations " +
1271+
s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
12681272

12691273
if (mergerLocs.nonEmpty) {
12701274
stage.shuffleDep.setMergerLocs(mergerLocs)

core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,14 @@ private[spark] trait SchedulerBackend {
9595

9696
/**
9797
* Get the list of both active and dead executors host locations for push based shuffle
98+
*
99+
* Currently push based shuffle is disabled for both stage retry and stage reuse cases
100+
* (for eg: in the case where few partitions are lost due to failure). Hence this method
101+
* should be invoked only once for a ShuffleDependency.
98102
* @return List of external shuffle services locations
99103
*/
100-
def getMergerLocations(numPartitions: Int, resourceProfileId: Int): Seq[BlockManagerId] = Nil
104+
def getShufflePushMergerLocations(
105+
numPartitions: Int,
106+
resourceProfileId: Int): Seq[BlockManagerId] = Nil
101107

102108
}

core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import scala.concurrent.Future
2424
import org.apache.spark.{SparkConf, SparkException}
2525
import org.apache.spark.internal.Logging
2626
import org.apache.spark.rpc.RpcEndpointRef
27-
import org.apache.spark.storage.BlockManagerMessages.{GetMergerLocations, _}
27+
import org.apache.spark.storage.BlockManagerMessages.{GetShufflePushMergerLocations, _}
2828
import org.apache.spark.util.{RpcUtils, ThreadUtils}
2929

3030
private[spark]
@@ -126,13 +126,14 @@ class BlockManagerMaster(
126126
}
127127

128128
/**
129-
* Get list of shuffle service locations available for pushing the shuffle blocks
130-
* with push based shuffle
129+
* Get list of unique shuffle service locations where an executor is successfully
130+
* registered in the past for block push/merge with push based shuffle.
131131
*/
132-
def getMergerLocations(
132+
def getShufflePushMergerLocations(
133133
numMergersNeeded: Int,
134134
hostsToFilter: Set[String]): Seq[BlockManagerId] = {
135-
driverEndpoint.askSync[Seq[BlockManagerId]](GetMergerLocations(numMergersNeeded, hostsToFilter))
135+
driverEndpoint.askSync[Seq[BlockManagerId]](
136+
GetShufflePushMergerLocations(numMergersNeeded, hostsToFilter))
136137
}
137138

138139
def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = {

core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,10 @@ class BlockManagerMasterEndpoint(
7474
// Mapping from block id to the set of block managers that have the block.
7575
private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
7676

77-
// Mapping from host name to shuffle (mergers) services
78-
private val mergerLocations = new mutable.LinkedHashMap[String, BlockManagerId]()
77+
// Mapping from host name to shuffle (mergers) services where the current app
78+
// registered an executor in the past. Older hosts are removed when the
79+
// maxRetainedMergerLocations size is reached in favor of newer locations.
80+
private val shuffleMergerLocations = new mutable.LinkedHashMap[String, BlockManagerId]()
7981

8082
// Maximum number of merger locations to cache
8183
private val maxRetainedMergerLocations = conf.get(config.MAX_MERGER_LOCATIONS_CACHED)
@@ -145,8 +147,8 @@ class BlockManagerMasterEndpoint(
145147
case GetBlockStatus(blockId, askStorageEndpoints) =>
146148
context.reply(blockStatus(blockId, askStorageEndpoints))
147149

148-
case GetMergerLocations(numMergersNeeded, hostsToFilter) =>
149-
context.reply(getMergerLocations(numMergersNeeded, hostsToFilter))
150+
case GetShufflePushMergerLocations(numMergersNeeded, hostsToFilter) =>
151+
context.reply(getShufflePushMergerLocations(numMergersNeeded, hostsToFilter))
150152

151153
case IsExecutorAlive(executorId) =>
152154
context.reply(blockManagerIdByExecutor.contains(executorId))
@@ -370,13 +372,13 @@ class BlockManagerMasterEndpoint(
370372
}
371373

372374
private def addMergerLocation(blockManagerId: BlockManagerId): Unit = {
373-
if (!mergerLocations.contains(blockManagerId.host) && !blockManagerId.isDriver) {
375+
if (!shuffleMergerLocations.contains(blockManagerId.host) && !blockManagerId.isDriver) {
374376
val shuffleServerId = BlockManagerId(blockManagerId.executorId, blockManagerId.host,
375377
StorageUtils.externalShuffleServicePort(conf))
376-
if (mergerLocations.size >= maxRetainedMergerLocations) {
377-
mergerLocations -= mergerLocations.head._1
378+
if (shuffleMergerLocations.size >= maxRetainedMergerLocations) {
379+
shuffleMergerLocations -= shuffleMergerLocations.head._1
378380
}
379-
mergerLocations(shuffleServerId.host) = shuffleServerId
381+
shuffleMergerLocations(shuffleServerId.host) = shuffleServerId
380382
}
381383
}
382384

@@ -679,11 +681,12 @@ class BlockManagerMasterEndpoint(
679681
}
680682
}
681683

682-
private def getMergerLocations(
684+
private def getShufflePushMergerLocations(
683685
numMergersNeeded: Int,
684686
hostsToFilter: Set[String]): Seq[BlockManagerId] = {
685-
// Copying the merger locations to a list so that the original mergerLocations won't be shuffled
686-
val mergers = mergerLocations.values.filterNot(x => hostsToFilter.contains(x.host)).toSeq
687+
// Copying the merger locations to a list so that the original
688+
// shuffleMergerLocations won't be shuffled
689+
val mergers = shuffleMergerLocations.values.filterNot(x => hostsToFilter.contains(x.host)).toSeq
687690
Utils.randomize(mergers).take(numMergersNeeded)
688691
}
689692

core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ private[spark] object BlockManagerMessages {
142142

143143
case class IsExecutorAlive(executorId: String) extends ToBlockManagerMaster
144144

145-
case class GetMergerLocations(numMergersNeeded: Int, hostsToFilter: Set[String])
145+
case class GetShufflePushMergerLocations(numMergersNeeded: Int, hostsToFilter: Set[String])
146146
extends ToBlockManagerMaster
147147

148148
}

core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1976,7 +1976,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
19761976

19771977
test("mergerLocations should be bounded with in" +
19781978
" spark.shuffle.push.retainedMergerLocations") {
1979-
assert(master.getMergerLocations(10, Set.empty).isEmpty)
1979+
assert(master.getShufflePushMergerLocations(10, Set.empty).isEmpty)
19801980
makeBlockManager(100, "execA",
19811981
transferService = Some(new MockBlockTransferService(10, "hostA")))
19821982
makeBlockManager(100, "execB",
@@ -1987,10 +1987,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
19871987
transferService = Some(new MockBlockTransferService(10, "hostD")))
19881988
makeBlockManager(100, "execE",
19891989
transferService = Some(new MockBlockTransferService(10, "hostA")))
1990-
assert(master.getMergerLocations(10, Set.empty).size == 4)
1991-
assert(master.getMergerLocations(10, Set.empty)
1990+
assert(master.getShufflePushMergerLocations(10, Set.empty).size == 4)
1991+
assert(master.getShufflePushMergerLocations(10, Set.empty)
19921992
.exists(x => Seq("hostC", "hostD", "hostA", "hostB").contains(x.host)))
1993-
assert(master.getMergerLocations(10, Set("hostB")).size == 3)
1993+
assert(master.getShufflePushMergerLocations(10, Set("hostB")).size == 3)
19941994
}
19951995

19961996
class MockBlockTransferService(

resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,24 @@ package org.apache.spark.scheduler.cluster
1919

2020
import java.util.EnumSet
2121
import java.util.concurrent.atomic.AtomicBoolean
22+
2223
import javax.servlet.DispatcherType
2324

2425
import scala.concurrent.{ExecutionContext, Future}
2526
import scala.util.{Failure, Success}
2627
import scala.util.control.NonFatal
27-
2828
import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId}
29-
3029
import org.apache.spark.SparkContext
3130
import org.apache.spark.deploy.security.HadoopDelegationTokenManager
32-
import org.apache.spark.internal.{config, Logging}
31+
import org.apache.spark.internal.{Logging, config}
3332
import org.apache.spark.internal.config.DYN_ALLOCATION_MAX_EXECUTORS
3433
import org.apache.spark.internal.config.UI._
3534
import org.apache.spark.resource.ResourceProfile
3635
import org.apache.spark.rpc._
3736
import org.apache.spark.scheduler._
3837
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
3938
import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster}
40-
import org.apache.spark.util.{RpcUtils, ThreadUtils}
39+
import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils}
4140

4241
/**
4342
* Abstract Yarn scheduler backend that contains common logic
@@ -170,19 +169,24 @@ private[spark] abstract class YarnSchedulerBackend(
170169
totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio
171170
}
172171

173-
override def getMergerLocations(
172+
override def getShufflePushMergerLocations(
174173
numPartitions: Int,
175174
resourceProfileId: Int): Seq[BlockManagerId] = {
176175
// Currently this is naive way of calculating numMergersNeeded for a stage. In future,
177176
// we can use better heuristics to calculate numMergersNeeded for a stage.
177+
val maxExecutors = if (Utils.isDynamicAllocationEnabled(sc.getConf)) {
178+
maxNumExecutors
179+
} else {
180+
Int.MaxValue
181+
}
178182
val tasksPerExecutor = sc.resourceProfileManager
179183
.resourceProfileFromId(resourceProfileId).maxTasksPerExecutor(sc.conf)
180184
val numMergersNeeded = math.min(
181-
math.max(1, math.ceil(numPartitions / tasksPerExecutor).toInt), maxNumExecutors)
185+
math.max(1, math.ceil(numPartitions / tasksPerExecutor).toInt), maxExecutors)
182186
val minMergersThreshold = math.max(minMergersStaticThreshold,
183187
math.floor(numMergersNeeded * minMergersThresholdRatio).toInt)
184188
val mergerLocations = blockManagerMaster
185-
.getMergerLocations(numMergersNeeded, scheduler.nodeBlacklist())
189+
.getShufflePushMergerLocations(numMergersNeeded, scheduler.nodeBlacklist())
186190
logDebug(s"Num merger locations available ${mergerLocations.length}")
187191
if (mergerLocations.size < numMergersNeeded && mergerLocations.size < minMergersThreshold) {
188192
Seq.empty[BlockManagerId]

0 commit comments

Comments
 (0)