[SPARK-23524] Big local shuffle blocks should not be checked for corruption.
## What changes were proposed in this pull request?

In the current code, all local blocks are checked for corruption regardless of whether they are big or not, for the following reasons:

1. The size in `FetchResult` for a local block is set to 0 (https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala#L327).
2. SPARK-4105 meant to check only the small blocks (size < maxBytesInFlight / 3), but because of reason 1 the check below is ineffective: https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala#L420

We can fix this and avoid the OOM.

## How was this patch tested?

UT added.

Author: jx158167 <jx158167@antfin.com>

Closes #20685 from jinxing64/SPARK-23524.
parent ac76eff6a8
commit 77c91cc746
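To make the failure mode concrete, here is a minimal, hedged sketch of the size check referenced above (not the exact Spark source; the `shouldCheck` helper and the constants are illustrative). Because a local block reported a size of 0, the "small block" condition always held, so even a very large local block was fully copied into memory for corruption detection, which is the OOM risk this patch removes.

```scala
object CorruptionCheckSketch {
  // Illustrative stand-in for the check introduced by SPARK-4105: only blocks
  // reported as smaller than maxBytesInFlight / 3 are copied into memory so that
  // corruption can be detected up front.
  def shouldCheck(detectCorrupt: Boolean, reportedSize: Long, maxBytesInFlight: Long): Boolean =
    detectCorrupt && reportedSize < maxBytesInFlight / 3

  def main(args: Array[String]): Unit = {
    val maxBytesInFlight = 48L * 1024 * 1024
    // Before this patch a local block always reported size 0, so the check passed
    // no matter how large the block really was.
    println(shouldCheck(detectCorrupt = true, reportedSize = 0L, maxBytesInFlight = maxBytesInFlight))
    // After this patch the real size is reported and a big local block is skipped.
    println(shouldCheck(detectCorrupt = true, reportedSize = 2L * 1024 * 1024 * 1024,
      maxBytesInFlight = maxBytesInFlight))
  }
}
```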
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

```diff
@@ -90,7 +90,7 @@ final class ShuffleBlockFetcherIterator(
   private[this] val startTime = System.currentTimeMillis
 
   /** Local blocks to fetch, excluding zero-sized blocks. */
-  private[this] val localBlocks = new ArrayBuffer[BlockId]()
+  private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[BlockId]()
 
   /** Remote blocks to fetch, excluding zero-sized blocks. */
   private[this] val remoteBlocks = new HashSet[BlockId]()
```
```diff
@@ -316,6 +316,7 @@ final class ShuffleBlockFetcherIterator(
    * track in-memory are the ManagedBuffer references themselves.
    */
   private[this] def fetchLocalBlocks() {
+    logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}")
     val iter = localBlocks.iterator
     while (iter.hasNext) {
       val blockId = iter.next()
```
```diff
@@ -324,7 +325,8 @@ final class ShuffleBlockFetcherIterator(
         shuffleMetrics.incLocalBlocksFetched(1)
         shuffleMetrics.incLocalBytesRead(buf.size)
         buf.retain()
-        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
+        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId,
+          buf.size(), buf, false))
       } catch {
         case e: Exception =>
           // If we see an exception, stop immediately.
```
```diff
@@ -397,7 +399,9 @@ final class ShuffleBlockFetcherIterator(
           }
           shuffleMetrics.incRemoteBlocksFetched(1)
         }
-        bytesInFlight -= size
+        if (!localBlocks.contains(blockId)) {
+          bytesInFlight -= size
+        }
         if (isNetworkReqDone) {
           reqsInFlight -= 1
           logDebug("Number of requests in flight " + reqsInFlight)
```
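A hedged reading of this hunk (my paraphrase, not text from the commit): local blocks never contribute to bytesInFlight when requests are issued, so now that their SuccessFetchResult carries a real size, the decrement has to be guarded. This is also why localBlocks became a LinkedHashSet above, which keeps insertion order while making the contains lookup cheap. A minimal sketch of that accounting invariant, with hypothetical names:

```scala
import scala.collection.mutable

// Minimal sketch (hypothetical class and method names) of the bytesInFlight invariant:
// only blocks that were counted when their remote request was sent are subtracted
// when their result arrives; local results leave the counter untouched.
class InFlightAccountingSketch(localBlocks: mutable.LinkedHashSet[String]) {
  private var bytesInFlight = 0L

  def onRemoteRequestSent(requestBytes: Long): Unit =
    bytesInFlight += requestBytes

  def onSuccessfulResult(blockId: String, reportedSize: Long): Unit =
    if (!localBlocks.contains(blockId)) bytesInFlight -= reportedSize

  def current: Long = bytesInFlight
}
```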
```diff
@@ -583,8 +587,8 @@ object ShuffleBlockFetcherIterator {
    * Result of a fetch from a remote block successfully.
    * @param blockId block id
    * @param address BlockManager that the block was fetched from.
-   * @param size estimated size of the block, used to calculate bytesInFlight.
-   *             Note that this is NOT the exact bytes.
+   * @param size estimated size of the block. Note that this is NOT the exact bytes.
+   *             Size of remote block is used to calculate bytesInFlight.
    * @param buf `ManagedBuffer` for the content.
    * @param isNetworkReqDone Is this the last network request for this host in this fetch request.
    */
```
ShuffleBlockFetcherIteratorSuite.scala

```diff
@@ -352,6 +352,51 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
     intercept[FetchFailedException] { iterator.next() }
   }
 
+  test("big blocks are not checked for corruption") {
+    val corruptStream = mock(classOf[InputStream])
+    when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
+    val corruptBuffer = mock(classOf[ManagedBuffer])
+    when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
+    doReturn(10000L).when(corruptBuffer).size()
+
+    val blockManager = mock(classOf[BlockManager])
+    val localBmId = BlockManagerId("test-client", "test-client", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+    doReturn(corruptBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0))
+    val localBlockLengths = Seq[Tuple2[BlockId, Long]](
+      ShuffleBlockId(0, 0, 0) -> corruptBuffer.size()
+    )
+
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val remoteBlockLengths = Seq[Tuple2[BlockId, Long]](
+      ShuffleBlockId(0, 1, 0) -> corruptBuffer.size()
+    )
+
+    val transfer = createMockTransfer(
+      Map(ShuffleBlockId(0, 0, 0) -> corruptBuffer, ShuffleBlockId(0, 1, 0) -> corruptBuffer))
+
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (localBmId, localBlockLengths),
+      (remoteBmId, remoteBlockLengths)
+    )
+
+    val taskContext = TaskContext.empty()
+    val iterator = new ShuffleBlockFetcherIterator(
+      taskContext,
+      transfer,
+      blockManager,
+      blocksByAddress,
+      (_, in) => new LimitedInputStream(in, 10000),
+      2048,
+      Int.MaxValue,
+      Int.MaxValue,
+      Int.MaxValue,
+      true)
+    // Blocks should be returned without exceptions.
+    assert(Set(iterator.next()._1, iterator.next()._1) ===
+      Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0)))
+  }
+
   test("retry corrupt blocks (disabled)") {
     val blockManager = mock(classOf[BlockManager])
     val localBmId = BlockManagerId("test-client", "test-client", 1)
```
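A note on how the new test works (my reading of the test body, not commit text): the mocked corruptStream throws an IOException on any read, so the test can only pass if the iterator never copies the streams for corruption detection. The 2048 passed to the constructor appears to be maxBytesInFlight, and both blocks report a size of 10000, so they fail the "small block" condition and are returned without being checked. A hedged sketch of that arithmetic:

```scala
object BigBlockThresholdSketch {
  def main(args: Array[String]): Unit = {
    // Values taken from the test above; the interpretation of 2048 as
    // maxBytesInFlight is an assumption about the constructor argument order.
    val maxBytesInFlight = 2048L
    val reportedBlockSize = 10000L // corruptBuffer.size()
    // Only "small" blocks are copied into memory for corruption detection.
    val wouldCheckForCorruption = reportedBlockSize < maxBytesInFlight / 3
    println(wouldCheckForCorruption) // false: the corrupt stream is never read
  }
}
```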