Commit bfe83f0

Backtrace RDD dependency tree to find all RDDs that belong to a Stage
The Stage boundary is marked by shuffle dependencies. When one or more RDDs are related by narrow dependencies, they should all be associated with the same Stage. Following backward narrow dependency pointers allows StageInfo to hold information about all relevant RDDs, rather than just the last one associated with the Stage. This commit also moves RDDInfo to its own file.
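
As a rough illustration (a minimal sketch assuming a local SparkContext; the variable names are hypothetical, not from this commit):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._   // pair RDD implicits for reduceByKey

// parallelize -> map -> filter are linked by narrow (one-to-one)
// dependencies, so RDDs 0-2 all fall inside the first stage; reduceByKey
// introduces a shuffle dependency, which starts a new stage.
val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("stage-demo"))
val counts = sc.parallelize(1 to 100)  // RDD 0
  .map(i => (i % 10, 1))               // RDD 1
  .filter(_._1 != 0)                   // RDD 2
  .reduceByKey(_ + _)                  // RDD 3, in a separate stage
counts.count()

After this commit, the StageInfo for the first stage lists RDDs 0-2 in rddInfos, rather than only the last RDD before the shuffle boundary.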
1 parent 3a390bf commit bfe83f0

13 files changed (+175 -76 lines)


core/src/main/scala/org/apache/spark/TaskContext.scala

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ class TaskContext(
   val attemptId: Long,
   val runningLocally: Boolean = false,
   @volatile var interrupted: Boolean = false,
-  private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty()
+  private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty
 ) extends Serializable {
 
   @deprecated("use partitionId", "0.8.1")

core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ class TaskMetrics extends Serializable {
 }
 
 private[spark] object TaskMetrics {
-  def empty(): TaskMetrics = new TaskMetrics
+  def empty: TaskMetrics = new TaskMetrics
 }

core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala

Lines changed: 1 addition & 1 deletion

@@ -207,7 +207,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
     val taskInfo = taskEnd.taskInfo
     var taskStatus = "TASK_TYPE=%s".format(taskEnd.taskType)
-    val taskMetrics = if (taskEnd.taskMetrics != null) taskEnd.taskMetrics else TaskMetrics.empty()
+    val taskMetrics = if (taskEnd.taskMetrics != null) taskEnd.taskMetrics else TaskMetrics.empty
     taskEnd.reason match {
       case Success => taskStatus += " STATUS=SUCCESS"
         recordTaskMetrics(taskEnd.stageId, taskStatus, taskInfo, taskMetrics)

core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala

Lines changed: 36 additions & 7 deletions

@@ -17,15 +17,19 @@
 
 package org.apache.spark.scheduler
 
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.NarrowDependency
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.RDDInfo
 
 /**
  * :: DeveloperApi ::
  * Stores information about a stage to pass from the scheduler to SparkListeners.
  */
 @DeveloperApi
-class StageInfo(val stageId: Int, val name: String, val numTasks: Int, val rddInfo: RDDInfo) {
+class StageInfo(val stageId: Int, val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo]) {
   /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
   var submissionTime: Option[Long] = None
   /** Time when all tasks in the stage completed or when the stage was cancelled. */
@@ -41,12 +45,37 @@ class StageInfo(val stageId: Int, val name: String, val numTasks: Int, val rddIn
   }
 }
 
-private[spark]
-object StageInfo {
+private[spark] object StageInfo {
+  /**
+   * Construct a StageInfo from a Stage.
+   *
+   * Each Stage is associated with one or many RDDs, with the boundary of a Stage marked by
+   * shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a
+   * sequence of narrow dependencies should also be associated with this Stage.
+   */
   def fromStage(stage: Stage): StageInfo = {
-    val rdd = stage.rdd
-    val rddName = Option(rdd.name).getOrElse(rdd.id.toString)
-    val rddInfo = new RDDInfo(rdd.id, rddName, rdd.partitions.size, rdd.getStorageLevel)
-    new StageInfo(stage.id, stage.name, stage.numTasks, rddInfo)
+    val ancestorRddInfos = getNarrowAncestors(stage.rdd).map(RDDInfo.fromRdd)
+    val rddInfos = ancestorRddInfos ++ Seq(RDDInfo.fromRdd(stage.rdd))
+    new StageInfo(stage.id, stage.name, stage.numTasks, rddInfos)
+  }
+
+  /**
+   * Return the ancestors of the given RDD that are related to it only through a sequence of
+   * narrow dependencies. This traverses the given RDD's dependency tree using DFS.
+   */
+  private def getNarrowAncestors(
+      rdd: RDD[_],
+      ancestors: ArrayBuffer[RDD[_]] = ArrayBuffer.empty): Seq[RDD[_]] = {
+    val narrowParents = getNarrowDependencies(rdd).map(_.rdd)
+    narrowParents.foreach { parent =>
+      ancestors += parent
+      getNarrowAncestors(parent, ancestors)
+    }
+    ancestors
+  }
+
+  /** Return the narrow dependencies of the given RDD. */
+  private def getNarrowDependencies(rdd: RDD[_]): Seq[NarrowDependency[_]] = {
+    rdd.dependencies.collect { case d: NarrowDependency[_] => d }
   }
 }
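
As a usage sketch of the new field (the listener class and its registration are hypothetical, written against the listener API exercised elsewhere in this diff):

import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted}

// Logs every RDD now associated with a submitted stage, including the
// narrow ancestors collected by getNarrowAncestors above.
class StageRddLogger extends SparkListener {
  override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
    val info = stageSubmitted.stageInfo
    info.rddInfos.foreach { rddInfo =>
      println("Stage %d contains RDD %d (%s)".format(info.stageId, rddInfo.id, rddInfo.name))
    }
  }
}

// Registered with sc.addSparkListener(new StageRddLogger) before running jobs.
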
core/src/main/scala/org/apache/spark/storage/RDDInfo.scala

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+
+@DeveloperApi
+class RDDInfo(
+    val id: Int,
+    val name: String,
+    val numPartitions: Int,
+    val storageLevel: StorageLevel)
+  extends Ordered[RDDInfo] {
+
+  var numCachedPartitions = 0
+  var memSize = 0L
+  var diskSize = 0L
+  var tachyonSize = 0L
+
+  override def toString = {
+    import Utils.bytesToString
+    ("RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; " +
+      "TachyonSize: %s; DiskSize: %s").format(
+        name, id, storageLevel.toString, numCachedPartitions, numPartitions,
+        bytesToString(memSize), bytesToString(tachyonSize), bytesToString(diskSize))
+  }
+
+  override def compare(that: RDDInfo) = {
+    this.id - that.id
+  }
+}
+
+private[spark] object RDDInfo {
+  def fromRdd(rdd: RDD[_]): RDDInfo = {
+    val rddName = Option(rdd.name).getOrElse(rdd.id.toString)
+    new RDDInfo(rdd.id, rddName, rdd.partitions.size, rdd.getStorageLevel)
+  }
+}

core/src/main/scala/org/apache/spark/storage/StorageUtils.scala

Lines changed: 7 additions & 37 deletions

@@ -21,60 +21,30 @@ import scala.collection.Map
 import scala.collection.mutable
 
 import org.apache.spark.SparkContext
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.util.Utils
 
-private[spark]
-class StorageStatus(
+/** Storage information for each BlockManager. */
+private[spark] class StorageStatus(
     val blockManagerId: BlockManagerId,
     val maxMem: Long,
     val blocks: mutable.Map[BlockId, BlockStatus] = mutable.Map.empty) {
 
-  def memUsed() = blocks.values.map(_.memSize).reduceOption(_ + _).getOrElse(0L)
+  def memUsed = blocks.values.map(_.memSize).reduceOption(_ + _).getOrElse(0L)
 
   def memUsedByRDD(rddId: Int) =
     rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_ + _).getOrElse(0L)
 
-  def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_ + _).getOrElse(0L)
+  def diskUsed = blocks.values.map(_.diskSize).reduceOption(_ + _).getOrElse(0L)
 
   def diskUsedByRDD(rddId: Int) =
     rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_ + _).getOrElse(0L)
 
-  def memRemaining : Long = maxMem - memUsed()
+  def memRemaining: Long = maxMem - memUsed
 
   def rddBlocks = blocks.collect { case (rdd: RDDBlockId, status) => (rdd, status) }
 }
 
-@DeveloperApi
-private[spark]
-class RDDInfo(
-    val id: Int,
-    val name: String,
-    val numPartitions: Int,
-    val storageLevel: StorageLevel)
-  extends Ordered[RDDInfo] {
-
-  var numCachedPartitions = 0
-  var memSize = 0L
-  var diskSize = 0L
-  var tachyonSize = 0L
-
-  override def toString = {
-    import Utils.bytesToString
-    ("RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s;" +
-      "TachyonSize: %s; DiskSize: %s").format(
-        name, id, storageLevel.toString, numCachedPartitions, numPartitions,
-        bytesToString(memSize), bytesToString(tachyonSize), bytesToString(diskSize))
-  }
-
-  override def compare(that: RDDInfo) = {
-    this.id - that.id
-  }
-}
-
-/* Helper methods for storage-related objects */
-private[spark]
-object StorageUtils {
+/** Helper methods for storage-related objects. */
+private[spark] object StorageUtils {
 
   /**
    * Returns basic information of all RDDs persisted in the given SparkContext. This does not
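
A note on the memUsed()/memUsed cleanup here and in the ExecutorsPage diff below: a Scala def declared without a parameter list must also be called without parentheses, and convention reserves the parameterless form for side-effect-free accessors. A minimal self-contained sketch (class and values hypothetical):

// Parameterless def: reads like a field at the call site.
class Status(val sizes: Seq[Long]) {
  def memUsed: Long = sizes.sum
}

val s = new Status(Seq(10L, 20L))
s.memUsed   // 30; writing s.memUsed() would not compile for a parameterless def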

core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala

Lines changed: 3 additions & 3 deletions

@@ -32,7 +32,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") {
   def render(request: HttpServletRequest): Seq[Node] = {
     val storageStatusList = listener.storageStatusList
     val maxMem = storageStatusList.map(_.maxMem).fold(0L)(_ + _)
-    val memUsed = storageStatusList.map(_.memUsed()).fold(0L)(_ + _)
+    val memUsed = storageStatusList.map(_.memUsed).fold(0L)(_ + _)
     val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_ + _)
     val execInfo = for (statusId <- 0 until storageStatusList.size) yield getExecInfo(statusId)
     val execInfoSorted = execInfo.sortBy(_.getOrElse("Executor ID", ""))
@@ -106,9 +106,9 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") {
     val execId = status.blockManagerId.executorId
     val hostPort = status.blockManagerId.hostPort
     val rddBlocks = status.blocks.size
-    val memUsed = status.memUsed()
+    val memUsed = status.memUsed
     val maxMem = status.maxMem
-    val diskUsed = status.diskUsed()
+    val diskUsed = status.diskUsed
     val activeTasks = listener.executorToTasksActive.getOrElse(execId, 0)
     val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0)
     val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0)

core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala

Lines changed: 2 additions & 2 deletions

@@ -66,8 +66,8 @@ private[ui] class StorageListener(storageStatusListener: StorageStatusListener)
   }
 
   override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized {
-    val rddInfo = stageSubmitted.stageInfo.rddInfo
-    _rddInfoMap.getOrElseUpdate(rddInfo.id, rddInfo)
+    val rddInfos = stageSubmitted.stageInfo.rddInfos
+    rddInfos.foreach { info => _rddInfoMap.getOrElseUpdate(info.id, info) }
   }
 
   override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized {

core/src/main/scala/org/apache/spark/util/JsonProtocol.scala

Lines changed: 12 additions & 10 deletions

@@ -176,7 +176,7 @@ private[spark] object JsonProtocol {
    * -------------------------------------------------------------------- */
 
   def stageInfoToJson(stageInfo: StageInfo): JValue = {
-    val rddInfo = rddInfoToJson(stageInfo.rddInfo)
+    val rddInfo = JArray(stageInfo.rddInfos.map(rddInfoToJson).toList)
     val submissionTime = stageInfo.submissionTime.map(JInt(_)).getOrElse(JNothing)
     val completionTime = stageInfo.completionTime.map(JInt(_)).getOrElse(JNothing)
     val failureReason = stageInfo.failureReason.map(JString(_)).getOrElse(JNothing)
@@ -208,7 +208,8 @@
       taskMetrics.shuffleReadMetrics.map(shuffleReadMetricsToJson).getOrElse(JNothing)
     val shuffleWriteMetrics =
       taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing)
-    val updatedBlocks = taskMetrics.updatedBlocks.map { blocks =>
+    val updatedBlocks =
+      taskMetrics.updatedBlocks.map { blocks =>
       JArray(blocks.toList.map { case (id, status) =>
         ("Block ID" -> id.toString) ~
         ("Status" -> blockStatusToJson(status))
@@ -467,13 +468,13 @@
     val stageId = (json \ "Stage ID").extract[Int]
     val stageName = (json \ "Stage Name").extract[String]
     val numTasks = (json \ "Number of Tasks").extract[Int]
-    val rddInfo = rddInfoFromJson(json \ "RDD Info")
+    val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson)
     val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long])
     val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long])
     val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String])
     val emittedTaskSizeWarning = (json \ "Emitted Task Size Warning").extract[Boolean]
 
-    val stageInfo = new StageInfo(stageId, stageName, numTasks, rddInfo)
+    val stageInfo = new StageInfo(stageId, stageName, numTasks, rddInfos)
     stageInfo.submissionTime = submissionTime
     stageInfo.completionTime = completionTime
     stageInfo.failureReason = failureReason
@@ -518,13 +519,14 @@
       Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson)
     metrics.shuffleWriteMetrics =
       Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson)
-    metrics.updatedBlocks = Utils.jsonOption(json \ "Updated Blocks").map { value =>
-      value.extract[List[JValue]].map { block =>
-        val id = BlockId((block \ "Block ID").extract[String])
-        val status = blockStatusFromJson(block \ "Status")
-        (id, status)
+    metrics.updatedBlocks =
+      Utils.jsonOption(json \ "Updated Blocks").map { value =>
+        value.extract[List[JValue]].map { block =>
+          val id = BlockId((block \ "Block ID").extract[String])
+          val status = blockStatusFromJson(block \ "Status")
+          (id, status)
+        }
       }
-    }
     metrics
   }
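
For orientation, a hedged sketch of the event-log shape after this change: "RDD Info" becomes a JSON array with one entry per RDD in the stage. The top-level field names are the ones read by stageInfoFromJson above; the inner RDDInfo fields are illustrative guesses, not taken from this diff.

// Hypothetical event-log fragment (values made up for illustration).
val stageInfoJson = """
  {
    "Stage ID": 1,
    "Stage Name": "count at <console>:15",
    "Number of Tasks": 4,
    "RDD Info": [
      {"RDD ID": 0, "Name": "0", "Number of Partitions": 4},
      {"RDD ID": 1, "Name": "1", "Number of Partitions": 4}
    ],
    "Emitted Task Size Warning": false
  }
"""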

core/src/test/scala/org/apache/spark/CacheManagerSuite.scala

Lines changed: 3 additions & 3 deletions

@@ -60,7 +60,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
 
     whenExecuting(blockManager) {
       val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
-        taskMetrics = TaskMetrics.empty())
+        taskMetrics = TaskMetrics.empty)
       val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
       assert(value.toList === List(1, 2, 3, 4))
     }
@@ -73,7 +73,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
 
     whenExecuting(blockManager) {
       val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
-        taskMetrics = TaskMetrics.empty())
+        taskMetrics = TaskMetrics.empty)
       val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
       assert(value.toList === List(5, 6, 7))
     }
@@ -87,7 +87,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
 
     whenExecuting(blockManager) {
      val context = new TaskContext(0, 0, 0, runningLocally = true, interrupted = false,
-        taskMetrics = TaskMetrics.empty())
+        taskMetrics = TaskMetrics.empty)
       val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
       assert(value.toList === List(1, 2, 3, 4))
     }
