diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index ec2e3dce66..0b7eaa6225 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -21,7 +21,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; +import java.util.LinkedHashMap; import com.google.common.primitives.Ints; import com.google.common.primitives.Longs; @@ -81,7 +81,6 @@ public class OneForOneBlockFetcher { TransportConf transportConf, DownloadFileManager downloadFileManager) { this.client = client; - this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); this.transportConf = transportConf; @@ -90,8 +89,10 @@ public class OneForOneBlockFetcher { throw new IllegalArgumentException("Zero-sized blockIds array"); } if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) { - this.message = createFetchShuffleBlocksMsg(appId, execId, blockIds); + this.blockIds = new String[blockIds.length]; + this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds); } else { + this.blockIds = blockIds; this.message = new OpenBlocks(appId, execId, blockIds); } } @@ -106,17 +107,16 @@ public class OneForOneBlockFetcher { } /** - * Analyze the pass in blockIds and create FetchShuffleBlocks message. - * The blockIds has been sorted by mapId and reduceId. It's produced in - * org.apache.spark.MapOutputTracker.convertMapStatuses. + * Create the FetchShuffleBlocks message and rebuild the internal blockIds, grouped by mapId + * in first-seen order, by analyzing the passed-in blockIds. 
*/ - private FetchShuffleBlocks createFetchShuffleBlocksMsg( + private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds( String appId, String execId, String[] blockIds) { String[] firstBlock = splitBlockId(blockIds[0]); int shuffleId = Integer.parseInt(firstBlock[1]); boolean batchFetchEnabled = firstBlock.length == 5; - HashMap> mapIdToReduceIds = new HashMap<>(); + LinkedHashMap mapIdToBlocksInfo = new LinkedHashMap<>(); for (String blockId : blockIds) { String[] blockIdParts = splitBlockId(blockId); if (Integer.parseInt(blockIdParts[1]) != shuffleId) { @@ -124,23 +124,36 @@ public class OneForOneBlockFetcher { ", got:" + blockId); } long mapId = Long.parseLong(blockIdParts[2]); - if (!mapIdToReduceIds.containsKey(mapId)) { - mapIdToReduceIds.put(mapId, new ArrayList<>()); + if (!mapIdToBlocksInfo.containsKey(mapId)) { + mapIdToBlocksInfo.put(mapId, new BlocksInfo()); } - mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[3])); + BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId); + blocksInfoByMapId.blockIds.add(blockId); + blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3])); if (batchFetchEnabled) { // When we read continuous shuffle blocks in batch, we will reuse reduceIds in // FetchShuffleBlocks to store the start and end reduce id for range // [startReduceId, endReduceId). 
assert(blockIdParts.length == 5); - mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[4])); + blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4])); } } - long[] mapIds = Longs.toArray(mapIdToReduceIds.keySet()); + long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet()); int[][] reduceIdArr = new int[mapIds.length][]; + int blockIdIndex = 0; for (int i = 0; i < mapIds.length; i++) { - reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i])); + BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]); + reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds); + + // The `blockIds`'s order must be the same as the read order specified in FetchShuffleBlocks + // because the shuffle data's return order should match the `blockIds`'s order to ensure + // that each blockId is paired with its own data. + for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) { + this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j); + } } + assert(blockIdIndex == this.blockIds.length); + return new FetchShuffleBlocks( appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled); } @@ -157,6 +170,18 @@ public class OneForOneBlockFetcher { return blockIdParts; } + /** The reduceIds and blockIds collected for a single mapId, in insertion order. */ + private class BlocksInfo { + + final ArrayList reduceIds; + final ArrayList blockIds; + + BlocksInfo() { + this.reduceIds = new ArrayList<>(); + this.blockIds = new ArrayList<>(); + } + } + /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. 
*/ private class ChunkCallback implements ChunkReceivedCallback { @Override diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 285eedb39c..a7eb59d366 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -201,6 +201,48 @@ public class OneForOneBlockFetcherSuite { } } + @Test + public void testFetchShuffleBlocksOrder() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[1]))); + blocks.put("shuffle_0_2_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[2]))); + blocks.put("shuffle_0_10_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[3]))); + String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + + BlockFetchingListener listener = fetchBlocks( + blocks, + blockIds, + new FetchShuffleBlocks("app-id", "exec-id", 0, + new long[]{0, 2, 10}, new int[][]{{0}, {1}, {2}}, false), + conf); + + for (int chunkIndex = 0; chunkIndex < blockIds.length; chunkIndex++) { + String blockId = blockIds[chunkIndex]; + verify(listener).onBlockFetchSuccess(blockId, blocks.get(blockId)); + } + } + + @Test + public void testBatchFetchShuffleBlocksOrder() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("shuffle_0_0_1_2", new NioManagedBuffer(ByteBuffer.wrap(new byte[1]))); + blocks.put("shuffle_0_2_2_3", new NioManagedBuffer(ByteBuffer.wrap(new byte[2]))); + blocks.put("shuffle_0_10_3_4", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[3]))); + String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + + BlockFetchingListener listener = fetchBlocks( + blocks, + blockIds, + new FetchShuffleBlocks("app-id", 
"exec-id", 0, + new long[]{0, 2, 10}, new int[][]{{1, 2}, {2, 3}, {3, 4}}, true), + conf); + + for (int chunkIndex = 0; chunkIndex < blockIds.length; chunkIndex++) { + String blockId = blockIds[chunkIndex]; + verify(listener).onBlockFetchSuccess(blockId, blocks.get(blockId)); + } + } + /** * Begins a fetch on the given set of blocks by mocking out the server side of the RPC which * simply returns the given (BlockId, Block) pairs.