[SPARK-34534] Fix blockIds order when use FetchShuffleBlocks to fetch blocks
### What changes were proposed in this pull request? Fix a problem which can lead to data correctness issues after partial block retries in `OneForOneBlockFetcher` when using `FetchShuffleBlocks`. ### Why are the changes needed? This is a data correctness bug. There were no problems when using the old protocol, which sends `OpenBlocks` before fetching chunks in `OneForOneBlockFetcher`; in the latest branch, `OpenBlocks` has been replaced by `FetchShuffleBlocks`. However, the order in which `FetchShuffleBlocks` reads shuffle blocks is not the same as the order of `blockIds` in `OneForOneBlockFetcher`. The `blockIds` array is used to match each blockId to its shuffle data by index, so it is now out of order; this leads to reading the wrong block chunk when some block fetches fail in `OneForOneBlockFetcher`, because it retries the remaining blocks based on the order of `blockIds`. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? Closes #31643 from seayoun/yuhaiyang_fix_use_FetchShuffleBlocks_order. Lead-authored-by: yuhaiyang <yuhaiyang@yuhaiyangs-MacBook-Pro.local> Co-authored-by: yuhaiyang <yuhaiyang@172.19.25.126> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
ecf4811764
commit
4e43819611
|
@ -21,7 +21,7 @@ import java.io.IOException;
|
|||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
|
||||
import com.google.common.primitives.Ints;
|
||||
import com.google.common.primitives.Longs;
|
||||
|
@ -81,7 +81,6 @@ public class OneForOneBlockFetcher {
|
|||
TransportConf transportConf,
|
||||
DownloadFileManager downloadFileManager) {
|
||||
this.client = client;
|
||||
this.blockIds = blockIds;
|
||||
this.listener = listener;
|
||||
this.chunkCallback = new ChunkCallback();
|
||||
this.transportConf = transportConf;
|
||||
|
@ -90,8 +89,10 @@ public class OneForOneBlockFetcher {
|
|||
throw new IllegalArgumentException("Zero-sized blockIds array");
|
||||
}
|
||||
if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
|
||||
this.message = createFetchShuffleBlocksMsg(appId, execId, blockIds);
|
||||
this.blockIds = new String[blockIds.length];
|
||||
this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
|
||||
} else {
|
||||
this.blockIds = blockIds;
|
||||
this.message = new OpenBlocks(appId, execId, blockIds);
|
||||
}
|
||||
}
|
||||
|
@ -106,17 +107,16 @@ public class OneForOneBlockFetcher {
|
|||
}
|
||||
|
||||
/**
|
||||
* Analyze the pass in blockIds and create FetchShuffleBlocks message.
|
||||
* The blockIds has been sorted by mapId and reduceId. It's produced in
|
||||
* org.apache.spark.MapOutputTracker.convertMapStatuses.
|
||||
* Create FetchShuffleBlocks message and rebuild internal blockIds by
|
||||
* analyzing the pass in blockIds.
|
||||
*/
|
||||
private FetchShuffleBlocks createFetchShuffleBlocksMsg(
|
||||
private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
|
||||
String appId, String execId, String[] blockIds) {
|
||||
String[] firstBlock = splitBlockId(blockIds[0]);
|
||||
int shuffleId = Integer.parseInt(firstBlock[1]);
|
||||
boolean batchFetchEnabled = firstBlock.length == 5;
|
||||
|
||||
HashMap<Long, ArrayList<Integer>> mapIdToReduceIds = new HashMap<>();
|
||||
LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
|
||||
for (String blockId : blockIds) {
|
||||
String[] blockIdParts = splitBlockId(blockId);
|
||||
if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
|
||||
|
@ -124,23 +124,36 @@ public class OneForOneBlockFetcher {
|
|||
", got:" + blockId);
|
||||
}
|
||||
long mapId = Long.parseLong(blockIdParts[2]);
|
||||
if (!mapIdToReduceIds.containsKey(mapId)) {
|
||||
mapIdToReduceIds.put(mapId, new ArrayList<>());
|
||||
if (!mapIdToBlocksInfo.containsKey(mapId)) {
|
||||
mapIdToBlocksInfo.put(mapId, new BlocksInfo());
|
||||
}
|
||||
mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[3]));
|
||||
BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
|
||||
blocksInfoByMapId.blockIds.add(blockId);
|
||||
blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
|
||||
if (batchFetchEnabled) {
|
||||
// When we read continuous shuffle blocks in batch, we will reuse reduceIds in
|
||||
// FetchShuffleBlocks to store the start and end reduce id for range
|
||||
// [startReduceId, endReduceId).
|
||||
assert(blockIdParts.length == 5);
|
||||
mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[4]));
|
||||
blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
|
||||
}
|
||||
}
|
||||
long[] mapIds = Longs.toArray(mapIdToReduceIds.keySet());
|
||||
long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
|
||||
int[][] reduceIdArr = new int[mapIds.length][];
|
||||
int blockIdIndex = 0;
|
||||
for (int i = 0; i < mapIds.length; i++) {
|
||||
reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i]));
|
||||
BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
|
||||
reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
|
||||
|
||||
// The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
|
||||
// because the shuffle data's return order should match the `blockIds`'s order to ensure
|
||||
// blockId and data match.
|
||||
for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
|
||||
this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
|
||||
}
|
||||
}
|
||||
assert(blockIdIndex == this.blockIds.length);
|
||||
|
||||
return new FetchShuffleBlocks(
|
||||
appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
|
||||
}
|
||||
|
@ -157,6 +170,18 @@ public class OneForOneBlockFetcher {
|
|||
return blockIdParts;
|
||||
}
|
||||
|
||||
/** The reduceIds and blocks in a single mapId */
|
||||
private class BlocksInfo {
|
||||
|
||||
final ArrayList<Integer> reduceIds;
|
||||
final ArrayList<String> blockIds;
|
||||
|
||||
BlocksInfo() {
|
||||
this.reduceIds = new ArrayList<>();
|
||||
this.blockIds = new ArrayList<>();
|
||||
}
|
||||
}
|
||||
|
||||
/** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
|
||||
private class ChunkCallback implements ChunkReceivedCallback {
|
||||
@Override
|
||||
|
|
|
@ -201,6 +201,48 @@ public class OneForOneBlockFetcherSuite {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFetchShuffleBlocksOrder() {
|
||||
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
|
||||
blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[1])));
|
||||
blocks.put("shuffle_0_2_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[2])));
|
||||
blocks.put("shuffle_0_10_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[3])));
|
||||
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
|
||||
|
||||
BlockFetchingListener listener = fetchBlocks(
|
||||
blocks,
|
||||
blockIds,
|
||||
new FetchShuffleBlocks("app-id", "exec-id", 0,
|
||||
new long[]{0, 2, 10}, new int[][]{{0}, {1}, {2}}, false),
|
||||
conf);
|
||||
|
||||
for (int chunkIndex = 0; chunkIndex < blockIds.length; chunkIndex++) {
|
||||
String blockId = blockIds[chunkIndex];
|
||||
verify(listener).onBlockFetchSuccess(blockId, blocks.get(blockId));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBatchFetchShuffleBlocksOrder() {
|
||||
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
|
||||
blocks.put("shuffle_0_0_1_2", new NioManagedBuffer(ByteBuffer.wrap(new byte[1])));
|
||||
blocks.put("shuffle_0_2_2_3", new NioManagedBuffer(ByteBuffer.wrap(new byte[2])));
|
||||
blocks.put("shuffle_0_10_3_4", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[3])));
|
||||
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
|
||||
|
||||
BlockFetchingListener listener = fetchBlocks(
|
||||
blocks,
|
||||
blockIds,
|
||||
new FetchShuffleBlocks("app-id", "exec-id", 0,
|
||||
new long[]{0, 2, 10}, new int[][]{{1, 2}, {2, 3}, {3, 4}}, true),
|
||||
conf);
|
||||
|
||||
for (int chunkIndex = 0; chunkIndex < blockIds.length; chunkIndex++) {
|
||||
String blockId = blockIds[chunkIndex];
|
||||
verify(listener).onBlockFetchSuccess(blockId, blocks.get(blockId));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Begins a fetch on the given set of blocks by mocking out the server side of the RPC which
|
||||
* simply returns the given (BlockId, Block) pairs.
|
||||
|
|
Loading…
Reference in a new issue