
Commit 470ec7b

New accumulator API
1 parent 6ab4d9e commit 470ec7b


41 files changed, +832 -579 lines changed

core/src/main/scala/org/apache/spark/Accumulable.scala

Lines changed: 15 additions & 48 deletions
@@ -63,7 +63,7 @@ class Accumulable[R, T] private (
       param: AccumulableParam[R, T],
       name: Option[String],
       countFailedValues: Boolean) = {
-    this(Accumulators.newId(), initialValue, param, name, countFailedValues)
+    this(AccumulatorContext.newId(), initialValue, param, name, countFailedValues)
   }
 
   private[spark] def this(initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = {
@@ -72,61 +72,43 @@ class Accumulable[R, T] private (
 
   def this(initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None)
 
-  @volatile @transient private var value_ : R = initialValue // Current value on driver
-  val zero = param.zero(initialValue)  // Zero value to be passed to executors
-  private var deserialized = false
-
-  Accumulators.register(this)
-
-  /**
-   * Return a copy of this [[Accumulable]].
-   *
-   * The copy will have the same ID as the original and will not be registered with
-   * [[Accumulators]] again. This method exists so that the caller can avoid passing the
-   * same mutable instance around.
-   */
-  private[spark] def copy(): Accumulable[R, T] = {
-    new Accumulable[R, T](id, initialValue, param, name, countFailedValues)
-  }
+  private[spark] val newAcc = new LegacyAccumulatorWrapper(initialValue, param)
+  newAcc.metadata = AccumulatorMetadata(id, name, countFailedValues)
+  // Register the new accumulator in ctor, to follow the previous behaviour.
+  AccumulatorContext.register(newAcc)
 
   /**
    * Add more data to this accumulator / accumulable
    * @param term the data to add
    */
-  def += (term: T) { value_ = param.addAccumulator(value_, term) }
+  def += (term: T) { newAcc.add(term) }
 
   /**
    * Add more data to this accumulator / accumulable
    * @param term the data to add
    */
-  def add(term: T) { value_ = param.addAccumulator(value_, term) }
+  def add(term: T) { newAcc.add(term) }
 
   /**
    * Merge two accumulable objects together
   *
    * Normally, a user will not want to use this version, but will instead call `+=`.
    * @param term the other `R` that will get merged with this
    */
-  def ++= (term: R) { value_ = param.addInPlace(value_, term)}
+  def ++= (term: R) { newAcc._value = param.addInPlace(newAcc._value, term) }
 
   /**
    * Merge two accumulable objects together
   *
    * Normally, a user will not want to use this version, but will instead call `add`.
    * @param term the other `R` that will get merged with this
    */
-  def merge(term: R) { value_ = param.addInPlace(value_, term)}
+  def merge(term: R) { newAcc._value = param.addInPlace(newAcc._value, term) }
 
   /**
    * Access the accumulator's current value; only allowed on driver.
    */
-  def value: R = {
-    if (!deserialized) {
-      value_
-    } else {
-      throw new UnsupportedOperationException("Can't read accumulator value in task")
-    }
-  }
+  def value: R = newAcc.value
 
   /**
    * Get the current value of this accumulator from within a task.
@@ -137,14 +119,14 @@ class Accumulable[R, T] private (
    * The typical use of this method is to directly mutate the local value, eg., to add
    * an element to a Set.
    */
-  def localValue: R = value_
+  def localValue: R = newAcc.localValue
 
   /**
    * Set the accumulator's value; only allowed on driver.
    */
   def value_= (newValue: R) {
-    if (!deserialized) {
-      value_ = newValue
+    if (newAcc.isAtDriverSide) {
+      newAcc._value = newValue
     } else {
       throw new UnsupportedOperationException("Can't assign accumulator value in task")
     }
@@ -153,7 +135,7 @@ class Accumulable[R, T] private (
   /**
    * Set the accumulator's value. For internal use only.
    */
-  def setValue(newValue: R): Unit = { value_ = newValue }
+  def setValue(newValue: R): Unit = { newAcc._value = newValue }
 
   /**
    * Set the accumulator's value. For internal use only.
@@ -168,22 +150,7 @@ class Accumulable[R, T] private (
     new AccumulableInfo(id, name, update, value, isInternal, countFailedValues)
   }
 
-  // Called by Java when deserializing an object
-  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
-    in.defaultReadObject()
-    value_ = zero
-    deserialized = true
-
-    // Automatically register the accumulator when it is deserialized with the task closure.
-    // This is for external accumulators and internal ones that do not represent task level
-    // metrics, e.g. internal SQL metrics, which are per-operator.
-    val taskContext = TaskContext.get()
-    if (taskContext != null) {
-      taskContext.registerAccumulator(this)
-    }
-  }
-
-  override def toString: String = if (value_ == null) "null" else value_.toString
+  override def toString: String = if (newAcc._value == null) "null" else newAcc._value.toString
 }

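Note on the change above: the legacy Accumulable keeps its public surface but forwards every mutation to newAcc, a LegacyAccumulatorWrapper registered with AccumulatorContext. The following is a minimal, self-contained sketch of that delegation pattern, assuming simplified, hypothetical names (SimpleAccumulableParam, SimpleLegacyWrapper, SimpleAccumulable are illustration only, not Spark classes):

// Minimal, self-contained sketch (not Spark code): a legacy param-based
// accumulator facade that delegates all mutation to a new-style wrapper.
trait SimpleAccumulableParam[R, T] {
  def addAccumulator(r: R, t: T): R
  def addInPlace(r1: R, r2: R): R
}

// New-style accumulator: owns the value and exposes add/merge/value.
class SimpleLegacyWrapper[R, T](initialValue: R, param: SimpleAccumulableParam[R, T]) {
  private var current: R = initialValue
  def add(t: T): Unit = { current = param.addAccumulator(current, t) }
  def merge(other: R): Unit = { current = param.addInPlace(current, other) }
  def value: R = current
}

// Legacy facade: +=, ++= and value all route through the wrapper, mirroring
// how Accumulable now routes its operations through newAcc.
class SimpleAccumulable[R, T](initialValue: R, param: SimpleAccumulableParam[R, T]) {
  private val newAcc = new SimpleLegacyWrapper(initialValue, param)
  def +=(term: T): Unit = newAcc.add(term)
  def ++=(term: R): Unit = newAcc.merge(term)
  def value: R = newAcc.value
}

object DelegationDemo extends App {
  val sumParam = new SimpleAccumulableParam[Int, Int] {
    def addAccumulator(r: Int, t: Int): Int = r + t
    def addInPlace(r1: Int, r2: Int): Int = r1 + r2
  }
  val acc = new SimpleAccumulable(0, sumParam)
  acc += 1
  acc ++= 41
  println(acc.value) // prints 42
}

The point of the pattern is that the old class no longer owns any accumulation state of its own; only the wrapper does.
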
core/src/main/scala/org/apache/spark/Accumulator.scala

Lines changed: 0 additions & 67 deletions
@@ -68,73 +68,6 @@ class Accumulator[T] private[spark] (
   extends Accumulable[T, T](initialValue, param, name, countFailedValues)
 
 
-// TODO: The multi-thread support in accumulators is kind of lame; check
-// if there's a more intuitive way of doing it right
-private[spark] object Accumulators extends Logging {
-  /**
-   * This global map holds the original accumulator objects that are created on the driver.
-   * It keeps weak references to these objects so that accumulators can be garbage-collected
-   * once the RDDs and user-code that reference them are cleaned up.
-   * TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051).
-   */
-  @GuardedBy("Accumulators")
-  val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]()
-
-  private val nextId = new AtomicLong(0L)
-
-  /**
-   * Return a globally unique ID for a new [[Accumulable]].
-   * Note: Once you copy the [[Accumulable]] the ID is no longer unique.
-   */
-  def newId(): Long = nextId.getAndIncrement
-
-  /**
-   * Register an [[Accumulable]] created on the driver such that it can be used on the executors.
-   *
-   * All accumulators registered here can later be used as a container for accumulating partial
-   * values across multiple tasks. This is what [[org.apache.spark.scheduler.DAGScheduler]] does.
-   * Note: if an accumulator is registered here, it should also be registered with the active
-   * context cleaner for cleanup so as to avoid memory leaks.
-   *
-   * If an [[Accumulable]] with the same ID was already registered, this does nothing instead
-   * of overwriting it. This happens when we copy accumulators, e.g. when we reconstruct
-   * [[org.apache.spark.executor.TaskMetrics]] from accumulator updates.
-   */
-  def register(a: Accumulable[_, _]): Unit = synchronized {
-    if (!originals.contains(a.id)) {
-      originals(a.id) = new WeakReference[Accumulable[_, _]](a)
-    }
-  }
-
-  /**
-   * Unregister the [[Accumulable]] with the given ID, if any.
-   */
-  def remove(accId: Long): Unit = synchronized {
-    originals.remove(accId)
-  }
-
-  /**
-   * Return the [[Accumulable]] registered with the given ID, if any.
-   */
-  def get(id: Long): Option[Accumulable[_, _]] = synchronized {
-    originals.get(id).map { weakRef =>
-      // Since we are storing weak references, we must check whether the underlying data is valid.
-      weakRef.get.getOrElse {
-        throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id")
-      }
-    }
-  }
-
-  /**
-   * Clear all registered [[Accumulable]]s. For testing only.
-   */
-  def clear(): Unit = synchronized {
-    originals.clear()
-  }
-}
-
-
 /**
  * A simpler version of [[org.apache.spark.AccumulableParam]] where the only data type you can add
  * in is the same type as the accumulated value. An implicit AccumulatorParam object needs to be

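The removed Accumulators object was the driver-side registry: it handed out globally unique IDs and held registered accumulators through weak references so they could still be garbage-collected. In this commit that role moves to AccumulatorContext (defined outside this file). Below is a standalone sketch of such a weak-reference registry, using simplified, hypothetical names (SimpleAccumulatorRegistry is illustration only, not Spark's AccumulatorContext):

import java.lang.ref.WeakReference
import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable

// Standalone sketch (not Spark code) of a driver-side registry in the style of
// the removed Accumulators object: unique IDs plus a weak-reference map so
// registered accumulators remain eligible for garbage collection.
object SimpleAccumulatorRegistry {
  private val nextId = new AtomicLong(0L)

  private val originals = mutable.Map[Long, WeakReference[AnyRef]]()

  // Hand out a globally unique ID for a new accumulator.
  def newId(): Long = nextId.getAndIncrement()

  // Register only if the ID is not already present, matching the old
  // "do nothing instead of overwriting" behaviour.
  def register(id: Long, acc: AnyRef): Unit = synchronized {
    if (!originals.contains(id)) {
      originals(id) = new WeakReference[AnyRef](acc)
    }
  }

  def remove(id: Long): Unit = synchronized {
    originals.remove(id)
  }

  // Because references are weak, the target may already have been collected.
  def get(id: Long): Option[AnyRef] = synchronized {
    originals.get(id).map { ref =>
      Option(ref.get()).getOrElse(
        throw new IllegalStateException(s"Accumulator $id was garbage collected"))
    }
  }
}

In the sketch, get fails loudly when the weak reference has been cleared, analogous to the IllegalAccessError raised by the removed code.
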
core/src/main/scala/org/apache/spark/ContextCleaner.scala

Lines changed: 2 additions & 2 deletions
@@ -144,7 +144,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
     registerForCleanup(rdd, CleanRDD(rdd.id))
   }
 
-  def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = {
+  def registerAccumulatorForCleanup(a: NewAccumulator[_, _]): Unit = {
     registerForCleanup(a, CleanAccum(a.id))
   }
 
@@ -241,7 +241,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   def doCleanupAccum(accId: Long, blocking: Boolean): Unit = {
     try {
       logDebug("Cleaning accumulator " + accId)
-      Accumulators.remove(accId)
+      AccumulatorContext.remove(accId)
       listeners.asScala.foreach(_.accumCleaned(accId))
       logInfo("Cleaned accumulator " + accId)
     } catch {

core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils}
  */
 private[spark] case class Heartbeat(
     executorId: String,
-    accumUpdates: Array[(Long, Seq[AccumulableInfo])], // taskId -> accum updates
+    accumUpdates: Array[(Long, Seq[AccumulatorUpdates])], // taskId -> accumulator updates
     blockManagerId: BlockManagerId)
 
 /**

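For context on the Heartbeat change above: accumUpdates pairs each task ID with that task's accumulator updates. A hypothetical, self-contained example of folding such a payload on the driver follows; the Update type and its fields are stand-ins for illustration, not Spark's AccumulatorUpdates:

object HeartbeatPayloadDemo extends App {
  // Stand-in for the per-task update type carried by the heartbeat.
  final case class Update(accId: Long, delta: Long)

  // taskId -> accumulator updates, mirroring the shape of the Heartbeat field.
  val accumUpdates: Array[(Long, Seq[Update])] = Array(
    1L -> Seq(Update(accId = 10L, delta = 3L), Update(accId = 11L, delta = 1L)),
    2L -> Seq(Update(accId = 10L, delta = 5L))
  )

  // Sum the deltas per accumulator ID across all reporting tasks.
  val totals: Map[Long, Long] = accumUpdates
    .flatMap { case (_, updates) => updates }
    .groupBy(_.accId)
    .map { case (id, ups) => id -> ups.map(_.delta).sum }

  println(totals) // Map(10 -> 8, 11 -> 1)
}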