
Commit 124568b

address comments
1 parent c74320d commit 124568b

File tree

23 files changed: +86 additions, −53 deletions


core/src/main/scala/org/apache/spark/Accumulable.scala

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ class Accumulable[R, T] private (
 
   def this(initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None)
 
+  val zero = param.zero(initialValue)
   private[spark] val newAcc = new LegacyAccumulatorWrapper(initialValue, param)
   newAcc.metadata = AccumulatorMetadata(id, name, countFailedValues)
   // Register the new accumulator in ctor, to follow the previous behaviour.
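For context on the restored field: `zero` is produced by the legacy `AccumulableParam.zero(initialValue)` hook. A minimal sketch of such a param, with a made-up `StringSetParam` that is not part of this change:

import org.apache.spark.AccumulableParam

// Illustrative only: `zero` reduces whatever initial value the user supplied
// to the identity element, which is what `val zero = param.zero(initialValue)`
// ends up holding (here, the empty set).
object StringSetParam extends AccumulableParam[Set[String], String] {
  def addAccumulator(acc: Set[String], elem: String): Set[String] = acc + elem
  def addInPlace(acc1: Set[String], acc2: Set[String]): Set[String] = acc1 ++ acc2
  def zero(initialValue: Set[String]): Set[String] = Set.empty[String]
}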

core/src/main/scala/org/apache/spark/NewAccumulator.scala

Lines changed: 50 additions & 15 deletions
@@ -34,18 +34,7 @@ private[spark] case class AccumulatorMetadata(
 
 /**
  * The base class for accumulators, that can accumulate inputs of type `IN`, and produce output of
- * type `OUT`. Implementations must define following methods:
- *  - isZero: tell if this accumulator is zero value or not. e.g. for a counter accumulator,
- *    0 is zero value; for a list accumulator, Nil is zero value.
- *  - copyAndReset: create a new copy of this accumulator, which is zero value. i.e. call `isZero`
- *    on the copy must return true.
- *  - add: defines how to accumulate the inputs. e.g. it can be a simple `+=` for counter
- *    accumulator
- *  - merge: defines how to merge another accumulator of same type.
- *  - localValue: defines how to produce the output by the current state of this accumulator.
- *
- * The implementations decide how to store intermediate values, e.g. a long field for a counter
- * accumulator, a double and a long field for a average accumulator(storing the sum and count).
+ * type `OUT`.
  */
 abstract class NewAccumulator[IN, OUT] extends Serializable {
   private[spark] var metadata: AccumulatorMetadata = _
@@ -63,6 +52,10 @@ abstract class NewAccumulator[IN, OUT] extends Serializable {
     sc.cleaner.foreach(_.registerAccumulatorForCleanup(this))
   }
 
+  /**
+   * Returns true if this accumulator has been registered. Note that all accumulators must be
+   * registered before use, or an exception will be thrown.
+   */
   final def isRegistered: Boolean =
     metadata != null && AccumulatorContext.originals.containsKey(metadata.id)
@@ -72,38 +65,69 @@ abstract class NewAccumulator[IN, OUT] extends Serializable {
     }
   }
 
+  /**
+   * Returns the id of this accumulator; can only be called after registration.
+   */
   final def id: Long = {
     assertMetadataNotNull()
     metadata.id
   }
 
+  /**
+   * Returns the name of this accumulator; can only be called after registration.
+   */
   final def name: Option[String] = {
     assertMetadataNotNull()
     metadata.name
   }
 
-  final def countFailedValues: Boolean = {
+  /**
+   * Whether to accumulate values from failed tasks. This is set to true for system and time
+   * metrics like serialization time or bytes spilled, and false for things with absolute values
+   * like number of input rows. This should be used for internal metrics only.
+   */
+  private[spark] final def countFailedValues: Boolean = {
     assertMetadataNotNull()
     metadata.countFailedValues
   }
 
+  /**
+   * Creates an [[AccumulableInfo]] representation of this [[NewAccumulator]] with the provided
+   * values.
+   */
   private[spark] def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
     val isInternal = name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))
     new AccumulableInfo(id, name, update, value, isInternal, countFailedValues)
   }
 
   final private[spark] def isAtDriverSide: Boolean = atDriverSide
 
+  /**
+   * Tells if this accumulator is zero value or not, e.g. for a counter accumulator, 0 is the
+   * zero value; for a list accumulator, Nil is the zero value.
+   */
   def isZero(): Boolean
 
+  /**
+   * Creates a new copy of this accumulator, which is zero value, i.e. calling `isZero` on the
+   * copy must return true.
+   */
   def copyAndReset(): NewAccumulator[IN, OUT]
 
+  /**
+   * Takes the inputs and accumulates, e.g. it can be a simple `+=` for a counter accumulator.
+   */
   def add(v: IN): Unit
 
-  def +=(v: IN): Unit = add(v)
-
+  /**
+   * Merges another same-type accumulator into this one and updates its state, i.e. this should
+   * be merge-in-place.
+   */
   def merge(other: NewAccumulator[IN, OUT]): Unit
 
+  /**
+   * Accesses this accumulator's current value; only allowed on the driver.
+   */
   final def value: OUT = {
     if (atDriverSide) {
       localValue
@@ -112,6 +136,12 @@ abstract class NewAccumulator[IN, OUT] extends Serializable {
     }
   }
 
+  /**
+   * Defines the current value of this accumulator.
+   *
+   * This is NOT the global value of the accumulator. To get the global value after a
+   * completed operation on the dataset, call `value`.
+   */
   def localValue: OUT
 
   // Called by Java when serializing an object
@@ -328,6 +358,11 @@ class ListAccumulator[T] extends NewAccumulator[T, java.util.List[T]] {
   }
 
   override def localValue: java.util.List[T] = java.util.Collections.unmodifiableList(_list)
+
+  private[spark] def setValue(newValue: java.util.List[T]): Unit = {
+    _list.clear()
+    _list.addAll(newValue)
+  }
 }
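The per-method docs above now carry the contract that used to live in the class comment. A minimal sketch of an implementation written against that contract; the `MaxAccumulator` name and max-tracking semantics are illustrative only, not part of this commit:

// Sketch only: tracks the maximum Long seen. `copyAndReset` must return a
// copy for which `isZero` holds, and `merge` must fold the other accumulator
// into this one in place.
class MaxAccumulator extends NewAccumulator[Long, Long] {
  private var _max: Long = Long.MinValue

  override def isZero(): Boolean = _max == Long.MinValue
  override def copyAndReset(): MaxAccumulator = new MaxAccumulator
  override def add(v: Long): Unit = { _max = math.max(_max, v) }
  override def merge(other: NewAccumulator[Long, Long]): Unit = other match {
    case o: MaxAccumulator => _max = math.max(_max, o._max)
    case _ => throw new UnsupportedOperationException(
      s"Cannot merge ${getClass.getName} with ${other.getClass.getName}")
  }
  override def localValue: Long = _max
}

Note that the convenience `+=` alias is removed in this file, so callers (see the test updates below) now go through `add` directly.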

core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala

Lines changed: 5 additions & 5 deletions
@@ -51,7 +51,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
     sc.addSparkListener(listener)
     // Have each task add 1 to the internal accumulator
     val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter =>
-      TaskContext.get().taskMetrics().testAccum.get += 1L
+      TaskContext.get().taskMetrics().testAccum.get.add(1)
       iter
     }
     // Register asserts in job completion callback to avoid flakiness
@@ -87,17 +87,17 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
     val rdd = sc.parallelize(1 to 100, numPartitions)
       .map { i => (i, i) }
       .mapPartitions { iter =>
-        TaskContext.get().taskMetrics().testAccum.get += 1L
+        TaskContext.get().taskMetrics().testAccum.get.add(1)
         iter
       }
       .reduceByKey { case (x, y) => x + y }
       .mapPartitions { iter =>
-        TaskContext.get().taskMetrics().testAccum.get += 10L
+        TaskContext.get().taskMetrics().testAccum.get.add(10)
         iter
       }
       .repartition(numPartitions * 2)
       .mapPartitions { iter =>
-        TaskContext.get().taskMetrics().testAccum.get += 100L
+        TaskContext.get().taskMetrics().testAccum.get.add(100)
         iter
       }
     // Register asserts in job completion callback to avoid flakiness
@@ -127,7 +127,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
     // This should retry both stages in the scheduler. Note that we only want to fail the
     // first stage attempt because we want the stage to eventually succeed.
     val x = sc.parallelize(1 to 100, numPartitions)
-      .mapPartitions { iter => TaskContext.get().taskMetrics().testAccum.get += 1L; iter }
+      .mapPartitions { iter => TaskContext.get().taskMetrics().testAccum.get.add(1); iter }
       .groupBy(identity)
     val sid = x.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle.shuffleId
     val rdd = x.mapPartitionsWithIndex { case (i, iter) =>
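The suite drives this contract through real tasks; the round trip it relies on can be sketched without a cluster (REPL-style, reusing the hypothetical `MaxAccumulator` from the earlier sketch): the driver hands each task a zero-valued copy via `copyAndReset`, tasks call `add`, and the driver merges the copies back in.

// Rough simulation of the accumulate-then-merge round trip (no SparkContext):
val driverAcc = new MaxAccumulator
val taskCopies = Seq.fill(3)(driverAcc.copyAndReset())              // one zero-valued copy per task
taskCopies.zipWithIndex.foreach { case (acc, i) => acc.add(i * 10L) } // tasks call add()
taskCopies.foreach(driverAcc.merge)                                  // driver folds results back in
assert(driverAcc.localValue == 20L)                                  // max of 0, 10, 20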

core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala

Lines changed: 3 additions & 5 deletions
@@ -20,13 +20,11 @@ package org.apache.spark.executor
 import org.scalatest.Assertions
 
 import org.apache.spark._
-import org.apache.spark.scheduler.AccumulableInfo
-import org.apache.spark.storage.{BlockId, BlockStatus, StorageLevel, TestBlockId}
+import org.apache.spark.storage.{BlockStatus, StorageLevel, TestBlockId}
 
 
 class TaskMetricsSuite extends SparkFunSuite {
   import StorageLevel._
-  import TaskMetricsSuite._
 
   test("mutating values") {
     val tm = new TaskMetrics
@@ -202,8 +200,8 @@
     tm.registerAccumulator(acc2)
     tm.registerAccumulator(acc3)
     tm.registerAccumulator(acc4)
-    acc1 += 1L
-    acc2 += 2L
+    acc1.add(1)
+    acc2.add(2)
     val newUpdates = tm.accumulators()
       .map(a => (a.id, a.asInstanceOf[NewAccumulator[Any, Any]])).toMap
     assert(newUpdates.contains(acc1.id))

project/MimaExcludes.scala

Lines changed: 0 additions & 1 deletion
@@ -676,7 +676,6 @@ object MimaExcludes {
       ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.util.collection.Spillable")
     ) ++ Seq(
       // SPARK-14654: New accumulator API
-      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Accumulable.zero"),
       ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ExceptionFailure$"),
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.apply"),
      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.metrics"),

sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala

Lines changed: 3 additions & 3 deletions
@@ -106,7 +106,7 @@ private[sql] case class RDDScanExec(
     override val nodeName: String) extends LeafExecNode {
 
   private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createSumMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
@@ -147,7 +147,7 @@ private[sql] case class RowDataSourceScanExec(
   extends DataSourceScanExec with CodegenSupport {
 
   private[sql] override lazy val metrics =
-    Map("numOutputRows" -> SQLMetrics.createSumMetric(sparkContext, "number of output rows"))
+    Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   val outputUnsafeRows = relation match {
     case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] =>
@@ -216,7 +216,7 @@ private[sql] case class BatchedDataSourceScanExec(
   extends DataSourceScanExec with CodegenSupport {
 
   private[sql] override lazy val metrics =
-    Map("numOutputRows" -> SQLMetrics.createSumMetric(sparkContext, "number of output rows"),
+    Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
       "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
 
   protected override def doExecute(): RDD[InternalRow] = {
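Across this and the next four files the change is the same mechanical rename: `SQLMetrics.createSumMetric` becomes `SQLMetrics.createMetric`, while the declare-then-`longMetric` pattern in the operators is untouched. As a reminder of that pattern, here is a sketch assembled from the lines above; it is a fragment of an operator body rather than a standalone program, and it assumes the returned `SQLMetric` exposes the same `add` method as the accumulators earlier in this commit:

  // Fragment (inside a physical operator such as RDDScanExec); `rdd` is the
  // operator's input RDD[InternalRow]. Simplified: the real doExecute goes
  // through Spark's internal mapPartitions variants.
  private[sql] override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")  // look up the declared SQLMetric
    rdd.mapPartitions { iter =>
      iter.map { row =>
        numOutputRows.add(1)  // assumed: SQLMetric.add, per the new accumulator API
        row
      }
    }
  }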

sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ case class ExpandExec(
   extends UnaryExecNode with CodegenSupport {
 
   private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createSumMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   // The GroupExpressions can output data with arbitrary partitioning, so set it
   // as UNKNOWN partitioning

sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ case class GenerateExec(
   extends UnaryExecNode {
 
   private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createSumMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   override def producedAttributes: AttributeSet = AttributeSet(output)

sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ private[sql] case class LocalTableScanExec(
     rows: Seq[InternalRow]) extends LeafExecNode {
 
   private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createSumMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   private val unsafeRows: Array[InternalRow] = {
     val proj = UnsafeProjection.create(output, output)

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ case class SortBasedAggregateExec(
     AttributeSet(aggregateBufferAttributes)
 
   override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createSumMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
