Commit a68171c

Author: Davies Liu (committed)
Commit message: checksum
1 parent 26b07f1 commit a68171c

Showing 25 changed files with 246 additions and 73 deletions.
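Taken as a whole, the excerpted diffs thread an Adler32 checksum through the shuffle path: the write side appends an 8-byte checksum value to each block (via ChecksumOutputStream and a new boolean flag on getDiskWriter), and the ManagedBuffer readers verify and strip that trailer when createInputStream(true) is called. The behavior is gated by a new spark.shuffle.checksum configuration key, read in UnsafeShuffleWriter below with a default of true. A minimal, illustrative usage sketch of that key (the class name here is hypothetical, only the configuration key comes from the diff):

import org.apache.spark.SparkConf;

// Illustrative only: explicitly setting the flag that UnsafeShuffleWriter reads in this commit.
public final class ShuffleChecksumConfExample {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf()
        .setAppName("shuffle-checksum-example")
        .set("spark.shuffle.checksum", "true");  // default is already true per the diff below
    System.out.println(conf.get("spark.shuffle.checksum"));
  }
}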

common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java

Lines changed: 21 additions & 7 deletions
@@ -17,13 +17,12 @@
 
 package org.apache.spark.network.buffer;
 
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.RandomAccessFile;
+import java.io.*;
 import java.nio.ByteBuffer;
 import java.nio.channels.FileChannel;
+import java.util.zip.Adler32;
+import java.util.zip.CheckedInputStream;
+import java.util.zip.Checksum;
 
 import com.google.common.base.Objects;
 import com.google.common.io.ByteStreams;
@@ -92,12 +91,27 @@ public ByteBuffer nioByteBuffer() throws IOException {
   }
 
   @Override
-  public InputStream createInputStream() throws IOException {
+  public InputStream createInputStream(boolean checksum) throws IOException {
     FileInputStream is = null;
     try {
       is = new FileInputStream(file);
       ByteStreams.skipFully(is, offset);
-      return new LimitedInputStream(is, length);
+      if (checksum) {
+        Checksum ck = new Adler32();
+        DataInputStream din = new DataInputStream(new CheckedInputStream(is, ck));
+        ByteStreams.skipFully(din, length - 8);
+        long sum = ck.getValue();
+        long expected = din.readLong();
+        if (sum != expected) {
+          throw new IOException("Checksum does not match " + sum + "!=" + expected);
+        }
+        is.close();
+        is = new FileInputStream(file);
+        ByteStreams.skipFully(is, offset);
+        return new LimitedInputStream(is, length - 8);
+      } else {
+        return new LimitedInputStream(is, length);
+      }
     } catch (IOException e) {
       try {
         if (is != null) {
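The readers in this commit (here and in the Netty/NIO buffers below) all assume the same block layout: the payload bytes followed by an 8-byte value holding the Adler32 checksum of the payload. A minimal sketch of a producer for that layout, assuming only what the readers check; the class and method names are illustrative and not part of this commit:

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.zip.Adler32;

// Hypothetical helper: appends the Adler32 trailer that the readers above strip and verify.
final class ChecksumTrailerExample {
  static byte[] withTrailer(byte[] payload) throws IOException {
    Adler32 adler = new Adler32();
    adler.update(payload, 0, payload.length);   // checksum covers the payload only
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(bos);
    out.write(payload);
    out.writeLong(adler.getValue());            // 8-byte big-endian trailer, as readLong() expects
    out.flush();
    return bos.toByteArray();
  }
}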

common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ public abstract class ManagedBuffer {
    * necessarily check for the length of bytes read, so the caller is responsible for making sure
    * it does not go over the limit.
    */
-  public abstract InputStream createInputStream() throws IOException;
+  public abstract InputStream createInputStream(boolean checksum) throws IOException;
 
   /**
    * Increment the reference count by one if applicable.

common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java

Lines changed: 16 additions & 1 deletion
@@ -20,6 +20,7 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.nio.ByteBuffer;
+import java.util.zip.Adler32;
 
 import com.google.common.base.Objects;
 import io.netty.buffer.ByteBuf;
@@ -46,7 +47,21 @@ public ByteBuffer nioByteBuffer() throws IOException {
   }
 
   @Override
-  public InputStream createInputStream() throws IOException {
+  public InputStream createInputStream(boolean checksum) throws IOException {
+    if (checksum) {
+      Adler32 adler = new Adler32();
+      long size = size();
+      buf.markReaderIndex();
+      for (int i = 0; i < size - 8; i++) {
+        adler.update(buf.readByte());
+      }
+      long sum = buf.readLong();
+      if (adler.getValue() != sum) {
+        throw new IOException("Checksum does not match " + adler.getValue() + "!=" + sum);
+      }
+      buf.resetReaderIndex();
+      buf.writerIndex(buf.writerIndex() - 8);
+    }
     return new ByteBufInputStream(buf);
   }
 

common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java

Lines changed: 18 additions & 1 deletion
@@ -20,6 +20,7 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.nio.ByteBuffer;
+import java.util.zip.Adler32;
 
 import com.google.common.base.Objects;
 import io.netty.buffer.ByteBufInputStream;
@@ -46,7 +47,23 @@ public ByteBuffer nioByteBuffer() throws IOException {
   }
 
   @Override
-  public InputStream createInputStream() throws IOException {
+  public InputStream createInputStream(boolean checksum) throws IOException {
+    if (checksum) {
+      Adler32 adler = new Adler32();
+      int position = buf.position();
+      int limit = buf.limit() - 8;
+      buf.position(limit);
+      long sum = buf.getLong();
+      buf.position(position);
+      // simplify this after drop Java 7 support
+      for (int i = buf.position(); i < limit; i++) {
+        adler.update(buf.get(i));
+      }
+      if (sum != adler.getValue()) {
+        throw new IOException("Checksum does not match: " + adler.getValue() + "!=" + sum);
+      }
+      buf.limit(limit);
+    }
     return new ByteBufInputStream(Unpooled.wrappedBuffer(buf));
   }
 
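The "simplify this after drop Java 7 support" comment presumably refers to Adler32.update(ByteBuffer), which is available from Java 8 onward and replaces the index-by-index loop. A hedged sketch of that simplification (not part of this commit, names are illustrative):

import java.nio.ByteBuffer;
import java.util.zip.Adler32;

// Hypothetical Java 8+ variant of the verification loop above:
// Adler32.update(ByteBuffer) checksums the payload region in one call.
final class NioChecksumExample {
  static long adler32OfPayload(ByteBuffer buf) {
    ByteBuffer payload = buf.duplicate();   // independent position/limit, shared content
    payload.limit(payload.limit() - 8);     // exclude the 8-byte checksum trailer
    Adler32 adler = new Adler32();
    adler.update(payload);                  // reads from position to limit
    return adler.getValue();
  }
}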

common/network-common/src/test/java/org/apache/spark/network/TestManagedBuffer.java

Lines changed: 2 additions & 2 deletions
@@ -59,8 +59,8 @@ public ByteBuffer nioByteBuffer() throws IOException {
   }
 
   @Override
-  public InputStream createInputStream() throws IOException {
-    return underlying.createInputStream();
+  public InputStream createInputStream(boolean checksum) throws IOException {
+    return underlying.createInputStream(checksum);
   }
 
   @Override

common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java

Lines changed: 6 additions & 6 deletions
@@ -17,20 +17,17 @@
 
 package org.apache.spark.network.sasl;
 
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
-
+import javax.security.sasl.SaslException;
 import java.io.File;
 import java.lang.reflect.Method;
 import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Random;
 import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.TimeoutException;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicReference;
-import javax.security.sasl.SaslException;
 
 import com.google.common.collect.Lists;
 import com.google.common.io.ByteStreams;
@@ -62,6 +59,9 @@
 import org.apache.spark.network.util.SystemPropertyConfigProvider;
 import org.apache.spark.network.util.TransportConf;
 
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
 /**
  * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes.
  */
@@ -296,7 +296,7 @@ public Void answer(InvocationOnMock invocation) {
       verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
       verify(callback, never()).onFailure(anyInt(), any(Throwable.class));
 
-      byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
+      byte[] received = ByteStreams.toByteArray(response.get().createInputStream(false));
       assertTrue(Arrays.equals(data, received));
     } finally {
       file.delete();

common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java

Lines changed: 7 additions & 6 deletions
@@ -24,14 +24,15 @@
 
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.io.CharStreams;
-import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
-import org.apache.spark.network.util.SystemPropertyConfigProvider;
-import org.apache.spark.network.util.TransportConf;
-import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
 
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
 import static org.junit.Assert.*;
 
 public class ExternalShuffleBlockResolverSuite {
@@ -98,14 +99,14 @@ public void testSortShuffleBlocks() throws IOException {
       dataContext.createExecutorInfo(SORT_MANAGER));
 
     InputStream block0Stream =
-      resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream();
+      resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(false);
     String block0 = CharStreams.toString(
       new InputStreamReader(block0Stream, StandardCharsets.UTF_8));
     block0Stream.close();
     assertEquals(sortBlock0, block0);
 
    InputStream block1Stream =
-      resolver.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream();
+      resolver.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream(false);
     String block1 = CharStreams.toString(
       new InputStreamReader(block1Stream, StandardCharsets.UTF_8));
     block1Stream.close();

core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
       final File file = tempShuffleBlockIdPlusFile._2();
       final BlockId blockId = tempShuffleBlockIdPlusFile._1();
       partitionWriters[i] =
-        blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
+        blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics, true);
     }
     // Creating the file to write to and creating a disk writer both involve interacting with
     // the disk, and can take a long time in aggregate when we open many files, so should be

core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java

Lines changed: 3 additions & 1 deletion
@@ -173,7 +173,9 @@ private void writeSortedFile(boolean isLastFile) throws IOException {
     final SerializerInstance ser = DummySerializerInstance.INSTANCE;
 
     final DiskBlockObjectWriter writer =
-      blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+      blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse,
+        // only generate checksum for only spill
+        isLastFile && spills.isEmpty());
 
     int currentPartition = -1;
     while (sortedRecords.hasNext()) {
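A note on the condition above: `isLastFile && spills.isEmpty()` appears to request a checksum only when the sorted file being written is both the final file and the only one, i.e. when it will become the shuffle output directly without a merge pass. When several spill files exist, each intermediate spill is written without a trailer and the merge path in UnsafeShuffleWriter below appends the checksum to the merged output instead.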

core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java

Lines changed: 18 additions & 7 deletions
@@ -21,6 +21,7 @@
 import java.io.*;
 import java.nio.channels.FileChannel;
 import java.util.Iterator;
+import java.util.zip.Adler32;
 
 import scala.Option;
 import scala.Product2;
@@ -35,7 +36,10 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.spark.*;
+import org.apache.spark.Partitioner;
+import org.apache.spark.ShuffleDependency;
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
 import org.apache.spark.annotation.Private;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.io.CompressionCodec;
@@ -49,6 +53,7 @@
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.ChecksumOutputStream;
 import org.apache.spark.storage.TimeTrackingOutputStream;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.util.Utils;
@@ -75,6 +80,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final SparkConf sparkConf;
   private final boolean transferToEnabled;
   private final int initialSortBufferSize;
+  private final boolean checksum;
 
   @Nullable private MapStatus mapStatus;
   @Nullable private ShuffleExternalSorter sorter;
@@ -108,8 +114,8 @@ public UnsafeShuffleWriter(
     if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
       throw new IllegalArgumentException(
         "UnsafeShuffleWriter can only be used for shuffles with at most " +
-          SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() +
-          " reduce partitions");
+        SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() +
+        " reduce partitions");
     }
     this.blockManager = blockManager;
     this.shuffleBlockResolver = shuffleBlockResolver;
@@ -124,7 +130,9 @@ public UnsafeShuffleWriter(
     this.sparkConf = sparkConf;
     this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
     this.initialSortBufferSize = sparkConf.getInt("spark.shuffle.sort.initialBufferSize",
-        DEFAULT_INITIAL_SORT_BUFFER_SIZE);
+      DEFAULT_INITIAL_SORT_BUFFER_SIZE);
+    this.checksum = sparkConf.getBoolean("spark.shuffle.checksum", true);
+
     open();
   }
 
@@ -289,7 +297,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
        // Compression is disabled or we are using an IO compression codec that supports
        // decompression of concatenated compressed streams, so we can perform a fast spill merge
        // that doesn't need to interpret the spilled bytes.
-       if (transferToEnabled) {
+       if (transferToEnabled && !checksum) {
         logger.debug("Using transferTo-based fast merge");
         partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
       } else {
@@ -346,8 +354,11 @@ private long[] mergeSpillsWithFileStream(
     }
     for (int partition = 0; partition < numPartitions; partition++) {
       final long initialFileLength = outputFile.length();
-      mergedFileOutputStream =
-        new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
+      OutputStream fos = new FileOutputStream(outputFile, true);
+      if (checksum) {
+        fos = new ChecksumOutputStream(fos, new Adler32());
+      }
+      mergedFileOutputStream = new TimeTrackingOutputStream(writeMetrics, fos);
       if (compressionCodec != null) {
         mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
       }
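ChecksumOutputStream is imported and used above, but its source is not among the files excerpted here. A plausible sketch of such a wrapper, consistent with the trailer format the readers verify; the implementation details below are assumptions, not the commit's actual code:

import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.zip.Checksum;

// Hypothetical sketch: every byte written updates the checksum, and close() appends
// the 8-byte value that the ManagedBuffer readers strip and verify.
class ChecksumOutputStreamSketch extends FilterOutputStream {
  private final Checksum checksum;

  ChecksumOutputStreamSketch(OutputStream out, Checksum checksum) {
    super(out);
    this.checksum = checksum;
  }

  @Override
  public void write(int b) throws IOException {
    out.write(b);
    checksum.update(b);
  }

  @Override
  public void write(byte[] b, int off, int len) throws IOException {
    out.write(b, off, len);
    checksum.update(b, off, len);
  }

  @Override
  public void close() throws IOException {
    long sum = checksum.getValue();
    // big-endian 8-byte trailer, matching DataInputStream.readLong() on the read side
    for (int shift = 56; shift >= 0; shift -= 8) {
      out.write((int) (sum >>> shift));
    }
    super.close();
  }
}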
