[SPARK-30290][CORE] Count for merged block when fetching continuous blocks in batch
### What changes were proposed in this pull request?
We added shuffle block fetch optimization in SPARK-9853. In ShuffleBlockFetcherIterator, we merge single blocks into batch blocks. During merging, we should count merged blocks toward `maxBlocksInFlightPerAddress`, not the original single blocks.

### Why are the changes needed?
If `maxBlocksInFlightPerAddress` is specified — e.g. set to 1 — it should mean one batch block, not one original single block. Otherwise, it will conflict with batch shuffle fetch.

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Unit test.

Closes #26930 from viirya/respect-max-blocks-inflight.

Lead-authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Co-authored-by: Liang-Chi Hsieh <liangchi@uber.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
8f07839e74
commit
0042ad575a
|
@ -148,6 +148,8 @@ public class OneForOneBlockFetcher {
|
|||
/** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */
|
||||
private String[] splitBlockId(String blockId) {
|
||||
String[] blockIdParts = blockId.split("_");
|
||||
// For batch block id, the format contains shuffleId, mapId, begin reduceId, end reduceId.
|
||||
// For single block id, the format contains shuffleId, mapId, reduceId.
|
||||
if (blockIdParts.length < 4 || blockIdParts.length > 5 || !blockIdParts[0].equals("shuffle")) {
|
||||
throw new IllegalArgumentException(
|
||||
"Unexpected shuffle block id format: " + blockId);
|
||||
|
|
|
@ -337,14 +337,25 @@ final class ShuffleBlockFetcherIterator(
|
|||
assertPositiveBlockSize(blockId, size)
|
||||
curBlocks += FetchBlockInfo(blockId, size, mapIndex)
|
||||
curRequestSize += size
|
||||
if (curRequestSize >= targetRemoteRequestSize ||
|
||||
curBlocks.size >= maxBlocksInFlightPerAddress) {
|
||||
// For batch fetch, the number of blocks in flight should be counted in terms of merged blocks.
|
||||
val exceedsMaxBlocksInFlightPerAddress = !doBatchFetch &&
|
||||
curBlocks.size >= maxBlocksInFlightPerAddress
|
||||
if (curRequestSize >= targetRemoteRequestSize || exceedsMaxBlocksInFlightPerAddress) {
|
||||
// Add this FetchRequest
|
||||
val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks)
|
||||
collectedRemoteRequests += new FetchRequest(address, mergedBlocks)
|
||||
logDebug(s"Creating fetch request of $curRequestSize at $address "
|
||||
+ s"with ${mergedBlocks.size} blocks")
|
||||
.grouped(maxBlocksInFlightPerAddress)
|
||||
curBlocks = new ArrayBuffer[FetchBlockInfo]
|
||||
mergedBlocks.foreach { mergedBlock =>
|
||||
if (mergedBlock.size == maxBlocksInFlightPerAddress) {
|
||||
collectedRemoteRequests += new FetchRequest(address, mergedBlock)
|
||||
logDebug(s"Creating fetch request of $curRequestSize at $address "
|
||||
+ s"with ${mergedBlock.size} blocks")
|
||||
} else {
|
||||
// The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back
|
||||
// to `curBlocks`.
|
||||
curBlocks = mergedBlock
|
||||
}
|
||||
}
|
||||
curRequestSize = 0
|
||||
}
|
||||
}
|
||||
|
|
|
@ -341,6 +341,51 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs().size === 1)
|
||||
}
|
||||
|
||||
test("fetch continuous blocks in batch respects maxBlocksInFlightPerAddress") {
|
||||
val blockManager = mock(classOf[BlockManager])
|
||||
val localBmId = BlockManagerId("test-client", "test-local-host", 1)
|
||||
doReturn(localBmId).when(blockManager).blockManagerId
|
||||
|
||||
// Make sure remote blocks would return the merged block
|
||||
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
|
||||
val remoteBlocks = Seq[BlockId](
|
||||
ShuffleBlockId(0, 3, 0),
|
||||
ShuffleBlockId(0, 3, 1),
|
||||
ShuffleBlockId(0, 3, 2),
|
||||
ShuffleBlockId(0, 3, 3))
|
||||
val mergedRemoteBlocks = Map[BlockId, ManagedBuffer](
|
||||
ShuffleBlockBatchId(0, 3, 0, 4) -> createMockManagedBuffer())
|
||||
val transfer = createMockTransfer(mergedRemoteBlocks)
|
||||
|
||||
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
|
||||
(remoteBmId, remoteBlocks.map(blockId => (blockId, 1L, 1)))
|
||||
).toIterator
|
||||
|
||||
val taskContext = TaskContext.empty()
|
||||
val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
|
||||
val iterator = new ShuffleBlockFetcherIterator(
|
||||
taskContext,
|
||||
transfer,
|
||||
blockManager,
|
||||
blocksByAddress,
|
||||
(_, in) => in,
|
||||
48 * 1024 * 1024,
|
||||
Int.MaxValue,
|
||||
1,
|
||||
Int.MaxValue,
|
||||
true,
|
||||
false,
|
||||
metrics,
|
||||
true)
|
||||
|
||||
assert(iterator.hasNext)
|
||||
val (blockId, inputStream) = iterator.next()
|
||||
verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any())
|
||||
// Make sure we release buffers when a wrapped input stream is closed.
|
||||
val mockBuf = mergedRemoteBlocks(blockId)
|
||||
verifyBufferRelease(mockBuf, inputStream)
|
||||
}
|
||||
|
||||
test("release current unexhausted buffer in case the task completes early") {
|
||||
val blockManager = mock(classOf[BlockManager])
|
||||
val localBmId = BlockManagerId("test-client", "test-client", 1)
|
||||
|
|
Loading…
Reference in a new issue