[SPARK-30290][CORE] Count for merged block when fetching continuous blocks in batch

### What changes were proposed in this pull request?

SPARK-9853 added the batch shuffle block fetch optimization: in `ShuffleBlockFetcherIterator`, contiguous single blocks are merged into batch blocks. During merging, we should count merged blocks against `maxBlocksInFlightPerAddress`, not the original single blocks.
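
As a quick illustration (our own example, not code from this patch): four contiguous blocks from one map output merge into a single batch block, so the in-flight count should be 1, not 4:

```scala
// Illustration only: how many blocks count against maxBlocksInFlightPerAddress.
object MergedBlockCountDemo {
  def main(args: Array[String]): Unit = {
    // Four contiguous shuffle blocks for shuffle 0, map 3, reduce ids 0..3.
    val singleBlocks = Seq("shuffle_0_3_0", "shuffle_0_3_1", "shuffle_0_3_2", "shuffle_0_3_3")
    // Batch fetch merges them into one batch block covering reduce ids [0, 4).
    val batchBlocks = Seq("shuffle_0_3_0_4")

    println(singleBlocks.size) // 4 -- the pre-fix count; trips a limit of 1 and defeats batching
    println(batchBlocks.size)  // 1 -- what actually goes in flight with batch fetch
  }
}
```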

### Why are the changes needed?

If `maxBlocksInFlightPerAddress` is specified, e.g. set to 1, it should mean one batch block, not one original single block. Otherwise the limit conflicts with batch shuffle fetch: contiguous blocks that would be merged into a single batch block are flushed into requests before merging, defeating the optimization.

### Does this PR introduce any user-facing change?

No

### How was this patch tested?

Unit test.

Closes #26930 from viirya/respect-max-blocks-inflight.

Lead-authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Co-authored-by: Liang-Chi Hsieh <liangchi@uber.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Liang-Chi Hsieh 2019-12-25 18:57:02 +08:00 committed by Wenchen Fan
parent 8f07839e74
commit 0042ad575a
3 changed files with 63 additions and 5 deletions


@@ -148,6 +148,8 @@ public class OneForOneBlockFetcher {
   /** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */
   private String[] splitBlockId(String blockId) {
     String[] blockIdParts = blockId.split("_");
+    // For batch block id, the format contains shuffleId, mapId, begin reduceId, end reduceId.
+    // For single block id, the format contains shuffleId, mapId, reduceId.
     if (blockIdParts.length < 4 || blockIdParts.length > 5 || !blockIdParts[0].equals("shuffle")) {
       throw new IllegalArgumentException(
         "Unexpected shuffle block id format: " + blockId);

@@ -337,14 +337,25 @@ final class ShuffleBlockFetcherIterator(
         assertPositiveBlockSize(blockId, size)
         curBlocks += FetchBlockInfo(blockId, size, mapIndex)
         curRequestSize += size
-        if (curRequestSize >= targetRemoteRequestSize ||
-            curBlocks.size >= maxBlocksInFlightPerAddress) {
+        // For batch fetch, the actual block in flight should count for merged block.
+        val exceedsMaxBlocksInFlightPerAddress = !doBatchFetch &&
+          curBlocks.size >= maxBlocksInFlightPerAddress
+        if (curRequestSize >= targetRemoteRequestSize || exceedsMaxBlocksInFlightPerAddress) {
           // Add this FetchRequest
           val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks)
-          collectedRemoteRequests += new FetchRequest(address, mergedBlocks)
-          logDebug(s"Creating fetch request of $curRequestSize at $address "
-            + s"with ${mergedBlocks.size} blocks")
+            .grouped(maxBlocksInFlightPerAddress)
           curBlocks = new ArrayBuffer[FetchBlockInfo]
+          mergedBlocks.foreach { mergedBlock =>
+            if (mergedBlock.size == maxBlocksInFlightPerAddress) {
+              collectedRemoteRequests += new FetchRequest(address, mergedBlock)
+              logDebug(s"Creating fetch request of $curRequestSize at $address "
+                + s"with ${mergedBlock.size} blocks")
+            } else {
+              // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back
+              // to `curBlocks`.
+              curBlocks = mergedBlock
+            }
+          }
           curRequestSize = 0
         }
       }
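
To see what the `grouped` call buys us, here is a standalone sketch (our own simplified types, not Spark's) of the packing behavior above: full groups become fetch requests, and the trailing partial group is carried over instead of being sent immediately.

```scala
import scala.collection.mutable.ArrayBuffer

object GroupedPackingDemo {
  case class Block(id: String)           // stand-in for FetchBlockInfo
  case class Request(blocks: Seq[Block]) // stand-in for FetchRequest

  def main(args: Array[String]): Unit = {
    val maxBlocksInFlightPerAddress = 2
    val merged = Seq(Block("b0"), Block("b1"), Block("b2"), Block("b3"), Block("b4"))

    val requests = ArrayBuffer.empty[Request]
    var carryOver = Seq.empty[Block]
    merged.grouped(maxBlocksInFlightPerAddress).foreach { group =>
      if (group.size == maxBlocksInFlightPerAddress) {
        requests += Request(group) // full group: send as one fetch request
      } else {
        carryOver = group          // last partial group: keep for later
      }
    }

    println(requests)  // two full requests: (b0, b1) and (b2, b3)
    println(carryOver) // List(Block(b4)) -- analogous to putting it back into curBlocks
  }
}
```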


@@ -341,6 +341,51 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
     assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs().size === 1)
   }
 
+  test("fetch continuous blocks in batch respects maxBlocksInFlightPerAddress") {
+    val blockManager = mock(classOf[BlockManager])
+    val localBmId = BlockManagerId("test-client", "test-local-host", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+
+    // Make sure remote blocks would return the merged block
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val remoteBlocks = Seq[BlockId](
+      ShuffleBlockId(0, 3, 0),
+      ShuffleBlockId(0, 3, 1),
+      ShuffleBlockId(0, 3, 2),
+      ShuffleBlockId(0, 3, 3))
+    val mergedRemoteBlocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockBatchId(0, 3, 0, 4) -> createMockManagedBuffer())
+    val transfer = createMockTransfer(mergedRemoteBlocks)
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (remoteBmId, remoteBlocks.map(blockId => (blockId, 1L, 1)))
+    ).toIterator
+
+    val taskContext = TaskContext.empty()
+    val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
+    val iterator = new ShuffleBlockFetcherIterator(
+      taskContext,
+      transfer,
+      blockManager,
+      blocksByAddress,
+      (_, in) => in,
+      48 * 1024 * 1024,
+      Int.MaxValue,
+      1,
+      Int.MaxValue,
+      true,
+      false,
+      metrics,
+      true)
+
+    assert(iterator.hasNext)
+    val (blockId, inputStream) = iterator.next()
+    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any())
+
+    // Make sure we release buffers when a wrapped input stream is closed.
+    val mockBuf = mergedRemoteBlocks(blockId)
+    verifyBufferRelease(mockBuf, inputStream)
+  }
+
   test("release current unexhausted buffer in case the task completes early") {
     val blockManager = mock(classOf[BlockManager])
     val localBmId = BlockManagerId("test-client", "test-client", 1)