From a9b32b390c69fe8dc4bbc812894d67175b873ce6 Mon Sep 17 00:00:00 2001
From: Venkata krishnan Sowrirajan
Date: Sun, 1 Aug 2021 23:16:33 -0500
Subject: [PATCH] [SPARK-32923][CORE][SHUFFLE] Handle indeterminate stage retries for push-based shuffle

[[SPARK-23243](https://issues.apache.org/jira/browse/SPARK-23243)] and [[SPARK-25341](https://issues.apache.org/jira/browse/SPARK-25341)] addressed cases of stage retries for indeterminate stages involving operations like repartition. This PR addresses the same issues in the context of push-based shuffle. Currently there is no way to distinguish the current execution of a stage for a shuffle ID, so the changes explained below are necessary.

Core changes are summarized as follows:
1. Introduce a new variable `shuffleMergeId` in `ShuffleDependency`, a monotonically increasing value tracking the temporal ordering of executions for a shuffle ID.
2. Correspondingly, make changes in the push-based shuffle protocol layer in `MergedShuffleFileManager` and `BlockStoreClient`, passing the `shuffleMergeId` in order to keep track of the shuffle output in separate files on the shuffle service side.
3. `DAGScheduler` increments the `shuffleMergeId` tracked in `ShuffleDependency` in the case of an indeterminate stage execution.
4. Determinate stages will have `shuffleMergeId` set to 0, as no special handling is needed in this case; indeterminate stages will have `shuffleMergeId` starting from 1.

New protocol changes are needed due to the reasons explained above. No user-facing changes. Added new unit tests in `RemoteBlockPushResolverSuite`, `DAGSchedulerSuite`, `BlockIdSuite`, `ErrorHandlerSuite`.

Closes #33034 from venkata91/SPARK-32923.

Authored-by: Venkata krishnan Sowrirajan
Signed-off-by: Mridul Muralidharan gmail.com>
(cherry picked from commit c039d998128dd0dab27f43e7de083a71b9d1cfcf)
Signed-off-by: Mridul Muralidharan
--- .../spark/network/client/TransportClient.java | 6 +- .../protocol/MergedBlockMetaRequest.java | 23 +- .../network/TransportRequestHandlerSuite.java | 4 +- .../network/shuffle/BlockStoreClient.java | 6 + .../spark/network/shuffle/ErrorHandler.java | 53 +- .../network/shuffle/ExternalBlockHandler.java | 90 ++- .../shuffle/ExternalBlockStoreClient.java | 22 +- .../shuffle/MergedBlocksMetaListener.java | 8 +- .../shuffle/MergedShuffleFileManager.java | 17 +- .../shuffle/OneForOneBlockFetcher.java | 126 +++-- .../network/shuffle/OneForOneBlockPusher.java | 5 +- .../shuffle/RemoteBlockPushResolver.java | 472 +++++++++++----- .../protocol/FetchShuffleBlockChunks.java | 19 +- .../protocol/FinalizeShuffleMerge.java | 17 +- .../shuffle/protocol/MergeStatuses.java | 18 +- .../shuffle/protocol/PushBlockStream.java | 15 +- .../network/shuffle/ErrorHandlerSuite.java | 34 +- .../shuffle/ExternalBlockHandlerSuite.java | 29 +- .../shuffle/OneForOneBlockFetcherSuite.java | 42 +- .../shuffle/OneForOneBlockPusherSuite.java | 67 +-- .../shuffle/RemoteBlockPushResolverSuite.java | 527 ++++++++++++------ .../FetchShuffleBlockChunksSuite.java | 4 +- .../scala/org/apache/spark/Dependency.scala | 32 +- .../org/apache/spark/MapOutputTracker.scala | 9 +- .../apache/spark/scheduler/DAGScheduler.scala | 5 +- .../apache/spark/scheduler/MergeStatus.scala | 19 +- .../shuffle/IndexShuffleBlockResolver.scala | 25 +- .../spark/shuffle/ShuffleBlockPusher.scala | 22 +- .../spark/shuffle/ShuffleBlockResolver.scala | 10 +- .../org/apache/spark/storage/BlockId.scala | 77 ++- .../apache/spark/storage/BlockManager.scala | 4 +- .../spark/storage/PushBasedFetchHelper.scala | 41 +-
.../storage/ShuffleBlockFetcherIterator.scala | 35 +- .../apache/spark/MapOutputTrackerSuite.scala | 22 +- .../spark/scheduler/DAGSchedulerSuite.scala | 78 ++- .../shuffle/ShuffleBlockPusherSuite.scala | 15 +- .../sort/IndexShuffleBlockResolverSuite.scala | 18 +- .../apache/spark/storage/BlockIdSuite.scala | 47 +- .../ShuffleBlockFetcherIteratorSuite.scala | 161 +++--- 39 files changed, 1502 insertions(+), 722 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index a50c04cf80..dd2fdb08ee 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -206,12 +206,15 @@ public class TransportClient implements Closeable { * * @param appId applicationId. * @param shuffleId shuffle id. + * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process + * of shuffle by an indeterminate stage attempt. * @param reduceId reduce id. * @param callback callback the handle the reply. */ public void sendMergedBlockMetaReq( String appId, int shuffleId, + int shuffleMergeId, int reduceId, MergedBlockMetaResponseCallback callback) { long requestId = requestId(); @@ -222,7 +225,8 @@ public class TransportClient implements Closeable { handler.addRpcRequest(requestId, callback); RpcChannelListener listener = new RpcChannelListener(requestId, callback); channel.writeAndFlush( - new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId)).addListener(listener); + new MergedBlockMetaRequest(requestId, appId, shuffleId, shuffleMergeId, + reduceId)).addListener(listener); } /** diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java index cf7c22d241..c85d104fdd 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java @@ -32,13 +32,20 @@ public class MergedBlockMetaRequest extends AbstractMessage implements RequestMe public final long requestId; public final String appId; public final int shuffleId; + public final int shuffleMergeId; public final int reduceId; - public MergedBlockMetaRequest(long requestId, String appId, int shuffleId, int reduceId) { + public MergedBlockMetaRequest( + long requestId, + String appId, + int shuffleId, + int shuffleMergeId, + int reduceId) { super(null, false); this.requestId = requestId; this.appId = appId; this.shuffleId = shuffleId; + this.shuffleMergeId = shuffleMergeId; this.reduceId = reduceId; } @@ -49,7 +56,7 @@ public class MergedBlockMetaRequest extends AbstractMessage implements RequestMe @Override public int encodedLength() { - return 8 + Encoders.Strings.encodedLength(appId) + 4 + 4; + return 8 + Encoders.Strings.encodedLength(appId) + 4 + 4 + 4; } @Override @@ -57,6 +64,7 @@ public class MergedBlockMetaRequest extends AbstractMessage implements RequestMe buf.writeLong(requestId); Encoders.Strings.encode(buf, appId); buf.writeInt(shuffleId); + buf.writeInt(shuffleMergeId); buf.writeInt(reduceId); } @@ -64,21 +72,23 @@ public class MergedBlockMetaRequest extends AbstractMessage implements RequestMe long requestId = buf.readLong(); String appId = 
Encoders.Strings.decode(buf); int shuffleId = buf.readInt(); + int shuffleMergeId = buf.readInt(); int reduceId = buf.readInt(); - return new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId); + return new MergedBlockMetaRequest(requestId, appId, shuffleId, shuffleMergeId, reduceId); } @Override public int hashCode() { - return Objects.hashCode(requestId, appId, shuffleId, reduceId); + return Objects.hashCode(requestId, appId, shuffleId, shuffleMergeId, reduceId); } @Override public boolean equals(Object other) { if (other instanceof MergedBlockMetaRequest) { MergedBlockMetaRequest o = (MergedBlockMetaRequest) other; - return requestId == o.requestId && shuffleId == o.shuffleId && reduceId == o.reduceId - && Objects.equal(appId, o.appId); + return requestId == o.requestId && shuffleId == o.shuffleId && + shuffleMergeId == o.shuffleMergeId && reduceId == o.reduceId && + Objects.equal(appId, o.appId); } return false; } @@ -89,6 +99,7 @@ public class MergedBlockMetaRequest extends AbstractMessage implements RequestMe .append("requestId", requestId) .append("appId", appId) .append("shuffleId", shuffleId) + .append("shuffleMergeId", shuffleMergeId) .append("reduceId", reduceId) .toString(); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java index b3befb8baf..70c7a1684f 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -152,14 +152,14 @@ public class TransportRequestHandlerSuite { TransportClient reverseClient = mock(TransportClient.class); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient, rpcHandler, 2L, null); - MergedBlockMetaRequest validMetaReq = new MergedBlockMetaRequest(19, "app1", 0, 0); + MergedBlockMetaRequest validMetaReq = new MergedBlockMetaRequest(19, "app1", 0, 0, 0); requestHandler.handle(validMetaReq); assertEquals(1, responseAndPromisePairs.size()); assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof MergedBlockMetaSuccess); assertEquals(2, ((MergedBlockMetaSuccess) (responseAndPromisePairs.get(0).getLeft())).getNumChunks()); - MergedBlockMetaRequest invalidMetaReq = new MergedBlockMetaRequest(21, "app1", -1, 1); + MergedBlockMetaRequest invalidMetaReq = new MergedBlockMetaRequest(21, "app1", -1, 0, 1); requestHandler.handle(invalidMetaReq); assertEquals(2, responseAndPromisePairs.size()); assertTrue(responseAndPromisePairs.get(1).getLeft() instanceof RpcFailure); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java index b6852130c9..829884645d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java @@ -167,6 +167,8 @@ public abstract class BlockStoreClient implements Closeable { * @param host host of shuffle server * @param port port of shuffle server. * @param shuffleId shuffle ID of the shuffle to be finalized + * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process + * of shuffle by an indeterminate stage attempt. 
* @param listener the listener to receive MergeStatuses * * @since 3.1.0 */ @@ -175,6 +177,7 @@ public abstract class BlockStoreClient implements Closeable { String host, int port, int shuffleId, + int shuffleMergeId, MergeFinalizerListener listener) { throw new UnsupportedOperationException(); } @@ -185,6 +188,8 @@ public abstract class BlockStoreClient implements Closeable { * @param host the host of the remote node. * @param port the port of the remote node. * @param shuffleId shuffle id. + * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process + * of shuffle by an indeterminate stage attempt. * @param reduceId reduce id. * @param listener the listener to receive chunk counts. * @@ -194,6 +199,7 @@ public abstract class BlockStoreClient implements Closeable { String host, int port, int shuffleId, + int shuffleMergeId, int reduceId, MergedBlocksMetaListener listener) { throw new UnsupportedOperationException(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java index a75887525e..0149ad7434 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java @@ -55,12 +55,14 @@ public interface ErrorHandler { class BlockPushErrorHandler implements ErrorHandler { /** * String constant used for generating exception messages indicating a block to be merged - * arrives too late on the server side, and also for later checking such exceptions on the - * client side. When we get a block push failure because of the block arrives too late, we - * will not retry pushing the block nor log the exception on the client side. + * arrives too late, or is a stale block push in the case of indeterminate stage retries on + * the server side, and also for later checking such exceptions on the client side. When we + * get a block push failure because the block push is stale or arrives too late, we will + * not retry pushing the block nor log the exception on the client side. */ - public static final String TOO_LATE_MESSAGE_SUFFIX = - "received after merged shuffle is finalized"; + public static final String TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX = + "received after merged shuffle is finalized or stale block push as shuffle blocks of a" + + " higher shuffleMergeId for the shuffle is being pushed"; /** * String constant used for generating exception messages indicating the server couldn't @@ -81,25 +83,54 @@ public interface ErrorHandler { public static final String IOEXCEPTIONS_EXCEEDED_THRESHOLD_PREFIX = "IOExceptions exceeded the threshold"; + /** + * String constant used for generating exception messages indicating the server rejected a + * shuffle finalize request since shuffle blocks of a higher shuffleMergeId for a shuffle are + * already being pushed. This typically happens in the case of indeterminate stage retries + * where, if a stage attempt fails, the entirety of the shuffle output needs to be rolled + * back. For more details, refer to SPARK-23243, SPARK-25341 and SPARK-32923.
+ */ + public static final String STALE_SHUFFLE_FINALIZE_SUFFIX = + "stale shuffle finalize request as shuffle blocks of a higher shuffleMergeId for the" + + " shuffle is already being pushed"; + @Override public boolean shouldRetryError(Throwable t) { // If it is a connection time-out or a connection closed exception, no need to retry. // If it is a FileNotFoundException originating from the client while pushing the shuffle - // blocks to the server, even then there is no need to retry. We will still log this exception - // once which helps with debugging. + // blocks to the server, even then there is no need to retry. We will still log this + // exception once which helps with debugging. if (t.getCause() != null && (t.getCause() instanceof ConnectException || t.getCause() instanceof FileNotFoundException)) { return false; } - // If the block is too late, there is no need to retry it - return !Throwables.getStackTraceAsString(t).contains(TOO_LATE_MESSAGE_SUFFIX); + + String errorStackTrace = Throwables.getStackTraceAsString(t); + // If the block is too late or stale block push, there is no need to retry it + return !errorStackTrace.contains(TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX); } @Override public boolean shouldLogError(Throwable t) { String errorStackTrace = Throwables.getStackTraceAsString(t); - return !errorStackTrace.contains(BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX) && - !errorStackTrace.contains(TOO_LATE_MESSAGE_SUFFIX); + return !(errorStackTrace.contains(BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX) || + errorStackTrace.contains(TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX)); + } + } + + class BlockFetchErrorHandler implements ErrorHandler { + public static final String STALE_SHUFFLE_BLOCK_FETCH = + "stale shuffle block fetch request as shuffle blocks of a higher shuffleMergeId for the" + + " shuffle is available"; + + @Override + public boolean shouldRetryError(Throwable t) { + return !Throwables.getStackTraceAsString(t).contains(STALE_SHUFFLE_BLOCK_FETCH); + } + + @Override + public boolean shouldLogError(Throwable t) { + return !Throwables.getStackTraceAsString(t).contains(STALE_SHUFFLE_BLOCK_FETCH); } } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java index e0f2e95950..cfabcd5ba4 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java @@ -218,7 +218,8 @@ public class ExternalBlockHandler extends RpcHandler callback.onSuccess(statuses.toByteBuffer()); } catch(IOException e) { throw new RuntimeException(String.format("Error while finalizing shuffle merge " - + "for application %s shuffle %d", msg.appId, msg.shuffleId), e); + + "for application %s shuffle %d with shuffleMergeId %d", msg.appId, msg.shuffleId, + msg.shuffleMergeId), e); } finally { responseDelayContext.stop(); } @@ -237,7 +238,7 @@ public class ExternalBlockHandler extends RpcHandler checkAuth(client, metaRequest.appId); MergedBlockMeta mergedMeta = mergeManager.getMergedBlockMeta(metaRequest.appId, metaRequest.shuffleId, - metaRequest.reduceId); + metaRequest.shuffleMergeId, metaRequest.reduceId); logger.debug( "Merged block chunks appId {} shuffleId {} reduceId {} num-chunks : {} ", metaRequest.appId, metaRequest.shuffleId, metaRequest.reduceId, @@ -364,7 +365,6 @@ public class ExternalBlockHandler 
extends RpcHandler private int index = 0; private final Function blockDataForIndexFn; private final int size; - private boolean requestForMergedBlockChunks; ManagedBufferIterator(OpenBlocks msg) { String appId = msg.appId; @@ -377,13 +377,14 @@ public class ExternalBlockHandler extends RpcHandler size = mapIdAndReduceIds.length; blockDataForIndexFn = index -> blockManager.getBlockData(appId, execId, shuffleId, mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]); - } else if (blockId0Parts.length == 4 && blockId0Parts[0].equals(SHUFFLE_CHUNK_ID)) { - requestForMergedBlockChunks = true; + } else if (blockId0Parts.length == 5 && blockId0Parts[0].equals(SHUFFLE_CHUNK_ID)) { final int shuffleId = Integer.parseInt(blockId0Parts[1]); - final int[] reduceIdAndChunkIds = shuffleMapIdAndReduceIds(blockIds, shuffleId); + final int shuffleMergeId = Integer.parseInt(blockId0Parts[2]); + final int[] reduceIdAndChunkIds = shuffleReduceIdAndChunkIds(blockIds, shuffleId, + shuffleMergeId); size = reduceIdAndChunkIds.length; blockDataForIndexFn = index -> mergeManager.getMergedBlockData(msg.appId, shuffleId, - reduceIdAndChunkIds[index], reduceIdAndChunkIds[index + 1]); + shuffleMergeId, reduceIdAndChunkIds[index], reduceIdAndChunkIds[index + 1]); } else if (blockId0Parts.length == 3 && blockId0Parts[0].equals("rdd")) { final int[] rddAndSplitIds = rddAndSplitIds(blockIds); size = rddAndSplitIds.length; @@ -407,27 +408,64 @@ public class ExternalBlockHandler extends RpcHandler return rddAndSplitIds; } + /** + * @param blockIds Regular shuffle blockIds starts with SHUFFLE_BLOCK_ID to be parsed + * @param shuffleId shuffle blocks shuffleId + * @return mapId and reduceIds of the shuffle blocks in the same order as that of the blockIds + * + * Regular shuffle blocks format should be shuffle_$shuffleId_$mapId_$reduceId + */ private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) { - // For regular shuffle blocks, primaryId is mapId and secondaryIds are reduceIds. - // For shuffle chunks, primaryIds is reduceId and secondaryIds are chunkIds. - final int[] primaryIdAndSecondaryIds = new int[2 * blockIds.length]; + final int[] mapIdAndReduceIds = new int[2 * blockIds.length]; for (int i = 0; i < blockIds.length; i++) { String[] blockIdParts = blockIds[i].split("_"); - if (blockIdParts.length != 4 - || (!requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_BLOCK_ID)) - || (requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_CHUNK_ID))) { + if (blockIdParts.length != 4 || !blockIdParts[0].equals(SHUFFLE_BLOCK_ID)) { throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]); } if (Integer.parseInt(blockIdParts[1]) != shuffleId) { throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockIds[i]); } - // For regular blocks, blockIdParts[2] is mapId. For chunks, it is reduceId. - primaryIdAndSecondaryIds[2 * i] = Integer.parseInt(blockIdParts[2]); - // For regular blocks, blockIdParts[3] is reduceId. For chunks, it is chunkId. 
- primaryIdAndSecondaryIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]); + // mapId + mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]); + // reduceId + mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]); } - return primaryIdAndSecondaryIds; + return mapIdAndReduceIds; + } + + /** + * @param blockIds Shuffle merged chunks starts with SHUFFLE_CHUNK_ID to be parsed + * @param shuffleId shuffle blocks shuffleId + * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process + * of shuffle by an indeterminate stage attempt. + * @return reduceId and chunkIds of the shuffle chunks in the same order as that of the + * blockIds + * + * Shuffle merged chunks format should be + * shuffleChunk_$shuffleId_$shuffleMergeId_$reduceId_$chunkId + */ + private int[] shuffleReduceIdAndChunkIds( + String[] blockIds, + int shuffleId, + int shuffleMergeId) { + final int[] reduceIdAndChunkIds = new int[2 * blockIds.length]; + for(int i = 0; i < blockIds.length; i++) { + String[] blockIdParts = blockIds[i].split("_"); + if (blockIdParts.length != 5 || !blockIdParts[0].equals(SHUFFLE_CHUNK_ID)) { + throw new IllegalArgumentException("Unexpected shuffle chunk id format: " + blockIds[i]); + } + if (Integer.parseInt(blockIdParts[1]) != shuffleId || + Integer.parseInt(blockIdParts[2]) != shuffleMergeId) { + throw new IllegalArgumentException(String.format("Expected shuffleId = %s" + + " and shuffleMergeId = %s but got %s", shuffleId, shuffleMergeId, blockIds[i])); + } + // reduceId + reduceIdAndChunkIds[2 * i] = Integer.parseInt(blockIdParts[3]); + // chunkId + reduceIdAndChunkIds[2 * i + 1] = Integer.parseInt(blockIdParts[4]); + } + return reduceIdAndChunkIds; } @Override @@ -511,12 +549,14 @@ public class ExternalBlockHandler extends RpcHandler private final String appId; private final int shuffleId; + private final int shuffleMergeId; private final int[] reduceIds; private final int[][] chunkIds; ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) { appId = msg.appId; shuffleId = msg.shuffleId; + shuffleMergeId = msg.shuffleMergeId; reduceIds = msg.reduceIds; chunkIds = msg.chunkIds; // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks @@ -533,7 +573,7 @@ public class ExternalBlockHandler extends RpcHandler @Override public ManagedBuffer next() { ManagedBuffer block = Preconditions.checkNotNull(mergeManager.getMergedBlockData( - appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx])); + appId, shuffleId, shuffleMergeId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx])); if (chunkIdx < chunkIds[reduceIdx].length - 1) { chunkIdx += 1; } else { @@ -580,12 +620,20 @@ public class ExternalBlockHandler extends RpcHandler @Override public ManagedBuffer getMergedBlockData( - String appId, int shuffleId, int reduceId, int chunkId) { + String appId, + int shuffleId, + int shuffleMergeId, + int reduceId, + int chunkId) { throw new UnsupportedOperationException("Cannot handle shuffle block merge"); } @Override - public MergedBlockMeta getMergedBlockMeta(String appId, int shuffleId, int reduceId) { + public MergedBlockMeta getMergedBlockMeta( + String appId, + int shuffleId, + int shuffleMergeId, + int reduceId) { throw new UnsupportedOperationException("Cannot handle shuffle block merge"); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java index 
f88915b504..eb2d118b7d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -172,12 +172,14 @@ public class ExternalBlockStoreClient extends BlockStoreClient { String host, int port, int shuffleId, + int shuffleMergeId, MergeFinalizerListener listener) { checkInit(); try { TransportClient client = clientFactory.createClient(host, port); ByteBuffer finalizeShuffleMerge = - new FinalizeShuffleMerge(appId, conf.appAttemptId(), shuffleId).toByteBuffer(); + new FinalizeShuffleMerge(appId, conf.appAttemptId(), shuffleId, + shuffleMergeId).toByteBuffer(); client.sendRpc(finalizeShuffleMerge, new RpcResponseCallback() { @Override public void onSuccess(ByteBuffer response) { @@ -202,29 +204,31 @@ public class ExternalBlockStoreClient extends BlockStoreClient { String host, int port, int shuffleId, + int shuffleMergeId, int reduceId, MergedBlocksMetaListener listener) { checkInit(); - logger.debug("Get merged blocks meta from {}:{} for shuffleId {} reduceId {}", host, port, - shuffleId, reduceId); + logger.debug("Get merged blocks meta from {}:{} for shuffleId {} shuffleMergeId {}" + + " reduceId {}", host, port, shuffleId, shuffleMergeId, reduceId); try { TransportClient client = clientFactory.createClient(host, port); - client.sendMergedBlockMetaReq(appId, shuffleId, reduceId, + client.sendMergedBlockMetaReq(appId, shuffleId, shuffleMergeId, reduceId, new MergedBlockMetaResponseCallback() { @Override public void onSuccess(int numChunks, ManagedBuffer buffer) { - logger.trace("Successfully got merged block meta for shuffleId {} reduceId {}", - shuffleId, reduceId); - listener.onSuccess(shuffleId, reduceId, new MergedBlockMeta(numChunks, buffer)); + logger.trace("Successfully got merged block meta for shuffleId {} shuffleMergeId {}" + + " reduceId {}", shuffleId, shuffleMergeId, reduceId); + listener.onSuccess(shuffleId, shuffleMergeId, reduceId, + new MergedBlockMeta(numChunks, buffer)); } @Override public void onFailure(Throwable e) { - listener.onFailure(shuffleId, reduceId, e); + listener.onFailure(shuffleId, shuffleMergeId, reduceId, e); } }); } catch (Exception e) { - listener.onFailure(shuffleId, reduceId, e); + listener.onFailure(shuffleId, shuffleMergeId, reduceId, e); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlocksMetaListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlocksMetaListener.java index 0e277d3303..cea76ddb1a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlocksMetaListener.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlocksMetaListener.java @@ -30,17 +30,21 @@ public interface MergedBlocksMetaListener extends EventListener { * Called after successfully receiving the meta of a merged block. * * @param shuffleId shuffle id. + * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process + * of shuffle by an indeterminate stage attempt. * @param reduceId reduce id. * @param meta contains meta information of a merged block. */ - void onSuccess(int shuffleId, int reduceId, MergedBlockMeta meta); + void onSuccess(int shuffleId, int shuffleMergeId, int reduceId, MergedBlockMeta meta); /** * Called when there is an exception while fetching the meta of a merged block. * * @param shuffleId shuffle id.
+ * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process + * of shuffle by an indeterminate stage attempt. * @param reduceId reduce id. * @param exception exception getting chunk counts. */ - void onFailure(int shuffleId, int reduceId, Throwable exception); + void onFailure(int shuffleId, int shuffleMergeId, int reduceId, Throwable exception); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java index 4ce6a478ff..630386d97d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java @@ -85,21 +85,34 @@ public interface MergedShuffleFileManager { * * @param appId application ID * @param shuffleId shuffle ID + * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process + * of shuffle by an indeterminate stage attempt. * @param reduceId reducer ID * @param chunkId merged shuffle file chunk ID * @return The {@link ManagedBuffer} for the given merged shuffle chunk */ - ManagedBuffer getMergedBlockData(String appId, int shuffleId, int reduceId, int chunkId); + ManagedBuffer getMergedBlockData( + String appId, + int shuffleId, + int shuffleMergeId, + int reduceId, + int chunkId); /** * Get the meta information of a merged block. * * @param appId application ID * @param shuffleId shuffle ID + * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process + * of shuffle by an indeterminate stage attempt. * @param reduceId reducer ID * @return meta information of a merged block */ - MergedBlockMeta getMergedBlockMeta(String appId, int shuffleId, int reduceId); + MergedBlockMeta getMergedBlockMeta( + String appId, + int shuffleId, + int shuffleMergeId, + int reduceId); /** * Get the local directories which stores the merged shuffle files. 
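To make the widened callback signatures above concrete, here is a minimal, hypothetical caller-side listener for `getMergedBlocksMeta`; it is a sketch, not part of this patch, and it assumes `MergedBlockMeta.getNumChunks()` as the accessor for the chunk count. It shows how threading `shuffleMergeId` through both callbacks lets a client drop responses that belong to an older merge attempt:

```java
import org.apache.spark.network.shuffle.MergedBlockMeta;
import org.apache.spark.network.shuffle.MergedBlocksMetaListener;

// Hypothetical listener: ignores meta responses whose shuffleMergeId does not
// match the merge attempt the client currently expects.
class ExpectedMergeIdMetaListener implements MergedBlocksMetaListener {
  private final int expectedShuffleMergeId;

  ExpectedMergeIdMetaListener(int expectedShuffleMergeId) {
    this.expectedShuffleMergeId = expectedShuffleMergeId;
  }

  @Override
  public void onSuccess(int shuffleId, int shuffleMergeId, int reduceId, MergedBlockMeta meta) {
    if (shuffleMergeId != expectedShuffleMergeId) {
      // Response from a stale merge attempt; safe to drop without retrying.
      return;
    }
    System.out.printf("shuffle %d mergeId %d reduceId %d: %d merged chunks%n",
        shuffleId, shuffleMergeId, reduceId, meta.getNumChunks());
  }

  @Override
  public void onFailure(int shuffleId, int shuffleMergeId, int reduceId, Throwable exception) {
    System.err.printf("meta fetch failed for shuffle %d mergeId %d reduceId %d: %s%n",
        shuffleId, shuffleMergeId, reduceId, exception.getMessage());
  }
}
```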
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 2bf16b0097..a98e2029f0 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -22,7 +22,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.LinkedHashMap; -import java.util.Set; +import java.util.Map; import com.google.common.primitives.Ints; import com.google.common.primitives.Longs; @@ -56,6 +56,8 @@ public class OneForOneBlockFetcher { private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class); private static final String SHUFFLE_BLOCK_PREFIX = "shuffle_"; private static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk_"; + private static final String SHUFFLE_BLOCK_SPLIT = "shuffle"; + private static final String SHUFFLE_CHUNK_SPLIT = "shuffleChunk"; private final TransportClient client; private final BlockTransferMessage message; @@ -125,63 +127,87 @@ public class OneForOneBlockFetcher { String execId, String[] blockIds) { if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) { - return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true); + return createFetchShuffleChunksMsg(appId, execId, blockIds); } else { - return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false); + return createFetchShuffleBlocksMsg(appId, execId, blockIds); } } - /** - * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by - * analyzing the passed in blockIds. - */ - private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds( + private AbstractFetchShuffleBlocks createFetchShuffleBlocksMsg( String appId, String execId, - String[] blockIds, - boolean areMergedChunks) { + String[] blockIds) { String[] firstBlock = splitBlockId(blockIds[0]); int shuffleId = Integer.parseInt(firstBlock[1]); boolean batchFetchEnabled = firstBlock.length == 5; - - // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId - // is reduceId. - LinkedHashMap primaryIdToBlocksInfo = new LinkedHashMap<>(); + Map mapIdToBlocksInfo = new LinkedHashMap<>(); for (String blockId : blockIds) { String[] blockIdParts = splitBlockId(blockId); if (Integer.parseInt(blockIdParts[1]) != shuffleId) { - throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + - ", got:" + blockId); + throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockId); } - Number primaryId; - if (!areMergedChunks) { - primaryId = Long.parseLong(blockIdParts[2]); - } else { - primaryId = Integer.parseInt(blockIdParts[2]); - } - BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId, + + long mapId = Long.parseLong(blockIdParts[2]); + BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.computeIfAbsent(mapId, id -> new BlocksInfo()); - blocksInfoByPrimaryId.blockIds.add(blockId); - // If blockId is a regular shuffle block, then blockIdParts[3] = reduceId. 
If blockId is a - // shuffleChunk block, then blockIdParts[3] = chunkId - blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3])); + blocksInfoByMapId.blockIds.add(blockId); + blocksInfoByMapId.ids.add(Integer.parseInt(blockIdParts[3])); + if (batchFetchEnabled) { - // It comes here only if the blockId is a regular shuffle block not a shuffleChunk block. // When we read continuous shuffle blocks in batch, we will reuse reduceIds in // FetchShuffleBlocks to store the start and end reduce id for range // [startReduceId, endReduceId). assert(blockIdParts.length == 5); // blockIdParts[4] is the end reduce id for the batch range - blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4])); + blocksInfoByMapId.ids.add(Integer.parseInt(blockIdParts[4])); } } + + int[][] reduceIdsArray = getSecondaryIds(mapIdToBlocksInfo); + long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet()); + return new FetchShuffleBlocks( + appId, execId, shuffleId, mapIds, reduceIdsArray, batchFetchEnabled); + } + + private AbstractFetchShuffleBlocks createFetchShuffleChunksMsg( + String appId, + String execId, + String[] blockIds) { + String[] firstBlock = splitBlockId(blockIds[0]); + int shuffleId = Integer.parseInt(firstBlock[1]); + int shuffleMergeId = Integer.parseInt(firstBlock[2]); + + Map reduceIdToBlocksInfo = new LinkedHashMap<>(); + for (String blockId : blockIds) { + String[] blockIdParts = splitBlockId(blockId); + if (Integer.parseInt(blockIdParts[1]) != shuffleId || + Integer.parseInt(blockIdParts[2]) != shuffleMergeId) { + throw new IllegalArgumentException(String.format("Expected shuffleId = %s and" + + " shuffleMergeId = %s but got %s", shuffleId, shuffleMergeId, blockId)); + } + + int reduceId = Integer.parseInt(blockIdParts[3]); + BlocksInfo blocksInfoByReduceId = reduceIdToBlocksInfo.computeIfAbsent(reduceId, + id -> new BlocksInfo()); + blocksInfoByReduceId.blockIds.add(blockId); + blocksInfoByReduceId.ids.add(Integer.parseInt(blockIdParts[4])); + } + + int[][] chunkIdsArray = getSecondaryIds(reduceIdToBlocksInfo); + int[] reduceIds = Ints.toArray(reduceIdToBlocksInfo.keySet()); + + return new FetchShuffleBlockChunks(appId, execId, shuffleId, shuffleMergeId, reduceIds, + chunkIdsArray); + } + + private int[][] getSecondaryIds(Map primaryIdsToBlockInfo) { // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks, // secondaryIds are chunkIds. 
- int[][] secondaryIdsArray = new int[primaryIdToBlocksInfo.size()][]; + int[][] secondaryIds = new int[primaryIdsToBlockInfo.size()][]; int blockIdIndex = 0; int secIndex = 0; - for (BlocksInfo blocksInfo: primaryIdToBlocksInfo.values()) { - secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfo.ids); + for (BlocksInfo blocksInfo: primaryIdsToBlockInfo.values()) { + secondaryIds[secIndex++] = Ints.toArray(blocksInfo.ids); // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/ // FetchShuffleBlockChunks because the shuffle data's return order should match the @@ -191,35 +217,27 @@ public class OneForOneBlockFetcher { } } assert(blockIdIndex == this.blockIds.length); - Set primaryIds = primaryIdToBlocksInfo.keySet(); - if (!areMergedChunks) { - long[] mapIds = Longs.toArray(primaryIds); - return new FetchShuffleBlocks( - appId, execId, shuffleId, mapIds, secondaryIdsArray, batchFetchEnabled); - } else { - int[] reduceIds = Ints.toArray(primaryIds); - return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, secondaryIdsArray); - } + return secondaryIds; } - /** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */ + /** + * Split the blockId and return accordingly + * shuffleChunk - return shuffleId, shuffleMergeId, reduceId and chunkIds + * shuffle block - return shuffleId, mapId, reduceId + * shuffle batch block - return shuffleId, mapId, begin reduceId and end reduceId + */ private String[] splitBlockId(String blockId) { String[] blockIdParts = blockId.split("_"); - // For batch block id, the format contains shuffleId, mapId, begin reduceId, end reduceId. - // For single block id, the format contains shuffleId, mapId, educeId. - // For single block chunk id, the format contains shuffleId, reduceId, chunkId. 
if (blockIdParts.length < 4 || blockIdParts.length > 5) { - throw new IllegalArgumentException( - "Unexpected shuffle block id format: " + blockId); + throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockId); } - if (blockIdParts.length == 5 && !blockIdParts[0].equals("shuffle")) { - throw new IllegalArgumentException( - "Unexpected shuffle block id format: " + blockId); + if (blockIdParts.length == 4 && !blockIdParts[0].equals(SHUFFLE_BLOCK_SPLIT)) { + throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockId); } - if (blockIdParts.length == 4 && - !(blockIdParts[0].equals("shuffle") || blockIdParts[0].equals("shuffleChunk"))) { - throw new IllegalArgumentException( - "Unexpected shuffle block id format: " + blockId); + if (blockIdParts.length == 5 && + !(blockIdParts[0].equals(SHUFFLE_BLOCK_SPLIT) || + blockIdParts[0].equals(SHUFFLE_CHUNK_SPLIT))) { + throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockId); } return blockIdParts; } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockPusher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockPusher.java index 0e1c59f352..f9d313c254 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockPusher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockPusher.java @@ -135,13 +135,14 @@ public class OneForOneBlockPusher { assert buffers.containsKey(blockIds[i]) : "Could not find the block buffer for block " + blockIds[i]; String[] blockIdParts = blockIds[i].split("_"); - if (blockIdParts.length != 4 || !blockIdParts[0].equals(SHUFFLE_PUSH_BLOCK_PREFIX)) { + if (blockIdParts.length != 5 || !blockIdParts[0].equals(SHUFFLE_PUSH_BLOCK_PREFIX)) { throw new IllegalArgumentException( "Unexpected shuffle push block id format: " + blockIds[i]); } ByteBuffer header = new PushBlockStream(appId, appAttemptId, Integer.parseInt(blockIdParts[1]), - Integer.parseInt(blockIdParts[2]), Integer.parseInt(blockIdParts[3]) , i).toByteBuffer(); + Integer.parseInt(blockIdParts[2]), Integer.parseInt(blockIdParts[3]), + Integer.parseInt(blockIdParts[4]), i).toByteBuffer(); client.uploadStream(new NioManagedBuffer(header), buffers.get(blockIds[i]), new BlockPushCallback(i, blockIds[i])); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java index f88cfee105..cc7d4dbe00 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java @@ -28,6 +28,7 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentMap; @@ -78,6 +79,15 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { public static final String MERGE_DIR_KEY = "mergeDir"; public static final String ATTEMPT_ID_KEY = "attemptId"; private static final int UNDEFINED_ATTEMPT_ID = -1; + // Shuffles of determinate stages will have shuffleMergeId set to 0 + private static final int DETERMINATE_SHUFFLE_MERGE_ID = 0; + + // ConcurrentHashMap doesn't allow null for keys or values 
which is why this is required. + // Marker to identify finalized indeterminate shuffle partitions in the case of indeterminate + // stage retries. + @VisibleForTesting + public static final Map INDETERMINATE_SHUFFLE_FINALIZED = + Collections.emptyMap(); /** * A concurrent hashmap where the key is the applicationId, and the value includes @@ -128,50 +138,79 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { } /** - * Given the appShuffleInfo, shuffleId and reduceId that uniquely identifies a given shuffle - * partition of an application, retrieves the associated metadata. If not present and the - * corresponding merged shuffle does not exist, initializes the metadata. + * Given the appShuffleInfo, shuffleId, shuffleMergeId and reduceId that uniquely identifies + * a given shuffle partition of an application, retrieves the associated metadata. If not + * present and the corresponding merged shuffle does not exist, initializes the metadata. */ private AppShufflePartitionInfo getOrCreateAppShufflePartitionInfo( AppShuffleInfo appShuffleInfo, int shuffleId, - int reduceId) { - File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, reduceId); - ConcurrentMap> partitions = - appShuffleInfo.partitions; - Map shufflePartitions = - partitions.compute(shuffleId, (id, map) -> { - if (map == null) { + int shuffleMergeId, + int reduceId) throws StaleBlockPushException { + ConcurrentMap shuffles = appShuffleInfo.shuffles; + AppShuffleMergePartitionsInfo shufflePartitionsWithMergeId = + shuffles.compute(shuffleId, (id, appShuffleMergePartitionsInfo) -> { + if (appShuffleMergePartitionsInfo == null) { + File dataFile = + appShuffleInfo.getMergedShuffleDataFile(shuffleId, shuffleMergeId, reduceId); // If this partition is already finalized then the partitions map will not contain the - // shuffleId but the data file would exist. In that case the block is considered late. + // shuffleId for determinate stages but the data file would exist. + // In that case the block is considered late. In the case of indeterminate stages, most + // recent shuffleMergeId finalized would be pointing to INDETERMINATE_SHUFFLE_FINALIZED if (dataFile.exists()) { return null; + } else { + logger.info("Creating a new attempt for shuffle blocks push request for shuffle {}" + + " with shuffleMergeId {} for application {}_{}", shuffleId, shuffleMergeId, + appShuffleInfo.appId, appShuffleInfo.attemptId); + return new AppShuffleMergePartitionsInfo(shuffleMergeId, false); } - return new ConcurrentHashMap<>(); } else { - return map; + // Reject the request as we have already seen a higher shuffleMergeId than the + // current incoming one + int latestShuffleMergeId = appShuffleMergePartitionsInfo.shuffleMergeId; + if (latestShuffleMergeId > shuffleMergeId) { + throw new StaleBlockPushException(String.format("Rejecting shuffle blocks push request" + + " for shuffle %s with shuffleMergeId %s for application %s_%s as a higher" + + " shuffleMergeId %s request is already seen", shuffleId, shuffleMergeId, + appShuffleInfo.appId, appShuffleInfo.attemptId, latestShuffleMergeId)); + } else if (latestShuffleMergeId == shuffleMergeId) { + return appShuffleMergePartitionsInfo; + } else { + // Higher shuffleMergeId seen for the shuffle ID meaning new stage attempt is being + // run for the shuffle ID. 
Close and clean up old shuffleMergeId files, + // happens in the indeterminate stage retries + logger.info("Creating a new attempt for shuffle blocks push request for shuffle {}" + + " with shuffleMergeId {} for application {}_{} since it is higher than the" + + " latest shuffleMergeId {} already seen", shuffleId, shuffleMergeId, + appShuffleInfo.appId, appShuffleInfo.attemptId, latestShuffleMergeId); + mergedShuffleCleaner.execute(() -> + closeAndDeletePartitionFiles(appShuffleMergePartitionsInfo.shuffleMergePartitions)); + return new AppShuffleMergePartitionsInfo(shuffleMergeId, false); + } } }); - if (shufflePartitions == null) { + + // It only gets here when the shuffle is already finalized. + if (null == shufflePartitionsWithMergeId || + INDETERMINATE_SHUFFLE_FINALIZED == shufflePartitionsWithMergeId.shuffleMergePartitions) { return null; } - return shufflePartitions.computeIfAbsent(reduceId, key -> { - // It only gets here when the key is not present in the map. This could either - // be the first time the merge manager receives a pushed block for a given application - // shuffle partition, or after the merged shuffle file is finalized. We handle these - // two cases accordingly by checking if the file already exists. + Map shuffleMergePartitions = + shufflePartitionsWithMergeId.shuffleMergePartitions; + return shuffleMergePartitions.computeIfAbsent(reduceId, key -> { + // It only gets here when the key is not present in the map. The first time the merge + // manager receives a pushed block for a given application shuffle partition. + File dataFile = + appShuffleInfo.getMergedShuffleDataFile(shuffleId, shuffleMergeId, reduceId); File indexFile = - appShuffleInfo.getMergedShuffleIndexFile(shuffleId, reduceId); + appShuffleInfo.getMergedShuffleIndexFile(shuffleId, shuffleMergeId, reduceId); File metaFile = - appShuffleInfo.getMergedShuffleMetaFile(shuffleId, reduceId); + appShuffleInfo.getMergedShuffleMetaFile(shuffleId, shuffleMergeId, reduceId); try { - if (dataFile.exists()) { - return null; - } else { - return newAppShufflePartitionInfo( - appShuffleInfo.appId, shuffleId, reduceId, dataFile, indexFile, metaFile); - } + return newAppShufflePartitionInfo(appShuffleInfo.appId, shuffleId, shuffleMergeId, + reduceId, dataFile, indexFile, metaFile); } catch (IOException e) { logger.error( "Cannot create merged shuffle partition with data file {}, index file {}, and " @@ -179,7 +218,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { indexFile.getAbsolutePath(), metaFile.getAbsolutePath()); throw new RuntimeException( String.format("Cannot initialize merged shuffle partition for appId %s shuffleId %s " - + "reduceId %s", appShuffleInfo.appId, shuffleId, reduceId), e); + + "shuffleMergeId %s reduceId %s", appShuffleInfo.appId, shuffleId, shuffleMergeId, + reduceId), e); } }); } @@ -188,19 +228,31 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { AppShufflePartitionInfo newAppShufflePartitionInfo( String appId, int shuffleId, + int shuffleMergeId, int reduceId, File dataFile, File indexFile, File metaFile) throws IOException { - return new AppShufflePartitionInfo(appId, shuffleId, reduceId, dataFile, + return new AppShufflePartitionInfo(appId, shuffleId, shuffleMergeId, reduceId, dataFile, new MergeShuffleFile(indexFile), new MergeShuffleFile(metaFile)); } @Override - public MergedBlockMeta getMergedBlockMeta(String appId, int shuffleId, int reduceId) { + public MergedBlockMeta getMergedBlockMeta( + String appId, + int shuffleId, + int 
shuffleMergeId, + int reduceId) { AppShuffleInfo appShuffleInfo = validateAndGetAppShuffleInfo(appId); + AppShuffleMergePartitionsInfo partitionsInfo = appShuffleInfo.shuffles.get(shuffleId); + if (null != partitionsInfo && partitionsInfo.shuffleMergeId > shuffleMergeId) { + throw new RuntimeException(String.format( + "MergedBlockMeta fetch for shuffle %s with shuffleMergeId %s reduceId %s is %s", + shuffleId, shuffleMergeId, reduceId, + ErrorHandler.BlockFetchErrorHandler.STALE_SHUFFLE_BLOCK_FETCH)); + } File indexFile = - appShuffleInfo.getMergedShuffleIndexFile(shuffleId, reduceId); + appShuffleInfo.getMergedShuffleIndexFile(shuffleId, shuffleMergeId, reduceId); if (!indexFile.exists()) { throw new RuntimeException(String.format( "Merged shuffle index file %s not found", indexFile.getPath())); @@ -208,7 +260,7 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { int size = (int) indexFile.length(); // First entry is the zero offset int numChunks = (size / Long.BYTES) - 1; - File metaFile = appShuffleInfo.getMergedShuffleMetaFile(shuffleId, reduceId); + File metaFile = appShuffleInfo.getMergedShuffleMetaFile(shuffleId, shuffleMergeId, reduceId); if (!metaFile.exists()) { throw new RuntimeException(String.format("Merged shuffle meta file %s not found", metaFile.getPath())); @@ -216,21 +268,30 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { FileSegmentManagedBuffer chunkBitMaps = new FileSegmentManagedBuffer(conf, metaFile, 0L, metaFile.length()); logger.trace( - "{} shuffleId {} reduceId {} num chunks {}", appId, shuffleId, reduceId, numChunks); + "{} shuffleId {} shuffleMergeId {} reduceId {} num chunks {}", + appId, shuffleId, shuffleMergeId, reduceId, numChunks); return new MergedBlockMeta(numChunks, chunkBitMaps); } @SuppressWarnings("UnstableApiUsage") @Override - public ManagedBuffer getMergedBlockData(String appId, int shuffleId, int reduceId, int chunkId) { + public ManagedBuffer getMergedBlockData( + String appId, int shuffleId, int shuffleMergeId, int reduceId, int chunkId) { AppShuffleInfo appShuffleInfo = validateAndGetAppShuffleInfo(appId); - File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, reduceId); + AppShuffleMergePartitionsInfo partitionsInfo = appShuffleInfo.shuffles.get(shuffleId); + if (null != partitionsInfo && partitionsInfo.shuffleMergeId > shuffleMergeId) { + throw new RuntimeException(String.format( + "MergedBlockData fetch for shuffle %s with shuffleMergeId %s reduceId %s is %s", + shuffleId, shuffleMergeId, reduceId, + ErrorHandler.BlockFetchErrorHandler.STALE_SHUFFLE_BLOCK_FETCH)); + } + File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, shuffleMergeId, reduceId); if (!dataFile.exists()) { throw new RuntimeException(String.format("Merged shuffle data file %s not found", dataFile.getPath())); } File indexFile = - appShuffleInfo.getMergedShuffleIndexFile(shuffleId, reduceId); + appShuffleInfo.getMergedShuffleIndexFile(shuffleId, shuffleMergeId, reduceId); try { // If we get here, the merged shuffle file should have been properly finalized. Thus we can // use the file length to determine the size of the merged shuffle block. 
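The fetch-path identifiers handled above follow the extended chunk id layout this patch documents, `shuffleChunk_$shuffleId_$shuffleMergeId_$reduceId_$chunkId`. A tiny standalone demonstration of encoding and decoding that layout (illustrative only, mirroring the validation the server performs, not code from the patch):

```java
import java.util.Arrays;

// Illustrative round-trip of the 5-part shuffleChunk block id introduced here.
public class ShuffleChunkIdDemo {
  static String encode(int shuffleId, int shuffleMergeId, int reduceId, int chunkId) {
    return String.format("shuffleChunk_%d_%d_%d_%d", shuffleId, shuffleMergeId, reduceId, chunkId);
  }

  static int[] decode(String blockId) {
    String[] parts = blockId.split("_");
    if (parts.length != 5 || !parts[0].equals("shuffleChunk")) {
      throw new IllegalArgumentException("Unexpected shuffle chunk id format: " + blockId);
    }
    // shuffleId, shuffleMergeId, reduceId, chunkId
    return new int[] { Integer.parseInt(parts[1]), Integer.parseInt(parts[2]),
        Integer.parseInt(parts[3]), Integer.parseInt(parts[4]) };
  }

  public static void main(String[] args) {
    // shuffleMergeId 2 would correspond to the second retry of an indeterminate stage.
    String id = encode(0, 2, 7, 3);
    System.out.println(id + " -> " + Arrays.toString(decode(id))); // [0, 2, 7, 3]
  }
}
```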
@@ -270,18 +331,32 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { void closeAndDeletePartitionFilesIfNeeded( AppShuffleInfo appShuffleInfo, boolean cleanupLocalDirs) { - for (Map partitionMap : appShuffleInfo.partitions.values()) { - for (AppShufflePartitionInfo partitionInfo : partitionMap.values()) { + appShuffleInfo.shuffles.forEach((shuffleId, shuffleInfo) -> shuffleInfo.shuffleMergePartitions .forEach((shuffleMergeId, partitionInfo) -> { synchronized (partitionInfo) { - partitionInfo.closeAllFiles(); + partitionInfo.closeAllFilesAndDeleteIfNeeded(false); } - } - } + })); if (cleanupLocalDirs) { deleteExecutorDirs(appShuffleInfo); } } + /** + * Clean up all the AppShufflePartitionInfo for a specific shuffleMergeId. This is done + * since there is a higher shuffleMergeId request made for a shuffleId, therefore clean + * up older shuffleMergeId partitions. The cleanup will be executed in a separate thread. + */ + @VisibleForTesting + void closeAndDeletePartitionFiles(Map partitions) { + partitions + .forEach((partitionId, partitionInfo) -> { + synchronized (partitionInfo) { + partitionInfo.closeAllFilesAndDeleteIfNeeded(true); + } + }); + } + /** * Serially delete local dirs. */ @@ -304,9 +379,9 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { @Override public StreamCallbackWithID receiveBlockDataAsStream(PushBlockStream msg) { AppShuffleInfo appShuffleInfo = validateAndGetAppShuffleInfo(msg.appId); - final String streamId = String.format("%s_%d_%d_%d", - OneForOneBlockPusher.SHUFFLE_PUSH_BLOCK_PREFIX, msg.shuffleId, msg.mapIndex, - msg.reduceId); + final String streamId = String.format("%s_%d_%d_%d_%d", + OneForOneBlockPusher.SHUFFLE_PUSH_BLOCK_PREFIX, msg.shuffleId, msg.shuffleMergeId, + msg.mapIndex, msg.reduceId); if (appShuffleInfo.attemptId != msg.appAttemptId) { // If this Block belongs to a former application attempt, it is considered late, // as only the blocks from the current application attempt will be merged @@ -317,12 +392,19 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { msg.appAttemptId, appShuffleInfo.attemptId, msg.appId)); } // Retrieve merged shuffle file metadata - AppShufflePartitionInfo partitionInfoBeforeCheck = - getOrCreateAppShufflePartitionInfo(appShuffleInfo, msg.shuffleId, msg.reduceId); - // Here partitionInfo will be null in 2 cases: + AppShufflePartitionInfo partitionInfoBeforeCheck; + try { + partitionInfoBeforeCheck = getOrCreateAppShufflePartitionInfo(appShuffleInfo, msg.shuffleId, + msg.shuffleMergeId, msg.reduceId); + } catch (StaleBlockPushException sbp) { + // Set partitionInfoBeforeCheck to null so that stale block push gets handled. + partitionInfoBeforeCheck = null; + } + // Here partitionInfo will be null in 3 cases: // 1) The request is received for a block that has already been merged, this is possible due // to the retry logic. // 2) The request is received after the merged shuffle is finalized, thus is too late. + // 3) The request is received for an older shuffleMergeId, therefore the block push is rejected. // // For case 1, we will drain the data in the channel and just respond success // to the client. This is required because the response of the previously merged @@ -345,6 +427,13 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { // to notify the client of the failure, so that it can properly halt pushing the remaining // blocks upon receiving such failures to preserve resources on the server/client side.
// + // For case 3, we will also drain the data in the channel, but throw an exception in + // {@link org.apache.spark.network.client.StreamCallback#onComplete(String)}. This way, + // the client will be notified of the failure but the channel will remain active. It is + // important to notify the client of the failure, so that it can properly halt pushing the + // remaining blocks upon receiving such failures to preserve resources on the server/client + // side. + // // Speculative execution would also raise a possible scenario with duplicate blocks. Although // speculative execution would kill the slower task attempt, leading to only 1 task attempt // succeeding in the end, there is no guarantee that only one copy of the block will be @@ -353,18 +442,19 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { // getting killed. When this happens, we need to distinguish the duplicate blocks as they // arrive. More details on this is explained in later comments. - // Track if the block is received after shuffle merge finalize - final boolean isTooLate = partitionInfoBeforeCheck == null; - // Check if the given block is already merged by checking the bitmap against the given map index - final AppShufflePartitionInfo partitionInfo = partitionInfoBeforeCheck != null - && partitionInfoBeforeCheck.mapTracker.contains(msg.mapIndex) ? null - : partitionInfoBeforeCheck; + // Track if the block is received after shuffle merge finalized or from an older + // shuffleMergeId attempt. + final boolean isStaleBlockOrTooLate = partitionInfoBeforeCheck == null; + // Check if the given block is already merged by checking the bitmap against the given map + // index + final AppShufflePartitionInfo partitionInfo = isStaleBlockOrTooLate ? null : + partitionInfoBeforeCheck.mapTracker.contains(msg.mapIndex) ? null : partitionInfoBeforeCheck; if (partitionInfo != null) { return new PushBlockStreamCallback( this, appShuffleInfo, streamId, partitionInfo, msg.mapIndex); } else { - // For a duplicate block or a block which is late, respond back with a callback that handles - // them differently. + // For a duplicate block or a block which is late or stale block from an older + // shuffleMergeId, respond back with a callback that handles them differently. return new StreamCallbackWithID() { @Override public String getID() { @@ -379,11 +469,11 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { @Override public void onComplete(String streamId) { - if (isTooLate) { + if (isStaleBlockOrTooLate) { // Throw an exception here so the block data is drained from channel and server // responds RpcFailure to the client. throw new RuntimeException(String.format("Block %s %s", streamId, - ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)); + ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX)); } // For duplicate block that is received before the shuffle merge finalizes, the // server should respond success to the client. 
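The acceptance rule that `getOrCreateAppShufflePartitionInfo` enforces can be distilled into a few lines. The following standalone sketch (simplified types, with a stand-in `StaleBlockPushException`; in the patch the map value also carries the live merge partitions) mirrors the `ConcurrentHashMap.compute`-based atomic decision: reject a lower `shuffleMergeId`, keep merging for the current one, and supersede with cleanup when a higher one arrives:

```java
import java.util.concurrent.ConcurrentHashMap;

// Distilled sketch of the monotonic shuffleMergeId guard, not the actual class.
class MergeIdGuard {
  static class StaleBlockPushException extends RuntimeException {
    StaleBlockPushException(String msg) { super(msg); }
  }

  private final ConcurrentHashMap<Integer, Integer> latestMergeIdByShuffle =
      new ConcurrentHashMap<>();

  /** Atomically decide whether a push for (shuffleId, shuffleMergeId) is current. */
  void acceptPush(int shuffleId, int shuffleMergeId) {
    latestMergeIdByShuffle.compute(shuffleId, (id, latest) -> {
      if (latest == null || latest < shuffleMergeId) {
        // New shuffle, or a newer stage attempt: here the patch schedules cleanup
        // of the older attempt's files before adopting the higher shuffleMergeId.
        return shuffleMergeId;
      }
      if (latest > shuffleMergeId) {
        // A higher shuffleMergeId has already been seen; this push is stale.
        // Throwing inside compute leaves the mapping unchanged and propagates.
        throw new StaleBlockPushException(
            "stale push for shuffle " + shuffleId + " mergeId " + shuffleMergeId);
      }
      return latest; // same attempt, keep merging
    });
  }
}
```

Doing the comparison inside `compute` is what makes the check-and-update atomic under concurrent pushes, which is why the patch rejects stale pushes from within the remapping function rather than with a separate get-then-put.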
@@ -397,9 +487,9 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
  }

  @Override
-  public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) throws IOException {
-    logger.info("Finalizing shuffle {} from Application {}_{}.",
-      msg.shuffleId, msg.appId, msg.appAttemptId);
+  public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) {
+    logger.info("Finalizing shuffle {} with shuffleMergeId {} from Application {}_{}.",
+      msg.shuffleId, msg.shuffleMergeId, msg.appId, msg.appAttemptId);
    AppShuffleInfo appShuffleInfo = validateAndGetAppShuffleInfo(msg.appId);
    if (appShuffleInfo.attemptId != msg.appAttemptId) {
      // If this block belongs to a former application attempt, it is considered late,
@@ -410,17 +500,42 @@
        + "with the current attempt id %s stored in shuffle service for application %s",
        msg.appAttemptId, appShuffleInfo.attemptId, msg.appId));
    }
-    Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-      appShuffleInfo.partitions.remove(msg.shuffleId);
-    MergeStatuses mergeStatuses;
-    if (shufflePartitions == null || shufflePartitions.isEmpty()) {
-      mergeStatuses =
-        new MergeStatuses(msg.shuffleId, new RoaringBitmap[0], new int[0], new long[0]);
+    AtomicReference<Map<Integer, AppShufflePartitionInfo>> shuffleMergePartitionsRef =
+      new AtomicReference<>(null);
+    // Metadata of a determinate stage shuffle can be safely removed as part of finalizing
+    // the shuffle merge. Currently, once the shuffle is finalized for a determinate stage,
+    // retry stages of the same shuffle will have shuffle push disabled.
+    if (msg.shuffleMergeId == DETERMINATE_SHUFFLE_MERGE_ID) {
+      AppShuffleMergePartitionsInfo appShuffleMergePartitionsInfo =
+        appShuffleInfo.shuffles.remove(msg.shuffleId);
+      if (appShuffleMergePartitionsInfo != null) {
+        shuffleMergePartitionsRef.set(appShuffleMergePartitionsInfo.shuffleMergePartitions);
+      }
    } else {
-      List<RoaringBitmap> bitmaps = new ArrayList<>(shufflePartitions.size());
-      List<Integer> reduceIds = new ArrayList<>(shufflePartitions.size());
-      List<Long> sizes = new ArrayList<>(shufflePartitions.size());
-      for (AppShufflePartitionInfo partition: shufflePartitions.values()) {
+      appShuffleInfo.shuffles.compute(msg.shuffleId, (id, value) -> {
+        if (null == value || msg.shuffleMergeId != value.shuffleMergeId ||
+            INDETERMINATE_SHUFFLE_FINALIZED == value.shuffleMergePartitions) {
+          throw new RuntimeException(String.format(
+            "Shuffle merge finalize request for shuffle %s with" + " shuffleMergeId %s is %s",
+            msg.shuffleId, msg.shuffleMergeId,
+            ErrorHandler.BlockPushErrorHandler.STALE_SHUFFLE_FINALIZE_SUFFIX));
+        } else {
+          shuffleMergePartitionsRef.set(value.shuffleMergePartitions);
+          return new AppShuffleMergePartitionsInfo(msg.shuffleMergeId, true);
+        }
+      });
+    }
+    Map<Integer, AppShufflePartitionInfo> shuffleMergePartitions = shuffleMergePartitionsRef.get();
+    MergeStatuses mergeStatuses;
+    if (null == shuffleMergePartitions || shuffleMergePartitions.isEmpty()) {
+      mergeStatuses =
+        new MergeStatuses(msg.shuffleId, msg.shuffleMergeId,
+          new RoaringBitmap[0], new int[0], new long[0]);
+    } else {
+      List<RoaringBitmap> bitmaps = new ArrayList<>(shuffleMergePartitions.size());
+      List<Integer> reduceIds = new ArrayList<>(shuffleMergePartitions.size());
+      List<Long> sizes = new ArrayList<>(shuffleMergePartitions.size());
+      for (AppShufflePartitionInfo partition: shuffleMergePartitions.values()) {
        synchronized (partition) {
          try {
            // This can throw IOException, which will mark this shuffle partition as not merged.
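The indeterminate branch above leans on ConcurrentHashMap.compute() running atomically per key: checking that the request's shuffleMergeId is still current and installing the finalized marker happen as one step, so a concurrent block push for the same shuffle can never observe a half-finalized state. A minimal sketch of the pattern, with a simplified state class standing in for AppShuffleMergePartitionsInfo:

    import java.util.concurrent.ConcurrentHashMap;

    class FinalizePatternSketch {
      // Simplified stand-in for AppShuffleMergePartitionsInfo.
      static class MergeState {
        final int shuffleMergeId;
        final boolean finalized;
        MergeState(int shuffleMergeId, boolean finalized) {
          this.shuffleMergeId = shuffleMergeId;
          this.finalized = finalized;
        }
      }

      final ConcurrentHashMap<Integer, MergeState> shuffles = new ConcurrentHashMap<>();

      void finalizeMerge(int shuffleId, int requestMergeId) {
        // compute() runs atomically for the key, so the staleness check and the
        // switch to the finalized marker cannot interleave with a concurrent push.
        shuffles.compute(shuffleId, (id, state) -> {
          if (state == null || state.shuffleMergeId != requestMergeId || state.finalized) {
            // Unknown shuffle, already finalized, or superseded by a newer
            // shuffleMergeId: the finalize request is stale.
            throw new RuntimeException("stale shuffle merge finalize request");
          }
          return new MergeState(requestMergeId, true);
        });
      }
    }

Throwing inside the compute() lambda leaves the mapping unchanged, which matches the behavior relied on above: a stale finalize request must not clobber the newer shuffleMergeId's state.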
@@ -432,16 +547,16 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
            logger.warn("Exception while finalizing shuffle partition {}_{} {} {}", msg.appId,
              msg.appAttemptId, msg.shuffleId, partition.reduceId, ioe);
          } finally {
-            partition.closeAllFiles();
+            partition.closeAllFilesAndDeleteIfNeeded(false);
          }
        }
      }
-      mergeStatuses = new MergeStatuses(msg.shuffleId,
+      mergeStatuses = new MergeStatuses(msg.shuffleId, msg.shuffleMergeId,
        bitmaps.toArray(new RoaringBitmap[bitmaps.size()]), Ints.toArray(reduceIds),
        Longs.toArray(sizes));
    }
-    logger.info("Finalized shuffle {} from Application {}_{}.",
-      msg.shuffleId, msg.appId, msg.appAttemptId);
+    logger.info("Finalized shuffle {} with shuffleMergeId {} from Application {}_{}.",
+      msg.shuffleId, msg.shuffleMergeId, msg.appId, msg.appAttemptId);
    return mergeStatuses;
  }

@@ -563,9 +678,10 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
    private void writeBuf(ByteBuffer buf) throws IOException {
      while (buf.hasRemaining()) {
        long updatedPos = partitionInfo.getDataFilePos() + length;
-        logger.debug("{} shuffleId {} reduceId {} current pos {} updated pos {}",
-          partitionInfo.appId, partitionInfo.shuffleId,
-          partitionInfo.reduceId, partitionInfo.getDataFilePos(), updatedPos);
+        logger.debug("{} shuffleId {} shuffleMergeId {} reduceId {} current pos" +
+          " {} updated pos {}", partitionInfo.appId, partitionInfo.shuffleId,
+          partitionInfo.shuffleMergeId, partitionInfo.reduceId,
+          partitionInfo.getDataFilePos(), updatedPos);
        length += partitionInfo.dataChannel.write(buf, updatedPos);
      }
    }
@@ -631,6 +747,23 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
      abortIfNecessary();
    }

+    /**
+     * The block push is too late if the shuffle has already been finalized, that is, if
+     * appShuffleMergePartitionsInfo is null, if its shuffleMergePartitions is set to
+     * INDETERMINATE_SHUFFLE_FINALIZED, or if the reduceId is not in the map. The block
+     * push is stale if appShuffleMergePartitionsInfo's shuffleMergeId is greater than
+     * the request's shuffleMergeId.
+     */
+    private boolean isStaleOrTooLate(
+        AppShuffleMergePartitionsInfo appShuffleMergePartitionsInfo,
+        int shuffleMergeId,
+        int reduceId) {
+      return null == appShuffleMergePartitionsInfo ||
+        INDETERMINATE_SHUFFLE_FINALIZED == appShuffleMergePartitionsInfo.shuffleMergePartitions ||
+        appShuffleMergePartitionsInfo.shuffleMergeId > shuffleMergeId ||
+        !appShuffleMergePartitionsInfo.shuffleMergePartitions.containsKey(reduceId);
+    }
+
    @Override
    public void onData(String streamId, ByteBuffer buf) throws IOException {
      // When handling the block data using StreamInterceptor, it can help to reduce the amount
@@ -648,14 +781,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
      // to disk as well. This way, we avoid having to buffer the entirety of every block in
      // memory, while still providing the necessary guarantee.
      synchronized (partitionInfo) {
-        Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-          appShuffleInfo.partitions.get(partitionInfo.shuffleId);
-        // If the partitionInfo corresponding to (appId, shuffleId, reduceId) is no longer present
-        // then it means that the shuffle merge has already been finalized. We should thus ignore
-        // the data and just drain the remaining bytes of this message. This check should be
-        // placed inside the synchronized block to make sure that checking whether the key is
-        // still present and processing the data is atomic.
-        if (shufflePartitions == null || !shufflePartitions.containsKey(partitionInfo.reduceId)) {
+        if (isStaleOrTooLate(appShuffleInfo.shuffles.get(partitionInfo.shuffleId),
+            partitionInfo.shuffleMergeId, partitionInfo.reduceId)) {
          deferredBufs = null;
          return;
        }
@@ -668,8 +795,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
          return;
        }
        abortIfNecessary();
-        logger.trace("{} shuffleId {} reduceId {} onData writable",
-          partitionInfo.appId, partitionInfo.shuffleId,
+        logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} onData writable",
+          partitionInfo.appId, partitionInfo.shuffleId, partitionInfo.shuffleMergeId,
          partitionInfo.reduceId);
        if (partitionInfo.getCurrentMapIndex() < 0) {
          partitionInfo.setCurrentMapIndex(mapIndex);
@@ -690,8 +817,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
            throw ioe;
          }
        } else {
-          logger.trace("{} shuffleId {} reduceId {} onData deferred",
-            partitionInfo.appId, partitionInfo.shuffleId,
+          logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} onData deferred",
+            partitionInfo.appId, partitionInfo.shuffleId, partitionInfo.shuffleMergeId,
            partitionInfo.reduceId);
          // If we cannot write to disk, we buffer the current block chunk in memory so it could
          // potentially be written to disk later. We take our best effort without guarantee
@@ -725,19 +852,21 @@
    @Override
    public void onComplete(String streamId) throws IOException {
      synchronized (partitionInfo) {
-        logger.trace("{} shuffleId {} reduceId {} onComplete invoked",
-          partitionInfo.appId, partitionInfo.shuffleId,
+        logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} onComplete invoked",
+          partitionInfo.appId, partitionInfo.shuffleId, partitionInfo.shuffleMergeId,
          partitionInfo.reduceId);
-        Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-          appShuffleInfo.partitions.get(partitionInfo.shuffleId);
-        // When this request initially got to the server, the shuffle merge finalize request
-        // was not received yet. By the time we finish reading this message, the shuffle merge
-        // however is already finalized. We should thus respond RpcFailure to the client.
-        if (shufflePartitions == null || !shufflePartitions.containsKey(partitionInfo.reduceId)) {
+        // When this request initially reached the server, the shuffle merge finalize request
+        // had not been received yet and this was still the latest stage attempt (i.e. the
+        // latest shuffleMergeId) generating shuffle output for the shuffle ID. By the time we
+        // finish reading this message, the block request has become either stale or too late.
+        // We should thus respond RpcFailure to the client.
+        if (isStaleOrTooLate(appShuffleInfo.shuffles.get(partitionInfo.shuffleId),
+            partitionInfo.shuffleMergeId, partitionInfo.reduceId)) {
          deferredBufs = null;
-          throw new RuntimeException(String.format("Block %s %s", streamId,
-            ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX));
+          throw new RuntimeException(String.format("Block %s is %s", streamId,
+            ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX));
        }
+
        // Check if we can commit this block
        if (allowedToWrite()) {
          // Identify duplicate block generated by speculative tasks. We respond success to
@@ -800,21 +929,20 @@
        logger.debug("Encountered issue when merging {}", streamId, throwable);
      }
      // Only update partitionInfo if the failure corresponds to a valid request. If the
-      // request is too late, i.e. received after shuffle merge finalize, #onFailure will
-      // also be triggered, and we can just ignore. Also, if we couldn't find an opportunity
-      // to write the block data to disk, we should also ignore here.
+      // request is too late, i.e. received after shuffle merge finalize, or is a stale block
+      // push, #onFailure will also be triggered, and we can just ignore it. Also, if we
+      // couldn't find an opportunity to write the block data to disk, we should ignore it here.
      if (isWriting) {
        synchronized (partitionInfo) {
-          Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-            appShuffleInfo.partitions.get(partitionInfo.shuffleId);
-          if (shufflePartitions != null && shufflePartitions.containsKey(partitionInfo.reduceId)) {
-            logger.debug("{} shuffleId {} reduceId {} encountered failure",
-              partitionInfo.appId, partitionInfo.shuffleId,
-              partitionInfo.reduceId);
-            partitionInfo.setCurrentMapIndex(-1);
+          if (!isStaleOrTooLate(appShuffleInfo.shuffles.get(partitionInfo.shuffleId),
+              partitionInfo.shuffleMergeId, partitionInfo.reduceId)) {
+            logger.debug("{} shuffleId {} shuffleMergeId {} reduceId {}" +
+              " encountered failure", partitionInfo.appId, partitionInfo.shuffleId,
+              partitionInfo.shuffleMergeId, partitionInfo.reduceId);
+            partitionInfo.setCurrentMapIndex(-1);
+          }
        }
      }
-      }
      isWriting = false;
    }
@@ -824,12 +952,35 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
    }
  }

+  /**
+   * Wrapper class to hold the merged shuffle related information for a specific
+   * shuffleMergeId, as required for the shuffles of indeterminate stages.
+   */
+  public static class AppShuffleMergePartitionsInfo {
+    private final int shuffleMergeId;
+    private final Map<Integer, AppShufflePartitionInfo> shuffleMergePartitions;
+
+    public AppShuffleMergePartitionsInfo(
+        int shuffleMergeId, boolean shuffleFinalized) {
+      this.shuffleMergeId = shuffleMergeId;
+      this.shuffleMergePartitions = shuffleFinalized ?
+        INDETERMINATE_SHUFFLE_FINALIZED : new ConcurrentHashMap<>();
+    }
+
+    @VisibleForTesting
+    public Map<Integer, AppShufflePartitionInfo> getShuffleMergePartitions() {
+      return shuffleMergePartitions;
+    }
+  }
+
  /** Metadata tracked for an actively merged shuffle partition */
  public static class AppShufflePartitionInfo {

    private final String appId;
    private final int shuffleId;
+    private final int shuffleMergeId;
    private final int reduceId;
+    private final File dataFile;
    // The merged shuffle data file channel
    public final FileChannel dataChannel;
    // The index file for a particular merged shuffle contains the chunk offsets.
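Both finalizeShuffleMerge and isStaleOrTooLate above compare shuffleMergePartitions against INDETERMINATE_SHUFFLE_FINALIZED with ==, which only works if that constant is one shared sentinel instance. A self-contained sketch of the idiom (names here are illustrative, not the patch's; the sentinel is assumed to be a shared, empty, immutable map):

    import java.util.Collections;
    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;

    class SentinelSketch {
      // Collections.emptyMap() always returns the same immutable instance,
      // so identity comparison against FINALIZED is reliable.
      static final Map<Integer, Object> FINALIZED = Collections.emptyMap();

      static Map<Integer, Object> newState(boolean finalized) {
        // A finalized shuffle shares the one sentinel instance; an active shuffle
        // gets its own mutable map of reduceId -> partition metadata.
        return finalized ? FINALIZED : new ConcurrentHashMap<>();
      }

      static boolean isFinalized(Map<Integer, Object> partitions) {
        // Identity comparison, not equals(): an active shuffle whose map happens
        // to be empty must not be mistaken for a finalized one.
        return partitions == FINALIZED;
      }
    }

The design choice buys a branch-free marker with no extra boolean field to keep in sync with the partitions map.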
@@ -854,6 +1005,7 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { AppShufflePartitionInfo( String appId, int shuffleId, + int shuffleMergeId, int reduceId, File dataFile, MergeShuffleFile indexFile, @@ -861,8 +1013,10 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { Preconditions.checkArgument(appId != null, "app id is null"); this.appId = appId; this.shuffleId = shuffleId; + this.shuffleMergeId = shuffleMergeId; this.reduceId = reduceId; this.dataChannel = new FileOutputStream(dataFile).getChannel(); + this.dataFile = dataFile; this.indexFile = indexFile; this.metaFile = metaFile; this.currentMapIndex = -1; @@ -878,8 +1032,9 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { } public void setDataFilePos(long dataFilePos) { - logger.trace("{} shuffleId {} reduceId {} current pos {} update pos {}", appId, - shuffleId, reduceId, this.dataFilePos, dataFilePos); + logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} current pos {}" + + " update pos {}", appId, shuffleId, shuffleMergeId, reduceId, this.dataFilePos, + dataFilePos); this.dataFilePos = dataFilePos; } @@ -888,8 +1043,9 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { } void setCurrentMapIndex(int mapIndex) { - logger.trace("{} shuffleId {} reduceId {} updated mapIndex {} current mapIndex {}", - appId, shuffleId, reduceId, currentMapIndex, mapIndex); + logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} updated mapIndex {}" + + " current mapIndex {}", appId, shuffleId, shuffleMergeId, reduceId, + currentMapIndex, mapIndex); this.currentMapIndex = mapIndex; } @@ -898,8 +1054,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { } void blockMerged(int mapIndex) { - logger.debug("{} shuffleId {} reduceId {} updated merging mapIndex {}", appId, - shuffleId, reduceId, mapIndex); + logger.debug("{} shuffleId {} shuffleMergeId {} reduceId {} updated merging mapIndex {}", + appId, shuffleId, shuffleMergeId, reduceId, mapIndex); mapTracker.add(mapIndex); chunkTracker.add(mapIndex); lastMergedMapIndex = mapIndex; @@ -917,8 +1073,9 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { */ void updateChunkInfo(long chunkOffset, int mapIndex) throws IOException { try { - logger.trace("{} shuffleId {} reduceId {} index current {} updated {}", - appId, shuffleId, reduceId, this.lastChunkOffset, chunkOffset); + logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} index current {}" + + " updated {}", appId, shuffleId, shuffleMergeId, reduceId, + this.lastChunkOffset, chunkOffset); if (indexMetaUpdateFailed) { indexFile.getChannel().position(indexFile.getPos()); } @@ -946,8 +1103,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { return; } chunkTracker.add(mapIndex); - logger.trace("{} shuffleId {} reduceId {} mapIndex {} write chunk to meta file", - appId, shuffleId, reduceId, mapIndex); + logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} mapIndex {}" + + " write chunk to meta file", appId, shuffleId, shuffleMergeId, reduceId, mapIndex); if (indexMetaUpdateFailed) { metaFile.getChannel().position(metaFile.getPos()); } @@ -980,32 +1137,41 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { metaFile.getChannel().truncate(metaFile.getPos()); } - void closeAllFiles() { + void closeAllFilesAndDeleteIfNeeded(boolean delete) { try { if (dataChannel.isOpen()) { dataChannel.close(); + if (delete) { + 
dataFile.delete();
+        }
      }
    } catch (IOException ioe) {
-      logger.warn("Error closing data channel for {} shuffleId {} reduceId {}",
-        appId, shuffleId, reduceId);
+      logger.warn("Error closing data channel for {} shuffleId {} shuffleMergeId {}" +
+        " reduceId {}", appId, shuffleId, shuffleMergeId, reduceId);
    }
    try {
      metaFile.close();
+      if (delete) {
+        metaFile.delete();
+      }
    } catch (IOException ioe) {
-      logger.warn("Error closing meta file for {} shuffleId {} reduceId {}",
-        appId, shuffleId, reduceId);
-    }
+      logger.warn("Error closing meta file for {} shuffleId {} shuffleMergeId {}" +
+        " reduceId {}", appId, shuffleId, shuffleMergeId, reduceId);
+    }
    try {
      indexFile.close();
+      if (delete) {
+        indexFile.delete();
+      }
    } catch (IOException ioe) {
-      logger.warn("Error closing index file for {} shuffleId {} reduceId {}",
-        appId, shuffleId, reduceId);
+      logger.warn("Error closing index file for {} shuffleId {} shuffleMergeId {}" +
+        " reduceId {}", appId, shuffleId, shuffleMergeId, reduceId);
    }
  }

  @Override
  protected void finalize() throws Throwable {
-    closeAllFiles();
+    closeAllFilesAndDeleteIfNeeded(false);
  }

  @VisibleForTesting
@@ -1065,7 +1231,12 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
    private final String appId;
    private final int attemptId;
    private final AppPathsInfo appPathsInfo;
-    private final ConcurrentMap<Integer, Map<Integer, AppShufflePartitionInfo>> partitions;
+    /**
+     * 1. Key: the shuffleId for an application.
+     * 2. Value: the AppShuffleMergePartitionsInfo, which holds the shuffleMergeId and a
+     *    Map tracking the AppShufflePartitionInfo of all the shuffle partitions.
+     */
+    private final ConcurrentMap<Integer, AppShuffleMergePartitionsInfo> shuffles;

    AppShuffleInfo(
      String appId,
@@ -1074,12 +1245,12 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
      this.appId = appId;
      this.attemptId = attemptId;
      this.appPathsInfo = appPathsInfo;
-      partitions = new ConcurrentHashMap<>();
+      shuffles = new ConcurrentHashMap<>();
    }

    @VisibleForTesting
-    public ConcurrentMap<Integer, Map<Integer, AppShufflePartitionInfo>> getPartitions() {
-      return partitions;
+    public ConcurrentMap<Integer, AppShuffleMergePartitionsInfo> getShuffles() {
+      return shuffles;
    }

    /**
@@ -1098,29 +1269,37 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
    private String generateFileName(
        String appId,
        int shuffleId,
+        int shuffleMergeId,
        int reduceId) {
      return String.format(
-        "%s_%s_%d_%d", MERGED_SHUFFLE_FILE_NAME_PREFIX, appId, shuffleId, reduceId);
+        "%s_%s_%d_%d_%d", MERGED_SHUFFLE_FILE_NAME_PREFIX, appId, shuffleId,
+        shuffleMergeId, reduceId);
    }

    public File getMergedShuffleDataFile(
        int shuffleId,
+        int shuffleMergeId,
        int reduceId) {
-      String fileName = String.format("%s.data", generateFileName(appId, shuffleId, reduceId));
+      String fileName = String.format("%s.data", generateFileName(appId, shuffleId,
+        shuffleMergeId, reduceId));
      return getFile(fileName);
    }

    public File getMergedShuffleIndexFile(
        int shuffleId,
+        int shuffleMergeId,
        int reduceId) {
-      String indexName = String.format("%s.index", generateFileName(appId, shuffleId, reduceId));
+      String indexName = String.format("%s.index", generateFileName(appId, shuffleId,
+        shuffleMergeId, reduceId));
      return getFile(indexName);
    }

    public File getMergedShuffleMetaFile(
        int shuffleId,
+        int shuffleMergeId,
        int reduceId) {
-      String metaName = String.format("%s.meta", generateFileName(appId, shuffleId, reduceId));
+      String metaName = String.format("%s.meta", generateFileName(appId, shuffleId,
+        shuffleMergeId, reduceId));
      return getFile(metaName);
    }
  }
@@ -1130,18 +1309,21 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
  private
final FileChannel channel; private final DataOutputStream dos; private long pos; + private File file; @VisibleForTesting MergeShuffleFile(File file) throws IOException { FileOutputStream fos = new FileOutputStream(file); channel = fos.getChannel(); dos = new DataOutputStream(fos); + this.file = file; } @VisibleForTesting MergeShuffleFile(FileChannel channel, DataOutputStream dos) { this.channel = channel; this.dos = dos; + this.file = null; } private void updatePos(long numBytes) { @@ -1154,6 +1336,16 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { } } + void delete() throws IOException { + try { + if (null != file) { + file.delete(); + } + } finally { + file = null; + } + } + @VisibleForTesting DataOutputStream getDos() { return dos; @@ -1169,4 +1361,10 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { return pos; } } + + public static class StaleBlockPushException extends RuntimeException { + public StaleBlockPushException(String message) { + super(message); + } + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java index 27345dd8e7..cf4cbcf1ed 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java @@ -37,14 +37,19 @@ public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks { public final int[] reduceIds; // The i-th int[] in chunkIds contains all the chunks for the i-th reduceId in reduceIds. public final int[][] chunkIds; + // shuffleMergeId is used to uniquely identify merging process of shuffle by + // an indeterminate stage attempt. 
+ public final int shuffleMergeId; public FetchShuffleBlockChunks( String appId, String execId, int shuffleId, + int shuffleMergeId, int[] reduceIds, int[][] chunkIds) { super(appId, execId, shuffleId); + this.shuffleMergeId = shuffleMergeId; this.reduceIds = reduceIds; this.chunkIds = chunkIds; assert(reduceIds.length == chunkIds.length); @@ -56,6 +61,7 @@ public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks { @Override public String toString() { return toStringHelper() + .append("shuffleMergeId", shuffleMergeId) .append("reduceIds", Arrays.toString(reduceIds)) .append("chunkIds", Arrays.deepToString(chunkIds)) .toString(); @@ -68,13 +74,16 @@ public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks { FetchShuffleBlockChunks that = (FetchShuffleBlockChunks) o; if (!super.equals(that)) return false; - if (!Arrays.equals(reduceIds, that.reduceIds)) return false; + if (shuffleMergeId != that.shuffleMergeId || + !Arrays.equals(reduceIds, that.reduceIds)) { + return false; + } return Arrays.deepEquals(chunkIds, that.chunkIds); } @Override public int hashCode() { - int result = super.hashCode(); + int result = super.hashCode() * 31 + shuffleMergeId; result = 31 * result + Arrays.hashCode(reduceIds); result = 31 * result + Arrays.deepHashCode(chunkIds); return result; @@ -89,12 +98,14 @@ public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks { return super.encodedLength() + Encoders.IntArrays.encodedLength(reduceIds) + 4 /* encoded length of chunkIds.size() */ + + 4 /* encoded length of shuffleMergeId */ + encodedLengthOfChunkIds; } @Override public void encode(ByteBuf buf) { super.encode(buf); + buf.writeInt(shuffleMergeId); Encoders.IntArrays.encode(buf, reduceIds); // Even though reduceIds.length == chunkIds.length, we are explicitly setting the length in the // interest of forward compatibility. 
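The encode path above and the decode path just below only stay in sync because every field has a fixed width and a fixed position on the wire: shuffleMergeId is written immediately after shuffleId, which is the extra 4 bytes accounted for in encodedLength(), and an old reader that does not expect the field would misparse everything after it, hence the protocol change called out in the commit message. A cut-down round-trip sketch (not the real message class; Encoders.IntArrays is approximated here by a plain length-prefixed loop):

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    class WireFormatSketch {
      // Writes a cut-down message: shuffleId, shuffleMergeId, then a
      // length-prefixed int array, mirroring the layout above in spirit.
      static ByteBuf encode(int shuffleId, int shuffleMergeId, int[] reduceIds) {
        // 4 bytes each for shuffleId, shuffleMergeId, and the array length prefix.
        ByteBuf buf = Unpooled.buffer(4 + 4 + 4 + 4 * reduceIds.length);
        buf.writeInt(shuffleId);
        buf.writeInt(shuffleMergeId);  // the new field, in a fixed position
        buf.writeInt(reduceIds.length);
        for (int r : reduceIds) {
          buf.writeInt(r);
        }
        return buf;
      }

      // Reads the fields back in exactly the order they were written.
      static int[] decode(ByteBuf buf) {
        int shuffleId = buf.readInt();
        int shuffleMergeId = buf.readInt();
        int[] reduceIds = new int[buf.readInt()];
        for (int i = 0; i < reduceIds.length; i++) {
          reduceIds[i] = buf.readInt();
        }
        System.out.printf("shuffle %d mergeId %d: %d reduceIds%n",
            shuffleId, shuffleMergeId, reduceIds.length);
        return reduceIds;
      }
    }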
@@ -117,12 +128,14 @@ public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks { String appId = Encoders.Strings.decode(buf); String execId = Encoders.Strings.decode(buf); int shuffleId = buf.readInt(); + int shuffleMergeId = buf.readInt(); int[] reduceIds = Encoders.IntArrays.decode(buf); int chunkIdsLen = buf.readInt(); int[][] chunkIds = new int[chunkIdsLen][]; for (int i = 0; i < chunkIdsLen; i++) { chunkIds[i] = Encoders.IntArrays.decode(buf); } - return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, chunkIds); + return new FetchShuffleBlockChunks(appId, execId, shuffleId, shuffleMergeId, reduceIds, + chunkIds); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java index 088ff388ea..675739a41e 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java @@ -34,14 +34,17 @@ public class FinalizeShuffleMerge extends BlockTransferMessage { public final String appId; public final int appAttemptId; public final int shuffleId; + public final int shuffleMergeId; public FinalizeShuffleMerge( String appId, int appAttemptId, - int shuffleId) { + int shuffleId, + int shuffleMergeId) { this.appId = appId; this.appAttemptId = appAttemptId; this.shuffleId = shuffleId; + this.shuffleMergeId = shuffleMergeId; } @Override @@ -51,7 +54,7 @@ public class FinalizeShuffleMerge extends BlockTransferMessage { @Override public int hashCode() { - return Objects.hashCode(appId, appAttemptId, shuffleId); + return Objects.hashCode(appId, appAttemptId, shuffleId, shuffleMergeId); } @Override @@ -60,6 +63,7 @@ public class FinalizeShuffleMerge extends BlockTransferMessage { .append("appId", appId) .append("attemptId", appAttemptId) .append("shuffleId", shuffleId) + .append("shuffleMergeId", shuffleMergeId) .toString(); } @@ -69,14 +73,15 @@ public class FinalizeShuffleMerge extends BlockTransferMessage { FinalizeShuffleMerge o = (FinalizeShuffleMerge) other; return Objects.equal(appId, o.appId) && appAttemptId == o.appAttemptId - && shuffleId == o.shuffleId; + && shuffleId == o.shuffleId + && shuffleMergeId == o.shuffleMergeId; } return false; } @Override public int encodedLength() { - return Encoders.Strings.encodedLength(appId) + 4 + 4; + return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4; } @Override @@ -84,12 +89,14 @@ public class FinalizeShuffleMerge extends BlockTransferMessage { Encoders.Strings.encode(buf, appId); buf.writeInt(appAttemptId); buf.writeInt(shuffleId); + buf.writeInt(shuffleMergeId); } public static FinalizeShuffleMerge decode(ByteBuf buf) { String appId = Encoders.Strings.decode(buf); int attemptId = buf.readInt(); int shuffleId = buf.readInt(); - return new FinalizeShuffleMerge(appId, attemptId, shuffleId); + int shuffleMergeId = buf.readInt(); + return new FinalizeShuffleMerge(appId, attemptId, shuffleId, shuffleMergeId); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java index 142ab73e79..b2658d62b4 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java +++ 
b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java @@ -40,6 +40,11 @@ import org.apache.spark.network.protocol.Encoders; public class MergeStatuses extends BlockTransferMessage { /** Shuffle ID **/ public final int shuffleId; + /** + * shuffleMergeId is used to uniquely identify merging process of shuffle by + * an indeterminate stage attempt. + */ + public final int shuffleMergeId; /** * Array of bitmaps tracking the set of mapper partition blocks merged for each * reducer partition @@ -55,10 +60,12 @@ public class MergeStatuses extends BlockTransferMessage { public MergeStatuses( int shuffleId, + int shuffleMergeId, RoaringBitmap[] bitmaps, int[] reduceIds, long[] sizes) { this.shuffleId = shuffleId; + this.shuffleMergeId = shuffleMergeId; this.bitmaps = bitmaps; this.reduceIds = reduceIds; this.sizes = sizes; @@ -71,7 +78,8 @@ public class MergeStatuses extends BlockTransferMessage { @Override public int hashCode() { - int objectHashCode = Objects.hashCode(shuffleId); + int objectHashCode = Objects.hashCode(shuffleId) * 41 + + Objects.hashCode(shuffleMergeId); return (objectHashCode * 41 + Arrays.hashCode(reduceIds) * 41 + Arrays.hashCode(bitmaps) * 41 + Arrays.hashCode(sizes)); } @@ -80,6 +88,7 @@ public class MergeStatuses extends BlockTransferMessage { public String toString() { return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) .append("shuffleId", shuffleId) + .append("shuffleMergeId", shuffleMergeId) .append("reduceId size", reduceIds.length) .toString(); } @@ -89,6 +98,7 @@ public class MergeStatuses extends BlockTransferMessage { if (other != null && other instanceof MergeStatuses) { MergeStatuses o = (MergeStatuses) other; return Objects.equal(shuffleId, o.shuffleId) + && Objects.equal(shuffleMergeId, o.shuffleMergeId) && Arrays.equals(bitmaps, o.bitmaps) && Arrays.equals(reduceIds, o.reduceIds) && Arrays.equals(sizes, o.sizes); @@ -98,7 +108,7 @@ public class MergeStatuses extends BlockTransferMessage { @Override public int encodedLength() { - return 4 // int + return 4 + 4 // shuffleId and shuffleMergeId + Encoders.BitmapArrays.encodedLength(bitmaps) + Encoders.IntArrays.encodedLength(reduceIds) + Encoders.LongArrays.encodedLength(sizes); @@ -107,6 +117,7 @@ public class MergeStatuses extends BlockTransferMessage { @Override public void encode(ByteBuf buf) { buf.writeInt(shuffleId); + buf.writeInt(shuffleMergeId); Encoders.BitmapArrays.encode(buf, bitmaps); Encoders.IntArrays.encode(buf, reduceIds); Encoders.LongArrays.encode(buf, sizes); @@ -114,9 +125,10 @@ public class MergeStatuses extends BlockTransferMessage { public static MergeStatuses decode(ByteBuf buf) { int shuffleId = buf.readInt(); + int shuffleMergeId = buf.readInt(); RoaringBitmap[] bitmaps = Encoders.BitmapArrays.decode(buf); int[] reduceIds = Encoders.IntArrays.decode(buf); long[] sizes = Encoders.LongArrays.decode(buf); - return new MergeStatuses(shuffleId, bitmaps, reduceIds, sizes); + return new MergeStatuses(shuffleId, shuffleMergeId, bitmaps, reduceIds, sizes); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java index d5e1cf2464..b868d7ccff 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java @@ -37,6 +37,7 @@ 
public class PushBlockStream extends BlockTransferMessage { public final String appId; public final int appAttemptId; public final int shuffleId; + public final int shuffleMergeId; public final int mapIndex; public final int reduceId; // Similar to the chunkIndex in StreamChunkId, indicating the index of a block in a batch of @@ -47,12 +48,14 @@ public class PushBlockStream extends BlockTransferMessage { String appId, int appAttemptId, int shuffleId, + int shuffleMergeId, int mapIndex, int reduceId, int index) { this.appId = appId; this.appAttemptId = appAttemptId; this.shuffleId = shuffleId; + this.shuffleMergeId = shuffleMergeId; this.mapIndex = mapIndex; this.reduceId = reduceId; this.index = index; @@ -65,7 +68,8 @@ public class PushBlockStream extends BlockTransferMessage { @Override public int hashCode() { - return Objects.hashCode(appId, appAttemptId, shuffleId, mapIndex , reduceId, index); + return Objects.hashCode(appId, appAttemptId, shuffleId, shuffleMergeId, mapIndex , reduceId, + index); } @Override @@ -74,6 +78,7 @@ public class PushBlockStream extends BlockTransferMessage { .append("appId", appId) .append("attemptId", appAttemptId) .append("shuffleId", shuffleId) + .append("shuffleMergeId", shuffleMergeId) .append("mapIndex", mapIndex) .append("reduceId", reduceId) .append("index", index) @@ -87,6 +92,7 @@ public class PushBlockStream extends BlockTransferMessage { return Objects.equal(appId, o.appId) && appAttemptId == o.appAttemptId && shuffleId == o.shuffleId + && shuffleMergeId == o.shuffleMergeId && mapIndex == o.mapIndex && reduceId == o.reduceId && index == o.index; @@ -96,7 +102,7 @@ public class PushBlockStream extends BlockTransferMessage { @Override public int encodedLength() { - return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4 + 4 + 4; + return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4 + 4 + 4 + 4; } @Override @@ -104,6 +110,7 @@ public class PushBlockStream extends BlockTransferMessage { Encoders.Strings.encode(buf, appId); buf.writeInt(appAttemptId); buf.writeInt(shuffleId); + buf.writeInt(shuffleMergeId); buf.writeInt(mapIndex); buf.writeInt(reduceId); buf.writeInt(index); @@ -113,9 +120,11 @@ public class PushBlockStream extends BlockTransferMessage { String appId = Encoders.Strings.decode(buf); int attemptId = buf.readInt(); int shuffleId = buf.readInt(); + int shuffleMergeId = buf.readInt(); int mapIdx = buf.readInt(); int reduceId = buf.readInt(); int index = buf.readInt(); - return new PushBlockStream(appId, attemptId, shuffleId, mapIdx, reduceId, index); + return new PushBlockStream(appId, attemptId, shuffleId, shuffleMergeId, mapIdx, reduceId, + index); } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ErrorHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ErrorHandlerSuite.java index 992e7762c5..c8066d1e6b 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ErrorHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ErrorHandlerSuite.java @@ -29,23 +29,31 @@ import static org.junit.Assert.*; public class ErrorHandlerSuite { @Test - public void testPushErrorRetry() { - ErrorHandler.BlockPushErrorHandler handler = new ErrorHandler.BlockPushErrorHandler(); - assertFalse(handler.shouldRetryError(new RuntimeException(new IllegalArgumentException( - ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)))); - assertFalse(handler.shouldRetryError(new RuntimeException(new ConnectException()))); 
- assertTrue(handler.shouldRetryError(new RuntimeException(new IllegalArgumentException( + public void testErrorRetry() { + ErrorHandler.BlockPushErrorHandler pushHandler = new ErrorHandler.BlockPushErrorHandler(); + assertFalse(pushHandler.shouldRetryError(new RuntimeException(new IllegalArgumentException( + ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX)))); + assertFalse(pushHandler.shouldRetryError(new RuntimeException(new ConnectException()))); + assertTrue(pushHandler.shouldRetryError(new RuntimeException(new IllegalArgumentException( ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX)))); - assertTrue(handler.shouldRetryError(new Throwable())); + assertTrue(pushHandler.shouldRetryError(new Throwable())); + + ErrorHandler.BlockFetchErrorHandler fetchHandler = new ErrorHandler.BlockFetchErrorHandler(); + assertFalse(fetchHandler.shouldRetryError(new RuntimeException( + ErrorHandler.BlockFetchErrorHandler.STALE_SHUFFLE_BLOCK_FETCH))); } @Test - public void testPushErrorLogging() { - ErrorHandler.BlockPushErrorHandler handler = new ErrorHandler.BlockPushErrorHandler(); - assertFalse(handler.shouldLogError(new RuntimeException(new IllegalArgumentException( - ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)))); - assertFalse(handler.shouldLogError(new RuntimeException(new IllegalArgumentException( + public void testErrorLogging() { + ErrorHandler.BlockPushErrorHandler pushHandler = new ErrorHandler.BlockPushErrorHandler(); + assertFalse(pushHandler.shouldLogError(new RuntimeException(new IllegalArgumentException( + ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX)))); + assertFalse(pushHandler.shouldLogError(new RuntimeException(new IllegalArgumentException( ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX)))); - assertTrue(handler.shouldLogError(new Throwable())); + assertTrue(pushHandler.shouldLogError(new Throwable())); + + ErrorHandler.BlockFetchErrorHandler fetchHandler = new ErrorHandler.BlockFetchErrorHandler(); + assertFalse(fetchHandler.shouldLogError(new RuntimeException( + ErrorHandler.BlockFetchErrorHandler.STALE_SHUFFLE_BLOCK_FETCH))); } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 00756b1b62..9e0b3c65c9 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -243,9 +243,9 @@ public class ExternalBlockHandlerSuite { public void testFinalizeShuffleMerge() throws IOException { RpcResponseCallback callback = mock(RpcResponseCallback.class); - FinalizeShuffleMerge req = new FinalizeShuffleMerge("app0", 1, 0); + FinalizeShuffleMerge req = new FinalizeShuffleMerge("app0", 1, 0, 0); RoaringBitmap bitmap = RoaringBitmap.bitmapOf(0, 1, 2); - MergeStatuses statuses = new MergeStatuses(0, new RoaringBitmap[]{bitmap}, + MergeStatuses statuses = new MergeStatuses(0, 0, new RoaringBitmap[]{bitmap}, new int[]{3}, new long[]{30}); when(mergedShuffleManager.finalizeShuffleMerge(req)).thenReturn(statuses); @@ -269,22 +269,22 @@ public class ExternalBlockHandlerSuite { @Test public void testFetchMergedBlocksMeta() { - when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0)).thenReturn( + 
when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0, 0)).thenReturn( new MergedBlockMeta(1, mock(ManagedBuffer.class))); - when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 1)).thenReturn( + when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0, 1)).thenReturn( new MergedBlockMeta(3, mock(ManagedBuffer.class))); - when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 2)).thenReturn( + when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0, 2)).thenReturn( new MergedBlockMeta(5, mock(ManagedBuffer.class))); int[] expectedCount = new int[]{1, 3, 5}; String appId = "app0"; long requestId = 0L; for (int reduceId = 0; reduceId < 3; reduceId++) { - MergedBlockMetaRequest req = new MergedBlockMetaRequest(requestId++, appId, 0, reduceId); + MergedBlockMetaRequest req = new MergedBlockMetaRequest(requestId++, appId, 0, 0, reduceId); MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class); handler.getMergedBlockMetaReqHandler() .receiveMergeBlockMetaReq(client, req, callback); - verify(mergedShuffleManager, times(1)).getMergedBlockMeta("app0", 0, reduceId); + verify(mergedShuffleManager, times(1)).getMergedBlockMeta("app0", 0, 0, reduceId); ArgumentCaptor numChunksResponse = ArgumentCaptor.forClass(Integer.class); ArgumentCaptor chunkBitmapResponse = @@ -313,12 +313,12 @@ public class ExternalBlockHandlerSuite { if (useOpenBlocks) { OpenBlocks openBlocks = new OpenBlocks("app0", "exec1", - new String[] {"shuffleChunk_0_0_0", "shuffleChunk_0_0_1", "shuffleChunk_0_1_0", - "shuffleChunk_0_1_1"}); + new String[] {"shuffleChunk_0_0_0_0", "shuffleChunk_0_0_0_1", "shuffleChunk_0_0_1_0", + "shuffleChunk_0_0_1_1"}); buffer = openBlocks.toByteBuffer(); } else { FetchShuffleBlockChunks fetchChunks = new FetchShuffleBlockChunks( - "app0", "exec1", 0, new int[] {0, 1}, new int[][] {{0, 1}, {0, 1}}); + "app0", "exec1", 0, 0, new int[] {0, 1}, new int[][] {{0, 1}, {0, 1}}); buffer = fetchChunks.toByteBuffer(); } ManagedBuffer[][] buffers = new ManagedBuffer[][] { @@ -334,7 +334,7 @@ public class ExternalBlockHandlerSuite { for (int reduceId = 0; reduceId < 2; reduceId++) { for (int chunkId = 0; chunkId < 2; chunkId++) { when(mergedShuffleManager.getMergedBlockData( - "app0", 0, reduceId, chunkId)).thenReturn(buffers[reduceId][chunkId]); + "app0", 0, 0, reduceId, chunkId)).thenReturn(buffers[reduceId][chunkId]); } } handler.receive(client, buffer, callback); @@ -356,11 +356,12 @@ public class ExternalBlockHandlerSuite { } } assertFalse(bufferIter.hasNext()); - verify(mergedShuffleManager, never()).getMergedBlockMeta(anyString(), anyInt(), anyInt()); + verify(mergedShuffleManager, never()).getMergedBlockMeta(anyString(), anyInt(), anyInt(), + anyInt()); verify(blockResolver, never()).getBlockData( anyString(), anyString(), anyInt(), anyInt(), anyInt()); - verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 0); - verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 1); + verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 0, 0); + verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 0, 1); // Verify open block request latency metrics Timer openBlockRequestLatencyMillis = (Timer) ((ExternalBlockHandler) handler) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index c4967eab31..d336779c1d 100644 --- 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -246,51 +246,49 @@ public class OneForOneBlockFetcherSuite { @Test public void testShuffleBlockChunksFetch() { LinkedHashMap blocks = Maps.newLinkedHashMap(); - blocks.put("shuffleChunk_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); - blocks.put("shuffleChunk_0_0_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23]))); - blocks.put("shuffleChunk_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); + blocks.put("shuffleChunk_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("shuffleChunk_0_0_0_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23]))); + blocks.put("shuffleChunk_0_0_0_2", + new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); BlockFetchingListener listener = fetchBlocks(blocks, blockIds, - new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 }, + new FetchShuffleBlockChunks("app-id", "exec-id", 0, 0, new int[] { 0 }, new int[][] {{ 0, 1, 2 }}), conf); for (int i = 0; i < 3; i ++) { - verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_" + i, - blocks.get("shuffleChunk_0_0_" + i)); + verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0_" + i, + blocks.get("shuffleChunk_0_0_0_" + i)); } } @Test public void testShuffleBlockChunkFetchFailure() { LinkedHashMap blocks = Maps.newLinkedHashMap(); - blocks.put("shuffleChunk_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); - blocks.put("shuffleChunk_0_0_1", null); - blocks.put("shuffleChunk_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); + blocks.put("shuffleChunk_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("shuffleChunk_0_0_0_1", null); + blocks.put("shuffleChunk_0_0_0_2", + new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); BlockFetchingListener listener = fetchBlocks(blocks, blockIds, - new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[]{0}, new int[][]{{0, 1, 2}}), + new FetchShuffleBlockChunks("app-id", "exec-id", 0, 0, new int[]{0}, new int[][]{{0, 1, 2}}), conf); - verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0", - blocks.get("shuffleChunk_0_0_0")); - verify(listener, times(1)).onBlockFetchFailure(eq("shuffleChunk_0_0_1"), any()); - verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_2", - blocks.get("shuffleChunk_0_0_2")); + verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0_0", + blocks.get("shuffleChunk_0_0_0_0")); + verify(listener, times(1)).onBlockFetchFailure(eq("shuffleChunk_0_0_0_1"), any()); + verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0_2", + blocks.get("shuffleChunk_0_0_0_2")); } @Test public void testInvalidShuffleBlockIds() { assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(), new String[]{"shuffle_0_0"}, - new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 }, - new int[][] {{ 0 }}), conf)); + new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 }, + new int[][] {{ 0 }}, false), conf)); assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(), new String[]{"shuffleChunk_0_0_0_0_0"}, - new 
FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 }, - new int[][] {{ 0 }}), conf)); - assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(), - new String[]{"shuffleChunk_0_0_0_0"}, - new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 }, + new FetchShuffleBlockChunks("app-id", "exec-id", 0, 0, new int[] { 0 }, new int[][] {{ 0 }}), conf)); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockPusherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockPusherSuite.java index f709a565c0..d2fd5d9be6 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockPusherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockPusherSuite.java @@ -45,77 +45,78 @@ public class OneForOneBlockPusherSuite { @Test public void testPushOne() { LinkedHashMap blocks = Maps.newLinkedHashMap(); - blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[1]))); + blocks.put("shufflePush_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[1]))); String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); BlockPushingListener listener = pushBlocks( blocks, blockIds, - Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0))); + Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0, 0))); - verify(listener).onBlockPushSuccess(eq("shufflePush_0_0_0"), any()); + verify(listener).onBlockPushSuccess(eq("shufflePush_0_0_0_0"), any()); } @Test public void testPushThree() { LinkedHashMap blocks = Maps.newLinkedHashMap(); - blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); - blocks.put("shufflePush_0_1_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[23]))); - blocks.put("shufflePush_0_2_0", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); + blocks.put("shufflePush_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("shufflePush_0_0_1_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[23]))); + blocks.put("shufflePush_0_0_2_0", + new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); BlockPushingListener listener = pushBlocks( blocks, blockIds, - Arrays.asList(new PushBlockStream("app-id",0, 0, 0, 0, 0), - new PushBlockStream("app-id", 0, 0, 1, 0, 1), - new PushBlockStream("app-id", 0, 0, 2, 0, 2))); + Arrays.asList(new PushBlockStream("app-id",0, 0, 0, 0, 0, 0), + new PushBlockStream("app-id", 0, 0, 0, 1, 0, 1), + new PushBlockStream("app-id", 0, 0, 0, 2, 0, 2))); - verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0"), any()); - verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_1_0"), any()); - verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_2_0"), any()); + verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0_0"), any()); + verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_1_0"), any()); + verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_2_0"), any()); } @Test public void testServerFailures() { LinkedHashMap blocks = Maps.newLinkedHashMap(); - blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); - blocks.put("shufflePush_0_1_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); - blocks.put("shufflePush_0_2_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); + 
blocks.put("shufflePush_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("shufflePush_0_0_1_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); + blocks.put("shufflePush_0_0_2_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); BlockPushingListener listener = pushBlocks( blocks, blockIds, - Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0), - new PushBlockStream("app-id", 0, 0, 1, 0, 1), - new PushBlockStream("app-id", 0, 0, 2, 0, 2))); + Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0, 0), + new PushBlockStream("app-id", 0, 0, 0, 1, 0, 1), + new PushBlockStream("app-id", 0, 0, 0, 2, 0, 2))); - verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0"), any()); - verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_1_0"), any()); - verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_2_0"), any()); + verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0_0"), any()); + verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_0_1_0"), any()); + verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_0_2_0"), any()); } @Test public void testHandlingRetriableFailures() { LinkedHashMap blocks = Maps.newLinkedHashMap(); - blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); - blocks.put("shufflePush_0_1_0", null); - blocks.put("shufflePush_0_2_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); + blocks.put("shufflePush_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("shufflePush_0_0_1_0", null); + blocks.put("shufflePush_0_0_2_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); BlockPushingListener listener = pushBlocks( blocks, blockIds, - Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0), - new PushBlockStream("app-id", 0, 0, 1, 0, 1), - new PushBlockStream("app-id", 0, 0, 2, 0, 2))); + Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0, 0), + new PushBlockStream("app-id", 0, 0, 0, 1, 0, 1), + new PushBlockStream("app-id", 0, 0, 0, 2, 0, 2))); - verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0"), any()); - verify(listener, times(0)).onBlockPushSuccess(not(eq("shufflePush_0_0_0")), any()); - verify(listener, times(0)).onBlockPushFailure(eq("shufflePush_0_0_0"), any()); - verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_1_0"), any()); - verify(listener, times(2)).onBlockPushFailure(eq("shufflePush_0_2_0"), any()); + verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0_0"), any()); + verify(listener, times(0)).onBlockPushSuccess(not(eq("shufflePush_0_0_0_0")), any()); + verify(listener, times(0)).onBlockPushFailure(eq("shufflePush_0_0_0_0"), any()); + verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_0_1_0"), any()); + verify(listener, times(2)).onBlockPushFailure(eq("shufflePush_0_0_2_0"), any()); } /** @@ -147,7 +148,7 @@ public class OneForOneBlockPusherSuite { + ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX)); } else { callback.onFailure(new RuntimeException("Quick fail " + entry.getKey() - + ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)); + + ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX)); } assertEquals(msgIterator.next(), message); return null; diff --git 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java index 2a73aa56b2..46a1569008 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java @@ -106,7 +106,7 @@ public class RemoteBlockPushResolverSuite { @Test(expected = RuntimeException.class) public void testNoIndexFile() { try { - pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); } catch (Throwable t) { assertTrue(t.getMessage().startsWith("Merged shuffle index file")); Throwables.propagate(t); @@ -116,58 +116,58 @@ public class RemoteBlockPushResolverSuite { @Test public void testBasicBlockMerge() throws IOException { PushBlock[] pushBlocks = new PushBlock[] { - new PushBlock(0, 0, 0, ByteBuffer.wrap(new byte[4])), - new PushBlock(0, 1, 0, ByteBuffer.wrap(new byte[5])) + new PushBlock(0, 0, 0, 0, ByteBuffer.wrap(new byte[4])), + new PushBlock(0, 0, 1, 0, ByteBuffer.wrap(new byte[5])) }; pushBlockHelper(TEST_APP, NO_ATTEMPT_ID, pushBlocks); MergeStatuses statuses = pushResolver.finalizeShuffleMerge( - new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); + new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); validateMergeStatuses(statuses, new int[] {0}, new long[] {9}); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4, 5}, new int[][]{{0}, {1}}); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{4, 5}, new int[][]{{0}, {1}}); } @Test public void testDividingMergedBlocksIntoChunks() throws IOException { PushBlock[] pushBlocks = new PushBlock[] { - new PushBlock(0, 0, 0, ByteBuffer.wrap(new byte[2])), - new PushBlock(0, 1, 0, ByteBuffer.wrap(new byte[3])), - new PushBlock(0, 2, 0, ByteBuffer.wrap(new byte[5])), - new PushBlock(0, 3, 0, ByteBuffer.wrap(new byte[3])) + new PushBlock(0, 0, 0, 0, ByteBuffer.wrap(new byte[2])), + new PushBlock(0, 0, 1, 0, ByteBuffer.wrap(new byte[3])), + new PushBlock(0, 0, 2, 0, ByteBuffer.wrap(new byte[5])), + new PushBlock(0, 0, 3, 0, ByteBuffer.wrap(new byte[3])) }; pushBlockHelper(TEST_APP, NO_ATTEMPT_ID, pushBlocks); MergeStatuses statuses = pushResolver.finalizeShuffleMerge( - new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); + new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); validateMergeStatuses(statuses, new int[] {0}, new long[] {13}); - MergedBlockMeta meta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, meta, new int[]{5, 5, 3}, new int[][]{{0, 1}, {2}, {3}}); + MergedBlockMeta meta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, meta, new int[]{5, 5, 3}, new int[][]{{0, 1}, {2}, {3}}); } @Test public void testFinalizeWithMultipleReducePartitions() throws IOException { PushBlock[] pushBlocks = new PushBlock[] { - new PushBlock(0, 0, 0, ByteBuffer.wrap(new byte[2])), - new PushBlock(0, 1, 0, ByteBuffer.wrap(new byte[3])), - new PushBlock(0, 0, 1, ByteBuffer.wrap(new byte[5])), - new PushBlock(0, 1, 1, ByteBuffer.wrap(new byte[3])) + new PushBlock(0, 0, 0, 0, ByteBuffer.wrap(new byte[2])), + new PushBlock(0, 0, 1, 0, ByteBuffer.wrap(new byte[3])), + new 
PushBlock(0, 0, 0, 1, ByteBuffer.wrap(new byte[5])), + new PushBlock(0, 0, 1, 1, ByteBuffer.wrap(new byte[3])) }; pushBlockHelper(TEST_APP, NO_ATTEMPT_ID, pushBlocks); MergeStatuses statuses = pushResolver.finalizeShuffleMerge( - new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); + new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); validateMergeStatuses(statuses, new int[] {0, 1}, new long[] {5, 8}); - MergedBlockMeta meta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, meta, new int[]{5}, new int[][]{{0, 1}}); + MergedBlockMeta meta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, meta, new int[]{5}, new int[][]{{0, 1}}); } @Test public void testDeferredBufsAreWrittenDuringOnData() throws IOException { StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); StreamCallbackWithID stream2 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0)); // This should be deferred stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[3])); // stream 1 now completes @@ -176,20 +176,20 @@ public class RemoteBlockPushResolverSuite { // stream 2 has more data and then completes stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[3])); stream2.onComplete(stream2.getID()); - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4, 6}, new int[][]{{0}, {1}}); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{4, 6}, new int[][]{{0}, {1}}); } @Test public void testDeferredBufsAreWrittenDuringOnComplete() throws IOException { StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); StreamCallbackWithID stream2 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0)); // This should be deferred stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[3])); stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[3])); @@ -198,40 +198,40 @@ public class RemoteBlockPushResolverSuite { stream1.onComplete(stream1.getID()); // stream 2 now completes stream2.onComplete(stream2.getID()); - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4, 6}, new int[][]{{0}, {1}}); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{4, 6}, new int[][]{{0}, {1}}); } @Test public void 
testDuplicateBlocksAreIgnoredWhenPrevStreamHasCompleted() throws IOException { StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); stream1.onComplete(stream1.getID()); StreamCallbackWithID stream2 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); // This should be ignored stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); stream2.onComplete(stream2.getID()); - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4}, new int[][]{{0}}); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{4}, new int[][]{{0}}); } @Test public void testDuplicateBlocksAreIgnoredWhenPrevStreamIsInProgress() throws IOException { StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); StreamCallbackWithID stream2 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); // This should be ignored stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); @@ -240,20 +240,20 @@ public class RemoteBlockPushResolverSuite { stream1.onComplete(stream1.getID()); // stream 2 now completes stream2.onComplete(stream2.getID()); - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4}, new int[][]{{0}}); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{4}, new int[][]{{0}}); } @Test public void testFailureAfterData() throws IOException { StreamCallbackWithID stream = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); stream.onData(stream.getID(), ByteBuffer.wrap(new byte[4])); stream.onFailure(stream.getID(), new RuntimeException("Forced Failure")); - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); assertEquals("num-chunks", 0, blockMeta.getNumChunks()); } @@ -261,13 +261,13 @@ public class 
RemoteBlockPushResolverSuite { public void testFailureAfterMultipleDataBlocks() throws IOException { StreamCallbackWithID stream = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); stream.onData(stream.getID(), ByteBuffer.wrap(new byte[2])); stream.onData(stream.getID(), ByteBuffer.wrap(new byte[3])); stream.onData(stream.getID(), ByteBuffer.wrap(new byte[4])); stream.onFailure(stream.getID(), new RuntimeException("Forced Failure")); - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); assertEquals("num-chunks", 0, blockMeta.getNumChunks()); } @@ -275,15 +275,15 @@ public class RemoteBlockPushResolverSuite { public void testFailureAfterComplete() throws IOException { StreamCallbackWithID stream = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); stream.onData(stream.getID(), ByteBuffer.wrap(new byte[2])); stream.onData(stream.getID(), ByteBuffer.wrap(new byte[3])); stream.onData(stream.getID(), ByteBuffer.wrap(new byte[4])); stream.onComplete(stream.getID()); stream.onFailure(stream.getID(), new RuntimeException("Forced Failure")); - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}}); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}}); } @Test(expected = RuntimeException.class) @@ -293,22 +293,23 @@ public class RemoteBlockPushResolverSuite { ByteBuffer.wrap(new byte[5]) }; StreamCallbackWithID stream = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); for (ByteBuffer block : blocks) { stream.onData(stream.getID(), block); } stream.onComplete(stream.getID()); - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0)); stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[4])); try { stream1.onComplete(stream1.getID()); } catch (RuntimeException re) { - assertEquals( - "Block shufflePush_0_1_0 received after merged shuffle is finalized", re.getMessage()); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}}); + assertEquals("Block shufflePush_0_0_1_0 received after merged shuffle is finalized or stale" + + " block push as shuffle blocks of a higher shuffleMergeId for the shuffle is being" + + " pushed", re.getMessage()); + MergedBlockMeta blockMeta = 
pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}}); throw re; } } @@ -321,7 +322,7 @@ public class RemoteBlockPushResolverSuite { StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); byte[] data = new byte[10]; ThreadLocalRandom.current().nextBytes(data); stream1.onData(stream1.getID(), ByteBuffer.wrap(data)); @@ -329,21 +330,21 @@ public class RemoteBlockPushResolverSuite { stream1.onFailure(stream1.getID(), new RuntimeException("forced error")); StreamCallbackWithID stream2 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0)); ByteBuffer nextBuf= ByteBuffer.wrap(expectedBytes, 0, 2); stream2.onData(stream2.getID(), nextBuf); stream2.onComplete(stream2.getID()); StreamCallbackWithID stream3 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 2, 0, 0)); nextBuf = ByteBuffer.wrap(expectedBytes, 2, 2); stream3.onData(stream3.getID(), nextBuf); stream3.onComplete(stream3.getID()); - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4}, new int[][]{{1, 2}}); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{4}, new int[][]{{1, 2}}); FileSegmentManagedBuffer mb = - (FileSegmentManagedBuffer) pushResolver.getMergedBlockData(TEST_APP, 0, 0, 0); + (FileSegmentManagedBuffer) pushResolver.getMergedBlockData(TEST_APP, 0, 0, 0, 0); assertArrayEquals(expectedBytes, mb.nioByteBuffer().array()); } @@ -351,11 +352,11 @@ public class RemoteBlockPushResolverSuite { public void testCollision() throws IOException { StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); StreamCallbackWithID stream2 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0)); // This should be deferred stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5])); // Since stream2 didn't get any opportunity it will throw couldn't find opportunity error @@ -363,7 +364,7 @@ public class RemoteBlockPushResolverSuite { stream2.onComplete(stream2.getID()); } catch (RuntimeException re) { assertEquals( - "Couldn't find an opportunity to write block shufflePush_0_1_0 to merged shuffle", + "Couldn't find an opportunity to write block shufflePush_0_0_1_0 to merged shuffle", re.getMessage()); throw re; } @@ -373,16 +374,16 @@ public class RemoteBlockPushResolverSuite { public void testFailureInAStreamDoesNotInterfereWithStreamWhichIsWriting() throws IOException { StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); 
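Taken together, the renamed assertions above pin down the new conventions: PushBlockStream now carries shuffleMergeId as its fourth constructor argument, between shuffleId and mapIndex, and the push block name gains the matching component, shufflePush_<shuffleId>_<shuffleMergeId>_<mapIndex>_<reduceId>. A minimal sketch of that ordering follows, assuming the message exposes public final fields mirroring its constructor arguments, in the style of the other protocol messages in this patch; the printed layout mirrors the assertion messages rather than BlockId itself.

import org.apache.spark.network.shuffle.protocol.PushBlockStream;

public class PushBlockStreamOrderingExample {
  public static void main(String[] args) {
    // Argument order used throughout this suite:
    // appId, appAttemptId, shuffleId, shuffleMergeId, mapIndex, reduceId, index.
    // First execution of the stage: a deterministic stage keeps shuffleMergeId 0.
    PushBlockStream firstRun = new PushBlockStream("app", 0, 0, 0, 1, 0, 0);
    // Retry of an indeterminate stage: same shuffle, shuffleMergeId bumped to 1,
    // so the shuffle service merges these pushes into a separate set of files.
    PushBlockStream retry = new PushBlockStream("app", 0, 0, 1, 1, 0, 0);
    for (PushBlockStream s : new PushBlockStream[] {firstRun, retry}) {
      System.out.println("shufflePush_" + s.shuffleId + "_" + s.shuffleMergeId
          + "_" + s.mapIndex + "_" + s.reduceId);
    }
  }
}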
stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); StreamCallbackWithID stream2 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0)); // There is a failure with stream2 stream2.onFailure(stream2.getID(), new RuntimeException("forced error")); StreamCallbackWithID stream3 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 2, 0, 0)); // This should be deferred stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[5])); // Since this stream didn't get any opportunity it will throw couldn't find opportunity error @@ -391,7 +392,7 @@ public class RemoteBlockPushResolverSuite { stream3.onComplete(stream3.getID()); } catch (RuntimeException re) { assertEquals( - "Couldn't find an opportunity to write block shufflePush_0_2_0 to merged shuffle", + "Couldn't find an opportunity to write block shufflePush_0_0_2_0 to merged shuffle", re.getMessage()); failedEx = re; } @@ -399,9 +400,9 @@ public class RemoteBlockPushResolverSuite { stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); stream1.onComplete(stream1.getID()); - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {4}, new int[][] {{0}}); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {4}, new int[][] {{0}}); if (failedEx != null) { throw failedEx; } @@ -502,11 +503,11 @@ public class RemoteBlockPushResolverSuite { Path[] activeDirs = createLocalDirs(1); registerExecutor(testApp, prepareLocalDirs(activeDirs, MERGE_DIRECTORY), MERGE_DIRECTORY_META); PushBlock[] pushBlocks = new PushBlock[] { - new PushBlock(0, 0, 0, ByteBuffer.wrap(new byte[4]))}; + new PushBlock(0, 0, 0, 0, ByteBuffer.wrap(new byte[4]))}; pushBlockHelper(testApp, NO_ATTEMPT_ID, pushBlocks); - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, NO_ATTEMPT_ID, 0)); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(testApp, 0, 0); - validateChunks(testApp, 0, 0, blockMeta, new int[]{4}, new int[][]{{0}}); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, NO_ATTEMPT_ID, 0, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(testApp, 0, 0, 0); + validateChunks(testApp, 0, 0, 0, blockMeta, new int[]{4}, new int[][]{{0}}); String[] mergeDirs = pushResolver.getMergedBlockDirs(testApp); pushResolver.applicationRemoved(testApp, true); // Since the cleanup happen in a different thread, check few times to see if the merge dirs gets @@ -522,7 +523,7 @@ public class RemoteBlockPushResolverSuite { useTestFiles(true, false); RemoteBlockPushResolver.PushBlockStreamCallback callback1 = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[4])); callback1.onComplete(callback1.getID()); RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback1.getPartitionInfo(); @@ -530,7 +531,7 @@ public class 
RemoteBlockPushResolverSuite { TestMergeShuffleFile testIndexFile = (TestMergeShuffleFile) partitionInfo.getIndexFile(); testIndexFile.close(); StreamCallbackWithID callback2 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0)); callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[5])); // This will complete without any IOExceptions because number of IOExceptions are less than // the threshold but the update to index file will be unsuccessful. @@ -539,15 +540,15 @@ public class RemoteBlockPushResolverSuite { // Restore the index stream so it can write successfully again. testIndexFile.restore(); StreamCallbackWithID callback3 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 2, 0, 0)); callback3.onData(callback3.getID(), ByteBuffer.wrap(new byte[2])); callback3.onComplete(callback3.getID()); assertEquals("index position", 24, testIndexFile.getPos()); MergeStatuses statuses = pushResolver.finalizeShuffleMerge( - new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); + new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); validateMergeStatuses(statuses, new int[] {0}, new long[] {11}); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {4, 7}, new int[][] {{0}, {1, 2}}); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {4, 7}, new int[][] {{0}, {1, 2}}); } @Test @@ -555,7 +556,7 @@ public class RemoteBlockPushResolverSuite { useTestFiles(true, false); RemoteBlockPushResolver.PushBlockStreamCallback callback1 = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[4])); callback1.onComplete(callback1.getID()); RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback1.getPartitionInfo(); @@ -563,7 +564,7 @@ public class RemoteBlockPushResolverSuite { TestMergeShuffleFile testIndexFile = (TestMergeShuffleFile) partitionInfo.getIndexFile(); testIndexFile.close(); StreamCallbackWithID callback2 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0)); callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[5])); // This will complete without any IOExceptions because number of IOExceptions are less than // the threshold but the update to index file will be unsuccessful. @@ -573,11 +574,11 @@ public class RemoteBlockPushResolverSuite { // Restore the index stream so it can write successfully again. 
testIndexFile.restore(); MergeStatuses statuses = pushResolver.finalizeShuffleMerge( - new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); + new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); assertEquals("index position", 24, testIndexFile.getPos()); validateMergeStatuses(statuses, new int[] {0}, new long[] {9}); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {4, 5}, new int[][] {{0}, {1}}); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {4, 5}, new int[][] {{0}, {1}}); } @Test @@ -585,7 +586,7 @@ public class RemoteBlockPushResolverSuite { useTestFiles(false, true); RemoteBlockPushResolver.PushBlockStreamCallback callback1 = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[4])); callback1.onComplete(callback1.getID()); RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback1.getPartitionInfo(); @@ -594,7 +595,7 @@ public class RemoteBlockPushResolverSuite { long metaPosBeforeClose = testMetaFile.getPos(); testMetaFile.close(); StreamCallbackWithID callback2 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0)); callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[5])); // This will complete without any IOExceptions because number of IOExceptions are less than // the threshold but the update to index and meta file will be unsuccessful. @@ -604,16 +605,16 @@ public class RemoteBlockPushResolverSuite { // Restore the meta stream so it can write successfully again. 
testMetaFile.restore(); StreamCallbackWithID callback3 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 2, 0, 0)); callback3.onData(callback3.getID(), ByteBuffer.wrap(new byte[2])); callback3.onComplete(callback3.getID()); assertEquals("index position", 24, partitionInfo.getIndexFile().getPos()); assertTrue("meta position", testMetaFile.getPos() > metaPosBeforeClose); MergeStatuses statuses = pushResolver.finalizeShuffleMerge( - new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); + new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); validateMergeStatuses(statuses, new int[] {0}, new long[] {11}); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {4, 7}, new int[][] {{0}, {1, 2}}); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {4, 7}, new int[][] {{0}, {1, 2}}); } @Test @@ -621,7 +622,7 @@ public class RemoteBlockPushResolverSuite { useTestFiles(false, true); RemoteBlockPushResolver.PushBlockStreamCallback callback1 = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[4])); callback1.onComplete(callback1.getID()); RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback1.getPartitionInfo(); @@ -630,7 +631,7 @@ public class RemoteBlockPushResolverSuite { long metaPosBeforeClose = testMetaFile.getPos(); testMetaFile.close(); StreamCallbackWithID callback2 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0)); callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[5])); // This will complete without any IOExceptions because number of IOExceptions are less than // the threshold but the update to index and meta file will be unsuccessful. @@ -641,19 +642,19 @@ public class RemoteBlockPushResolverSuite { // Restore the meta stream so it can write successfully again. 
testMetaFile.restore(); MergeStatuses statuses = pushResolver.finalizeShuffleMerge( - new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); + new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); assertEquals("index position", 24, indexFile.getPos()); assertTrue("meta position", testMetaFile.getPos() > metaPosBeforeClose); validateMergeStatuses(statuses, new int[] {0}, new long[] {9}); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {4, 5}, new int[][] {{0}, {1}}); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {4, 5}, new int[][] {{0}, {1}}); } @Test(expected = RuntimeException.class) public void testIOExceptionsExceededThreshold() throws IOException { RemoteBlockPushResolver.PushBlockStreamCallback callback = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback.getPartitionInfo(); callback.onData(callback.getID(), ByteBuffer.wrap(new byte[4])); callback.onComplete(callback.getID()); @@ -662,7 +663,7 @@ public class RemoteBlockPushResolverSuite { for (int i = 1; i < 5; i++) { RemoteBlockPushResolver.PushBlockStreamCallback callback1 = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, i, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, i, 0, 0)); try { callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[2])); } catch (IOException ioe) { @@ -675,10 +676,10 @@ public class RemoteBlockPushResolverSuite { try { RemoteBlockPushResolver.PushBlockStreamCallback callback2 = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 5, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 5, 0, 0)); callback2.onData(callback.getID(), ByteBuffer.wrap(new byte[1])); } catch (Throwable t) { - assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_5_0", + assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_5_0", t.getMessage()); throw t; } @@ -689,7 +690,7 @@ public class RemoteBlockPushResolverSuite { useTestFiles(true, false); RemoteBlockPushResolver.PushBlockStreamCallback callback = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback.getPartitionInfo(); callback.onData(callback.getID(), ByteBuffer.wrap(new byte[4])); callback.onComplete(callback.getID()); @@ -698,7 +699,7 @@ public class RemoteBlockPushResolverSuite { for (int i = 1; i < 5; i++) { RemoteBlockPushResolver.PushBlockStreamCallback callback1 = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, i, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, i, 0, 0)); callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[5])); // This will complete without any exceptions but the exception count is increased. 
callback1.onComplete(callback1.getID()); @@ -709,11 +710,11 @@ public class RemoteBlockPushResolverSuite { try { RemoteBlockPushResolver.PushBlockStreamCallback callback2 = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 5, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 5, 0, 0)); callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[4])); callback2.onComplete(callback2.getID()); } catch (Throwable t) { - assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_5_0", + assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_5_0", t.getMessage()); throw t; } @@ -728,9 +729,9 @@ public class RemoteBlockPushResolverSuite { } try { pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 10, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 10, 0, 0)); } catch (Throwable t) { - assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_10_0", + assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_10_0", t.getMessage()); throw t; } @@ -741,14 +742,14 @@ public class RemoteBlockPushResolverSuite { useTestFiles(true, false); RemoteBlockPushResolver.PushBlockStreamCallback callback = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback.getPartitionInfo(); TestMergeShuffleFile testIndexFile = (TestMergeShuffleFile) partitionInfo.getIndexFile(); testIndexFile.close(); for (int i = 1; i < 6; i++) { RemoteBlockPushResolver.PushBlockStreamCallback callback1 = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, i, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, i, 0, 0)); try { callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[5])); // This will complete without any exceptions but the exception count is increased. 
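The preceding threshold tests all exercise one bookkeeping pattern: each IOException while writing a partition's data, index, or meta file increments a per-partition counter, and once the counter reaches the server-side limit, every later push for that partition fails fast with the "IOExceptions exceeded the threshold" message. A condensed sketch of that pattern, with illustrative names (the real bookkeeping lives in RemoteBlockPushResolver.AppShufflePartitionInfo) and a threshold of 4 inferred from the four failing pushes in the loops above:

public class PartitionIoGuard {
  private final String blockName;
  private final int threshold;
  private int numIoExceptions = 0;

  public PartitionIoGuard(String blockName, int threshold) {
    this.blockName = blockName;
    this.threshold = threshold;
  }

  // Invoked whenever a write to the partition's data/index/meta file fails.
  public void recordIoException() {
    numIoExceptions++;
  }

  // Invoked before accepting more bytes for this partition.
  public void ensureUsable() {
    if (numIoExceptions >= threshold) {
      throw new RuntimeException(
          "IOExceptions exceeded the threshold when merging " + blockName);
    }
  }
}

In the tests, four failed pushes drive the counter to the threshold, so it is the fifth push (or a later onData on an existing stream) that observes the exception.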
@@ -763,7 +764,7 @@ public class RemoteBlockPushResolverSuite { try { callback.onData(callback.getID(), ByteBuffer.wrap(new byte[4])); } catch (Throwable t) { - assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_0", + assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_0_0", t.getMessage()); throw t; } @@ -774,14 +775,14 @@ public class RemoteBlockPushResolverSuite { useTestFiles(true, false); RemoteBlockPushResolver.PushBlockStreamCallback callback = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback.getPartitionInfo(); TestMergeShuffleFile testIndexFile = (TestMergeShuffleFile) partitionInfo.getIndexFile(); testIndexFile.close(); for (int i = 1; i < 5; i++) { RemoteBlockPushResolver.PushBlockStreamCallback callback1 = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, i, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, i, 0, 0)); try { callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[5])); // This will complete without any exceptions but the exception count is increased. @@ -793,7 +794,7 @@ public class RemoteBlockPushResolverSuite { assertEquals(4, partitionInfo.getNumIOExceptions()); RemoteBlockPushResolver.PushBlockStreamCallback callback2 = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, 1, 0, 5, 0, 0)); + new PushBlockStream(TEST_APP, 1, 0, 0, 5, 0, 0)); callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[5])); // This is deferred callback.onData(callback.getID(), ByteBuffer.wrap(new byte[4])); @@ -820,15 +821,15 @@ public class RemoteBlockPushResolverSuite { public void testFailureWhileTruncatingFiles() throws IOException { useTestFiles(true, false); PushBlock[] pushBlocks = new PushBlock[] { - new PushBlock(0, 0, 0, ByteBuffer.wrap(new byte[2])), - new PushBlock(0, 1, 0, ByteBuffer.wrap(new byte[3])), - new PushBlock(0, 0, 1, ByteBuffer.wrap(new byte[5])), - new PushBlock(0, 1, 1, ByteBuffer.wrap(new byte[3])) + new PushBlock(0, 0, 0, 0, ByteBuffer.wrap(new byte[2])), + new PushBlock(0, 0, 1, 0, ByteBuffer.wrap(new byte[3])), + new PushBlock(0, 0, 0, 1, ByteBuffer.wrap(new byte[5])), + new PushBlock(0, 0, 1, 1, ByteBuffer.wrap(new byte[3])) }; pushBlockHelper(TEST_APP, NO_ATTEMPT_ID, pushBlocks); RemoteBlockPushResolver.PushBlockStreamCallback callback = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 2, 0, 0)); callback.onData(callback.getID(), ByteBuffer.wrap(new byte[2])); callback.onComplete(callback.getID()); RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo = callback.getPartitionInfo(); @@ -836,37 +837,37 @@ public class RemoteBlockPushResolverSuite { // Close the index file so truncate throws IOException testIndexFile.close(); MergeStatuses statuses = pushResolver.finalizeShuffleMerge( - new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); + new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); validateMergeStatuses(statuses, new int[] {1}, new long[] {8}); - MergedBlockMeta meta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 
1); - validateChunks(TEST_APP, 0, 1, meta, new int[]{5, 3}, new int[][]{{0},{1}}); + MergedBlockMeta meta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 1); + validateChunks(TEST_APP, 0, 0, 1, meta, new int[]{5, 3}, new int[][]{{0},{1}}); } @Test public void testOnFailureInvokedMoreThanOncePerBlock() throws IOException { StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); stream1.onFailure(stream1.getID(), new RuntimeException("forced error")); StreamCallbackWithID stream2 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0)); stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5])); // On failure on stream1 gets invoked again and should cause no interference stream1.onFailure(stream1.getID(), new RuntimeException("2nd forced error")); StreamCallbackWithID stream3 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 3, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 3, 0, 0)); // This should be deferred as stream 2 is still the active stream stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[2])); // Stream 2 writes more and completes stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[4])); stream2.onComplete(stream2.getID()); stream3.onComplete(stream3.getID()); - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {9, 2}, new int[][] {{1},{3}}); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {9, 2}, new int[][] {{1},{3}}); removeApplication(TEST_APP); } @@ -874,24 +875,24 @@ public class RemoteBlockPushResolverSuite { public void testFailureAfterDuplicateBlockDoesNotInterfereActiveStream() throws IOException { StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); StreamCallbackWithID stream1Duplicate = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 0, 0, 0)); stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); stream1.onComplete(stream1.getID()); stream1Duplicate.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); StreamCallbackWithID stream2 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0)); stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5])); // Should not change the current map id of the reduce partition stream1Duplicate.onFailure(stream2.getID(), new RuntimeException("forced error")); StreamCallbackWithID stream3 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0)); + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 2, 0, 0)); // This should be deferred as stream 2 is still the active stream stream3.onData(stream3.getID(), 
ByteBuffer.wrap(new byte[2])); RuntimeException failedEx = null; @@ -899,16 +900,16 @@ public class RemoteBlockPushResolverSuite { stream3.onComplete(stream3.getID()); } catch (RuntimeException re) { assertEquals( - "Couldn't find an opportunity to write block shufflePush_0_2_0 to merged shuffle", + "Couldn't find an opportunity to write block shufflePush_0_0_2_0 to merged shuffle", re.getMessage()); failedEx = re; } // Stream 2 writes more and completes stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[4])); stream2.onComplete(stream2.getID()); - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0)); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); - validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {11}, new int[][] {{0, 1}}); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {11}, new int[][] {{0, 1}}); removeApplication(TEST_APP); if (failedEx != null) { throw failedEx; @@ -938,46 +939,42 @@ public class RemoteBlockPushResolverSuite { ByteBuffer.wrap(new byte[5]) }; StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(testApp, 1, 0, 0, 0, 0)); + new PushBlockStream(testApp, 1, 0, 0, 0, 0, 0)); for (ByteBuffer block : blocks) { stream1.onData(stream1.getID(), block); } stream1.onComplete(stream1.getID()); RemoteBlockPushResolver.AppShuffleInfo appShuffleInfo = pushResolver.validateAndGetAppShuffleInfo(testApp); - Map<Integer, Map<Integer, RemoteBlockPushResolver.AppShufflePartitionInfo>> partitions = - appShuffleInfo.getPartitions(); - for (Map<Integer, RemoteBlockPushResolver.AppShufflePartitionInfo> partitionMap : - partitions.values()) { - for (RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo : partitionMap.values()) { - assertTrue(partitionInfo.getDataChannel().isOpen()); - assertTrue(partitionInfo.getMetaFile().getChannel().isOpen()); - assertTrue(partitionInfo.getIndexFile().getChannel().isOpen()); - } + RemoteBlockPushResolver.AppShuffleMergePartitionsInfo partitions + = appShuffleInfo.getShuffles().get(0); + for (RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo : + partitions.getShuffleMergePartitions().values()) { + assertTrue(partitionInfo.getDataChannel().isOpen()); + assertTrue(partitionInfo.getMetaFile().getChannel().isOpen()); + assertTrue(partitionInfo.getIndexFile().getChannel().isOpen()); } Path[] attempt2LocalDirs = createLocalDirs(2); registerExecutor(testApp, prepareLocalDirs(attempt2LocalDirs, MERGE_DIRECTORY + "_" + ATTEMPT_ID_2), MERGE_DIRECTORY_META_2); StreamCallbackWithID stream2 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(testApp, 2, 0, 1, 0, 0)); + new PushBlockStream(testApp, 2, 0, 0, 1, 0, 0)); for (ByteBuffer block : blocks) { stream2.onData(stream2.getID(), block); } stream2.onComplete(stream2.getID()); closed.acquire(); // Check if all the file channels created for the first attempt are safely closed. 
- for (Map<Integer, RemoteBlockPushResolver.AppShufflePartitionInfo> partitionMap : - partitions.values()) { - for (RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo : partitionMap.values()) { - assertFalse(partitionInfo.getDataChannel().isOpen()); - assertFalse(partitionInfo.getMetaFile().getChannel().isOpen()); - assertFalse(partitionInfo.getIndexFile().getChannel().isOpen()); - } + for (RemoteBlockPushResolver.AppShufflePartitionInfo partitionInfo : + partitions.getShuffleMergePartitions().values()) { + assertFalse(partitionInfo.getDataChannel().isOpen()); + assertFalse(partitionInfo.getMetaFile().getChannel().isOpen()); + assertFalse(partitionInfo.getIndexFile().getChannel().isOpen()); } try { pushResolver.receiveBlockDataAsStream( - new PushBlockStream(testApp, 1, 0, 1, 0, 0)); + new PushBlockStream(testApp, 1, 0, 0, 1, 0, 0)); } catch (IllegalArgumentException re) { assertEquals( "The attempt id 1 in this PushBlockStream message does not match " + @@ -1000,7 +997,7 @@ public class RemoteBlockPushResolverSuite { ByteBuffer.wrap(new byte[5]) }; StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(testApp, 1, 0, 0, 0, 0)); + new PushBlockStream(testApp, 1, 0, 0, 0, 0, 0)); for (ByteBuffer block : blocks) { stream1.onData(stream1.getID(), block); } @@ -1010,7 +1007,7 @@ public class RemoteBlockPushResolverSuite { prepareLocalDirs(attempt2LocalDirs, MERGE_DIRECTORY + "_" + ATTEMPT_ID_2), MERGE_DIRECTORY_META_2); try { - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, ATTEMPT_ID_1, 0)); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, ATTEMPT_ID_1, 0, 0)); } catch (IllegalArgumentException e) { assertEquals(e.getMessage(), String.format("The attempt id %s in this FinalizeShuffleMerge message does not " + @@ -1022,13 +1019,13 @@ public class RemoteBlockPushResolverSuite { @Test(expected = ClosedChannelException.class) public void testOngoingMergeOfBlockFromPreviousAttemptIsAborted() - throws IOException, InterruptedException { + throws IOException, InterruptedException { Semaphore closed = new Semaphore(0); pushResolver = new RemoteBlockPushResolver(conf) { @Override void closeAndDeletePartitionFilesIfNeeded( - AppShuffleInfo appShuffleInfo, - boolean cleanupLocalDirs) { + AppShuffleInfo appShuffleInfo, + boolean cleanupLocalDirs) { super.closeAndDeletePartitionFilesIfNeeded(appShuffleInfo, cleanupLocalDirs); closed.release(); } @@ -1045,7 +1042,7 @@ public class RemoteBlockPushResolverSuite { ByteBuffer.wrap(new byte[7]) }; StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(testApp, 1, 0, 0, 0, 0)); + new PushBlockStream(testApp, 1, 0, 0, 0, 0, 0)); // The onData callback should be called 4 times here before the onComplete callback. But a // register executor message arrives in shuffle service after the 2nd onData callback. The // remaining onData callbacks should throw ClosedChannelException as their channels are closed. 
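The two attempt-id tests above rely on the same supersession rule: registering an executor under a higher application attempt id closes and deletes everything the old attempt merged, after which a PushBlockStream or FinalizeShuffleMerge still carrying the old attempt id is rejected with an IllegalArgumentException. A rough sketch of that validation, under the assumption of one tracked attempt id per application; the class and method names are illustrative, and since the suite only asserts the prefix of the rejection message, the wording past "does not match" is abbreviated here.

public class AppAttemptValidator {
  private volatile int currentAttemptId;

  public AppAttemptValidator(int initialAttemptId) {
    this.currentAttemptId = initialAttemptId;
  }

  // A registration from a newer application attempt supersedes the old one;
  // the old attempt's merged-shuffle files are closed and deleted asynchronously.
  public void registerNewAttempt(int attemptId) {
    if (attemptId > currentAttemptId) {
      currentAttemptId = attemptId;
    }
  }

  // Any PushBlockStream or FinalizeShuffleMerge carrying a stale attempt id
  // is rejected before it can touch the current attempt's merged files.
  public void validate(int messageAttemptId, String messageType) {
    if (messageAttemptId != currentAttemptId) {
      throw new IllegalArgumentException("The attempt id " + messageAttemptId
          + " in this " + messageType + " message does not match the current attempt id "
          + currentAttemptId);
    }
  }
}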
@@ -1060,17 +1057,202 @@ public class RemoteBlockPushResolverSuite { stream1.onData(stream1.getID(), blocks[3]); } + @Test + public void testBlockPushWithOlderShuffleMergeId() throws IOException { + StreamCallbackWithID stream1 = + pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0, 0)); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + StreamCallbackWithID stream2 = + pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0, 0)); + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + try { + // stream 1 push should be rejected as it is from an older shuffleMergeId + stream1.onComplete(stream1.getID()); + } catch(RuntimeException re) { + assertEquals("Block shufflePush_0_1_0_0 is received after merged shuffle is finalized or" + + " stale block push as shuffle blocks of a higher shuffleMergeId for the shuffle is being" + + " pushed", re.getMessage()); + } + // stream 2 now completes + stream2.onComplete(stream2.getID()); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 2)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 2, 0); + validateChunks(TEST_APP, 0, 2, 0, blockMeta, new int[]{4}, new int[][]{{0}}); + } + + @Test + public void testFinalizeWithOlderShuffleMergeId() throws IOException { + StreamCallbackWithID stream1 = + pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0, 0)); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + StreamCallbackWithID stream2 = + pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0, 0)); + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + try { + // stream 1 push should be rejected as it is from an older shuffleMergeId + stream1.onComplete(stream1.getID()); + } catch(RuntimeException re) { + assertEquals("Block shufflePush_0_1_0_0 is received after merged shuffle is finalized or" + + " stale block push as shuffle blocks of a higher shuffleMergeId for the shuffle is being" + + " pushed", re.getMessage()); + } + // stream 2 now completes + stream2.onComplete(stream2.getID()); + try { + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 1)); + } catch(RuntimeException re) { + assertEquals("Shuffle merge finalize request for shuffle 0 with shuffleMergeId 1 is stale" + + " shuffle finalize request as shuffle blocks of a higher shuffleMergeId for the shuffle" + + " is already being pushed", re.getMessage()); + } + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 2)); + + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 2, 0); + validateChunks(TEST_APP, 0, 2, 0, blockMeta, new int[]{4}, new int[][]{{0}}); + } + + @Test + public void testFinalizeOfDeterminateShuffle() throws IOException { + PushBlock[] pushBlocks = new PushBlock[] { + new PushBlock(0, 0, 0, 0, ByteBuffer.wrap(new byte[4])), + new PushBlock(0, 0, 1, 0, ByteBuffer.wrap(new byte[5])) + }; + pushBlockHelper(TEST_APP, NO_ATTEMPT_ID, pushBlocks); + MergeStatuses statuses = pushResolver.finalizeShuffleMerge( + new FinalizeShuffleMerge(TEST_APP, 
NO_ATTEMPT_ID, 0, 0)); + + RemoteBlockPushResolver.AppShuffleInfo appShuffleInfo = + pushResolver.validateAndGetAppShuffleInfo(TEST_APP); + assertTrue("Metadata of determinate shuffle should be removed after finalize shuffle" + + " merge", appShuffleInfo.getShuffles().get(0) == null); + validateMergeStatuses(statuses, new int[] {0}, new long[] {9}); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{4, 5}, new int[][]{{0}, {1}}); + } + + @Test + public void testBlockFetchWithOlderShuffleMergeId() throws IOException { + StreamCallbackWithID stream1 = + pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 1, 0, 0, 0)); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + StreamCallbackWithID stream2 = + pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 2, 0, 0, 0)); + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + try { + // stream 1 push should be rejected as it is from an older shuffleMergeId + stream1.onComplete(stream1.getID()); + } catch(RuntimeException re) { + assertEquals("Block shufflePush_0_1_0_0 is received after merged shuffle is finalized or" + + " stale block push as shuffle blocks of a higher shuffleMergeId for the shuffle is being" + + " pushed", re.getMessage()); + } + // stream 2 now completes + stream2.onComplete(stream2.getID()); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 2)); + try { + pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + } catch(RuntimeException re) { + assertEquals("MergedBlockMeta fetch for shuffle 0 with shuffleMergeId 0 reduceId 0" + + " is stale shuffle block fetch request as shuffle blocks of a higher shuffleMergeId for" + + " the shuffle is available", re.getMessage()); + } + + try { + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 1)); + } catch(RuntimeException re) { + assertEquals("Shuffle merge finalize request for shuffle 0 with shuffleMergeId 1 is stale" + + " shuffle finalize request as shuffle blocks of a higher shuffleMergeId for the shuffle" + + " is already being pushed", re.getMessage()); + } + try { + pushResolver.getMergedBlockData(TEST_APP, 0, 1, 0, 0); + } catch(RuntimeException re) { + assertEquals("MergedBlockData fetch for shuffle 0 with shuffleMergeId 1 reduceId 0" + + " is stale shuffle block fetch request as shuffle blocks of a higher shuffleMergeId for" + + " the shuffle is available", re.getMessage()); + } + + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 2, 0); + validateChunks(TEST_APP, 0, 2, 0, blockMeta, new int[]{4}, new int[][]{{0}}); + } + + @Test + public void testCleanupOlderShuffleMergeId() throws IOException, InterruptedException { + Semaphore closed = new Semaphore(0); + pushResolver = new RemoteBlockPushResolver(conf) { + @Override + void closeAndDeletePartitionFiles(Map<Integer, AppShufflePartitionInfo> partitions) { + super.closeAndDeletePartitionFiles(partitions); + closed.release(); + } + }; + String testApp = "testCleanupOlderShuffleMergeId"; + registerExecutor(testApp, prepareLocalDirs(localDirs, MERGE_DIRECTORY), MERGE_DIRECTORY_META); + StreamCallbackWithID stream1 = + pushResolver.receiveBlockDataAsStream( + new PushBlockStream(testApp, NO_ATTEMPT_ID, 0, 1, 0, 0, 0)); + stream1.onData(stream1.getID(), 
ByteBuffer.wrap(new byte[2])); + StreamCallbackWithID stream2 = + pushResolver.receiveBlockDataAsStream( + new PushBlockStream(testApp, NO_ATTEMPT_ID, 0, 2, 0, 0, 0)); + RemoteBlockPushResolver.AppShuffleInfo appShuffleInfo = + pushResolver.validateAndGetAppShuffleInfo(testApp); + closed.acquire(); + assertFalse("Data files on the disk should be cleaned up", + appShuffleInfo.getMergedShuffleDataFile(0, 1, 0).exists()); + assertFalse("Meta files on the disk should be cleaned up", + appShuffleInfo.getMergedShuffleMetaFile(0, 1, 0).exists()); + assertFalse("Index files on the disk should be cleaned up", + appShuffleInfo.getMergedShuffleIndexFile(0, 1, 0).exists()); + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); + // stream 2 now completes + stream2.onComplete(stream2.getID()); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, NO_ATTEMPT_ID, 0, 2)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(testApp, 0, 2, 0); + validateChunks(testApp, 0, 2, 0, blockMeta, new int[]{4}, new int[][]{{0}}); + + // Check whether the metadata is cleaned up or not + StreamCallbackWithID stream3 = + pushResolver.receiveBlockDataAsStream( + new PushBlockStream(testApp, NO_ATTEMPT_ID, 0, 3, 0, 0, 0)); + closed.acquire(); + stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[2])); + stream3.onComplete(stream3.getID()); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, NO_ATTEMPT_ID, 0, 3)); + MergedBlockMeta mergedBlockMeta = pushResolver.getMergedBlockMeta(testApp, 0, 3, 0); + validateChunks(testApp, 0, 3, 0, mergedBlockMeta, new int[]{2}, new int[][]{{0}}); + } + private void useTestFiles(boolean useTestIndexFile, boolean useTestMetaFile) throws IOException { pushResolver = new RemoteBlockPushResolver(conf) { @Override - AppShufflePartitionInfo newAppShufflePartitionInfo(String appId, int shuffleId, - int reduceId, File dataFile, File indexFile, File metaFile) throws IOException { + AppShufflePartitionInfo newAppShufflePartitionInfo( + String appId, + int shuffleId, + int shuffleMergeId, + int reduceId, + File dataFile, + File indexFile, + File metaFile) throws IOException { MergeShuffleFile mergedIndexFile = useTestIndexFile ? new TestMergeShuffleFile(indexFile) : new MergeShuffleFile(indexFile); MergeShuffleFile mergedMetaFile = useTestMetaFile ? 
new TestMergeShuffleFile(metaFile) : new MergeShuffleFile(metaFile); - return new AppShufflePartitionInfo(appId, shuffleId, reduceId, dataFile, mergedIndexFile, - mergedMetaFile); + return new AppShufflePartitionInfo(appId, shuffleId, shuffleMergeId, reduceId, dataFile, + mergedIndexFile, mergedMetaFile); } }; registerExecutor(TEST_APP, prepareLocalDirs(localDirs, MERGE_DIRECTORY), MERGE_DIRECTORY_META); @@ -1116,6 +1298,7 @@ public class RemoteBlockPushResolverSuite { private void validateChunks( String appId, int shuffleId, + int shuffleMergeId, int reduceId, MergedBlockMeta meta, int[] expectedSizes, @@ -1129,7 +1312,8 @@ public class RemoteBlockPushResolverSuite { } for (int i = 0; i < meta.getNumChunks(); i++) { FileSegmentManagedBuffer mb = - (FileSegmentManagedBuffer) pushResolver.getMergedBlockData(appId, shuffleId, reduceId, i); + (FileSegmentManagedBuffer) pushResolver.getMergedBlockData(appId, shuffleId, + shuffleMergeId, reduceId, i); assertEquals(expectedSizes[i], mb.getLength()); } } @@ -1141,7 +1325,8 @@ public class RemoteBlockPushResolverSuite { for (int i = 0; i < blocks.length; i++) { StreamCallbackWithID stream = pushResolver.receiveBlockDataAsStream( new PushBlockStream( - appId, attemptId, blocks[i].shuffleId, blocks[i].mapIndex, blocks[i].reduceId, 0)); + appId, attemptId, blocks[i].shuffleId, blocks[i].shuffleMergeId, + blocks[i].mapIndex, blocks[i].reduceId, 0)); stream.onData(stream.getID(), blocks[i].buffer); stream.onComplete(stream.getID()); } @@ -1149,11 +1334,13 @@ public class RemoteBlockPushResolverSuite { private static class PushBlock { private final int shuffleId; + private final int shuffleMergeId; private final int mapIndex; private final int reduceId; private final ByteBuffer buffer; - PushBlock(int shuffleId, int mapIndex, int reduceId, ByteBuffer buffer) { + PushBlock(int shuffleId, int shuffleMergeId, int mapIndex, int reduceId, ByteBuffer buffer) { this.shuffleId = shuffleId; + this.shuffleMergeId = shuffleMergeId; this.mapIndex = mapIndex; this.reduceId = reduceId; this.buffer = buffer; diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java index 91f319ded4..c79b01e87a 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java @@ -29,10 +29,10 @@ public class FetchShuffleBlockChunksSuite { @Test public void testFetchShuffleBlockChunksEncodeDecode() { FetchShuffleBlockChunks shuffleBlockChunks = - new FetchShuffleBlockChunks("app0", "exec1", 0, new int[] {0}, new int[][] {{0, 1}}); + new FetchShuffleBlockChunks("app0", "exec1", 0, 0, new int[] {0}, new int[][] {{0, 1}}); Assert.assertEquals(2, shuffleBlockChunks.getNumBlocks()); int len = shuffleBlockChunks.encodedLength(); - Assert.assertEquals(45, len); + Assert.assertEquals(49, len); ByteBuf buf = Unpooled.buffer(len); shuffleBlockChunks.encode(buf); diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 4063d113b1..81e4c8f031 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -20,6 +20,7 @@ package org.apache.spark import scala.reflect.ClassTag import 
org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor} @@ -78,7 +79,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val aggregator: Option[Aggregator[K, V, C]] = None, val mapSideCombine: Boolean = false, val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor) - extends Dependency[Product2[K, V]] { + extends Dependency[Product2[K, V]] with Logging { if (mapSideCombine) { require(aggregator.isDefined, "Map-side combine without Aggregator specified!") @@ -101,10 +102,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( // By default, shuffle merge is enabled for ShuffleDependency if push based shuffle // is enabled - private[this] var _shuffleMergeEnabled = - Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf) && - // TODO: SPARK-35547: Push based shuffle is currently unsupported for Barrier stages - !rdd.isBarrier() + private[this] var _shuffleMergeEnabled = canShuffleMergeBeEnabled() private[spark] def setShuffleMergeEnabled(shuffleMergeEnabled: Boolean): Unit = { _shuffleMergeEnabled = shuffleMergeEnabled @@ -124,6 +122,14 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( */ private[this] var _shuffleMergedFinalized: Boolean = false + /** + * shuffleMergeId is used to uniquely identify merging process of shuffle + * by an indeterminate stage attempt. + */ + private[this] var _shuffleMergeId: Int = 0 + + def shuffleMergeId: Int = _shuffleMergeId + def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = { if (mergerLocs != null) { this.mergerLocs = mergerLocs @@ -150,6 +156,22 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( } } + def newShuffleMergeState(): Unit = { + _shuffleMergedFinalized = false + mergerLocs = Nil + _shuffleMergeId += 1 + } + + private def canShuffleMergeBeEnabled(): Boolean = { + val isPushShuffleEnabled = Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf) + if (isPushShuffleEnabled && rdd.isBarrier()) { + logWarning("Push-based shuffle is currently not supported for barrier stages") + } + isPushShuffleEnabled && + // TODO: SPARK-35547: Push based shuffle is currently unsupported for Barrier stages + !rdd.isBarrier() + } + _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) _rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId) } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index e605eea802..1b25ec5044 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -38,7 +38,7 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus} import org.apache.spark.shuffle.MetadataFetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId} import org.apache.spark.util._ /** @@ -1450,11 +1450,10 @@ private[spark] object MapOutputTracker extends Logging { val remainingMapStatuses = if (mergeStatus != null && mergeStatus.totalSize > 0) { // If MergeStatus is available for the given partition, add location of the // 
pre-merged shuffle partition for this partition ID. Here we create a - // ShuffleBlockId with mapId being SHUFFLE_PUSH_MAP_ID to indicate this is - // a merged shuffle block. + // ShuffleMergedBlockId to indicate this is a merged shuffle block. splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) += - ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize, - SHUFFLE_PUSH_MAP_ID)) + ((ShuffleMergedBlockId(shuffleId, mergeStatus.shuffleMergeId, partId), + mergeStatus.totalSize, SHUFFLE_PUSH_MAP_ID)) // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper // shuffle partition blocks, fetch the original map produced shuffle partition blocks val mapStatusesWithIndex = mapStatuses.zipWithIndex diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 9709ec16b8..b276de1392 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1323,9 +1323,8 @@ private[spark] class DAGScheduler( // `findMissingPartitions()` returns all partitions every time. stage match { case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => - // TODO: SPARK-32923: Clean all push-based shuffle metadata like merge enabled and - // TODO: finalized as we are clearing all the merge results. mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) + sms.shuffleDep.newShuffleMergeState() case _ => } @@ -2057,7 +2056,7 @@ private[spark] class DAGScheduler( // TODO: SPARK-35536: Cancel finalizeShuffleMerge if the stage is cancelled // TODO: during shuffleMergeFinalizeWaitSec shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host, - shuffleServiceLoc.port, shuffleId, + shuffleServiceLoc.port, shuffleId, stage.shuffleDep.shuffleMergeId, new MergeFinalizerListener { override def onShuffleMergeSuccess(statuses: MergeStatuses): Unit = { assert(shuffleId == statuses.shuffleId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala index 77d8f8e040..6d16026453 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala @@ -43,14 +43,17 @@ import org.apache.spark.util.Utils */ private[spark] class MergeStatus( private[this] var loc: BlockManagerId, + private[this] var _shuffleMergeId: Int, private[this] var mapTracker: RoaringBitmap, private[this] var size: Long) extends Externalizable with ShuffleOutputStatus { - protected def this() = this(null, null, -1) // For deserialization only + protected def this() = this(null, -1, null, -1) // For deserialization only def location: BlockManagerId = loc + def shuffleMergeId: Int = _shuffleMergeId + def totalSize: Long = size def tracker: RoaringBitmap = mapTracker @@ -73,12 +76,14 @@ private[spark] class MergeStatus( override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) + out.writeInt(_shuffleMergeId) mapTracker.writeExternal(out) out.writeLong(size) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) + _shuffleMergeId = in.readInt() mapTracker = new RoaringBitmap() mapTracker.readExternal(in) size = in.readLong() @@ -100,14 +105,20 @@ private[spark] object MergeStatus { assert(mergeStatuses.bitmaps.length == 
mergeStatuses.reduceIds.length && mergeStatuses.bitmaps.length == mergeStatuses.sizes.length) val mergerLoc = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, loc.host, loc.port) + val shuffleMergeId = mergeStatuses.shuffleMergeId mergeStatuses.bitmaps.zipWithIndex.map { case (bitmap, index) => - val mergeStatus = new MergeStatus(mergerLoc, bitmap, mergeStatuses.sizes(index)) + val mergeStatus = new MergeStatus(mergerLoc, shuffleMergeId, bitmap, + mergeStatuses.sizes(index)) (mergeStatuses.reduceIds(index), mergeStatus) } } - def apply(loc: BlockManagerId, bitmap: RoaringBitmap, size: Long): MergeStatus = { - new MergeStatus(loc, bitmap, size) + def apply( + loc: BlockManagerId, + shuffleMergeId: Int, + bitmap: RoaringBitmap, + size: Long): MergeStatus = { + new MergeStatus(loc, shuffleMergeId, bitmap, size) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 9c50569c78..07928f8c52 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -116,28 +116,31 @@ private[spark] class IndexShuffleBlockResolver( private def getMergedBlockDataFile( appId: String, shuffleId: Int, + shuffleMergeId: Int, reduceId: Int, dirs: Option[Array[String]] = None): File = { blockManager.diskBlockManager.getMergedShuffleFile( - ShuffleMergedDataBlockId(appId, shuffleId, reduceId), dirs) + ShuffleMergedDataBlockId(appId, shuffleId, shuffleMergeId, reduceId), dirs) } private def getMergedBlockIndexFile( appId: String, shuffleId: Int, + shuffleMergeId: Int, reduceId: Int, dirs: Option[Array[String]] = None): File = { blockManager.diskBlockManager.getMergedShuffleFile( - ShuffleMergedIndexBlockId(appId, shuffleId, reduceId), dirs) + ShuffleMergedIndexBlockId(appId, shuffleId, shuffleMergeId, reduceId), dirs) } private def getMergedBlockMetaFile( appId: String, shuffleId: Int, + shuffleMergeId: Int, reduceId: Int, dirs: Option[Array[String]] = None): File = { blockManager.diskBlockManager.getMergedShuffleFile( - ShuffleMergedMetaBlockId(appId, shuffleId, reduceId), dirs) + ShuffleMergedMetaBlockId(appId, shuffleId, shuffleMergeId, reduceId), dirs) } /** @@ -466,11 +469,13 @@ private[spark] class IndexShuffleBlockResolver( * knows how to consume local merged shuffle file as multiple chunks. */ override def getMergedBlockData( - blockId: ShuffleBlockId, + blockId: ShuffleMergedBlockId, dirs: Option[Array[String]]): Seq[ManagedBuffer] = { val indexFile = - getMergedBlockIndexFile(conf.getAppId, blockId.shuffleId, blockId.reduceId, dirs) - val dataFile = getMergedBlockDataFile(conf.getAppId, blockId.shuffleId, blockId.reduceId, dirs) + getMergedBlockIndexFile(conf.getAppId, blockId.shuffleId, blockId.shuffleMergeId, + blockId.reduceId, dirs) + val dataFile = getMergedBlockDataFile(conf.getAppId, blockId.shuffleId, + blockId.shuffleMergeId, blockId.reduceId, dirs) // Load all the indexes in order to identify all chunks in the specified merged shuffle file. val size = indexFile.length.toInt val offsets = Utils.tryWithResource { @@ -493,13 +498,15 @@ private[spark] class IndexShuffleBlockResolver( * This is only used for reading local merged block meta data. 
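The index file read here is a flat sequence of 8-byte offsets, which is why the meta path below can derive numChunks as (size / 8) - 1. A minimal, self-contained sketch of that decoding, assuming big-endian longs as written by DataOutputStream (the helper name is illustrative, not Spark's API):

    // Sketch: decode a merged-shuffle index file into chunk offsets.
    // A file of N longs bounds N - 1 chunks: chunk i spans
    // [offsets(i), offsets(i + 1)) in the merged data file.
    import java.io.{DataInputStream, File, FileInputStream}

    def readChunkOffsets(indexFile: File): Array[Long] = {
      val in = new DataInputStream(new FileInputStream(indexFile))
      try {
        Array.fill((indexFile.length() / 8).toInt)(in.readLong())
      } finally {
        in.close()
      }
    }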
*/ override def getMergedBlockMeta( - blockId: ShuffleBlockId, + blockId: ShuffleMergedBlockId, dirs: Option[Array[String]]): MergedBlockMeta = { val indexFile = - getMergedBlockIndexFile(conf.getAppId, blockId.shuffleId, blockId.reduceId, dirs) + getMergedBlockIndexFile(conf.getAppId, blockId.shuffleId, + blockId.shuffleMergeId, blockId.reduceId, dirs) val size = indexFile.length.toInt val numChunks = (size / 8) - 1 - val metaFile = getMergedBlockMetaFile(conf.getAppId, blockId.shuffleId, blockId.reduceId, dirs) + val metaFile = getMergedBlockMetaFile(conf.getAppId, blockId.shuffleId, + blockId.shuffleMergeId, blockId.reduceId, dirs) val chunkBitMaps = new FileSegmentManagedBuffer(transportConf, metaFile, 0L, metaFile.length) new MergedBlockMeta(numChunks, chunkBitMaps) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala index 0d2462f233..56f915be87 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala @@ -69,7 +69,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging { new BlockPushErrorHandler() { // For a connection exception against a particular host, we will stop pushing any // blocks to just that host and continue push blocks to other hosts. So, here push of - // all blocks will only stop when it is "Too Late". Also see updateStateAndCheckIfPushMore. + // all blocks will only stop when it is "Too Late" or "Invalid Block Push". + // Also see updateStateAndCheckIfPushMore. override def shouldRetryError(t: Throwable): Boolean = { // If it is a FileNotFoundException originating from the client while pushing the shuffle // blocks to the server, then we stop pushing all the blocks because this indicates the @@ -77,8 +78,10 @@ if (t.getCause != null && t.getCause.isInstanceOf[FileNotFoundException]) { return false } - // If the block is too late, there is no need to retry it - !Throwables.getStackTraceAsString(t).contains(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX) + val errorStackTraceString = Throwables.getStackTraceAsString(t) + // If the block push is too late or stale, there is no need to retry it + !errorStackTraceString.contains( + BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX) } } } @@ -99,8 +102,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging { mapIndex: Int): Unit = { val numPartitions = dep.partitioner.numPartitions val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") - val requests = prepareBlockPushRequests(numPartitions, mapIndex, dep.shuffleId, dataFile, - partitionLengths, dep.getMergerLocs, transportConf) + val requests = prepareBlockPushRequests(numPartitions, mapIndex, dep.shuffleId, + dep.shuffleMergeId, dataFile, partitionLengths, dep.getMergerLocs, transportConf) // Randomize the orders of the PushRequest, so different mappers pushing blocks at the same // time won't be pushing the same ranges of shuffle partitions.
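The retry decision above reduces to a small predicate. A self-contained restatement (the message suffix is taken as a parameter instead of hard-coding the constant from BlockPushErrorHandler):

    // Sketch: whether a failed block push is worth retrying. A
    // FileNotFoundException from the local executor means the shuffle files
    // are gone; a too-late/stale response means the merge was already
    // finalized or the push carries an outdated shuffleMergeId.
    import java.io.FileNotFoundException
    import com.google.common.base.Throwables

    def shouldRetry(t: Throwable, tooLateOrStaleSuffix: String): Boolean = {
      if (t.getCause != null && t.getCause.isInstanceOf[FileNotFoundException]) {
        false
      } else {
        !Throwables.getStackTraceAsString(t).contains(tooLateOrStaleSuffix)
      }
    }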
pushRequests ++= Utils.randomize(requests) @@ -335,6 +338,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging { * @param numPartitions number of shuffle partitions in the shuffle file * @param partitionId map index of the current mapper * @param shuffleId shuffleId of current shuffle + * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process + * of shuffle by an indeterminate stage attempt. * @param dataFile shuffle data file * @param partitionLengths array of sizes of blocks in the shuffle data file * @param mergerLocs target locations to push blocks to @@ -347,6 +352,7 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging { numPartitions: Int, partitionId: Int, shuffleId: Int, + shuffleMergeId: Int, dataFile: File, partitionLengths: Array[Long], mergerLocs: Seq[BlockManagerId], @@ -361,7 +367,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging { for (reduceId <- 0 until numPartitions) { val blockSize = partitionLengths(reduceId) logDebug( - s"Block ${ShufflePushBlockId(shuffleId, partitionId, reduceId)} is of size $blockSize") + s"Block ${ShufflePushBlockId(shuffleId, shuffleMergeId, partitionId, + reduceId)} is of size $blockSize") // Skip 0-length blocks and blocks that are large enough if (blockSize > 0) { val mergerId = math.min(math.floor(reduceId * 1.0 / numPartitions * numMergers), @@ -394,7 +401,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging { // Only push blocks under the size limit if (blockSize <= maxBlockSizeToPush) { val blockSizeInt = blockSize.toInt - blocks += ((ShufflePushBlockId(shuffleId, partitionId, reduceId), blockSizeInt)) + blocks += ((ShufflePushBlockId(shuffleId, shuffleMergeId, partitionId, + reduceId), blockSizeInt)) // Only update currentReqOffset if the current block is the first in the request if (currentReqOffset == -1) { currentReqOffset = offset diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index 49e59298cc..0f35f8c983 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -19,7 +19,7 @@ package org.apache.spark.shuffle import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.MergedBlockMeta -import org.apache.spark.storage.{BlockId, ShuffleBlockId} +import org.apache.spark.storage.{BlockId, ShuffleMergedBlockId} private[spark] /** @@ -44,12 +44,16 @@ trait ShuffleBlockResolver { /** * Retrieve the data for the specified merged shuffle block as multiple chunks. */ - def getMergedBlockData(blockId: ShuffleBlockId, dirs: Option[Array[String]]): Seq[ManagedBuffer] + def getMergedBlockData( + blockId: ShuffleMergedBlockId, + dirs: Option[Array[String]]): Seq[ManagedBuffer] /** * Retrieve the meta data for the specified merged shuffle block. 
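The merger chosen for each reduce partition in prepareBlockPushRequests above comes from a simple linear bucketing; isolated for clarity (the function name is illustrative):

    // Sketch: spread reduce partitions evenly across the available mergers,
    // mirroring the math.min(math.floor(...), numMergers - 1) expression above.
    def mergerIndexFor(reduceId: Int, numPartitions: Int, numMergers: Int): Int =
      math.min(math.floor(reduceId * 1.0 / numPartitions * numMergers),
        numMergers - 1).toInt

    // With 6 partitions and 3 mergers: partitions 0-1 -> merger 0,
    // 2-3 -> merger 1, 4-5 -> merger 2.
    assert((0 until 6).map(mergerIndexFor(_, 6, 3)) == Seq(0, 0, 1, 1, 2, 2))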
*/ - def getMergedBlockMeta(blockId: ShuffleBlockId, dirs: Option[Array[String]]): MergedBlockMeta + def getMergedBlockMeta( + blockId: ShuffleMergedBlockId, + dirs: Option[Array[String]]): MergedBlockMeta def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index db5862dec2..ce53f08bae 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -77,9 +77,11 @@ case class ShuffleBlockBatchId( @DeveloperApi case class ShuffleBlockChunkId( shuffleId: Int, + shuffleMergeId: Int, reduceId: Int, chunkId: Int) extends BlockId { - override def name: String = "shuffleChunk_" + shuffleId + "_" + reduceId + "_" + chunkId + override def name: String = + "shuffleChunk_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId + "_" + chunkId } @DeveloperApi @@ -100,15 +102,34 @@ case class ShuffleChecksumBlockId(shuffleId: Int, mapId: Long, reduceId: Int) ex @Since("3.2.0") @DeveloperApi -case class ShufflePushBlockId(shuffleId: Int, mapIndex: Int, reduceId: Int) extends BlockId { - override def name: String = "shufflePush_" + shuffleId + "_" + mapIndex + "_" + reduceId +case class ShufflePushBlockId( + shuffleId: Int, + shuffleMergeId: Int, + mapIndex: Int, + reduceId: Int) extends BlockId { + override def name: String = "shufflePush_" + shuffleId + "_" + + shuffleMergeId + "_" + mapIndex + "_" + reduceId } @Since("3.2.0") @DeveloperApi -case class ShuffleMergedDataBlockId(appId: String, shuffleId: Int, reduceId: Int) extends BlockId { +case class ShuffleMergedBlockId( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int) extends BlockId { + override def name: String = "shuffleMerged_" + shuffleId + "_" + + shuffleMergeId + "_" + reduceId +} + +@Since("3.2.0") +@DeveloperApi +case class ShuffleMergedDataBlockId( + appId: String, + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int) extends BlockId { override def name: String = RemoteBlockPushResolver.MERGED_SHUFFLE_FILE_NAME_PREFIX + "_" + - appId + "_" + shuffleId + "_" + reduceId + ".data" + appId + "_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId + ".data" } @Since("3.2.0") @@ -116,9 +137,10 @@ case class ShuffleMergedDataBlockId(appId: String, shuffleId: Int, reduceId: Int case class ShuffleMergedIndexBlockId( appId: String, shuffleId: Int, + shuffleMergeId: Int, reduceId: Int) extends BlockId { override def name: String = RemoteBlockPushResolver.MERGED_SHUFFLE_FILE_NAME_PREFIX + "_" + - appId + "_" + shuffleId + "_" + reduceId + ".index" + appId + "_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId + ".index" } @Since("3.2.0") @@ -126,9 +148,10 @@ case class ShuffleMergedIndexBlockId( case class ShuffleMergedMetaBlockId( appId: String, shuffleId: Int, + shuffleMergeId: Int, reduceId: Int) extends BlockId { override def name: String = RemoteBlockPushResolver.MERGED_SHUFFLE_FILE_NAME_PREFIX + "_" + - appId + "_" + shuffleId + "_" + reduceId + ".meta" + appId + "_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId + ".meta" } @DeveloperApi @@ -172,11 +195,15 @@ object BlockId { val SHUFFLE_BATCH = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r - val SHUFFLE_PUSH = "shufflePush_([0-9]+)_([0-9]+)_([0-9]+)".r - val SHUFFLE_MERGED_DATA = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).data".r - val SHUFFLE_MERGED_INDEX =
"shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).index".r - val SHUFFLE_MERGED_META = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).meta".r - val SHUFFLE_CHUNK = "shuffleChunk_([0-9]+)_([0-9]+)_([0-9]+)".r + val SHUFFLE_PUSH = "shufflePush_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r + val SHUFFLE_MERGED = "shuffleMerged_([0-9]+)_([0-9]+)_([0-9]+)".r + val SHUFFLE_MERGED_DATA = + "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+)_([0-9]+).data".r + val SHUFFLE_MERGED_INDEX = + "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+)_([0-9]+).index".r + val SHUFFLE_MERGED_META = + "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+)_([0-9]+).meta".r + val SHUFFLE_CHUNK = "shuffleChunk_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r @@ -195,16 +222,22 @@ object BlockId { ShuffleDataBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt) case SHUFFLE_INDEX(shuffleId, mapId, reduceId) => ShuffleIndexBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt) - case SHUFFLE_PUSH(shuffleId, mapIndex, reduceId) => - ShufflePushBlockId(shuffleId.toInt, mapIndex.toInt, reduceId.toInt) - case SHUFFLE_MERGED_DATA(appId, shuffleId, reduceId) => - ShuffleMergedDataBlockId(appId, shuffleId.toInt, reduceId.toInt) - case SHUFFLE_MERGED_INDEX(appId, shuffleId, reduceId) => - ShuffleMergedIndexBlockId(appId, shuffleId.toInt, reduceId.toInt) - case SHUFFLE_MERGED_META(appId, shuffleId, reduceId) => - ShuffleMergedMetaBlockId(appId, shuffleId.toInt, reduceId.toInt) - case SHUFFLE_CHUNK(shuffleId, reduceId, chunkId) => - ShuffleBlockChunkId(shuffleId.toInt, reduceId.toInt, chunkId.toInt) + case SHUFFLE_PUSH(shuffleId, shuffleMergeId, mapIndex, reduceId) => + ShufflePushBlockId(shuffleId.toInt, shuffleMergeId.toInt, mapIndex.toInt, + reduceId.toInt) + case SHUFFLE_MERGED(shuffleId, shuffleMergeId, reduceId) => + ShuffleMergedBlockId(shuffleId.toInt, shuffleMergeId.toInt, reduceId.toInt) + case SHUFFLE_MERGED_DATA(appId, shuffleId, shuffleMergeId, reduceId) => + ShuffleMergedDataBlockId(appId, shuffleId.toInt, shuffleMergeId.toInt, reduceId.toInt) + case SHUFFLE_MERGED_INDEX(appId, shuffleId, shuffleMergeId, reduceId) => + ShuffleMergedIndexBlockId(appId, shuffleId.toInt, shuffleMergeId.toInt, + reduceId.toInt) + case SHUFFLE_MERGED_META(appId, shuffleId, shuffleMergeId, reduceId) => + ShuffleMergedMetaBlockId(appId, shuffleId.toInt, shuffleMergeId.toInt, + reduceId.toInt) + case SHUFFLE_CHUNK(shuffleId, shuffleMergeId, reduceId, chunkId) => + ShuffleBlockChunkId(shuffleId.toInt, shuffleMergeId.toInt, reduceId.toInt, + chunkId.toInt) case BROADCAST(broadcastId, field) => BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) case TASKRESULT(taskId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 43c7baf050..e57944aa16 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -748,7 +748,7 @@ private[spark] class BlockManager( * which will be memory efficient when performing certain operations. 
*/ def getLocalMergedBlockData( - blockId: ShuffleBlockId, + blockId: ShuffleMergedBlockId, dirs: Array[String]): Seq[ManagedBuffer] = { shuffleManager.shuffleBlockResolver.getMergedBlockData(blockId, Some(dirs)) } @@ -757,7 +757,7 @@ private[spark] class BlockManager( * Get the local merged shuffle block meta data for the given block ID. */ def getLocalMergedBlockMeta( - blockId: ShuffleBlockId, + blockId: ShuffleMergedBlockId, dirs: Array[String]): MergedBlockMeta = { shuffleManager.shuffleBlockResolver.getMergedBlockMeta(blockId, Some(dirs)) } diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala index 096ea247af..99138b670a 100644 --- a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala +++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala @@ -110,13 +110,14 @@ private class PushBasedFetchHelper( */ def createChunkBlockInfosFromMetaResponse( shuffleId: Int, + shuffleMergeId: Int, reduceId: Int, blockSize: Long, bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = { val approxChunkSize = blockSize / bitmaps.length val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]() for (i <- bitmaps.indices) { - val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i) + val blockChunkId = ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, i) chunksMetaMap.put(blockChunkId, bitmaps(i)) logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize") blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID)) @@ -134,37 +135,41 @@ private class PushBasedFetchHelper( def sendFetchMergedStatusRequest(req: FetchRequest): Unit = { val sizeMap = req.blocks.map { case FetchBlockInfo(blockId, size, _) => - val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] + val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId] ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size) }.toMap val address = req.address val mergedBlocksMetaListener = new MergedBlocksMetaListener { - override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = { - logInfo(s"Received the meta of push-merged block for ($shuffleId, $reduceId) " + - s"from ${req.address.host}:${req.address.port}") + override def onSuccess(shuffleId: Int, shuffleMergeId: Int, reduceId: Int, + meta: MergedBlockMeta): Unit = { + logInfo(s"Received the meta of push-merged block for ($shuffleId, $shuffleMergeId," + + s" $reduceId) from ${req.address.host}:${req.address.port}") try { - iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, reduceId, - sizeMap((shuffleId, reduceId)), meta.readChunkBitmaps(), address)) + iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, shuffleMergeId, + reduceId, sizeMap((shuffleId, reduceId)), meta.readChunkBitmaps(), address)) } catch { case exception: Exception => logError(s"Failed to parse the meta of push-merged block for ($shuffleId, " + - s"$reduceId) from ${req.address.host}:${req.address.port}", exception) + s"$shuffleMergeId, $reduceId) from" + + s" ${req.address.host}:${req.address.port}", exception) iterator.addToResultsQueue( - PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address)) + PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, + address)) } } - override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = { + override def onFailure(shuffleId: Int, shuffleMergeId: Int, reduceId: Int, + 
exception: Throwable): Unit = { logError(s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " + s"from ${req.address.host}:${req.address.port}", exception) iterator.addToResultsQueue( - PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address)) + PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address)) } } req.blocks.foreach { block => - val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId] + val shuffleBlockId = block.blockId.asInstanceOf[ShuffleMergedBlockId] shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId, - shuffleBlockId.reduceId, mergedBlocksMetaListener) + shuffleBlockId.shuffleMergeId, shuffleBlockId.reduceId, mergedBlocksMetaListener) } } @@ -241,11 +246,11 @@ private class PushBasedFetchHelper( localDirs: Array[String], blockManagerId: BlockManagerId): Unit = { try { - val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] + val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId] val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs) iterator.addToResultsQueue(PushMergedLocalMetaFetchResult( - shuffleBlockId.shuffleId, shuffleBlockId.reduceId, chunksMeta.readChunkBitmaps(), - localDirs)) + shuffleBlockId.shuffleId, shuffleBlockId.shuffleMergeId, + shuffleBlockId.reduceId, chunksMeta.readChunkBitmaps(), localDirs)) } catch { case e: Exception => // If we see an exception with reading a push-merged-local meta, we fallback to @@ -283,13 +288,13 @@ private class PushBasedFetchHelper( def initiateFallbackFetchForPushMergedBlock( blockId: BlockId, address: BlockManagerId): Unit = { - assert(blockId.isInstanceOf[ShuffleBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId]) + assert(blockId.isInstanceOf[ShuffleMergedBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId]) logWarning(s"Falling back to fetch the original blocks for push-merged block $blockId") // Increase the blocks processed since we will process another block in the next iteration of // the while loop in ShuffleBlockFetcherIterator.next(). val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = blockId match { - case shuffleBlockId: ShuffleBlockId => + case shuffleBlockId: ShuffleMergedBlockId => iterator.decreaseNumBlocksToFetch(1) mapOutputTracker.getMapSizesForMergeResult( shuffleBlockId.shuffleId, shuffleBlockId.reduceId) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d03f20adf9..fd87f5e568 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -490,14 +490,14 @@ final class ShuffleBlockFetcherIterator( // Either all blocks are push-merged blocks, shuffle chunks, or original blocks. // Based on these types, we decide to do batch fetch and create FetchRequests with // forMergedMetas set. 
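That type-based decision, restated as a standalone sketch (built on the block id types this patch defines; doBatchFetch stands in for the iterator's field, and the real match cases follow below):

    // Sketch: per-request flags picked from the first block id of a request.
    // Chunks and merged blocks both disable batch fetch; merged blocks
    // additionally need a meta request first (forMergedMetas = true).
    import org.apache.spark.storage.{BlockId, ShuffleBlockChunkId, ShuffleMergedBlockId}

    def requestFlags(blockId: BlockId, doBatchFetch: Boolean): (Boolean, Boolean) =
      blockId match {
        case ShuffleBlockChunkId(_, _, _, _) => (false, false)
        case ShuffleMergedBlockId(_, _, _) => (false, true)
        case _ => (doBatchFetch, false) // (enableBatchFetch, forMergedMetas)
      }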
- case ShuffleBlockChunkId(_, _, _) => + case ShuffleBlockChunkId(_, _, _, _) => if (curRequestSize >= targetRemoteRequestSize || curBlocks.size >= maxBlocksInFlightPerAddress) { curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false, collectedRemoteRequests, enableBatchFetch = false) curRequestSize = curBlocks.map(_.size).sum } - case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => + case ShuffleMergedBlockId(_, _, _) => if (curBlocks.size >= maxBlocksInFlightPerAddress) { curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false, collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true) @@ -516,8 +516,8 @@ final class ShuffleBlockFetcherIterator( if (curBlocks.nonEmpty) { val (enableBatchFetch, forMergedMetas) = { curBlocks.head.blockId match { - case ShuffleBlockChunkId(_, _, _) => (false, false) - case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => (false, true) + case ShuffleBlockChunkId(_, _, _, _) => (false, false) + case ShuffleMergedBlockId(_, _, _) => (false, true) case _ => (doBatchFetch, false) } } @@ -901,9 +901,10 @@ final class ShuffleBlockFetcherIterator( // a SuccessFetchResult or a FailureFetchResult. result = null - case PushMergedLocalMetaFetchResult(shuffleId, reduceId, bitmaps, localDirs) => + case PushMergedLocalMetaFetchResult( + shuffleId, shuffleMergeId, reduceId, bitmaps, localDirs) => // Fetch push-merged-local shuffle block data as multiple shuffle chunks - val shuffleBlockId = ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId) + val shuffleBlockId = ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId) try { val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId, localDirs) @@ -915,7 +916,8 @@ final class ShuffleBlockFetcherIterator( numBlocksToFetch += bufs.size bufs.zipWithIndex.foreach { case (buf, chunkId) => buf.retain() - val shuffleChunkId = ShuffleBlockChunkId(shuffleId, reduceId, chunkId) + val shuffleChunkId = ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, + chunkId) pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId)) results.put(SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, pushBasedFetchHelper.localShuffleMergerBlockMgrId, buf.size(), buf, @@ -933,27 +935,29 @@ final class ShuffleBlockFetcherIterator( } result = null - case PushMergedRemoteMetaFetchResult(shuffleId, reduceId, blockSize, bitmaps, address) => + case PushMergedRemoteMetaFetchResult( + shuffleId, shuffleMergeId, reduceId, blockSize, bitmaps, address) => // The original meta request is processed so we decrease numBlocksToFetch and // numBlocksInFlightPerAddress by 1. We will collect new shuffle chunks request and the // count of this is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks. numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 numBlocksToFetch -= 1 val blocksToFetch = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse( - shuffleId, reduceId, blockSize, bitmaps) + shuffleId, shuffleMergeId, reduceId, blockSize, bitmaps) val additionalRemoteReqs = new ArrayBuffer[FetchRequest] collectFetchRequests(address, blocksToFetch.toSeq, additionalRemoteReqs) fetchRequests ++= additionalRemoteReqs // Set result to null to force another iteration. result = null - case PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address) => + case PushMergedRemoteMetaFailedFetchResult( + shuffleId, shuffleMergeId, reduceId, address) => // The original meta request failed so we decrease numBlocksInFlightPerAddress by 1. 
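For reference, the meta-to-chunks expansion used by both the local and remote paths above, as a standalone sketch (one chunk per bitmap, with sizes approximated by an even split):

    // Sketch: expand one merged block into its fetchable chunks. Every chunk
    // id carries shuffleMergeId, so chunks from a superseded merge attempt of
    // an indeterminate stage cannot be mistaken for current ones.
    import org.apache.spark.storage.ShuffleBlockChunkId
    import org.roaringbitmap.RoaringBitmap

    def chunksFor(
        shuffleId: Int,
        shuffleMergeId: Int,
        reduceId: Int,
        blockSize: Long,
        bitmaps: Array[RoaringBitmap]): Seq[(ShuffleBlockChunkId, Long)] = {
      val approxChunkSize = blockSize / bitmaps.length
      bitmaps.indices.map { i =>
        (ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, i), approxChunkSize)
      }
    }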
numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 // If we fail to fetch the meta of a push-merged block, we fall back to fetching the // original blocks. pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( - ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId), address) + ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId), address) // Set result to null to force another iteration. result = null } @@ -1421,6 +1425,8 @@ object ShuffleBlockFetcherIterator { * Result of a successful fetch of meta information for a remote push-merged block. * * @param shuffleId shuffle id. + * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process + * of shuffle by an indeterminate stage attempt. * @param reduceId reduce id. * @param blockSize size of each push-merged block. * @param bitmaps bitmaps for every chunk. @@ -1428,6 +1434,7 @@ object ShuffleBlockFetcherIterator { */ private[storage] case class PushMergedRemoteMetaFetchResult( shuffleId: Int, + shuffleMergeId: Int, reduceId: Int, blockSize: Long, bitmaps: Array[RoaringBitmap], @@ -1437,11 +1444,14 @@ object ShuffleBlockFetcherIterator { * Result of a failure while fetching the meta information for a remote push-merged block. * * @param shuffleId shuffle id. + * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process + * of shuffle by an indeterminate stage attempt. * @param reduceId reduce id. * @param address BlockManager that the meta was fetched from. */ private[storage] case class PushMergedRemoteMetaFailedFetchResult( shuffleId: Int, + shuffleMergeId: Int, reduceId: Int, address: BlockManagerId) extends FetchResult @@ -1449,12 +1459,15 @@ object ShuffleBlockFetcherIterator { * Result of a successful fetch of meta information for a push-merged-local block. * * @param shuffleId shuffle id. + * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process + * of shuffle by an indeterminate stage attempt. * @param reduceId reduce id. * @param bitmaps bitmaps for every chunk. 
* @param localDirs local directories where the push-merged shuffle files are stored */ private[storage] case class PushMergedLocalMetaFetchResult( shuffleId: Int, + shuffleMergeId: Int, reduceId: Int, bitmaps: Array[RoaringBitmap], localDirs: Array[String]) extends FetchResult diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index f4b47e2bb0..69cc8c1bce 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus, MergeStatus} import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId} class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { private val conf = new SparkConf @@ -347,9 +347,9 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { bitmap.add(0) bitmap.add(1) - tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), + tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), 0, bitmap, 1000L)) - tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000), + tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000), 0, bitmap, 1000L)) assert(tracker.getNumAvailableMergeResults(10) == 2) tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000)) @@ -386,12 +386,12 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2)) masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3)) - masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, + masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, 0, bitmap, 3000L)) slaveTracker.updateEpoch(masterTracker.getEpoch) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === - Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, -1, 0), 3000, -1), + Seq((blockMgrId, ArrayBuffer((ShuffleMergedBlockId(10, 0, 0), 3000, -1), (ShuffleBlockId(10, 2, 0), size1000, 2))))) masterTracker.stop() @@ -431,7 +431,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2)) masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3)) - masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, + masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, 0, bitmap, 4000L)) slaveTracker.updateEpoch(masterTracker.getEpoch) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) @@ -523,7 +523,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { bitmap80.add(2) bitmap80.add(3) bitmap80.add(4) - tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), + tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), 0, bitmap80, 11)) val preferredLocs1 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0) @@ -535,7 +535,7 @@ class MapOutputTrackerSuite
extends SparkFunSuite with LocalSparkContext { // Prepare another MergeStatus that merges only 1 out of 5 blocks val bitmap20 = new RoaringBitmap() bitmap20.add(0) - tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), + tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), 0, bitmap20, 2)) val preferredLocs2 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0) @@ -612,7 +612,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5)) } - masterTracker.registerMergeResult(20, 0, MergeStatus(BlockManagerId("999", "mps", 1000), + masterTracker.registerMergeResult(20, 0, MergeStatus(BlockManagerId("999", "mps", 1000), 0, bitmap1, 1000L)) val mapWorkerRpcEnv = createRpcEnv("spark-worker", "localhost", 0, new SecurityManager(conf)) @@ -646,13 +646,13 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { val bitmap1 = new RoaringBitmap() bitmap1.add(0) bitmap1.add(1) - tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), + tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), 0, bitmap1, 1000L)) val bitmap2 = new RoaringBitmap() bitmap2.add(5) bitmap2.add(6) - tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000), + tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000), 0, bitmap2, 1000L)) assert(tracker.getNumAvailableMergeResults(10) == 2) tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000), Option(0)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index cc09e02487..312d1f8316 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -315,7 +315,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti shuffleMapStage: ShuffleMapStage): Unit = { if (shuffleMergeRegister) { for (part <- 0 until shuffleMapStage.shuffleDep.partitioner.numPartitions) { - val mergeStatuses = Seq((part, makeMergeStatus(""))) + val mergeStatuses = Seq((part, makeMergeStatus("", + shuffleMapStage.shuffleDep.shuffleMergeId))) handleRegisterMergeStatuses(shuffleMapStage, mergeStatuses) } if (shuffleMergeFinalize) { @@ -3726,9 +3727,11 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti (Success, makeMapStatus("hostA", parts)) }.toSeq) val shuffleMapStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage] - scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((0, makeMergeStatus("hostA")))) + scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((0, makeMergeStatus("hostA", + shuffleDep.shuffleMergeId)))) scheduler.handleShuffleMergeFinalized(shuffleMapStage) - scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((1, makeMergeStatus("hostA")))) + scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((1, makeMergeStatus("hostA", + shuffleDep.shuffleMergeId)))) assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) == 1) } @@ -3779,6 +3782,71 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti assert(mapOutputTracker.getNumAvailableOutputs(shuffleDep.shuffleId) == parts) } + test("SPARK-32923: 
handle stage failure for indeterminate map stage with push-based shuffle") { + initPushBasedShuffleConfs(conf) + DAGSchedulerSuite.clearMergerLocs + DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5")) + val (shuffleId1, shuffleId2) = constructIndeterminateStageFetchFailed() + + // Check status for all failedStages + val failedStages = scheduler.failedStages.toSeq + assert(failedStages.map(_.id) == Seq(1, 2)) + // Shuffle blocks of "hostC" are lost, so the first task of the `shuffleMapRdd2` needs to retry. + assert(failedStages.collect { + case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 => stage + }.head.findMissingPartitions() == Seq(0)) + // The result stage is still waiting for its 2 tasks to complete + assert(failedStages.collect { + case stage: ResultStage => stage + }.head.findMissingPartitions() == Seq(0, 1)) + // shuffleMergeId for indeterminate stages would start from 1 + assert(failedStages.collect { + case stage: ShuffleMapStage => stage.shuffleDep.shuffleMergeId + }.forall(x => x == 1)) + scheduler.resubmitFailedStages() + + // The first task of the `shuffleMapRdd2` failed with fetch failure + runEvent(makeCompletionEvent( + taskSets(3).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, "ignored"), + null)) + + val newFailedStages = scheduler.failedStages.toSeq + assert(newFailedStages.map(_.id) == Seq(0, 1)) + // shuffleMergeId for indeterminate failed stages should be 2 + assert(failedStages.collect { + case stage: ShuffleMapStage => stage.shuffleDep.shuffleMergeId + }.forall(x => x == 2)) + scheduler.resubmitFailedStages() + + // First shuffle map stage resubmitted and reran all tasks. + assert(taskSets(4).stageId == 0) + assert(taskSets(4).stageAttemptId == 1) + assert(taskSets(4).tasks.length == 2) + + // Finish all stages. + completeShuffleMapStageSuccessfully(0, 1, 2) + assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) + // shuffleMergeId should be 2 for the attempt number 1 for stage 0 + assert(mapOutputTracker.shuffleStatuses.get(shuffleId1).forall( + _.mergeStatuses.forall(x => x.shuffleMergeId == 2))) + assert(mapOutputTracker.getNumAvailableMergeResults(shuffleId1) == 2) + + completeShuffleMapStageSuccessfully(1, 2, 2, Seq("hostC", "hostD")) + assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty)) + // shuffleMergeId should be 3 for the attempt number 2 for stage 1 + assert(mapOutputTracker.shuffleStatuses.get(shuffleId2).forall( + _.mergeStatuses.forall(x => x.shuffleMergeId == 3))) + assert(mapOutputTracker.getNumAvailableMergeResults(shuffleId2) == 2) + + complete(taskSets(6), Seq((Success, 11), (Success, 12))) + + // Job successfully ended. + assert(results === Map(0 -> 11, 1 -> 12)) + results.clear() + assertDataStructuresEmpty() + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID.
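The shuffleMergeId arithmetic this test checks can be captured in a toy model (illustrative only; the real state lives in ShuffleDependency and is advanced by newShuffleMergeState on each resubmission of a not-yet-available indeterminate stage):

    // Sketch: merge-state lifecycle across indeterminate stage attempts.
    final class ToyMergeState {
      private var mergeId = 0
      private var finalized = false
      def shuffleMergeId: Int = mergeId
      def markFinalized(): Unit = finalized = true
      def isFinalized: Boolean = finalized
      def newAttempt(): Unit = { finalized = false; mergeId += 1 }
    }

    val state = new ToyMergeState
    state.newAttempt()                       // first submission -> mergeId 1
    state.newAttempt(); state.newAttempt()   // two retries after fetch failures
    assert(state.shuffleMergeId == 3)        // matches the expectation for stage 1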
@@ -3843,8 +3911,8 @@ object DAGSchedulerSuite { BlockManagerId(host + "-exec", host, 12345) } - def makeMergeStatus(host: String, size: Long = 1000): MergeStatus = - MergeStatus(makeBlockManagerId(host), mock(classOf[RoaringBitmap]), size) + def makeMergeStatus(host: String, shuffleMergeId: Int, size: Long = 1000): MergeStatus = + MergeStatus(makeBlockManagerId(host), shuffleMergeId, mock(classOf[RoaringBitmap]), size) def addMergerLocs(locs: Seq[String]): Unit = { locs.foreach { loc => mergerLocs.append(makeBlockManagerId(loc)) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala index 2800be1cb9..26cdad8f94 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala @@ -96,7 +96,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach { val blockPusher = new TestShuffleBlockPusher(conf) val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port)) val largeBlockSize = 2 * 1024 * 1024 - val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, + val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0, mock(classOf[File]), Array(2, 2, 2, largeBlockSize, largeBlockSize), mergerLocs, mock(classOf[TransportConf])) assert(pushRequests.length == 3) @@ -107,7 +107,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach { conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k") val blockPusher = new TestShuffleBlockPusher(conf) val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port)) - val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, + val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0, mock(classOf[File]), Array(2, 2, 2, 1028, 1024), mergerLocs, mock(classOf[TransportConf])) assert(pushRequests.length == 2) verifyPushRequests(pushRequests, Seq(6, 1024)) @@ -117,7 +117,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach { conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1") val blockPusher = new TestShuffleBlockPusher(conf) val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port)) - val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, + val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0, mock(classOf[File]), Array(2, 2, 2, 2, 2), mergerLocs, mock(classOf[TransportConf])) assert(pushRequests.length == 5) verifyPushRequests(pushRequests, Seq(2, 2, 2, 2, 2)) @@ -220,7 +220,8 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach { val errorHandler = pusher.createErrorHandler() assert( !errorHandler.shouldRetryError(new RuntimeException( - new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)))) + new IllegalArgumentException( + BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX)))) assert(errorHandler.shouldRetryError(new RuntimeException(new ConnectException()))) assert( errorHandler.shouldRetryError(new RuntimeException(new IllegalArgumentException( @@ -233,7 +234,8 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach { val errorHandler = pusher.createErrorHandler() assert( !errorHandler.shouldLogError(new RuntimeException( - new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)))) + new IllegalArgumentException( 
+ BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX)))) assert(!errorHandler.shouldLogError(new RuntimeException( new IllegalArgumentException( BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX)))) @@ -284,7 +286,8 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach { failBlock = false // Fail the first block with the too late exception. blockPushListener.onBlockPushFailure(blockId, new RuntimeException( - new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))) + new IllegalArgumentException( + BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX))) } else { pushedBlocks += blockId blockPushListener.onBlockPushSuccess(blockId, mock(classOf[ManagedBuffer])) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index 49c079cd4f..abe2b5694b 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -27,7 +27,7 @@ import org.mockito.invocation.InvocationOnMock import org.roaringbitmap.RoaringBitmap import org.scalatest.BeforeAndAfterEach -import org.apache.spark.{MapOutputTracker, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.config import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleBlockInfo} import org.apache.spark.storage._ @@ -172,8 +172,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa test("getMergedBlockData should return expected FileSegmentManagedBuffer list") { val shuffleId = 1 + val shuffleMergeId = 0 val reduceId = 1 - val dataFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.data" + val dataFileName = s"shuffleMerged_${appId}_${shuffleId}_${shuffleMergeId}_$reduceId.data" val dataFile = new File(tempDir.getAbsolutePath, dataFileName) val out = new FileOutputStream(dataFile) Utils.tryWithSafeFinally { @@ -181,12 +182,13 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out.close() } - val indexFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.index" + val indexFileName = s"shuffleMerged_${appId}_${shuffleId}_${shuffleMergeId}_$reduceId.index" generateMergedShuffleIndexFile(indexFileName) val resolver = new IndexShuffleBlockResolver(conf, blockManager) val dirs = Some(Array[String](tempDir.getAbsolutePath)) val managedBufferList = - resolver.getMergedBlockData(ShuffleBlockId(shuffleId, -1, reduceId), dirs) + resolver.getMergedBlockData(ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId), + dirs) assert(managedBufferList.size === 3) assert(managedBufferList(0).size === 10) assert(managedBufferList(1).size === 0) @@ -195,8 +197,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa test("getMergedBlockMeta should return expected MergedBlockMeta") { val shuffleId = 1 + val shuffleMergeId = 0 val reduceId = 1 - val metaFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.meta" + val metaFileName = s"shuffleMerged_${appId}_${shuffleId}_${shuffleMergeId}_$reduceId.meta" val metaFile = new File(tempDir.getAbsolutePath, metaFileName) val chunkTracker = new RoaringBitmap() val metaFileOutputStream = new FileOutputStream(metaFile) @@ -216,13 +219,14 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with 
BeforeAndAfterEa }{ outMeta.close() } - val indexFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.index" + val indexFileName = s"shuffleMerged_${appId}_${shuffleId}_${shuffleMergeId}_$reduceId.index" generateMergedShuffleIndexFile(indexFileName) val resolver = new IndexShuffleBlockResolver(conf, blockManager) val dirs = Some(Array[String](tempDir.getAbsolutePath)) val mergedBlockMeta = resolver.getMergedBlockMeta( - ShuffleBlockId(shuffleId, MapOutputTracker.SHUFFLE_PUSH_MAP_ID, reduceId), dirs) + ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId), + dirs) assert(mergedBlockMeta.getNumChunks === 3) assert(mergedBlockMeta.readChunkBitmaps().size === 3) assert(mergedBlockMeta.readChunkBitmaps()(0).contains(1)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index e8c3c2df26..2fb8fa428f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -105,37 +105,52 @@ class BlockIdSuite extends SparkFunSuite { } test("shuffle merged data") { - val id = ShuffleMergedDataBlockId("app_000", 8, 9) - assertSame(id, ShuffleMergedDataBlockId("app_000", 8, 9)) - assertDifferent(id, ShuffleMergedDataBlockId("app_000", 9, 9)) - assert(id.name === "shuffleMerged_app_000_8_9.data") + val id = ShuffleMergedDataBlockId("app_000", 8, 0, 9) + assertSame(id, ShuffleMergedDataBlockId("app_000", 8, 0, 9)) + assertDifferent(id, ShuffleMergedDataBlockId("app_000", 9, 0, 9)) + assert(id.name === "shuffleMerged_app_000_8_0_9.data") assert(id.asRDDId === None) assert(id.appId === "app_000") + assert(id.shuffleMergeId == 0) assert(id.shuffleId=== 8) assert(id.reduceId === 9) assertSame(id, BlockId(id.toString)) } test("shuffle merged index") { - val id = ShuffleMergedIndexBlockId("app_000", 8, 9) - assertSame(id, ShuffleMergedIndexBlockId("app_000", 8, 9)) - assertDifferent(id, ShuffleMergedIndexBlockId("app_000", 9, 9)) - assert(id.name === "shuffleMerged_app_000_8_9.index") + val id = ShuffleMergedIndexBlockId("app_000", 8, 0, 9) + assertSame(id, ShuffleMergedIndexBlockId("app_000", 8, 0, 9)) + assertDifferent(id, ShuffleMergedIndexBlockId("app_000", 9, 0, 9)) + assert(id.name === "shuffleMerged_app_000_8_0_9.index") assert(id.asRDDId === None) assert(id.appId === "app_000") assert(id.shuffleId=== 8) + assert(id.shuffleMergeId == 0) assert(id.reduceId === 9) assertSame(id, BlockId(id.toString)) } test("shuffle merged meta") { - val id = ShuffleMergedMetaBlockId("app_000", 8, 9) - assertSame(id, ShuffleMergedMetaBlockId("app_000", 8, 9)) - assertDifferent(id, ShuffleMergedMetaBlockId("app_000", 9, 9)) - assert(id.name === "shuffleMerged_app_000_8_9.meta") + val id = ShuffleMergedMetaBlockId("app_000", 8, 0, 9) + assertSame(id, ShuffleMergedMetaBlockId("app_000", 8, 0, 9)) + assertDifferent(id, ShuffleMergedMetaBlockId("app_000", 9, 0, 9)) + assert(id.name === "shuffleMerged_app_000_8_0_9.meta") assert(id.asRDDId === None) assert(id.appId === "app_000") assert(id.shuffleId=== 8) + assert(id.shuffleMergeId == 0) + assert(id.reduceId === 9) + assertSame(id, BlockId(id.toString)) + } + + test("shuffle merged block") { + val id = ShuffleMergedBlockId(8, 0, 9) + assertSame(id, ShuffleMergedBlockId(8, 0, 9)) + assertDifferent(id, ShuffleMergedBlockId(8, 1, 9)) + assert(id.name === "shuffleMerged_8_0_9") + assert(id.asRDDId === None) + assert(id.shuffleId=== 8) + assert(id.shuffleMergeId == 0) assert(id.reduceId === 9) 
assertSame(id, BlockId(id.toString)) } @@ -224,10 +239,10 @@ class BlockIdSuite extends SparkFunSuite { } test("shuffle chunk") { - val id = ShuffleBlockChunkId(1, 1, 0) - assertSame(id, ShuffleBlockChunkId(1, 1, 0)) - assertDifferent(id, ShuffleBlockChunkId(1, 1, 1)) - assert(id.name === "shuffleChunk_1_1_0") + val id = ShuffleBlockChunkId(1, 0, 1, 0) + assertSame(id, ShuffleBlockChunkId(1, 0, 1, 0)) + assertDifferent(id, ShuffleBlockChunkId(1, 0, 1, 1)) + assert(id.name === "shuffleChunk_1_0_1_0") assert(id.asRDDId === None) assert(id.shuffleId === 1) assert(id.reduceId === 1) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index a5143cd95e..c22e1d0ca2 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -1059,15 +1059,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT test("SPARK-32922: fetch remote push-merged block meta") { val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1), - toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)), + toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)), 2L, + SHUFFLE_PUSH_MAP_ID)), (BlockManagerId("remote-client-1", "remote-host-1", 1), toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)) ) val blockChunks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(), ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(), - ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(), - ShuffleBlockChunkId(0, 2, 1) -> createMockManagedBuffer() + ShuffleBlockChunkId(0, 0, 2, 0) -> createMockManagedBuffer(), + ShuffleBlockChunkId(0, 0, 2, 1) -> createMockManagedBuffer() ) val blocksSem = new Semaphore(0) configureMockTransferForPushShuffle(blocksSem, blockChunks) @@ -1078,17 +1079,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer])) val roaringBitmaps = Array(new RoaringBitmap, new RoaringBitmap) when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps) - when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any())) + when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { - val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener] + val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener] Future { val shuffleId = invocation.getArguments()(2).asInstanceOf[Int] - val reduceId = invocation.getArguments()(3).asInstanceOf[Int] + val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int] + val reduceId = invocation.getArguments()(4).asInstanceOf[Int] logInfo(s"acquiring semaphore for host = ${invocation.getArguments()(0)}, " + s"port = ${invocation.getArguments()(1)}, " + - s"shuffleId = $shuffleId, reduceId = $reduceId") + s"shuffleId = $shuffleId, shuffleMergeId = $shuffleMergeId, reduceId = $reduceId") metaSem.acquire() - metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta) + metaListener.onSuccess(shuffleId, shuffleMergeId, reduceId, pushMergedBlockMeta) } }) val iterator = 
       createShuffleBlockIteratorWithDefaults(blocksByAddress)
@@ -1101,10 +1103,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     metaSem.release()
     val (id3, _) = iterator.next()
     blocksSem.acquire()
-    assert(id3 === ShuffleBlockChunkId(0, 2, 0))
+    assert(id3 === ShuffleBlockChunkId(0, 0, 2, 0))
     val (id4, _) = iterator.next()
     blocksSem.acquire()
-    assert(id4 === ShuffleBlockChunkId(0, 2, 1))
+    assert(id4 === ShuffleBlockChunkId(0, 0, 2, 1))
     assert(!iterator.hasNext)
   }

@@ -1113,7 +1115,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1)
     val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
       (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1),
-        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)),
+        toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)), 2L,
+          SHUFFLE_PUSH_MAP_ID)),
       (remoteBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)))

     val blockChunks = Map[BlockId, ManagedBuffer](
@@ -1127,13 +1130,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
         Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator)
     val blocksSem = new Semaphore(0)
     configureMockTransferForPushShuffle(blocksSem, blockChunks)
-    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
       .thenAnswer((invocation: InvocationOnMock) => {
-        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
         val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
-        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
         Future {
-          metaListener.onFailure(shuffleId, reduceId, new RuntimeException("forced error"))
+          metaListener.onFailure(shuffleId, shuffleMergeId, reduceId,
+            new RuntimeException("forced error"))
         }
       })
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress)
@@ -1154,7 +1159,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1)
     val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
       (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1),
-        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+        toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)), 2L,
+          SHUFFLE_PUSH_MAP_ID)))

     val blockChunks = Map[BlockId, ManagedBuffer](
       ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(),
@@ -1165,13 +1171,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
         Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator)
     val blocksSem = new Semaphore(0)
     configureMockTransferForPushShuffle(blocksSem, blockChunks)
-    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
       .thenAnswer((invocation: InvocationOnMock) => {
-        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
         val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
-        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
         Future {
-          metaListener.onFailure(shuffleId, reduceId, new RuntimeException("forced error"))
+          metaListener.onFailure(shuffleId, shuffleMergeId, reduceId,
+            new RuntimeException("forced error"))
         }
       })
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress)
@@ -1225,7 +1233,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val dirsForMergedData = localDirsMap(SHUFFLE_MERGER_IDENTIFIER)
     doReturn(Seq(createMockManagedBuffer(2))).when(blockManager)
-      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), dirsForMergedData)
+      .getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), dirsForMergedData)

     // Get a valid chunk meta for this test
     val bitmaps = Array(new RoaringBitmap)
@@ -1236,7 +1244,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     } else {
       createMockPushMergedBlockMeta(bitmaps.length, bitmaps)
     }
-    when(blockManager.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
+    when(blockManager.getLocalMergedBlockMeta(ShuffleMergedBlockId(0, 0, 2),
       dirsForMergedData)).thenReturn(pushMergedBlockMeta)
     when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
       Seq((localBmId,
@@ -1248,7 +1256,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
       (localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)),
       (pushMergedBmId, toBlockList(
-        Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+        Seq(ShuffleMergedBlockId(0, 0, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
   }

   private def verifyLocalBlocksFromFallback(iterator: ShuffleBlockFetcherIterator): Unit = {
@@ -1270,7 +1278,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blocksByAddress = prepareForFallbackToLocalBlocks(
       blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
     doThrow(new RuntimeException("Forced error")).when(blockManager)
-      .getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+      .getLocalMergedBlockMeta(ShuffleMergedBlockId(0, 0, 2), localDirs)
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
       blockManager = Some(blockManager))
     verifyLocalBlocksFromFallback(iterator)
@@ -1295,7 +1303,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blocksByAddress = prepareForFallbackToLocalBlocks(
       blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
     doThrow(new RuntimeException("Forced error")).when(blockManager)
-      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+      .getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), localDirs)
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
       blockManager = Some(blockManager))
     verifyLocalBlocksFromFallback(iterator)
@@ -1312,18 +1320,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val pushMergedBmId = BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, localHost, 1)
     val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
       (localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)),
-      (pushMergedBmId, toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
-        ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3)), 2L, SHUFFLE_PUSH_MAP_ID)))
+      (pushMergedBmId, toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2),
+        ShuffleMergedBlockId(0, 0, 3)), 2L, SHUFFLE_PUSH_MAP_ID)))
     doThrow(new RuntimeException("Forced error")).when(blockManager)
-      .getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+      .getLocalMergedBlockMeta(ShuffleMergedBlockId(0, 0, 2), localDirs)
     // Create a valid chunk meta for partition 3
     val bitmaps = Array(new RoaringBitmap)
     bitmaps(0).add(1) // chunk 0 has mapId 1
     doReturn(createMockPushMergedBlockMeta(bitmaps.length, bitmaps)).when(blockManager)
-      .getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3), localDirs)
+      .getLocalMergedBlockMeta(ShuffleMergedBlockId(0, 0, 3), localDirs)
     // Return valid buffer for chunk in partition 3
     doReturn(Seq(createMockManagedBuffer(2))).when(blockManager)
-      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3), localDirs)
+      .getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 3), localDirs)
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
       blockManager = Some(blockManager))
     val (id1, _) = iterator.next()
@@ -1335,7 +1343,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val (id4, _) = iterator.next()
     assert(id4 === ShuffleBlockId(0, 2, 2))
     val (id5, _) = iterator.next()
-    assert(id5 === ShuffleBlockChunkId(0, 3, 0))
+    assert(id5 === ShuffleBlockChunkId(0, 0, 3, 0))
     assert(!iterator.hasNext)
   }

@@ -1358,7 +1366,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     // Since bitmaps are null, this will fail reading the push-merged block meta causing fallback to
     // initiate.
     val pushMergedBlockMeta: MergedBlockMeta = createMockPushMergedBlockMeta(2, null)
-    when(blockManager.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
+    when(blockManager.getLocalMergedBlockMeta(ShuffleMergedBlockId(0, 0, 2),
       dirsForMergedData)).thenReturn(pushMergedBlockMeta)
     when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
       Seq((localBmId,
@@ -1366,7 +1374,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
       (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-local-host", 1),
         toBlockList(
-          Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+          Seq(ShuffleMergedBlockId(0, 0, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
       blockManager = Some(blockManager))
     // 1st instance of iterator.next() returns the original shuffle block (0, 0, 2)
@@ -1385,7 +1393,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blocksByAddress = prepareForFallbackToLocalBlocks(blockManager, hostLocalDirs)

     doThrow(new RuntimeException("Forced error")).when(blockManager)
-      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir"))
+      .getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), Array("local-dir"))
     // host local read for a shuffle block
     doReturn(createMockManagedBuffer()).when(blockManager)
       .getHostLocalShuffleData(ShuffleBlockId(0, 2, 2), Array("local-dir"))
@@ -1416,7 +1424,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       blockManager, hostLocalDirs) ++ hostLocalBlocks

     doThrow(new RuntimeException("Forced error")).when(blockManager)
-      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir"))
+      .getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), Array("local-dir"))
     // host Local read for this original shuffle block
     doReturn(createMockManagedBuffer()).when(blockManager)
       .getHostLocalShuffleData(ShuffleBlockId(0, 1, 2), Array("local-dir"))
@@ -1452,7 +1460,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     doReturn(Seq({
       new FileSegmentManagedBuffer(null, new File("non-existent"), 0, 100)
     })).when(blockManager).getLocalMergedBlockData(
-      ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+      ShuffleMergedBlockId(0, 0, 2), localDirs)
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
       blockManager = Some(blockManager))
     verifyLocalBlocksFromFallback(iterator)
@@ -1466,7 +1474,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
     val corruptBuffer = createMockManagedBuffer(2)
     doReturn(Seq({corruptBuffer})).when(blockManager)
-      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+      .getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), localDirs)
     val corruptStream = mock(classOf[InputStream])
     when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
     doReturn(corruptStream).when(corruptBuffer).createInputStream()
@@ -1477,7 +1485,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT

   test("SPARK-32922: fallback to original blocks when failed to fetch remote shuffle chunk") {
     val blockChunks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 0) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer()
@@ -1489,13 +1497,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     bitmaps(1).add(4)
     bitmaps(1).add(5)
     val pushMergedBlockMeta = createMockPushMergedBlockMeta(2, bitmaps)
-    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
       .thenAnswer((invocation: InvocationOnMock) => {
-        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
         val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
-        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
         Future {
-          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+          metaListener.onSuccess(shuffleId, shuffleMergeId, reduceId, pushMergedBlockMeta)
         }
       })
     val fallbackBlocksByAddr = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
@@ -1506,10 +1515,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       .thenReturn(fallbackBlocksByAddr.iterator)
     val iterator = createShuffleBlockIteratorWithDefaults(Map(
       BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "remote-client-1", 1) ->
-        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 12L, SHUFFLE_PUSH_MAP_ID)))
+        toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)),
+          12L, SHUFFLE_PUSH_MAP_ID)))
     val (id1, _) = iterator.next()
     blocksSem.acquire(1)
-    assert(id1 === ShuffleBlockChunkId(0, 2, 0))
+    assert(id1 === ShuffleBlockChunkId(0, 0, 2, 0))
     val (id3, _) = iterator.next()
     blocksSem.acquire(3)
     assert(id3 === ShuffleBlockId(0, 3, 2))
@@ -1531,20 +1541,21 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blocksSem = new Semaphore(0)
     configureMockTransferForPushShuffle(blocksSem, blockChunks)
     val pushMergedBlockMeta = createMockPushMergedBlockMeta(2, null)
-    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
      .thenAnswer((invocation: InvocationOnMock) => {
-        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
        val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
-        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
         Future {
-          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+          metaListener.onSuccess(shuffleId, shuffleMergeId, reduceId, pushMergedBlockMeta)
         }
       })
     val remoteMergedBlockMgrId = BlockManagerId(
       SHUFFLE_MERGER_IDENTIFIER, "remote-host-2", 1)
     val iterator = createShuffleBlockIteratorWithDefaults(
       Map(remoteMergedBlockMgrId -> toBlockList(
-        Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+        Seq(ShuffleMergedBlockId(0, 0, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
     val (id1, _) = iterator.next()
     blocksSem.acquire(2)
     assert(id1 === ShuffleBlockId(0, 0, 2))
@@ -1556,10 +1567,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
   test("SPARK-32922: failure to fetch a remote shuffle chunk initiates the fallback of " +
     "pending shuffle chunks immediately") {
     val blockChunks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 0) -> createMockManagedBuffer(),
       // ShuffleBlockChunk(0, 2, 1) will cause a failure as it is not in block-chunks.
-      ShuffleBlockChunkId(0, 2, 2) -> createMockManagedBuffer(),
-      ShuffleBlockChunkId(0, 2, 3) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 2) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 3) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer(),
@@ -1574,17 +1585,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer]))
     val roaringBitmaps = Array.fill[RoaringBitmap](4)(new RoaringBitmap)
     when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps)
-    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
       .thenAnswer((invocation: InvocationOnMock) => {
-        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
         val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
-        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
         Future {
           logInfo(s"acquiring semaphore for host = ${invocation.getArguments()(0)}, " +
             s"port = ${invocation.getArguments()(1)}, " +
-            s"shuffleId = $shuffleId, reduceId = $reduceId")
+            s"shuffleId = $shuffleId, shuffleMergeId = $shuffleMergeId, reduceId = $reduceId")
           metaSem.release()
-          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+          metaListener.onSuccess(shuffleId, shuffleMergeId, reduceId, pushMergedBlockMeta)
         }
       })
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
@@ -1596,12 +1608,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT

     val iterator = createShuffleBlockIteratorWithDefaults(Map(
       BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) ->
-        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 16L, SHUFFLE_PUSH_MAP_ID)),
+        toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)),
+          16L, SHUFFLE_PUSH_MAP_ID)),
       maxBytesInFlight = 4)
     metaSem.acquire(1)
     val (id1, _) = iterator.next()
     blocksSem.acquire(1)
-    assert(id1 === ShuffleBlockChunkId(0, 2, 0))
+    assert(id1 === ShuffleBlockChunkId(0, 0, 2, 0))
     val regularBlocks = new mutable.HashSet[BlockId]()
     val (id2, _) = iterator.next()
     blocksSem.acquire(1)
@@ -1623,12 +1636,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
   test("SPARK-32922: failure to fetch a remote shuffle chunk initiates the fallback of " +
     "pending shuffle chunks immediately which got deferred") {
     val blockChunks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
-      ShuffleBlockChunkId(0, 2, 1) -> createMockManagedBuffer(),
-      ShuffleBlockChunkId(0, 2, 2) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 0) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 1) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 2) -> createMockManagedBuffer(),
       // ShuffleBlockChunkId(0, 2, 3) will cause failure as it is not in bock chunks
-      ShuffleBlockChunkId(0, 2, 4) -> createMockManagedBuffer(),
-      ShuffleBlockChunkId(0, 2, 5) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 4) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 0, 2, 5) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer(),
@@ -1642,17 +1655,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer]))
     val roaringBitmaps = Array.fill[RoaringBitmap](6)(new RoaringBitmap)
     when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps)
-    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
      .thenAnswer((invocation: InvocationOnMock) => {
-        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
        val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
-        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
         Future {
           logInfo(s"acquiring semaphore for host = ${invocation.getArguments()(0)}, " +
             s"port = ${invocation.getArguments()(1)}, " +
-            s"shuffleId = $shuffleId, reduceId = $reduceId")
+            s"shuffleId = $shuffleId, shuffleMergeId = $shuffleMergeId, reduceId = $reduceId")
           metaSem.release()
-          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+          metaListener.onSuccess(shuffleId, shuffleMergeId, reduceId, pushMergedBlockMeta)
         }
       })
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
@@ -1664,17 +1678,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT

     val iterator = createShuffleBlockIteratorWithDefaults(Map(
       BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) ->
-        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 24L, SHUFFLE_PUSH_MAP_ID)),
+        toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)), 24L,
+          SHUFFLE_PUSH_MAP_ID)),
       maxBytesInFlight = 8, maxBlocksInFlightPerAddress = 1)
     metaSem.acquire(1)
     val (id1, _) = iterator.next()
     blocksSem.acquire(2)
-    assert(id1 === ShuffleBlockChunkId(0, 2, 0))
+    assert(id1 === ShuffleBlockChunkId(0, 0, 2, 0))
     val (id2, _) = iterator.next()
-    assert(id2 === ShuffleBlockChunkId(0, 2, 1))
+    assert(id2 === ShuffleBlockChunkId(0, 0, 2, 1))
    val (id3, _) = iterator.next()
     blocksSem.acquire(1)
-    assert(id3 === ShuffleBlockChunkId(0, 2, 2))
+    assert(id3 === ShuffleBlockChunkId(0, 0, 2, 2))
     val regularBlocks = new mutable.HashSet[BlockId]()
     val (id4, _) = iterator.next()
     blocksSem.acquire(1)