 
 package org.apache.spark.shuffle.unsafe
 
-import java.io.{ByteArrayOutputStream, FileOutputStream}
+import java.io.{FileOutputStream, OutputStream}
 import java.nio.ByteBuffer
 import java.util
 
 import com.esotericsoftware.kryo.io.ByteBufferOutputStream
+
+import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
 import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle._
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.unsafe.PlatformDependent
-import org.apache.spark.unsafe.memory.{TaskMemoryManager, MemoryBlock}
+import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}
 import org.apache.spark.unsafe.sort.UnsafeSorter
 import org.apache.spark.unsafe.sort.UnsafeSorter.{KeyPointerAndPrefix, PrefixComparator, PrefixComputer, RecordComparator}
-import org.apache.spark.{SparkEnv, SparkConf, ShuffleDependency, TaskContext}
-import org.apache.spark.shuffle._
 
 private[spark] class UnsafeShuffleHandle[K, V](
     shuffleId: Int,
@@ -87,7 +87,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
 
   private[this] val dep = handle.dependency
 
-  private[this] var sorter: UnsafeSorter = null
+  private[this] val partitioner = dep.partitioner
 
   // Are we in the process of stopping? Because map tasks can call stop() with success = true
   // and then call stop() with success = false if they get an exception, we want to make sure
@@ -104,52 +104,55 @@ private[spark] class UnsafeShuffleWriter[K, V](
 
   private[this] val blockManager = SparkEnv.get.blockManager
 
-  /** Write a sequence of records to this task's output */
-  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
-    println("Opened writer!")
-    val serializer = Serializer.getSerializer(dep.serializer).newInstance()
-    val partitioner = dep.partitioner
-    sorter = new UnsafeSorter(
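+  /**
+   * Inserts all of the given records into an in-memory UnsafeSorter that orders record pointers
+   * by partition id (used as the sort prefix) and returns an iterator over the sorted pointers.
+   */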
+  private def sortRecords(records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = {
+    val sorter = new UnsafeSorter(
       context.taskMemoryManager(),
       DummyRecordComparator,
       PartitionerPrefixComputer,
       PartitionerPrefixComparator,
       4096 // initial size
     )
-
-    // Pack records into data pages:
+    val serializer = Serializer.getSerializer(dep.serializer).newInstance()
     val PAGE_SIZE = 1024 * 1024 * 1
+
     var currentPage: MemoryBlock = memoryManager.allocatePage(PAGE_SIZE)
+    // Track the first page as well, so that it is freed in stop() along with later pages
+    allocatedPages.add(currentPage)
     var currentPagePosition: Long = currentPage.getBaseOffset
 
-    // TODO make this configurable
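+    // Allocates a new data page when the current one cannot hold the next record; every page is
+    // recorded in allocatedPages so that it can be freed once the write completes.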
+    def ensureSpaceInDataPage(spaceRequired: Long): Unit = {
+      if (spaceRequired > PAGE_SIZE) {
+        throw new Exception(s"Size requirement $spaceRequired is greater than page size $PAGE_SIZE")
+      } else if (spaceRequired > (PAGE_SIZE - currentPagePosition)) {
+        currentPage = memoryManager.allocatePage(PAGE_SIZE)
+        allocatedPages.add(currentPage)
+        currentPagePosition = currentPage.getBaseOffset
+      }
+    }
+
+    // TODO: the size of this buffer should be configurable
     val serArray = new Array[Byte](1024 * 1024)
     val byteBuffer = ByteBuffer.wrap(serArray)
     val bbos = new ByteBufferOutputStream()
     bbos.setByteBuffer(byteBuffer)
     val serBufferSerStream = serializer.serializeStream(bbos)
 
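+    // Serializes a single record into the reusable buffer, copies it into the current data page,
+    // and registers the record's encoded page address with the sorter.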
-    while (records.hasNext) {
-      val nextRecord: Product2[K, V] = records.next()
-      println("Writing record " + nextRecord)
-      val partitionId: Int = partitioner.getPartition(nextRecord._1)
-      serBufferSerStream.writeObject(nextRecord)
-
-      val sizeRequirement: Int = byteBuffer.position() + 8 + 8
-      println("Size requirement in internal buffer is " + sizeRequirement)
-      if (sizeRequirement > (PAGE_SIZE - currentPagePosition)) {
-        println("Allocating a new data page after writing " + currentPagePosition)
-        currentPage = memoryManager.allocatePage(PAGE_SIZE)
-        allocatedPages.add(currentPage)
-        currentPagePosition = currentPage.getBaseOffset
-      }
-      println("Before writing record, current page position is " + currentPagePosition)
-      // TODO: check that it's still not too large
+    def writeRecord(record: Product2[Any, Any]): Unit = {
+      val (key, value) = record
+      val partitionId = partitioner.getPartition(key)
+      serBufferSerStream.writeKey(key)
+      serBufferSerStream.writeValue(value)
+      serBufferSerStream.flush()
+
+      val serializedRecordSize = byteBuffer.position()
+      // TODO: we should run the partition extraction function _now_, at insert time, rather than
+      // requiring it to be stored alongside the data, since this may lead to double storage
+      val sizeRequirementInSortDataPage = serializedRecordSize + 8 + 8
+      ensureSpaceInDataPage(sizeRequirementInSortDataPage)
+
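+      // Data page layout for each record:
+      //   [8-byte partition id][8-byte serialized length][serialized key and value bytes]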
       val newRecordAddress =
         memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition)
       PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId)
       currentPagePosition += 8
+      println("The stored record length is " + byteBuffer.position())
       PlatformDependent.UNSAFE.putLong(
         currentPage.getBaseObject, currentPagePosition, byteBuffer.position())
       currentPagePosition += 8
@@ -162,45 +165,53 @@ private[spark] class UnsafeShuffleWriter[K, V](
       currentPagePosition += byteBuffer.position()
       println("After writing record, current page position is " + currentPagePosition)
       sorter.insertRecord(newRecordAddress)
+
+      // Reset for writing the next record
       byteBuffer.position(0)
     }
-    // TODO: free the buffers, etc, at this point since they're not needed
-    val sortedIterator: util.Iterator[KeyPointerAndPrefix] = sorter.getSortedIterator
-    // Now that the partition is sorted, write out the data to a file, keeping track of offsets
-    // for use in the sort-based shuffle index.
+
+    while (records.hasNext) {
+      writeRecord(records.next())
+    }
+
+    sorter.getSortedIterator
+  }
+
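+  // Writes the sorted records to the shuffle data file, one compressed stream per partition, and
+  // returns each partition's length in the file (used below to build the shuffle index file).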
+  private def writeSortedRecordsToFile(sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = {
     val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID)
     val partitionLengths = new Array[Long](partitioner.numPartitions)
-    // TODO: compression tests?
-    // TODO why is append true here?
-    // TODO: metrics tracking and all of the other stuff that diskblockobjectwriter would give us
-    // TODO: note that we saw FAILED_TO_UNCOMPRESS(5) at some points during debugging when we were
-    // not properly wrapping the writer for compression even though readers expected compressed
-    // data; the fact that someone still reported this issue in newer Spark versions suggests that
-    // we should audit the code to make sure wrapping is done at the right set of places and to
-    // check that we haven't missed any rare corner-cases / rarely-used paths.
-    val out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
-    val serOut = serializer.serializeStream(out)
-    serOut.flush()
+
     var currentPartition = -1
-    var currentPartitionLength: Long = 0
-    while (sortedIterator.hasNext) {
-      val keyPointerAndPrefix: KeyPointerAndPrefix = sortedIterator.next()
-      val partition = keyPointerAndPrefix.keyPrefix.toInt
-      println("Partition is " + partition)
-      if (currentPartition == -1) {
-        currentPartition = partition
+    var prevPartitionLength: Long = 0
+    var out: OutputStream = null
+
+    // TODO: don't close and re-open file handles so often; this could be inefficient
+
+    def closePartition(): Unit = {
+      out.flush()
+      out.close()
+      partitionLengths(currentPartition) = outputFile.length() - prevPartitionLength
+    }
+
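+    // Closes the previous partition's stream (if any) and opens a new compressed stream that
+    // appends the next partition's records to the same data file.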
+    def switchToPartition(newPartition: Int): Unit = {
+      if (currentPartition != -1) {
+        closePartition()
+        // Accumulate, so the next partition's length is measured from the current end of the file
+        prevPartitionLength += partitionLengths(currentPartition)
       }
+      currentPartition = newPartition
+      out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
+    }
+
+    while (sortedRecords.hasNext) {
+      val keyPointerAndPrefix: KeyPointerAndPrefix = sortedRecords.next()
+      val partition = keyPointerAndPrefix.keyPrefix.toInt
       if (partition != currentPartition) {
-        println("switching partition")
-        partitionLengths(currentPartition) = currentPartitionLength
-        currentPartitionLength = 0
-        currentPartition = partition
+        switchToPartition(partition)
       }
       val baseObject = memoryManager.getPage(keyPointerAndPrefix.recordPointer)
       val baseOffset = memoryManager.getOffsetInPage(keyPointerAndPrefix.recordPointer)
       val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8)
-      partitionLengths(currentPartition) += recordLength
       println("Base offset is " + baseOffset)
       println("Record length is " + recordLength)
       var i: Int = 0
@@ -213,10 +224,19 @@ private[spark] class UnsafeShuffleWriter[K, V](
         i += 1
       }
     }
-    out.flush()
-    // serOut.close()
-    // out.flush()
-    out.close()
+    closePartition()
+
+    partitionLengths
+  }
+
+  /** Write a sequence of records to this task's output */
+  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+    println("Opened writer!")
+
+    val sortedIterator = sortRecords(records)
+    val partitionLengths = writeSortedRecordsToFile(sortedIterator)
+
+    println("Partition lengths are " + partitionLengths.toSeq)
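+    // Record the per-partition lengths in the shuffle index file and report them through the
+    // MapStatus so that reducers can locate and size their fetches.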
     shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }
@@ -239,7 +259,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
       }
     } finally {
       // Clean up our sorter, which may have its own intermediate files
-      if (sorter != null) {
+      if (!allocatedPages.isEmpty) {
         val iter = allocatedPages.iterator()
         while (iter.hasNext) {
           memoryManager.freePage(iter.next())
@@ -249,7 +269,6 @@ private[spark] class UnsafeShuffleWriter[K, V](
       // sorter.stop()
       context.taskMetrics().shuffleWriteMetrics.foreach(
         _.incShuffleWriteTime(System.nanoTime - startTime))
-      sorter = null
     }
   }
 }