[SPARK-30571][CORE] fix splitting shuffle fetch requests

### What changes were proposed in this pull request?

This is a follow-up of https://github.com/apache/spark/pull/26930 to fix a bug.

When we create shuffle fetch requests, we first collect blocks until they reach the max request size. Then we try to merge continuous blocks (the batch shuffle fetch feature) and split the merged blocks into several groups, to make sure each group doesn't exceed the max number of blocks per request. If the last group is smaller than that max, we put it back into the input list and deal with it later, together with blocks collected afterwards (see the sketch below).
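For intuition, here is a minimal, self-contained sketch of that grouping step under simplified assumptions; `Block`, `splitIntoRequests`, and `maxBlocksPerGroup` are hypothetical stand-ins, not the real Spark types:

```scala
object GroupingSketch extends App {
  // Simplified stand-in for FetchBlockInfo.
  case class Block(id: Int, size: Long)

  // Full groups become fetch requests immediately; an undersized trailing
  // group is deferred and handled together with later blocks.
  def splitIntoRequests(
      merged: Seq[Block],
      maxBlocksPerGroup: Int): (Seq[Seq[Block]], Seq[Block]) = {
    val (full, partial) = merged
      .grouped(maxBlocksPerGroup)
      .toSeq
      .partition(_.length == maxBlocksPerGroup)
    (full, partial.flatten)
  }

  // 5 blocks with at most 2 per group: two full requests now, one block deferred.
  val (requests, leftover) =
    splitIntoRequests((1 to 5).map(Block(_, 10L)), maxBlocksPerGroup = 2)
  println(requests) // two groups of 2 blocks each
  println(leftover) // the single deferred block
}
```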

The last step has two problems:
1. If we put a merged block back into the input list, the next merge attempt fails, because the merge logic assumes every input block ID is a plain `ShuffleBlockId`, while a previously merged block carries a `ShuffleBlockBatchId` (reproduced in the toy sketch after this list).
2. When putting blocks back, we should also update `numBlocksToFetch`, since the put-back blocks are counted again when they go through the merge step a second time.

This PR fixes these two problems.
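To make the first problem concrete, here is a toy reproduction with simplified stand-in ID types; `mergeNaive` is a hypothetical condensation of the pre-fix merge logic, not the actual Spark method:

```scala
object MergeBugSketch extends App {
  // Stand-in ID types, not the real org.apache.spark.storage classes.
  sealed trait BlockId
  case class ShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId
  case class ShuffleBlockBatchId(
      shuffleId: Int, mapId: Long, startReduceId: Int, endReduceId: Int) extends BlockId

  // The pre-fix merge logic casts every input ID to ShuffleBlockId, so a
  // previously merged batch ID makes the casts below throw ClassCastException.
  def mergeNaive(blocks: Seq[BlockId]): BlockId = {
    val head = blocks.head.asInstanceOf[ShuffleBlockId]
    val last = blocks.last.asInstanceOf[ShuffleBlockId]
    ShuffleBlockBatchId(head.shuffleId, head.mapId, head.reduceId, last.reduceId + 1)
  }

  // A batch ID that was put back into the input list, plus a fresh block:
  val putBack: Seq[BlockId] = Seq(ShuffleBlockBatchId(0, 3, 0, 2), ShuffleBlockId(0, 3, 2))
  // mergeNaive(putBack)  // fails: ShuffleBlockBatchId cannot be cast to ShuffleBlockId
}
```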

### Why are the changes needed?
bug fix

### Does this PR introduce any user-facing change?

no

### How was this patch tested?

A new test in `ShuffleBlockFetcherIteratorSuite`.

Closes #27280 from cloud-fan/aqe.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Wenchen Fan 2020-01-21 14:45:50 +08:00
parent 78df532556
commit 595cdb09a4
2 changed files with 97 additions and 53 deletions

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

@@ -332,37 +332,47 @@ final class ShuffleBlockFetcherIterator(
       val iterator = blockInfos.iterator
       var curRequestSize = 0L
       var curBlocks = new ArrayBuffer[FetchBlockInfo]
+
+      def createFetchRequest(blocks: Seq[FetchBlockInfo]): Unit = {
+        collectedRemoteRequests += FetchRequest(address, blocks)
+        logDebug(s"Creating fetch request of $curRequestSize at $address "
+          + s"with ${blocks.size} blocks")
+      }
+
+      def createFetchRequests(): Unit = {
+        val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks)
+        curBlocks = new ArrayBuffer[FetchBlockInfo]
+        if (mergedBlocks.length <= maxBlocksInFlightPerAddress) {
+          createFetchRequest(mergedBlocks)
+        } else {
+          mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { blocks =>
+            if (blocks.length == maxBlocksInFlightPerAddress) {
+              createFetchRequest(blocks)
+            } else {
+              // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back
+              // to `curBlocks`.
+              curBlocks = blocks
+              numBlocksToFetch -= blocks.size
+            }
+          }
+        }
+        curRequestSize = curBlocks.map(_.size).sum
+      }
+
       while (iterator.hasNext) {
         val (blockId, size, mapIndex) = iterator.next()
         assertPositiveBlockSize(blockId, size)
         curBlocks += FetchBlockInfo(blockId, size, mapIndex)
         curRequestSize += size
         // For batch fetch, the actual block in flight should count for merged block.
-        val exceedsMaxBlocksInFlightPerAddress = !doBatchFetch &&
-          curBlocks.size >= maxBlocksInFlightPerAddress
-        if (curRequestSize >= targetRemoteRequestSize || exceedsMaxBlocksInFlightPerAddress) {
-          // Add this FetchRequest
-          val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks)
-            .grouped(maxBlocksInFlightPerAddress)
-          curBlocks = new ArrayBuffer[FetchBlockInfo]
-          mergedBlocks.foreach { mergedBlock =>
-            if (mergedBlock.size == maxBlocksInFlightPerAddress) {
-              collectedRemoteRequests += new FetchRequest(address, mergedBlock)
-              logDebug(s"Creating fetch request of $curRequestSize at $address "
-                + s"with ${mergedBlock.size} blocks")
-            } else {
-              // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back
-              // to `curBlocks`.
-              curBlocks = mergedBlock
-            }
-          }
-          curRequestSize = 0
+        val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
+        if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
+          createFetchRequests()
         }
       }
       // Add in the final request
       if (curBlocks.nonEmpty) {
-        val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks)
-        collectedRemoteRequests += new FetchRequest(address, mergedBlocks)
+        createFetchRequests()
       }
     }
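The extracted `createFetchRequests()` helper is now shared by the loop body and the final-request path. Two details implement the fixes: the undersized trailing group goes back into `curBlocks` together with `numBlocksToFetch -= blocks.size`, so the deferred blocks aren't double-counted when they are merged again later, and `curRequestSize` is recomputed from the deferred blocks instead of being reset to 0.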
@@ -380,35 +390,57 @@
   private[this] def mergeContinuousShuffleBlockIdsIfNeeded(
       blocks: ArrayBuffer[FetchBlockInfo]): ArrayBuffer[FetchBlockInfo] = {
-    def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = {
-      val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId]
-      FetchBlockInfo(
-        ShuffleBlockBatchId(
-          startBlockId.shuffleId,
-          startBlockId.mapId,
-          startBlockId.reduceId,
-          toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1),
-        toBeMerged.map(_.size).sum,
-        toBeMerged.head.mapIndex)
-    }
     val result = if (doBatchFetch) {
       var curBlocks = new ArrayBuffer[FetchBlockInfo]
       val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo]
-      val iter = blocks.iterator
+      def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = {
+        val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId]
+        // The last merged block may comes from the input, and we can merge more blocks
+        // into it, if the map id is the same.
+        def shouldMergeIntoPreviousBatchBlockId =
+          mergedBlockInfo.last.blockId.asInstanceOf[ShuffleBlockBatchId].mapId == startBlockId.mapId
+        val startReduceId = if (mergedBlockInfo.nonEmpty && shouldMergeIntoPreviousBatchBlockId) {
+          // Remove the previous batch block id as we will add a new one to replace it.
+          mergedBlockInfo.remove(mergedBlockInfo.length - 1).blockId
+            .asInstanceOf[ShuffleBlockBatchId].startReduceId
+        } else {
+          startBlockId.reduceId
+        }
+        FetchBlockInfo(
+          ShuffleBlockBatchId(
+            startBlockId.shuffleId,
+            startBlockId.mapId,
+            startReduceId,
+            toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1),
+          toBeMerged.map(_.size).sum,
+          toBeMerged.head.mapIndex)
+      }
+      val iter = blocks.iterator
       while (iter.hasNext) {
         val info = iter.next()
-        val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId]
-        if (curBlocks.isEmpty) {
-          curBlocks += info
-        } else {
-          if (curBlockId.mapId != curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId) {
-            mergedBlockInfo += mergeFetchBlockInfo(curBlocks)
-            curBlocks.clear()
-          }
-          curBlocks += info
-        }
+        // It's possible that the input block id is already a batch ID. For example, we merge some
+        // blocks, and then make fetch requests with the merged blocks according to "max blocks per
+        // request". The last fetch request may be too small, and we give up and put the remaining
+        // merged blocks back to the input list.
+        if (info.blockId.isInstanceOf[ShuffleBlockBatchId]) {
+          mergedBlockInfo += info
+        } else {
+          if (curBlocks.isEmpty) {
+            curBlocks += info
+          } else {
+            val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId]
+            val currentMapId = curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId
+            if (curBlockId.mapId != currentMapId) {
+              mergedBlockInfo += mergeFetchBlockInfo(curBlocks)
+              curBlocks.clear()
+            }
+            curBlocks += info
+          }
+        }
       }
       if (curBlocks.nonEmpty) {
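The merge function now handles re-entering blocks in two ways: a `ShuffleBlockBatchId` from the input is passed through untouched, and when the next run of blocks belongs to the same map output as the previously emitted batch, that batch is replaced by a single extended one. Here is a toy rendering of the "merge into the previous batch" rule, again with simplified stand-in types and no validation that reduce IDs are contiguous:

```scala
object ExtendBatchSketch extends App {
  // Stand-in ID types, not the real org.apache.spark.storage classes.
  case class ShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int)
  case class ShuffleBlockBatchId(
      shuffleId: Int, mapId: Long, startReduceId: Int, endReduceId: Int)

  // If the previously emitted batch covers the same map output, replace it
  // with one extended batch instead of emitting a second adjacent batch.
  def extendOrAppend(
      merged: Vector[ShuffleBlockBatchId],
      next: Seq[ShuffleBlockId]): Vector[ShuffleBlockBatchId] = {
    val head = next.head
    val end = next.last.reduceId + 1
    merged.lastOption match {
      case Some(prev) if prev.mapId == head.mapId =>
        merged.init :+ prev.copy(endReduceId = end)
      case _ =>
        merged :+ ShuffleBlockBatchId(head.shuffleId, head.mapId, head.reduceId, end)
    }
  }

  // A batch put back from an earlier round, followed by one more block of the
  // same map output, collapses into the single batch [0, 3):
  val out = extendOrAppend(
    Vector(ShuffleBlockBatchId(0, 3, 0, 2)),
    Seq(ShuffleBlockId(0, 3, 2)))
  println(out) // Vector(ShuffleBlockBatchId(0, 3, 0, 3))
}
```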

core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala

@@ -341,7 +341,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs().size === 1)
   }

-  test("fetch continuous blocks in batch respects maxBlocksInFlightPerAddress") {
+  test("fetch continuous blocks in batch respects maxSize and maxBlocks") {
     val blockManager = mock(classOf[BlockManager])
     val localBmId = BlockManagerId("test-client", "test-local-host", 1)
     doReturn(localBmId).when(blockManager).blockManagerId
@@ -352,9 +352,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       ShuffleBlockId(0, 3, 0),
       ShuffleBlockId(0, 3, 1),
       ShuffleBlockId(0, 3, 2),
-      ShuffleBlockId(0, 3, 3))
+      ShuffleBlockId(0, 4, 0),
+      ShuffleBlockId(0, 4, 1),
+      ShuffleBlockId(0, 5, 0),
+      ShuffleBlockId(0, 5, 1),
+      ShuffleBlockId(0, 5, 2))
     val mergedRemoteBlocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockBatchId(0, 3, 0, 4) -> createMockManagedBuffer())
+      ShuffleBlockBatchId(0, 3, 0, 3) -> createMockManagedBuffer(),
+      ShuffleBlockBatchId(0, 4, 0, 2) -> createMockManagedBuffer(),
+      ShuffleBlockBatchId(0, 5, 0, 3) -> createMockManagedBuffer())
     val transfer = createMockTransfer(mergedRemoteBlocks)
     val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
@@ -369,21 +375,27 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       blockManager,
       blocksByAddress,
       (_, in) => in,
-      48 * 1024 * 1024,
+      35,
       Int.MaxValue,
-      1,
+      2,
       Int.MaxValue,
       true,
       false,
       metrics,
       true)
-    assert(iterator.hasNext)
-    val (blockId, inputStream) = iterator.next()
-    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any())
-    // Make sure we release buffers when a wrapped input stream is closed.
-    val mockBuf = mergedRemoteBlocks(blockId)
-    verifyBufferRelease(mockBuf, inputStream)
+    var numResults = 0
+    while (iterator.hasNext) {
+      val (blockId, inputStream) = iterator.next()
+      // Make sure we release buffers when a wrapped input stream is closed.
+      val mockBuf = mergedRemoteBlocks(blockId)
+      verifyBufferRelease(mockBuf, inputStream)
+      numResults += 1
+    }
+    // The first 2 batch block ids are in the same fetch request as they don't exceed the max size
+    // and max blocks, so 2 requests in total.
+    verify(transfer, times(2)).fetchBlocks(any(), any(), any(), any(), any(), any())
+    assert(numResults == 3)
   }

   test("release current unexhausted buffer in case the task completes early") {