
Commit 0748458

Port UnsafeShuffleWriter to Java.
1 parent 87e721b commit 0748458

2 files changed: +282 -243 lines changed

Lines changed: 282 additions & 0 deletions
@@ -0,0 +1,282 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle.unsafe;

import scala.Option;
import scala.Product2;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.LinkedList;

import com.esotericsoftware.kryo.io.ByteBufferOutputStream;

import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.IndexShuffleBlockManager;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.BlockObjectWriter;
import org.apache.spark.storage.ShuffleBlockId;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import org.apache.spark.unsafe.sort.UnsafeSorter;
import static org.apache.spark.unsafe.sort.UnsafeSorter.*;

// IntelliJ gets confused and claims that this class should be abstract, but this actually compiles
public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {

  private static final int PAGE_SIZE = 1024 * 1024;  // TODO: tune this
  private static final int SER_BUFFER_SIZE = 1024 * 1024;  // TODO: tune this
  private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();

  private final IndexShuffleBlockManager shuffleBlockManager;
  private final BlockManager blockManager = SparkEnv.get().blockManager();
  private final int shuffleId;
  private final int mapId;
  private final TaskMemoryManager memoryManager;
  private final SerializerInstance serializer;
  private final Partitioner partitioner;
  private final ShuffleWriteMetrics writeMetrics;
  private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
  private final int fileBufferSize;
  private MapStatus mapStatus = null;

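  // The data page that serialized records are currently written into, and the write position
  // within it. currentPagePosition starts at PAGE_SIZE so that the first call to
  // ensureSpaceInDataPage() allocates a fresh page before any record is written.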
  private MemoryBlock currentPage = null;
  private long currentPagePosition = PAGE_SIZE;

  /**
   * Are we in the process of stopping? Because map tasks can call stop() with success = true
   * and then call stop() with success = false if they get an exception, we want to make sure
   * we don't try deleting files, etc twice.
   */
  private boolean stopping = false;

  public UnsafeShuffleWriter(
      IndexShuffleBlockManager shuffleBlockManager,
      UnsafeShuffleHandle<K, V> handle,
      int mapId,
      TaskContext context) {
    this.shuffleBlockManager = shuffleBlockManager;
    this.mapId = mapId;
    this.memoryManager = context.taskMemoryManager();
    final ShuffleDependency<K, V, V> dep = handle.dependency();
    this.shuffleId = dep.shuffleId();
    this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
    this.partitioner = dep.partitioner();
    this.writeMetrics = new ShuffleWriteMetrics();
    context.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
    this.fileBufferSize =
      // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
      (int) SparkEnv.get().conf().getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
  }

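  // Write path: serialize each record into an in-memory data page, sort the record pointers by
  // partition id, then write the sorted records to the shuffle data file (one contiguous segment
  // per partition) and write an index file of the per-partition segment lengths.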
  public void write(scala.collection.Iterator<Product2<K, V>> records) {
    try {
      final long[] partitionLengths = writeSortedRecordsToFile(sortRecords(records));
      shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths);
      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
    } catch (Exception e) {
      PlatformDependent.throwException(e);
    }
  }

  private void ensureSpaceInDataPage(long requiredSpace) throws Exception {
    if (requiredSpace > PAGE_SIZE) {
      // TODO: throw a more specific exception?
      throw new Exception("Required space " + requiredSpace + " is greater than page size (" +
        PAGE_SIZE + ")");
    } else if (requiredSpace > (PAGE_SIZE - currentPagePosition)) {
      currentPage = memoryManager.allocatePage(PAGE_SIZE);
      currentPagePosition = currentPage.getBaseOffset();
      allocatedPages.add(currentPage);
    }
  }

  private void freeMemory() {
    final Iterator<MemoryBlock> iter = allocatedPages.iterator();
    while (iter.hasNext()) {
      memoryManager.freePage(iter.next());
      iter.remove();
    }
  }

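  // Serializes each incoming record into the current data page using the layout
  // [8-byte partition id][8-byte serialized length][serialized key and value bytes], then feeds
  // the encoded record address into the sorter. The partition id doubles as the sort key prefix.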
  private Iterator<RecordPointerAndKeyPrefix> sortRecords(
      scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception {
    final UnsafeSorter sorter = new UnsafeSorter(
      memoryManager,
      RECORD_COMPARATOR,
      PREFIX_COMPUTER,
      PREFIX_COMPARATOR,
      4096  // Initial size (TODO: tune this!)
    );

    final byte[] serArray = new byte[SER_BUFFER_SIZE];
    final ByteBuffer serByteBuffer = ByteBuffer.wrap(serArray);
    // TODO: we should not depend on this class from Kryo; copy its source or find an alternative
    final SerializationStream serOutputStream =
      serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer));

    while (records.hasNext()) {
      final Product2<K, V> record = records.next();
      final K key = record._1();
      final int partitionId = partitioner.getPartition(key);
      serByteBuffer.position(0);
      serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
      serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
      serOutputStream.flush();

      final int serializedRecordSize = serByteBuffer.position();
      assert (serializedRecordSize > 0);
      // TODO: we should run the partition extraction function _now_, at insert time, rather than
      // requiring it to be stored alongside the data, since this may lead to double storage
      // Need 8 bytes to store the prefix (for later retrieval in the prefix computer), plus
      // 8 to store the record length (TODO: can store as an int instead).
      ensureSpaceInDataPage(serializedRecordSize + 8 + 8);

      final long recordAddress =
        memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
      final Object baseObject = currentPage.getBaseObject();
      PlatformDependent.UNSAFE.putLong(baseObject, currentPagePosition, partitionId);
      currentPagePosition += 8;
      PlatformDependent.UNSAFE.putLong(baseObject, currentPagePosition, serializedRecordSize);
      currentPagePosition += 8;
      PlatformDependent.copyMemory(
        serArray,
        PlatformDependent.BYTE_ARRAY_OFFSET,
        baseObject,
        currentPagePosition,
        serializedRecordSize);
      currentPagePosition += serializedRecordSize;

      sorter.insertRecord(recordAddress);
    }

    return sorter.getSortedIterator();
  }

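  // Because the sort key prefix is the partition id, the sorted iterator yields records grouped
  // by partition; each partition is written as one contiguous segment of the output data file,
  // and its length is recorded for the index file.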
  private long[] writeSortedRecordsToFile(
      Iterator<RecordPointerAndKeyPrefix> sortedRecords) throws IOException {
    final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId);
    final ShuffleBlockId blockId =
      new ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID());
    final long[] partitionLengths = new long[partitioner.numPartitions()];

    int currentPartition = -1;
    BlockObjectWriter writer = null;

    while (sortedRecords.hasNext()) {
      final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next();
      final int partition = (int) recordPointer.keyPrefix;
      assert (partition >= currentPartition);
      if (partition != currentPartition) {
        // Switch to the new partition
        if (currentPartition != -1) {
          writer.commitAndClose();
          partitionLengths[currentPartition] = writer.fileSegment().length();
        }
        currentPartition = partition;
        writer =
          blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics);
      }

      final Object baseObject = memoryManager.getPage(recordPointer.recordPointer);
      final long baseOffset = memoryManager.getOffsetInPage(recordPointer.recordPointer);
      final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8);
      // TODO: re-use a buffer or avoid double-buffering entirely
      final byte[] arr = new byte[recordLength];
      PlatformDependent.copyMemory(
        baseObject,
        baseOffset + 16,
        arr,
        PlatformDependent.BYTE_ARRAY_OFFSET,
        recordLength);
      assert (writer != null);  // To suppress an IntelliJ warning
      writer.write(arr);
      // TODO: add a test that detects whether we leave this call out:
      writer.recordWritten();
    }

    if (writer != null) {
      writer.commitAndClose();
      partitionLengths[currentPartition] = writer.fileSegment().length();
    }

    return partitionLengths;
  }

  @Override
  public Option<MapStatus> stop(boolean success) {
    try {
      if (stopping) {
        return Option.apply(null);
      } else {
        stopping = true;
        freeMemory();
        if (success) {
          return Option.apply(mapStatus);
        } else {
          // The map task failed, so delete our output data.
          shuffleBlockManager.removeDataByMap(shuffleId, mapId);
          return Option.apply(null);
        }
      }
    } finally {
      freeMemory();
      // TODO: increment the shuffle write time metrics
    }
  }

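  // Records only need to be ordered by partition id, which is handled entirely by the prefix
  // comparison below, so the record comparator can treat all records as equal.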
  private static final RecordComparator RECORD_COMPARATOR = new RecordComparator() {
    @Override
    public int compare(
        Object leftBaseObject, long leftBaseOffset, Object rightBaseObject, long rightBaseOffset) {
      return 0;
    }
  };

  private static final PrefixComputer PREFIX_COMPUTER = new PrefixComputer() {
    @Override
    public long computePrefix(Object baseObject, long baseOffset) {
      // TODO: should the prefix be computed when inserting the record pointer rather than being
      // read from the record itself? May be more efficient in terms of space, etc, and is a simple
      // change.
      return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
    }
  };

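  // The prefixes compared here are partition ids (small, non-negative ints), so the difference
  // always fits in an int; this subtraction would not be safe for arbitrary long prefixes.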
  private static final PrefixComparator PREFIX_COMPARATOR = new PrefixComparator() {
    @Override
    public int compare(long prefix1, long prefix2) {
      return (int) (prefix1 - prefix2);
    }
  };
}
