Commit 8e3ec20

Begin code cleanup.
1 parent 4d2f5e1 commit 8e3ec20

1 file changed (84 additions, 65 deletions)

core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala

Lines changed: 84 additions & 65 deletions
@@ -17,23 +17,23 @@
 
 package org.apache.spark.shuffle.unsafe
 
-import java.io.{ByteArrayOutputStream, FileOutputStream}
+import java.io.{FileOutputStream, OutputStream}
 import java.nio.ByteBuffer
 import java.util
 
 import com.esotericsoftware.kryo.io.ByteBufferOutputStream
+
+import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
 import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle._
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.unsafe.PlatformDependent
-import org.apache.spark.unsafe.memory.{TaskMemoryManager, MemoryBlock}
+import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}
 import org.apache.spark.unsafe.sort.UnsafeSorter
 import org.apache.spark.unsafe.sort.UnsafeSorter.{KeyPointerAndPrefix, PrefixComparator, PrefixComputer, RecordComparator}
-import org.apache.spark.{SparkEnv, SparkConf, ShuffleDependency, TaskContext}
-import org.apache.spark.shuffle._
 
 private[spark] class UnsafeShuffleHandle[K, V](
     shuffleId: Int,
@@ -87,7 +87,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
 
   private[this] val dep = handle.dependency
 
-  private[this] var sorter: UnsafeSorter = null
+  private[this] val partitioner = dep.partitioner
 
   // Are we in the process of stopping? Because map tasks can call stop() with success = true
   // and then call stop() with success = false if they get an exception, we want to make sure
@@ -104,52 +104,55 @@ private[spark] class UnsafeShuffleWriter[K, V](
 
   private[this] val blockManager = SparkEnv.get.blockManager
 
-  /** Write a sequence of records to this task's output */
-  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
-    println("Opened writer!")
-    val serializer = Serializer.getSerializer(dep.serializer).newInstance()
-    val partitioner = dep.partitioner
-    sorter = new UnsafeSorter(
+  private def sortRecords(records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = {
+    val sorter = new UnsafeSorter(
       context.taskMemoryManager(),
       DummyRecordComparator,
       PartitionerPrefixComputer,
       PartitionerPrefixComparator,
       4096 // initial size
     )
-
-    // Pack records into data pages:
+    val serializer = Serializer.getSerializer(dep.serializer).newInstance()
     val PAGE_SIZE = 1024 * 1024 * 1
+
     var currentPage: MemoryBlock = memoryManager.allocatePage(PAGE_SIZE)
-    allocatedPages.add(currentPage)
     var currentPagePosition: Long = currentPage.getBaseOffset
 
-    // TODO make this configurable
+    def ensureSpaceInDataPage(spaceRequired: Long): Unit = {
+      if (spaceRequired > PAGE_SIZE) {
+        throw new Exception(s"Size requirement $spaceRequired is greater than page size $PAGE_SIZE")
+      } else if (spaceRequired > (PAGE_SIZE - currentPagePosition)) {
+        currentPage = memoryManager.allocatePage(PAGE_SIZE)
+        allocatedPages.add(currentPage)
+        currentPagePosition = currentPage.getBaseOffset
+      }
+    }
+
+    // TODO: the size of this buffer should be configurable
     val serArray = new Array[Byte](1024 * 1024)
     val byteBuffer = ByteBuffer.wrap(serArray)
    val bbos = new ByteBufferOutputStream()
     bbos.setByteBuffer(byteBuffer)
     val serBufferSerStream = serializer.serializeStream(bbos)
 
-    while (records.hasNext) {
-      val nextRecord: Product2[K, V] = records.next()
-      println("Writing record " + nextRecord)
-      val partitionId: Int = partitioner.getPartition(nextRecord._1)
-      serBufferSerStream.writeObject(nextRecord)
-
-      val sizeRequirement: Int = byteBuffer.position() + 8 + 8
-      println("Size requirement in intenral buffer is " + sizeRequirement)
-      if (sizeRequirement > (PAGE_SIZE - currentPagePosition)) {
-        println("Allocating a new data page after writing " + currentPagePosition)
-        currentPage = memoryManager.allocatePage(PAGE_SIZE)
-        allocatedPages.add(currentPage)
-        currentPagePosition = currentPage.getBaseOffset
-      }
-      println("Before writing record, current page position is " + currentPagePosition)
-      // TODO: check that it's still not too large
+    def writeRecord(record: Product2[Any, Any]): Unit = {
+      val (key, value) = record
+      val partitionId = partitioner.getPartition(key)
+      serBufferSerStream.writeKey(key)
+      serBufferSerStream.writeValue(value)
+      serBufferSerStream.flush()
+
+      val serializedRecordSize = byteBuffer.position()
+      // TODO: we should run the partition extraction function _now_, at insert time, rather than
+      // requiring it to be stored alongisde the data, since this may lead to double storage
+      val sizeRequirementInSortDataPage = serializedRecordSize + 8 + 8
+      ensureSpaceInDataPage(sizeRequirementInSortDataPage)
+
       val newRecordAddress =
         memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition)
       PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId)
       currentPagePosition += 8
+      println("The stored record length is " + byteBuffer.position())
       PlatformDependent.UNSAFE.putLong(
         currentPage.getBaseObject, currentPagePosition, byteBuffer.position())
       currentPagePosition += 8
@@ -162,45 +165,53 @@ private[spark] class UnsafeShuffleWriter[K, V](
       currentPagePosition += byteBuffer.position()
       println("After writing record, current page position is " + currentPagePosition)
       sorter.insertRecord(newRecordAddress)
+
+      // Reset for writing the next record
       byteBuffer.position(0)
     }
-    // TODO: free the buffers, etc, at this point since they're not needed
-    val sortedIterator: util.Iterator[KeyPointerAndPrefix] = sorter.getSortedIterator
-    // Now that the partition is sorted, write out the data to a file, keeping track off offsets
-    // for use in the sort-based shuffle index.
+
+    while (records.hasNext) {
+      writeRecord(records.next())
+    }
+
+    sorter.getSortedIterator
+  }
+
+  private def writeSortedRecordsToFile(sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = {
     val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID)
     val partitionLengths = new Array[Long](partitioner.numPartitions)
-    // TODO: compression tests?
-    // TODO why is append true here?
-    // TODO: metrics tracking and all of the other stuff that diskblockobjectwriter would give us
-    // TODO: note that we saw FAILED_TO_UNCOMPRESS(5) at some points during debugging when we were
-    // not properly wrapping the writer for compression even though readers expected compressed
-    // data; the fact that someone still reported this issue in newer Spark versions suggests that
-    // we should audit the code to make sure wrapping is done at the right set of places and to
-    // check that we haven't missed any rare corner-cases / rarely-used paths.
-    val out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
-    val serOut = serializer.serializeStream(out)
-    serOut.flush()
+
     var currentPartition = -1
-    var currentPartitionLength: Long = 0
-    while (sortedIterator.hasNext) {
-      val keyPointerAndPrefix: KeyPointerAndPrefix = sortedIterator.next()
-      val partition = keyPointerAndPrefix.keyPrefix.toInt
-      println("Partition is " + partition)
-      if (currentPartition == -1) {
-        currentPartition = partition
+    var prevPartitionLength: Long = 0
+    var out: OutputStream = null
+
+    // TODO: don't close and re-open file handles so often; this could be inefficient
+
+    def closePartition(): Unit = {
+      out.flush()
+      out.close()
+      partitionLengths(currentPartition) = outputFile.length() - prevPartitionLength
+    }
+
+    def switchToPartition(newPartition: Int): Unit = {
+      if (currentPartition != -1) {
+        closePartition()
+        prevPartitionLength = partitionLengths(currentPartition)
       }
+      currentPartition = newPartition
+      out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
+    }
+
+    while (sortedRecords.hasNext) {
+      val keyPointerAndPrefix: KeyPointerAndPrefix = sortedRecords.next()
+      val partition = keyPointerAndPrefix.keyPrefix.toInt
       if (partition != currentPartition) {
-        println("switching partition")
-        partitionLengths(currentPartition) = currentPartitionLength
-        currentPartitionLength = 0
-        currentPartition = partition
+        switchToPartition(partition)
       }
       val baseObject = memoryManager.getPage(keyPointerAndPrefix.recordPointer)
       val baseOffset = memoryManager.getOffsetInPage(keyPointerAndPrefix.recordPointer)
       val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8)
-      partitionLengths(currentPartition) += recordLength
       println("Base offset is " + baseOffset)
       println("Record length is " + recordLength)
       var i: Int = 0
@@ -213,10 +224,19 @@ private[spark] class UnsafeShuffleWriter[K, V](
         i += 1
       }
     }
-    out.flush()
-    //serOut.close()
-    //out.flush()
-    out.close()
+    closePartition()
+
+    partitionLengths
+  }
+
+  /** Write a sequence of records to this task's output */
+  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+    println("Opened writer!")
+
+    val sortedIterator = sortRecords(records)
+    val partitionLengths = writeSortedRecordsToFile(sortedIterator)
+
+    println("Partition lengths are " + partitionLengths.toSeq)
     shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }
@@ -239,7 +259,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
       }
     } finally {
      // Clean up our sorter, which may have its own intermediate files
-      if (sorter != null) {
+      if (!allocatedPages.isEmpty) {
        val iter = allocatedPages.iterator()
        while (iter.hasNext) {
          memoryManager.freePage(iter.next())
@@ -249,7 +269,6 @@ private[spark] class UnsafeShuffleWriter[K, V](
       //sorter.stop()
       context.taskMetrics().shuffleWriteMetrics.foreach(
         _.incShuffleWriteTime(System.nanoTime - startTime))
-      sorter = null
     }
   }
 }
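
For readers tracing the refactored code paths above, the short sketch below illustrates the two ideas the new writer is built around: packing each record into a data page as [8-byte partition id][8-byte serialized length][serialized bytes] (the source of the `serializedRecordSize + 8 + 8` space requirement), and accumulating per-partition byte counts for the shuffle index file. It is only an approximation under stated assumptions: it uses an on-heap java.nio.ByteBuffer instead of the memoryManager / PlatformDependent off-heap pages, a plain sortBy stands in for UnsafeSorter's prefix sort, and the names PageLayoutSketch, Page, writeRecord and readRecord are hypothetical, not part of this commit.

import java.nio.ByteBuffer

object PageLayoutSketch {
  final val PageSize = 1024 * 1024  // assumed 1 MB page, mirroring PAGE_SIZE above

  // One "data page"; offsets within the buffer stand in for the record
  // addresses produced by memoryManager.encodePageNumberAndOffset.
  class Page {
    val buffer: ByteBuffer = ByteBuffer.allocate(PageSize)

    /** Packs one record and returns its offset in the page, or -1 if it does not fit. */
    def writeRecord(partitionId: Int, serializedRecord: Array[Byte]): Long = {
      val required = 8 + 8 + serializedRecord.length   // two longs plus payload
      if (required > buffer.remaining()) return -1L
      val recordOffset = buffer.position().toLong
      buffer.putLong(partitionId.toLong)                // 8-byte partition id
      buffer.putLong(serializedRecord.length.toLong)    // 8-byte record length
      buffer.put(serializedRecord)                      // serialized key + value bytes
      recordOffset
    }

    /** Reads back the partition id and payload stored at `offset`. */
    def readRecord(offset: Long): (Int, Array[Byte]) = {
      val dup = buffer.duplicate()
      dup.position(offset.toInt)
      val partitionId = dup.getLong().toInt
      val length = dup.getLong().toInt
      val payload = new Array[Byte](length)
      dup.get(payload)
      (partitionId, payload)
    }
  }

  def main(args: Array[String]): Unit = {
    val page = new Page
    // Pretend these strings are already-serialized records for partitions 2, 0 and 2.
    val records = Seq((2, "apple"), (0, "banana"), (2, "cherry"))
    val offsets = records.map { case (pid, s) => page.writeRecord(pid, s.getBytes("UTF-8")) }

    // Sorting by partition id mirrors what UnsafeSorter does with its key prefix;
    // here a plain sortBy over (partitionId, payload) pairs stands in for it.
    val sorted = offsets.map(page.readRecord).sortBy(_._1)

    // Per-partition byte counts, as writeSortedRecordsToFile records for the index file.
    val partitionLengths = new Array[Long](3)
    sorted.foreach { case (pid, bytes) => partitionLengths(pid) += bytes.length }
    println(partitionLengths.toSeq)  // partition byte counts, e.g. (6, 0, 11)
  }
}

In the real writer, the offset returned by such a write would be encoded into a record pointer and handed to UnsafeSorter.insertRecord; sorting those pointers by their partition-id prefix is what lets writeSortedRecordsToFile emit one contiguous run of bytes per partition and record its length.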
