Commit fcd9a3c

Add notes + tests for maximum record / page sizes.
1 parent 9d1ee7c commit fcd9a3c

File tree: 8 files changed (+113, -40 lines)


core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java

Lines changed: 20 additions & 4 deletions
@@ -19,9 +19,24 @@

 /**
  * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
+ * <p>
+ * Within the long, the data is laid out as follows:
+ * <pre>
+ *   [24 bit partition number][13 bit memory page number][27 bit offset in page]
+ * </pre>
+ * This implies that the maximum addressable page size is 2^27 bytes = 128 megabytes, assuming that
+ * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the
+ * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this
+ * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
+ * <p>
+ * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
+ * optimization to future work as it will require more careful design to ensure that addresses are
+ * properly aligned (e.g. by padding records).
  */
 final class PackedRecordPointer {

+  static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27;  // 128 megabytes
+
   /** Bit mask for the lower 40 bits of a long. */
   private static final long MASK_LONG_LOWER_40_BITS = 0xFFFFFFFFFFL;

@@ -55,7 +70,11 @@ public static long packPointer(long recordPointer, int partitionId) {
     return (((long) partitionId) << 40) | compressedAddress;
   }

-  public long packedRecordPointer;
+  private long packedRecordPointer;
+
+  public void set(long packedRecordPointer) {
+    this.packedRecordPointer = packedRecordPointer;
+  }

   public int getPartitionId() {
     return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40);
@@ -68,7 +87,4 @@ public long getRecordPointer() {
     return pageNumber | offsetInPage;
   }

-  public int getRecordLength() {
-    return -1; // TODO
-  }
 }
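To make the arithmetic in the new javadoc concrete, here is a small, self-contained sketch of the [24 bit | 13 bit | 27 bit] encoding. It is an illustration only, not the Spark class: the class and mask names, and the assumption that the incoming record pointer keeps its page number in the upper 13 bits, are this sketch's own.

// Standalone illustration of packing a partition id and a record pointer into one long,
// using the [24 bit partition][13 bit page number][27 bit offset] layout described above.
public class PackedPointerSketch {
  static final long MASK_LOWER_13_BITS = (1L << 13) - 1;
  static final long MASK_LOWER_27_BITS = (1L << 27) - 1;

  // recordPointer is assumed to be encoded as [13 bit page number][51 bit offset in page].
  static long pack(long recordPointer, int partitionId) {
    final long pageNumber = (recordPointer >>> 51) & MASK_LOWER_13_BITS;
    final long offsetInPage = recordPointer & MASK_LOWER_27_BITS; // offsets must fit in 27 bits (128 MB pages)
    final long compressedAddress = (pageNumber << 27) | offsetInPage;
    return (((long) partitionId) << 40) | compressedAddress;
  }

  public static void main(String[] args) {
    // 2^13 pages * 2^27 bytes per page = 2^40 bytes = 1 terabyte addressable per task.
    final long recordPointer = (5L << 51) | 42;  // page 5, offset 42
    final long packed = pack(recordPointer, 360);
    System.out.println(packed >>> 40);                         // partition id: 360
    System.out.println((packed >>> 27) & MASK_LOWER_13_BITS);  // page number: 5
    System.out.println(packed & MASK_LOWER_27_BITS);           // offset in page: 42
  }
}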

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java

Lines changed: 30 additions & 21 deletions
@@ -57,8 +57,9 @@ final class UnsafeShuffleExternalSorter {

   private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);

-  private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this / don't duplicate
-  private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this
+  @VisibleForTesting
+  static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+  private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;

   private final int initialSize;
   private final int numPartitions;
@@ -88,13 +89,13 @@ final class UnsafeShuffleExternalSorter {
   private long freeSpaceInCurrentPage = 0;

   public UnsafeShuffleExternalSorter(
-    TaskMemoryManager memoryManager,
-    ShuffleMemoryManager shuffleMemoryManager,
-    BlockManager blockManager,
-    TaskContext taskContext,
-    int initialSize,
-    int numPartitions,
-    SparkConf conf) throws IOException {
+      TaskMemoryManager memoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      BlockManager blockManager,
+      TaskContext taskContext,
+      int initialSize,
+      int numPartitions,
+      SparkConf conf) throws IOException {
     this.memoryManager = memoryManager;
     this.shuffleMemoryManager = shuffleMemoryManager;
     this.blockManager = blockManager;
@@ -140,8 +141,9 @@ private SpillInfo writeSpillFile() throws IOException {

     // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
     // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
-    // records in a byte array. This array only needs to be big enough to hold a single record.
-    final byte[] arr = new byte[SER_BUFFER_SIZE];
+    // data through a byte array. This array does not need to be large enough to hold a single
+    // record;
+    final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];

     // Because this output will be read during shuffle, its compression codec must be controlled by
     // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
@@ -186,16 +188,23 @@ private SpillInfo writeSpillFile() throws IOException {
       }

       final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
-      final int recordLength = PlatformDependent.UNSAFE.getInt(
-        memoryManager.getPage(recordPointer), memoryManager.getOffsetInPage(recordPointer));
-      PlatformDependent.copyMemory(
-        memoryManager.getPage(recordPointer),
-        memoryManager.getOffsetInPage(recordPointer) + 4, // skip over record length
-        arr,
-        PlatformDependent.BYTE_ARRAY_OFFSET,
-        recordLength);
-      assert (writer != null); // To suppress an IntelliJ warning
-      writer.write(arr, 0, recordLength);
+      final Object recordPage = memoryManager.getPage(recordPointer);
+      final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer);
+      int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage);
+      long recordReadPosition = recordOffsetInPage + 4; // skip over record length
+      while (dataRemaining > 0) {
+        final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
+        PlatformDependent.copyMemory(
+          recordPage,
+          recordReadPosition,
+          writeBuffer,
+          PlatformDependent.BYTE_ARRAY_OFFSET,
+          toTransfer);
+        assert (writer != null); // To suppress an IntelliJ warning
+        writer.write(writeBuffer, 0, toTransfer);
+        recordReadPosition += toTransfer;
+        dataRemaining -= toTransfer;
+      }
       // TODO: add a test that detects whether we leave this call out:
       writer.recordWritten();
     }
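The new spill loop above never needs to hold a whole record in the write buffer. A minimal sketch of that chunking pattern follows, using a plain byte[] page and an OutputStream in place of managed memory and DiskBlockObjectWriter; both stand-ins are assumptions of this sketch, not Spark types.

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;

// Copies a record to an output stream in chunks of at most DISK_WRITE_BUFFER_SIZE bytes, so the
// intermediate buffer stays small even when a single record is much larger than the buffer.
public class ChunkedWriteSketch {
  static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;

  static void writeRecord(byte[] page, int recordOffset, int recordLength, OutputStream out)
      throws IOException {
    final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
    int dataRemaining = recordLength;
    int readPosition = recordOffset;
    while (dataRemaining > 0) {
      final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
      System.arraycopy(page, readPosition, writeBuffer, 0, toTransfer);
      out.write(writeBuffer, 0, toTransfer);
      readPosition += toTransfer;
      dataRemaining -= toTransfer;
    }
  }

  public static void main(String[] args) throws IOException {
    final byte[] record = new byte[(int) (DISK_WRITE_BUFFER_SIZE * 2.5)];  // larger than the buffer
    final ByteArrayOutputStream out = new ByteArrayOutputStream();
    writeRecord(record, 0, record.length, out);
    System.out.println(out.size() == record.length);  // true
  }
}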

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ public PackedRecordPointer newKey() {

   @Override
   public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
-    reuse.packedRecordPointer = data[pos];
+    reuse.set(data[pos]);
     return reuse;
   }

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ public boolean hasNext() {

       @Override
       public void loadNext() {
-        packedRecordPointer.packedRecordPointer = sortBuffer[position];
+        packedRecordPointer.set(sortBuffer[position]);
         position++;
       }
     };

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java

Lines changed: 13 additions & 5 deletions
@@ -54,7 +54,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {

   private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);

-  private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this
+  @VisibleForTesting
+  static final int MAXIMUM_RECORD_SIZE = 1024 * 1024 * 64; // 64 megabytes
   private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();

   private final BlockManager blockManager;
@@ -108,19 +109,26 @@ public UnsafeShuffleWriter(
     this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
   }

-  public void write(Iterator<Product2<K, V>> records) {
+  public void write(Iterator<Product2<K, V>> records) throws IOException {
     write(JavaConversions.asScalaIterator(records));
   }

   @Override
-  public void write(scala.collection.Iterator<Product2<K, V>> records) {
+  public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
     try {
       while (records.hasNext()) {
         insertRecordIntoSorter(records.next());
       }
       closeAndWriteOutput();
     } catch (Exception e) {
-      PlatformDependent.throwException(e);
+      // Unfortunately, we have to catch Exception here in order to ensure proper cleanup after
+      // errors because Spark's Scala code, or users' custom Serializers, might throw arbitrary
+      // unchecked exceptions.
+      try {
+        sorter.cleanupAfterError();
+      } finally {
+        throw new IOException("Error during shuffle write", e);
+      }
     }
   }

@@ -134,7 +142,7 @@ private void open() throws IOException {
       4096, // Initial size (TODO: tune this!)
       partitioner.numPartitions(),
       sparkConf);
-    serArray = new byte[SER_BUFFER_SIZE];
+    serArray = new byte[MAXIMUM_RECORD_SIZE];
     serByteBuffer = ByteBuffer.wrap(serArray);
     serOutputStream = serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer));
   }
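The rewritten catch block above follows a cleanup-then-rethrow shape: catch any exception, attempt cleanup, and always surface the original failure as the cause of an IOException. A minimal sketch of the same pattern, with a hypothetical Sorter interface rather than the Spark class:

import java.io.IOException;

// The finally block guarantees the rethrow happens even if cleanup itself fails, and catching
// Exception (not just IOException) covers unchecked exceptions thrown by user code.
public class CleanupOnErrorSketch {
  interface Sorter {
    void insert(Object record) throws IOException;
    void close() throws IOException;
    void cleanupAfterError();
  }

  static void writeAll(Sorter sorter, Iterable<Object> records) throws IOException {
    try {
      for (Object record : records) {
        sorter.insert(record);
      }
      sorter.close();
    } catch (Exception e) {
      try {
        sorter.cleanupAfterError();
      } finally {
        throw new IOException("Error during shuffle write", e);
      }
    }
  }
}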

core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala

Lines changed: 3 additions & 0 deletions
@@ -17,13 +17,16 @@

 package org.apache.spark.shuffle

+import java.io.IOException
+
 import org.apache.spark.scheduler.MapStatus

 /**
  * Obtained inside a map task to write out records to the shuffle system.
  */
 private[spark] abstract class ShuffleWriter[K, V] {
   /** Write a sequence of records to this task's output */
+  @throws[IOException]
   def write(records: Iterator[Product2[K, V]]): Unit

   /** Close this writer, passing along whether the map completed */

core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java

Lines changed: 8 additions & 8 deletions
@@ -34,10 +34,10 @@ public void heap() {
     final MemoryBlock page0 = memoryManager.allocatePage(100);
     final MemoryBlock page1 = memoryManager.allocatePage(100);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42);
-    PackedRecordPointer packedPointerWrapper = new PackedRecordPointer();
-    packedPointerWrapper.packedRecordPointer = PackedRecordPointer.packPointer(addressInPage1, 360);
-    Assert.assertEquals(360, packedPointerWrapper.getPartitionId());
-    Assert.assertEquals(addressInPage1, packedPointerWrapper.getRecordPointer());
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+    Assert.assertEquals(360, packedPointer.getPartitionId());
+    Assert.assertEquals(addressInPage1, packedPointer.getRecordPointer());
     memoryManager.cleanUpAllAllocatedMemory();
   }

@@ -48,10 +48,10 @@ public void offHeap() {
     final MemoryBlock page0 = memoryManager.allocatePage(100);
     final MemoryBlock page1 = memoryManager.allocatePage(100);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42);
-    PackedRecordPointer packedPointerWrapper = new PackedRecordPointer();
-    packedPointerWrapper.packedRecordPointer = PackedRecordPointer.packPointer(addressInPage1, 360);
-    Assert.assertEquals(360, packedPointerWrapper.getPartitionId());
-    Assert.assertEquals(addressInPage1, packedPointerWrapper.getRecordPointer());
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+    Assert.assertEquals(360, packedPointer.getPartitionId());
+    Assert.assertEquals(addressInPage1, packedPointer.getRecordPointer());
     memoryManager.cleanUpAllAllocatedMemory();
   }
 }

core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java

Lines changed: 37 additions & 0 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.shuffle.unsafe;

 import java.io.*;
+import java.nio.ByteBuffer;
 import java.util.*;

 import scala.*;
@@ -287,6 +288,42 @@ public void mergeSpillsWithFileStream() throws Exception {
     testMergingSpills(false);
   }

+  @Test
+  public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
+    new Random(42).nextBytes(bytes);
+    dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
+    writer.write(dataToWrite.iterator());
+    writer.stop(true);
+    Assert.assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void writeRecordsThatAreBiggerThanMaximumRecordSize() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    final byte[] bytes = new byte[UnsafeShuffleWriter.MAXIMUM_RECORD_SIZE * 2];
+    new Random(42).nextBytes(bytes);
+    dataToWrite.add(new Tuple2<Object, Object>(1, bytes));
+    try {
+      // Insert a record and force a spill so that there's something to clean up:
+      writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+      writer.forceSorterToSpill();
+      writer.write(dataToWrite.iterator());
+      Assert.fail("Expected exception to be thrown");
+    } catch (IOException e) {
+      // Pass
+    }
+    assertSpillFilesWereCleanedUp();
+  }
+
   @Test
   public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
