1717
1818package org .apache .spark .storage
1919
20- import java .io .InputStream
20+ import java .io .{InputStream , IOException }
21+ import java .nio .ByteBuffer
2122import java .util .concurrent .LinkedBlockingQueue
2223import javax .annotation .concurrent .GuardedBy
2324
25+ import scala .collection .mutable
2426import scala .collection .mutable .{ArrayBuffer , HashSet , Queue }
2527import scala .util .control .NonFatal
2628
2729import org .apache .spark .{SparkException , TaskContext }
2830import org .apache .spark .internal .Logging
29- import org .apache .spark .network .buffer .ManagedBuffer
31+ import org .apache .spark .network .buffer .{ FileSegmentManagedBuffer , ManagedBuffer }
3032import org .apache .spark .network .shuffle .{BlockFetchingListener , ShuffleClient }
3133import org .apache .spark .shuffle .FetchFailedException
3234import org .apache .spark .util .Utils
35+ import org .apache .spark .util .io .{ChunkedByteBufferInputStream , ChunkedByteBufferOutputStream }
3336
3437/**
3538 * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -56,6 +59,7 @@ final class ShuffleBlockFetcherIterator(
5659 shuffleClient : ShuffleClient ,
5760 blockManager : BlockManager ,
5861 blocksByAddress : Seq [(BlockManagerId , Seq [(BlockId , Long )])],
62+ streamWrapper : (BlockId , InputStream ) => InputStream ,
5963 maxBytesInFlight : Long ,
6064 maxReqsInFlight : Int )
6165 extends Iterator [(BlockId , InputStream )] with Logging {
@@ -108,6 +112,9 @@ final class ShuffleBlockFetcherIterator(
108112 /** Current number of requests in flight */
109113 private [this ] var reqsInFlight = 0
110114
115+ /** The blocks that can't be decompressed successfully */
116+ private [this ] val corruptedBlocks = mutable.HashSet [String ]()
117+
111118 private [this ] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
112119
113120 /**
@@ -305,35 +312,82 @@ final class ShuffleBlockFetcherIterator(
305312 */
306313 override def next (): (BlockId , InputStream ) = {
307314 numBlocksProcessed += 1
308- val startFetchWait = System .currentTimeMillis()
309- currentResult = results.take()
310- val result = currentResult
311- val stopFetchWait = System .currentTimeMillis()
312- shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
313315
314- result match {
315- case SuccessFetchResult (_, address, size, buf, isNetworkReqDone) =>
316- if (address != blockManager.blockManagerId) {
317- shuffleMetrics.incRemoteBytesRead(buf.size)
318- shuffleMetrics.incRemoteBlocksFetched(1 )
319- }
320- bytesInFlight -= size
321- if (isNetworkReqDone) {
322- reqsInFlight -= 1
323- logDebug(" Number of requests in flight " + reqsInFlight)
324- }
325- case _ =>
316+ var result : FetchResult = null
317+ var input : InputStream = null
318+ while (result == null ) {
319+ val startFetchWait = System .currentTimeMillis()
320+ result = results.take()
321+ val stopFetchWait = System .currentTimeMillis()
322+ shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
323+
324+ result match {
325+ case SuccessFetchResult (blockId, address, size, buf, isNetworkReqDone) =>
326+ if (address != blockManager.blockManagerId) {
327+ shuffleMetrics.incRemoteBytesRead(buf.size)
328+ shuffleMetrics.incRemoteBlocksFetched(1 )
329+ }
330+ bytesInFlight -= size
331+ if (isNetworkReqDone) {
332+ reqsInFlight -= 1
333+ logDebug(" Number of requests in flight " + reqsInFlight)
334+ }
335+
336+ val in = try {
337+ buf.createInputStream()
338+ } catch {
339+ // The exception could only be thrown by local shuffle block
340+ case e : IOException if buf.isInstanceOf [FileSegmentManagedBuffer ] =>
341+ logError(" Failed to create input stream from local block" , e)
342+ buf.release()
343+ result = FailureFetchResult (blockId, address, e)
344+ null
345+ }
346+ if (in != null ) {
347+ input = streamWrapper(blockId, in)
348+ // Only copy the stream if it's wrapped by compression or encryption, also the size of
349+ // block is small (the decompressed block is smaller than maxBytesInFlight)
350+ if (! input.eq(in) && size < maxBytesInFlight / 3 ) {
351+ val out = new ChunkedByteBufferOutputStream (64 * 1024 , ByteBuffer .allocate)
352+ try {
353+ // Decompress the whole block at once to detect any corruption, which could increase
354+ // the memory usage and potentially increase the chance of OOM.
355+ // TODO: manage the memory used here, and spill it into disk in case of OOM.
356+ Utils .copyStream(input, out)
357+ input = out.toChunkedByteBuffer.toInputStream(true )
358+ } catch {
359+ case e : IOException =>
360+ buf.release()
361+ if (buf.isInstanceOf [FileSegmentManagedBuffer ]
362+ || corruptedBlocks.contains(blockId.toString)) {
363+ result = FailureFetchResult (blockId, address, e)
364+ } else {
365+ logWarning(s " got an corrupted block $blockId from $address, fetch again " )
366+ fetchRequests += FetchRequest (address, Array ((blockId, size)))
367+ result = null
368+ }
369+ } finally {
370+ // TODO: release the buf here (earlier)
371+ in.close()
372+ }
373+ }
374+ }
375+
376+ case _ =>
377+ }
378+
379+ // Send fetch requests up to maxBytesInFlight
380+ fetchUpToMaxBytes()
326381 }
327- // Send fetch requests up to maxBytesInFlight
328- fetchUpToMaxBytes()
382+ currentResult = result
329383
330384 result match {
331385 case FailureFetchResult (blockId, address, e) =>
332386 throwFetchFailedException(blockId, address, e)
333387
334388 case SuccessFetchResult (blockId, address, _, buf, _) =>
335389 try {
336- (result.blockId, new BufferReleasingInputStream (buf.createInputStream() , this ))
390+ (result.blockId, new BufferReleasingInputStream (input , this ))
337391 } catch {
338392 case NonFatal (t) =>
339393 throwFetchFailedException(blockId, address, t)
0 commit comments