
Commit 5c93aaf

Author: Davies Liu
retry the fetch or stage if shuffle block is corrupt
1 parent 5558998 commit 5c93aaf

File tree

4 files changed, +83 -30 lines

core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala

Lines changed: 3 additions & 7 deletions
@@ -42,24 +42,20 @@ private[spark] class BlockStoreShuffleReader[K, C](
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    val blockFetcherItr = new ShuffleBlockFetcherIterator(
+    val wrappedStreams = new ShuffleBlockFetcherIterator(
       context,
       blockManager.shuffleClient,
       blockManager,
       mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
+      serializerManager.wrapStream,
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
       SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
 
-    // Wrap the streams for compression and encryption based on configuration
-    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
-      serializerManager.wrapStream(blockId, inputStream)
-    }
-
     val serializerInstance = dep.serializer.newInstance()
 
     // Create a key/value iterator for each stream
-    val recordIter = wrappedStreams.flatMap { wrappedStream =>
+    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
       // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
       // NextIterator. The NextIterator makes sure that close() is called on the
       // underlying InputStream when all records have been read.
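The reader now passes serializerManager.wrapStream straight to the iterator: in Scala, a method reference eta-expands into the (BlockId, InputStream) => InputStream function value the new constructor parameter expects. A minimal standalone sketch of that mechanism (StubSerializerManager is hypothetical, and String stands in for BlockId so the sketch compiles without Spark on the classpath):

import java.io.{ByteArrayInputStream, InputStream}

// Hypothetical stand-in for SerializerManager, only to illustrate eta-expansion.
class StubSerializerManager {
  // Same shape as SerializerManager.wrapStream: (blockId, stream) => wrapped stream.
  def wrapStream(blockId: String, in: InputStream): InputStream = in
}

object EtaExpansionSketch extends App {
  val sm = new StubSerializerManager
  // Passing the method where a function type is expected eta-expands it,
  // mirroring how BlockStoreShuffleReader now passes serializerManager.wrapStream.
  val wrapper: (String, InputStream) => InputStream = sm.wrapStream
  val wrapped = wrapper("shuffle_0_0_0", new ByteArrayInputStream(Array[Byte](1, 2, 3)))
  println(wrapped.read()) // prints 1
}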

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

Lines changed: 76 additions & 22 deletions
@@ -17,19 +17,22 @@
 
 package org.apache.spark.storage
 
-import java.io.InputStream
+import java.io.{InputStream, IOException}
+import java.nio.ByteBuffer
 import java.util.concurrent.LinkedBlockingQueue
 import javax.annotation.concurrent.GuardedBy
 
+import scala.collection.mutable
 import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
 import scala.util.control.NonFatal
 
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
-import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.Utils
+import org.apache.spark.util.io.{ChunkedByteBufferInputStream, ChunkedByteBufferOutputStream}
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -56,6 +59,7 @@ final class ShuffleBlockFetcherIterator(
     shuffleClient: ShuffleClient,
     blockManager: BlockManager,
     blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
+    streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
     maxReqsInFlight: Int)
   extends Iterator[(BlockId, InputStream)] with Logging {
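The simplest valid argument for the new streamWrapper parameter is an identity function, which is exactly what the updated test suite below passes; a one-line sketch (String again standing in for BlockId):

import java.io.InputStream

object IdentityWrapperSketch {
  // Identity wrapper: hand back the fetched stream unmodified
  // (no compression or encryption applied).
  val identityWrapper: (String, InputStream) => InputStream = (_, in) => in
}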
@@ -108,6 +112,9 @@ final class ShuffleBlockFetcherIterator(
   /** Current number of requests in flight */
   private[this] var reqsInFlight = 0
 
+  /** The blocks that can't be decompressed successfully */
+  private[this] val corruptedBlocks = mutable.HashSet[String]()
+
   private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
 
   /**
@@ -305,35 +312,82 @@
    */
   override def next(): (BlockId, InputStream) = {
     numBlocksProcessed += 1
-    val startFetchWait = System.currentTimeMillis()
-    currentResult = results.take()
-    val result = currentResult
-    val stopFetchWait = System.currentTimeMillis()
-    shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
 
-    result match {
-      case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) =>
-        if (address != blockManager.blockManagerId) {
-          shuffleMetrics.incRemoteBytesRead(buf.size)
-          shuffleMetrics.incRemoteBlocksFetched(1)
-        }
-        bytesInFlight -= size
-        if (isNetworkReqDone) {
-          reqsInFlight -= 1
-          logDebug("Number of requests in flight " + reqsInFlight)
-        }
-      case _ =>
+    var result: FetchResult = null
+    var input: InputStream = null
+    while (result == null) {
+      val startFetchWait = System.currentTimeMillis()
+      result = results.take()
+      val stopFetchWait = System.currentTimeMillis()
+      shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
+
+      result match {
+        case SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
+          if (address != blockManager.blockManagerId) {
+            shuffleMetrics.incRemoteBytesRead(buf.size)
+            shuffleMetrics.incRemoteBlocksFetched(1)
+          }
+          bytesInFlight -= size
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+
+          val in = try {
+            buf.createInputStream()
+          } catch {
+            // This exception can only be thrown by a local shuffle block
+            case e: IOException if buf.isInstanceOf[FileSegmentManagedBuffer] =>
+              logError("Failed to create input stream from local block", e)
+              buf.release()
+              result = FailureFetchResult(blockId, address, e)
+              null
+          }
+          if (in != null) {
+            input = streamWrapper(blockId, in)
+            // Only copy the stream if it's wrapped by compression or encryption, and the
+            // block is small (the decompressed block is smaller than maxBytesInFlight)
+            if (!input.eq(in) && size < maxBytesInFlight / 3) {
+              val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
+              try {
+                // Decompress the whole block at once to detect any corruption, which could
+                // increase the memory usage and potentially increase the chance of OOM.
+                // TODO: manage the memory used here, and spill it into disk in case of OOM.
+                Utils.copyStream(input, out)
+                input = out.toChunkedByteBuffer.toInputStream(true)
+              } catch {
+                case e: IOException =>
+                  buf.release()
+                  if (buf.isInstanceOf[FileSegmentManagedBuffer]
+                      || corruptedBlocks.contains(blockId.toString)) {
+                    result = FailureFetchResult(blockId, address, e)
+                  } else {
+                    logWarning(s"got a corrupted block $blockId from $address, fetch again")
+                    corruptedBlocks += blockId.toString
+                    fetchRequests += FetchRequest(address, Array((blockId, size)))
+                    result = null
+                  }
+              } finally {
+                // TODO: release the buf here (earlier)
+                in.close()
+              }
+            }
+          }
+
+        case _ =>
+      }
+
+      // Send fetch requests up to maxBytesInFlight
+      fetchUpToMaxBytes()
     }
-    // Send fetch requests up to maxBytesInFlight
-    fetchUpToMaxBytes()
+    currentResult = result
 
     result match {
       case FailureFetchResult(blockId, address, e) =>
         throwFetchFailedException(blockId, address, e)
 
       case SuccessFetchResult(blockId, address, _, buf, _) =>
         try {
-          (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
+          (result.blockId, new BufferReleasingInputStream(input, this))
         } catch {
           case NonFatal(t) =>
             throwFetchFailedException(blockId, address, t)
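The heart of the change: for small wrapped blocks, the iterator eagerly drains the decompressing stream so corruption surfaces as an IOException here, where the fetch can still be retried, rather than deep inside the record iterator where only a full stage retry is possible; the corruptedBlocks set records a block's first failure so a block that is corrupt twice fails the fetch instead of retrying forever. A standalone sketch of the same pattern under hypothetical names (fetchWithRetry and drain are illustrative, not Spark APIs):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException, InputStream}

// Sketch of the commit's retry-on-corruption pattern: decompress the whole
// block up front so a corrupt block fails here, where a refetch is possible.
object RetryOnCorruptionSketch {
  def fetchWithRetry(fetch: () => InputStream,
                     wrap: InputStream => InputStream): InputStream = {
    // Drain the (possibly decompressing) stream into memory; a corrupt
    // block makes the read throw IOException during this copy.
    def drain(in: InputStream): InputStream = {
      val out = new ByteArrayOutputStream()
      val buf = new Array[Byte](64 * 1024)
      var n = in.read(buf)
      while (n != -1) { out.write(buf, 0, n); n = in.read(buf) }
      new ByteArrayInputStream(out.toByteArray)
    }
    try drain(wrap(fetch()))
    catch {
      case _: IOException =>
        // First failure: refetch once, mirroring the once-only bookkeeping in
        // corruptedBlocks; a second IOException propagates to the caller, like
        // FailureFetchResult does in the commit.
        drain(wrap(fetch()))
    }
  }
}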

core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala

Lines changed: 1 addition & 1 deletion
@@ -151,7 +151,7 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
  * @param dispose if true, [[ChunkedByteBuffer.dispose()]] will be called at the end of the stream
  *                in order to close any memory-mapped files which back the buffer.
  */
-private class ChunkedByteBufferInputStream(
+private[spark] class ChunkedByteBufferInputStream(
     var chunkedByteBuffer: ChunkedByteBuffer,
     dispose: Boolean)
   extends InputStream {
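The visibility widening lets ShuffleBlockFetcherIterator consume the stream produced by toChunkedByteBuffer.toInputStream(...). A small sketch of that round trip, assuming Spark's org.apache.spark.util.io classes are on the classpath:

import java.nio.ByteBuffer
import org.apache.spark.util.io.ChunkedByteBufferOutputStream

object ChunkedRoundTripSketch extends App {
  // Buffer some bytes in chunks, then read them back through the
  // (now spark-visible) ChunkedByteBufferInputStream.
  val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
  out.write(Array[Byte](1, 2, 3))
  out.close() // toChunkedByteBuffer requires the stream to be closed
  val in = out.toChunkedByteBuffer.toInputStream(dispose = true)
  println(in.read()) // prints 1
}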

core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala

Lines changed: 3 additions & 0 deletions
@@ -99,6 +99,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => in,
       48 * 1024 * 1024,
       Int.MaxValue)
 
@@ -172,6 +173,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => in,
       48 * 1024 * 1024,
       Int.MaxValue)
 
@@ -235,6 +237,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => in,
       48 * 1024 * 1024,
       Int.MaxValue)
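The suite updates above only thread an identity wrapper through existing tests; a corruption test could inject a failing wrapper through the same streamWrapper hook. A hypothetical sketch (none of these names appear in the commit's test changes):

import java.io.{IOException, InputStream}

// Wrapper that simulates a corrupt block on the first attempt and behaves
// normally on the retry, to exercise the fetch-again path in the iterator.
class CorruptOnceWrapper {
  private var failedOnce = false
  def wrap(blockId: String, in: InputStream): InputStream =
    if (failedOnce) in
    else {
      failedOnce = true
      new InputStream {
        override def read(): Int = throw new IOException("simulated corruption")
      }
    }
}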