
Commit 4f70141

Fix merging; now passes UnsafeShuffleSuite tests.
1 parent 133c8c9 commit 4f70141

File tree

3 files changed (+30 lines, -10 lines)


core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java

Lines changed: 5 additions & 2 deletions
@@ -118,7 +118,6 @@ private SpillInfo writeSpillFile() throws IOException {
     final File file = spilledFileInfo._2();
     final BlockId blockId = spilledFileInfo._1();
     final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);
-    spills.add(spillInfo);
 
     final SerializerInstance ser = new DummySerializerInstance();
     writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics);
@@ -154,7 +153,11 @@ private SpillInfo writeSpillFile() throws IOException {
 
     if (writer != null) {
       writer.commitAndClose();
-      spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+      // TODO: comment and explain why our handling of empty spills, etc.
+      if (currentPartition != -1) {
+        spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+        spills.add(spillInfo);
+      }
     }
     return spillInfo;
   }
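The substance of this change is that an empty spill file is no longer registered: spills.add(spillInfo) moves from the top of writeSpillFile() into a guard that only fires when at least one record was actually written. Assuming currentPartition starts at -1 and is only advanced once a record is spilled (that initialization is outside this diff), the guard keeps zero-record spills out of the later merge. A minimal, hypothetical sketch of that sentinel pattern, independent of the real Spark classes:

import java.util.ArrayList;
import java.util.List;

public class EmptySpillGuardSketch {
  public static void main(String[] args) {
    List<long[]> spills = new ArrayList<>();   // stand-in for the writer's `spills` list
    int numPartitions = 4;
    int[] recordPartitions = {};               // an "empty spill": no records written

    long[] partitionLengths = new long[numPartitions];
    int currentPartition = -1;                 // -1 means nothing has been written yet
    for (int p : recordPartitions) {
      currentPartition = p;
      partitionLengths[p] += 8;                // pretend each record occupies 8 bytes
    }
    // Only register the spill if at least one record was written; otherwise the merge
    // step would see a spill whose partition lengths don't match its (empty) file.
    if (currentPartition != -1) {
      spills.add(partitionLengths);
    }
    System.out.println("registered spills: " + spills.size());  // prints 0
  }
}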

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java

Lines changed: 19 additions & 7 deletions
@@ -157,7 +157,14 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
     final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId);
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
+
+    if (spills.length == 0) {
+      new FileOutputStream(outputFile).close();
+      return partitionLengths;
+    }
+
     final FileChannel[] spillInputChannels = new FileChannel[spills.length];
+    final long[] spillInputChannelPositions = new long[spills.length];
 
     // TODO: We need to add an option to bypass transferTo here since older Linux kernels are
     // affected by a bug here that can lead to data truncation; see the comments Utils.scala,
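Two setup changes in this hunk: if there are no spills at all, mergeSpills() now just creates an empty output file and returns the all-zero partitionLengths array, and a spillInputChannelPositions array is allocated so the copy loop below can track each spill channel's read position itself. A small, self-contained sketch of the zero-spill fast path (the temp file and sum helper are made up for illustration):

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;

public class ZeroSpillFastPathSketch {
  public static void main(String[] args) throws IOException {
    int numPartitions = 4;
    File outputFile = File.createTempFile("merged-shuffle", ".data");
    long[] partitionLengths = new long[numPartitions];   // all zeros by default

    // Mirrors the new early return: create/truncate the output file and stop.
    new FileOutputStream(outputFile).close();

    System.out.println("output length = " + outputFile.length());                   // 0
    System.out.println("sum of partition lengths = " + sum(partitionLengths));      // 0
  }

  private static long sum(long[] xs) {
    long total = 0;
    for (long x : xs) total += x;
    return total;
  }
}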
@@ -173,24 +180,29 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
 
     final FileChannel mergedFileOutputChannel = new FileOutputStream(outputFile).getChannel();
 
-    for (int partition = 0; partition < numPartitions; partition++ ) {
+    for (int partition = 0; partition < numPartitions; partition++) {
       for (int i = 0; i < spills.length; i++) {
-        final long bytesToTransfer = spills[i].partitionLengths[partition];
-        long bytesRemainingToBeTransferred = bytesToTransfer;
+        System.out.println("In partition " + partition + " and spill " + i );
+        final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+        System.out.println("Partition length in spill is " + partitionLengthInSpill);
+        System.out.println("input channel position is " + spillInputChannels[i].position());
+        long bytesRemainingToBeTransferred = partitionLengthInSpill;
         final FileChannel spillInputChannel = spillInputChannels[i];
-        long fromPosition = spillInputChannel.position();
         while (bytesRemainingToBeTransferred > 0) {
-          bytesRemainingToBeTransferred -= spillInputChannel.transferTo(
-            fromPosition,
+          final long actualBytesTransferred = spillInputChannel.transferTo(
+            spillInputChannelPositions[i],
             bytesRemainingToBeTransferred,
             mergedFileOutputChannel);
+          spillInputChannelPositions[i] += actualBytesTransferred;
+          bytesRemainingToBeTransferred -= actualBytesTransferred;
         }
-        partitionLengths[partition] += bytesToTransfer;
+        partitionLengths[partition] += partitionLengthInSpill;
       }
     }
 
     // TODO: should this be in a finally block?
     for (int i = 0; i < spills.length; i++) {
+      assert(spillInputChannelPositions[i] == spills[i].file.length());
       spillInputChannels[i].close();
     }
     mergedFileOutputChannel.close();
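This hunk is the core of the merging fix. FileChannel.transferTo(position, count, target) may transfer fewer bytes than requested, and when given an explicit position it does not advance the channel's own position, so the old loop that kept passing a fixed fromPosition could re-copy the same bytes and never make progress across partitions. The new code advances spillInputChannelPositions[i] by the number of bytes actually transferred, and the assert before closing checks that each spill file was consumed exactly once. A self-contained sketch of that copy pattern (temp files and sizes are made up; this is not the Spark code itself):

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.channels.FileChannel;
import java.nio.file.Files;

public class TransferToLoopSketch {
  public static void main(String[] args) throws IOException {
    File src = File.createTempFile("spill", ".tmp");
    Files.write(src.toPath(), new byte[1234]);           // pretend spill data
    File dst = File.createTempFile("merged", ".tmp");

    try (FileChannel in = new RandomAccessFile(src, "r").getChannel();
         FileChannel out = new FileOutputStream(dst).getChannel()) {
      long position = 0;                                 // explicit source position
      long remaining = src.length();                     // bytes still to copy
      while (remaining > 0) {
        // transferTo may copy fewer bytes than asked and does not move `in`'s position,
        // so both counters must be advanced by the amount actually transferred.
        long transferred = in.transferTo(position, remaining, out);
        position += transferred;
        remaining -= transferred;
      }
    }
    System.out.println("copied " + dst.length() + " bytes");   // 1234
  }
}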

core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java

Lines changed: 6 additions & 1 deletion
@@ -158,14 +158,19 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
 
 
     writer.write(numbersToSort.iterator());
-    final MapStatus mapStatus = writer.stop(true).get();
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    Assert.assertTrue(mapStatus.isDefined());
 
     long sumOfPartitionSizes = 0;
     for (long size: partitionSizes) {
       sumOfPartitionSizes += size;
     }
     Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
 
+    // TODO: actually try to read the shuffle output?
+
+    // TODO: add a test that manually triggers spills in order to exercise the merging.
+
     // TODO: test that the temporary spill files were cleaned up after the merge.
   }
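The test now treats the result of writer.stop(true) as a Scala Option<MapStatus>, asserting that it is defined rather than unwrapping it blindly with .get(). A hypothetical JUnit-style fragment of that assertion pattern (the Option<String> stand-in is only for illustration and is not part of the suite):

import org.junit.Assert;
import org.junit.Test;
import scala.Option;

public class OptionAssertionSketch {
  @Test
  public void stopReturnsAMapStatusWhenSuccessful() {
    // Stand-in for the Option<MapStatus> returned by writer.stop(true).
    Option<String> mapStatus = Option.apply("map status stand-in");
    Assert.assertTrue(mapStatus.isDefined());
    String status = mapStatus.get();   // safe only after the isDefined() check
    Assert.assertNotNull(status);
  }
}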
