Commit 5b65b1f

In memory shuffle (cherry-picked from amplab#135)
(cherry picked from commit 5ec645d)

Conflicts:
    core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
    core/src/main/scala/org/apache/spark/storage/BlockManager.scala
    graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
1 parent 777654e commit 5b65b1f

4 files changed: 38 additions, 15 deletions
core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala

Lines changed: 10 additions & 14 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.storage
 
-import java.io.{FileOutputStream, File, OutputStream}
+import java.io.{ByteArrayOutputStream, FileOutputStream, File, OutputStream}
 import java.nio.channels.FileChannel
 
 import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
@@ -44,7 +44,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
    * Flush the partial writes and commit them as a single atomic block. Return the
    * number of bytes written for this commit.
    */
-  def commit(): Long
+  def commit(): Array[Byte]
 
   /**
    * Reverts writes that haven't been flushed yet. Callers should invoke this function
@@ -106,7 +106,7 @@ private[spark] class DiskBlockObjectWriter(
   /** The file channel, used for repositioning / truncating the file. */
   private var channel: FileChannel = null
   private var bs: OutputStream = null
-  private var fos: FileOutputStream = null
+  private var fos: ByteArrayOutputStream = null
   private var ts: TimeTrackingOutputStream = null
   private var objOut: SerializationStream = null
   private val initialPosition = file.length()
@@ -115,9 +115,8 @@ private[spark] class DiskBlockObjectWriter(
   private var _timeWriting = 0L
 
   override def open(): BlockObjectWriter = {
-    fos = new FileOutputStream(file, true)
+    fos = new ByteArrayOutputStream()
     ts = new TimeTrackingOutputStream(fos)
-    channel = fos.getChannel()
     lastValidPosition = initialPosition
     bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
     objOut = serializer.newInstance().serializeStream(bs)
@@ -130,9 +129,6 @@
       if (syncWrites) {
         // Force outstanding writes to disk and track how long it takes
         objOut.flush()
-        val start = System.nanoTime()
-        fos.getFD.sync()
-        _timeWriting += System.nanoTime() - start
       }
       objOut.close()
 
@@ -149,18 +145,18 @@
 
   override def isOpen: Boolean = objOut != null
 
-  override def commit(): Long = {
+  override def commit(): Array[Byte] = {
     if (initialized) {
       // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the
       // serializer stream and the lower level stream.
       objOut.flush()
       bs.flush()
       val prevPos = lastValidPosition
-      lastValidPosition = channel.position()
-      lastValidPosition - prevPos
+      lastValidPosition = fos.size()
+      fos.toByteArray
     } else {
       // lastValidPosition is zero if stream is uninitialized
-      lastValidPosition
+      null
     }
   }
 
@@ -170,7 +166,7 @@
       // truncate the file to the last valid position.
       objOut.flush()
       bs.flush()
-      channel.truncate(lastValidPosition)
+      throw new UnsupportedOperationException("Revert temporarily broken due to in memory shuffle code changes.")
     }
   }
 
@@ -182,7 +178,7 @@
   }
 
   override def fileSegment(): FileSegment = {
-    new FileSegment(file, initialPosition, bytesWritten)
+    new FileSegment(null, initialPosition, bytesWritten)
   }
 
   // Only valid if called after close()
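
The core of this change is that DiskBlockObjectWriter now buffers serialized output in a ByteArrayOutputStream and commit() returns the buffered bytes rather than a count of bytes written to disk. Below is a minimal, self-contained sketch of that contract, not the Spark class itself: the InMemoryBlockWriter name and the use of plain Java serialization are assumptions for illustration.

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Records are serialized straight into an in-memory buffer; commit() hands the
// accumulated bytes back to the caller instead of a byte count.
class InMemoryBlockWriter {
  private val buffer = new ByteArrayOutputStream()
  private val objOut = new ObjectOutputStream(buffer)

  def write(value: AnyRef): Unit = objOut.writeObject(value)

  def commit(): Array[Byte] = {
    objOut.flush()       // push any serializer-buffered data into `buffer`
    buffer.toByteArray   // the whole block, ready to be kept in memory
  }
}

A caller would hand the returned array to an in-memory store rather than tracking a file offset, which is presumably also why fileSegment() above now passes null for the file.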

core/src/main/scala/org/apache/spark/storage/MemoryStore.scala

Lines changed: 9 additions & 0 deletions
@@ -23,6 +23,7 @@ import java.util.LinkedHashMap
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.util.{SizeEstimator, Utils}
+import org.apache.spark.serializer.Serializer
 
 /**
  * Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as
@@ -119,6 +120,14 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
     }
   }
 
+  /**
+   * A version of getValues that allows a custom serializer. This is used as part of the
+   * shuffle short-circuit code.
+   */
+  def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
+    getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
+  }
+
   override def remove(blockId: BlockId): Boolean = {
     entries.synchronized {
       val entry = entries.remove(blockId)
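
The new overload lets the shuffle read path pull a block's raw bytes out of the memory store and deserialize them with the shuffle's own serializer rather than the store's default one. A toy sketch of that shape follows; ToyMemoryStore and the function-typed deserializer are illustrative stand-ins, not Spark APIs.

import scala.collection.mutable

// Blocks are kept as serialized byte arrays; reads hand the bytes to whatever
// deserializer the caller supplies, mirroring getValues(blockId, serializer).
object ToyMemoryStore {
  private val blocks = mutable.Map[String, Array[Byte]]()

  def putBytes(blockId: String, bytes: Array[Byte]): Unit =
    blocks.synchronized { blocks(blockId) = bytes }

  def getValues(blockId: String,
                deserialize: Array[Byte] => Iterator[Any]): Option[Iterator[Any]] =
    blocks.synchronized { blocks.get(blockId) }.map(deserialize)
}

// Example: a reducer-side read that interprets each stored byte as a record.
// ToyMemoryStore.getValues("shuffle_0_5_2", bytes => bytes.iterator.map(_.toInt))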

core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala

Lines changed: 12 additions & 1 deletion
@@ -207,6 +207,17 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
   private def cleanup(cleanupTime: Long) {
     shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId))
   }
+
+  def removeAllShuffleStuff() {
+    for (state <- shuffleStates.values;
+         group <- state.allFileGroups;
+         (mapId, _) <- group.mapIdToIndex.iterator;
+         reducerId <- 0 until group.files.length) {
+      val blockId = new ShuffleBlockId(group.shuffleId, mapId, reducerId)
+      blockManager.removeBlock(blockId, tellMaster = false)
+    }
+    shuffleStates.clear()
+  }
 }
 
 private[spark]
@@ -220,7 +231,7 @@ object ShuffleBlockManager {
    * Stores the absolute index of each mapId in the files of this group. For instance,
    * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0.
    */
-  private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()
+  val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()
 
   /**
    * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file.
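
removeAllShuffleStuff walks every tracked shuffle, every map output in each file group, and every reducer partition, and removes the corresponding block from the block manager; making mapIdToIndex public is what lets it enumerate the map ids. A small sketch of that enumeration over plain collections, with all names illustrative:

// Stand-in for Spark's shuffle block id: one block per (shuffle, map, reduce) triple.
case class ToyShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)

// Enumerate every block a shuffle produced so each one can be removed.
def blocksToRemove(shuffleId: Int, mapIds: Seq[Int], numReducers: Int): Seq[ToyShuffleBlockId] =
  for {
    mapId <- mapIds
    reducerId <- 0 until numReducers
  } yield ToyShuffleBlockId(shuffleId, mapId, reducerId)

// blocksToRemove(0, Seq(5, 7), 3) yields six block ids: two map outputs times three reducers.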

graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala

Lines changed: 7 additions & 0 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.graphx
 
 import scala.reflect.ClassTag
+import org.apache.spark.SparkEnv
 
 
 /**
@@ -142,6 +143,12 @@ object Pregel {
       // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the
       // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g).
       activeMessages = messages.count()
+
+      // Very ugly code to clear the in-memory shuffle data
+      messages.foreachPartition { iter =>
+        SparkEnv.get.blockManager.shuffleBlockManager.removeAllShuffleStuff()
+      }
+
       // Unpersist the RDDs hidden by newly-materialized RDDs
       oldMessages.unpersist(blocking=false)
       newVerts.unpersist(blocking=false)
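
The foreachPartition call is how the driver gets the cleanup to run inside every executor: the closure runs once per partition on whichever executor holds it, so it can reach the executor-local SparkEnv and its block manager. Below is a hedged sketch of the same pattern against a toy per-executor cache; it uses only stock Spark API, since removeAllShuffleStuff exists only in this patched build, and all other names are illustrative.

import org.apache.spark.{SparkConf, SparkContext}
import scala.collection.mutable

// Lives once per executor JVM; stands in for the in-memory shuffle blocks.
object PerExecutorCache {
  val blocks = mutable.Map[String, Array[Byte]]()
}

object ForeachPartitionCleanup {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("cleanup-sketch").setMaster("local[2]"))
    val messages = sc.parallelize(1 to 100, numSlices = 4)

    val activeMessages = messages.count()   // materialize first, as Pregel does

    // Then run one cleanup call per partition, on whichever executor holds it.
    messages.foreachPartition { _ =>
      PerExecutorCache.blocks.clear()
    }

    sc.stop()
  }
}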
