[SPARK-32923][CORE][SHUFFLE] Handle indeterminate stage retries for push-based shuffle

[[SPARK-23243](https://issues.apache.org/jira/browse/SPARK-23243)] and [[SPARK-25341](https://issues.apache.org/jira/browse/SPARK-25341)] addressed cases of stage retries for indeterminate stage involving operations like repartition. This PR addresses the same issues in the context of push-based shuffle. Currently there is no way to distinguish the current execution of a stage for a shuffle ID. Therefore the changes explained below are necessary.

Core changes are summarized as follows:

1. Introduce a new variable `shuffleMergeId` in `ShuffleDependency` which is monotonically increasing value tracking the temporal ordering of execution of <stage-id, stage-attempt-id> for a shuffle ID.
2. Correspondingly make changes in the push-based shuffle protocol layer in `MergedShuffleFileManager`, `BlockStoreClient` passing the `shuffleMergeId` in order to keep track of the shuffle output in separate files on the shuffle service side.
3. `DAGScheduler` increments the `shuffleMergeId` tracked in `ShuffleDependency` in the cases of a indeterministic stage execution
4. Deterministic stage will have `shuffleMergeId` set to 0 as no special handling is needed in this case and indeterminate stage will have `shuffleMergeId` starting from 1.

New protocol changes are needed due to the reasons explained above.

No

Added new unit tests in `RemoteBlockPushResolverSuite, DAGSchedulerSuite, BlockIdSuite, ErrorHandlerSuite`

Closes #33034 from venkata91/SPARK-32923.

Authored-by: Venkata krishnan Sowrirajan <vsowrirajan@linkedin.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
(cherry picked from commit c039d99812)
Signed-off-by: Mridul Muralidharan <mridulatgmail.com>
This commit is contained in:
Venkata krishnan Sowrirajan 2021-08-01 23:16:33 -05:00 committed by Mridul Muralidharan
parent cd3ab00382
commit a9b32b390c
39 changed files with 1502 additions and 722 deletions

View file

@ -206,12 +206,15 @@ public class TransportClient implements Closeable {
*
* @param appId applicationId.
* @param shuffleId shuffle id.
* @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
* of shuffle by an indeterminate stage attempt.
* @param reduceId reduce id.
* @param callback callback the handle the reply.
*/
public void sendMergedBlockMetaReq(
String appId,
int shuffleId,
int shuffleMergeId,
int reduceId,
MergedBlockMetaResponseCallback callback) {
long requestId = requestId();
@ -222,7 +225,8 @@ public class TransportClient implements Closeable {
handler.addRpcRequest(requestId, callback);
RpcChannelListener listener = new RpcChannelListener(requestId, callback);
channel.writeAndFlush(
new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId)).addListener(listener);
new MergedBlockMetaRequest(requestId, appId, shuffleId, shuffleMergeId,
reduceId)).addListener(listener);
}
/**

View file

@ -32,13 +32,20 @@ public class MergedBlockMetaRequest extends AbstractMessage implements RequestMe
public final long requestId;
public final String appId;
public final int shuffleId;
public final int shuffleMergeId;
public final int reduceId;
public MergedBlockMetaRequest(long requestId, String appId, int shuffleId, int reduceId) {
public MergedBlockMetaRequest(
long requestId,
String appId,
int shuffleId,
int shuffleMergeId,
int reduceId) {
super(null, false);
this.requestId = requestId;
this.appId = appId;
this.shuffleId = shuffleId;
this.shuffleMergeId = shuffleMergeId;
this.reduceId = reduceId;
}
@ -49,7 +56,7 @@ public class MergedBlockMetaRequest extends AbstractMessage implements RequestMe
@Override
public int encodedLength() {
return 8 + Encoders.Strings.encodedLength(appId) + 4 + 4;
return 8 + Encoders.Strings.encodedLength(appId) + 4 + 4 + 4;
}
@Override
@ -57,6 +64,7 @@ public class MergedBlockMetaRequest extends AbstractMessage implements RequestMe
buf.writeLong(requestId);
Encoders.Strings.encode(buf, appId);
buf.writeInt(shuffleId);
buf.writeInt(shuffleMergeId);
buf.writeInt(reduceId);
}
@ -64,21 +72,23 @@ public class MergedBlockMetaRequest extends AbstractMessage implements RequestMe
long requestId = buf.readLong();
String appId = Encoders.Strings.decode(buf);
int shuffleId = buf.readInt();
int shuffleMergeId = buf.readInt();
int reduceId = buf.readInt();
return new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId);
return new MergedBlockMetaRequest(requestId, appId, shuffleId, shuffleMergeId, reduceId);
}
@Override
public int hashCode() {
return Objects.hashCode(requestId, appId, shuffleId, reduceId);
return Objects.hashCode(requestId, appId, shuffleId, shuffleMergeId, reduceId);
}
@Override
public boolean equals(Object other) {
if (other instanceof MergedBlockMetaRequest) {
MergedBlockMetaRequest o = (MergedBlockMetaRequest) other;
return requestId == o.requestId && shuffleId == o.shuffleId && reduceId == o.reduceId
&& Objects.equal(appId, o.appId);
return requestId == o.requestId && shuffleId == o.shuffleId &&
shuffleMergeId == o.shuffleMergeId && reduceId == o.reduceId &&
Objects.equal(appId, o.appId);
}
return false;
}
@ -89,6 +99,7 @@ public class MergedBlockMetaRequest extends AbstractMessage implements RequestMe
.append("requestId", requestId)
.append("appId", appId)
.append("shuffleId", shuffleId)
.append("shuffleMergeId", shuffleMergeId)
.append("reduceId", reduceId)
.toString();
}

View file

@ -152,14 +152,14 @@ public class TransportRequestHandlerSuite {
TransportClient reverseClient = mock(TransportClient.class);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
rpcHandler, 2L, null);
MergedBlockMetaRequest validMetaReq = new MergedBlockMetaRequest(19, "app1", 0, 0);
MergedBlockMetaRequest validMetaReq = new MergedBlockMetaRequest(19, "app1", 0, 0, 0);
requestHandler.handle(validMetaReq);
assertEquals(1, responseAndPromisePairs.size());
assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof MergedBlockMetaSuccess);
assertEquals(2,
((MergedBlockMetaSuccess) (responseAndPromisePairs.get(0).getLeft())).getNumChunks());
MergedBlockMetaRequest invalidMetaReq = new MergedBlockMetaRequest(21, "app1", -1, 1);
MergedBlockMetaRequest invalidMetaReq = new MergedBlockMetaRequest(21, "app1", -1, 0, 1);
requestHandler.handle(invalidMetaReq);
assertEquals(2, responseAndPromisePairs.size());
assertTrue(responseAndPromisePairs.get(1).getLeft() instanceof RpcFailure);

View file

@ -167,6 +167,8 @@ public abstract class BlockStoreClient implements Closeable {
* @param host host of shuffle server
* @param port port of shuffle server.
* @param shuffleId shuffle ID of the shuffle to be finalized
* @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
* of shuffle by an indeterminate stage attempt.
* @param listener the listener to receive MergeStatuses
*
* @since 3.1.0
@ -175,6 +177,7 @@ public abstract class BlockStoreClient implements Closeable {
String host,
int port,
int shuffleId,
int shuffleMergeId,
MergeFinalizerListener listener) {
throw new UnsupportedOperationException();
}
@ -185,6 +188,8 @@ public abstract class BlockStoreClient implements Closeable {
* @param host the host of the remote node.
* @param port the port of the remote node.
* @param shuffleId shuffle id.
* @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
* of shuffle by an indeterminate stage attempt.
* @param reduceId reduce id.
* @param listener the listener to receive chunk counts.
*
@ -194,6 +199,7 @@ public abstract class BlockStoreClient implements Closeable {
String host,
int port,
int shuffleId,
int shuffleMergeId,
int reduceId,
MergedBlocksMetaListener listener) {
throw new UnsupportedOperationException();

View file

@ -55,12 +55,14 @@ public interface ErrorHandler {
class BlockPushErrorHandler implements ErrorHandler {
/**
* String constant used for generating exception messages indicating a block to be merged
* arrives too late on the server side, and also for later checking such exceptions on the
* client side. When we get a block push failure because of the block arrives too late, we
* will not retry pushing the block nor log the exception on the client side.
* arrives too late or stale block push in the case of indeterminate stage retries on the
* server side, and also for later checking such exceptions on the client side. When we get
* a block push failure because of the block push being stale or arrives too late, we will
* not retry pushing the block nor log the exception on the client side.
*/
public static final String TOO_LATE_MESSAGE_SUFFIX =
"received after merged shuffle is finalized";
public static final String TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX =
"received after merged shuffle is finalized or stale block push as shuffle blocks of a"
+ " higher shuffleMergeId for the shuffle is being pushed";
/**
* String constant used for generating exception messages indicating the server couldn't
@ -81,25 +83,54 @@ public interface ErrorHandler {
public static final String IOEXCEPTIONS_EXCEEDED_THRESHOLD_PREFIX =
"IOExceptions exceeded the threshold";
/**
* String constant used for generating exception messages indicating the server rejecting a
* shuffle finalize request since shuffle blocks of a higher shuffleMergeId for a shuffle is
* already being pushed. This typically happens in the case of indeterminate stage retries
* where if a stage attempt fails then the entirety of the shuffle output needs to be rolled
* back. For more details refer SPARK-23243, SPARK-25341 and SPARK-32923.
*/
public static final String STALE_SHUFFLE_FINALIZE_SUFFIX =
"stale shuffle finalize request as shuffle blocks of a higher shuffleMergeId for the"
+ " shuffle is already being pushed";
@Override
public boolean shouldRetryError(Throwable t) {
// If it is a connection time-out or a connection closed exception, no need to retry.
// If it is a FileNotFoundException originating from the client while pushing the shuffle
// blocks to the server, even then there is no need to retry. We will still log this exception
// once which helps with debugging.
// blocks to the server, even then there is no need to retry. We will still log this
// exception once which helps with debugging.
if (t.getCause() != null && (t.getCause() instanceof ConnectException ||
t.getCause() instanceof FileNotFoundException)) {
return false;
}
// If the block is too late, there is no need to retry it
return !Throwables.getStackTraceAsString(t).contains(TOO_LATE_MESSAGE_SUFFIX);
String errorStackTrace = Throwables.getStackTraceAsString(t);
// If the block is too late or stale block push, there is no need to retry it
return !errorStackTrace.contains(TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX);
}
@Override
public boolean shouldLogError(Throwable t) {
String errorStackTrace = Throwables.getStackTraceAsString(t);
return !errorStackTrace.contains(BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX) &&
!errorStackTrace.contains(TOO_LATE_MESSAGE_SUFFIX);
return !(errorStackTrace.contains(BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX) ||
errorStackTrace.contains(TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX));
}
}
class BlockFetchErrorHandler implements ErrorHandler {
public static final String STALE_SHUFFLE_BLOCK_FETCH =
"stale shuffle block fetch request as shuffle blocks of a higher shuffleMergeId for the"
+ " shuffle is available";
@Override
public boolean shouldRetryError(Throwable t) {
return !Throwables.getStackTraceAsString(t).contains(STALE_SHUFFLE_BLOCK_FETCH);
}
@Override
public boolean shouldLogError(Throwable t) {
return !Throwables.getStackTraceAsString(t).contains(STALE_SHUFFLE_BLOCK_FETCH);
}
}
}

View file

@ -218,7 +218,8 @@ public class ExternalBlockHandler extends RpcHandler
callback.onSuccess(statuses.toByteBuffer());
} catch(IOException e) {
throw new RuntimeException(String.format("Error while finalizing shuffle merge "
+ "for application %s shuffle %d", msg.appId, msg.shuffleId), e);
+ "for application %s shuffle %d with shuffleMergeId %d", msg.appId, msg.shuffleId,
msg.shuffleMergeId), e);
} finally {
responseDelayContext.stop();
}
@ -237,7 +238,7 @@ public class ExternalBlockHandler extends RpcHandler
checkAuth(client, metaRequest.appId);
MergedBlockMeta mergedMeta =
mergeManager.getMergedBlockMeta(metaRequest.appId, metaRequest.shuffleId,
metaRequest.reduceId);
metaRequest.shuffleMergeId, metaRequest.reduceId);
logger.debug(
"Merged block chunks appId {} shuffleId {} reduceId {} num-chunks : {} ",
metaRequest.appId, metaRequest.shuffleId, metaRequest.reduceId,
@ -364,7 +365,6 @@ public class ExternalBlockHandler extends RpcHandler
private int index = 0;
private final Function<Integer, ManagedBuffer> blockDataForIndexFn;
private final int size;
private boolean requestForMergedBlockChunks;
ManagedBufferIterator(OpenBlocks msg) {
String appId = msg.appId;
@ -377,13 +377,14 @@ public class ExternalBlockHandler extends RpcHandler
size = mapIdAndReduceIds.length;
blockDataForIndexFn = index -> blockManager.getBlockData(appId, execId, shuffleId,
mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
} else if (blockId0Parts.length == 4 && blockId0Parts[0].equals(SHUFFLE_CHUNK_ID)) {
requestForMergedBlockChunks = true;
} else if (blockId0Parts.length == 5 && blockId0Parts[0].equals(SHUFFLE_CHUNK_ID)) {
final int shuffleId = Integer.parseInt(blockId0Parts[1]);
final int[] reduceIdAndChunkIds = shuffleMapIdAndReduceIds(blockIds, shuffleId);
final int shuffleMergeId = Integer.parseInt(blockId0Parts[2]);
final int[] reduceIdAndChunkIds = shuffleReduceIdAndChunkIds(blockIds, shuffleId,
shuffleMergeId);
size = reduceIdAndChunkIds.length;
blockDataForIndexFn = index -> mergeManager.getMergedBlockData(msg.appId, shuffleId,
reduceIdAndChunkIds[index], reduceIdAndChunkIds[index + 1]);
shuffleMergeId, reduceIdAndChunkIds[index], reduceIdAndChunkIds[index + 1]);
} else if (blockId0Parts.length == 3 && blockId0Parts[0].equals("rdd")) {
final int[] rddAndSplitIds = rddAndSplitIds(blockIds);
size = rddAndSplitIds.length;
@ -407,27 +408,64 @@ public class ExternalBlockHandler extends RpcHandler
return rddAndSplitIds;
}
/**
* @param blockIds Regular shuffle blockIds starts with SHUFFLE_BLOCK_ID to be parsed
* @param shuffleId shuffle blocks shuffleId
* @return mapId and reduceIds of the shuffle blocks in the same order as that of the blockIds
*
* Regular shuffle blocks format should be shuffle_$shuffleId_$mapId_$reduceId
*/
private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) {
// For regular shuffle blocks, primaryId is mapId and secondaryIds are reduceIds.
// For shuffle chunks, primaryIds is reduceId and secondaryIds are chunkIds.
final int[] primaryIdAndSecondaryIds = new int[2 * blockIds.length];
final int[] mapIdAndReduceIds = new int[2 * blockIds.length];
for (int i = 0; i < blockIds.length; i++) {
String[] blockIdParts = blockIds[i].split("_");
if (blockIdParts.length != 4
|| (!requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_BLOCK_ID))
|| (requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_CHUNK_ID))) {
if (blockIdParts.length != 4 || !blockIdParts[0].equals(SHUFFLE_BLOCK_ID)) {
throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
}
if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
", got:" + blockIds[i]);
}
// For regular blocks, blockIdParts[2] is mapId. For chunks, it is reduceId.
primaryIdAndSecondaryIds[2 * i] = Integer.parseInt(blockIdParts[2]);
// For regular blocks, blockIdParts[3] is reduceId. For chunks, it is chunkId.
primaryIdAndSecondaryIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
// mapId
mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
// reduceId
mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
}
return primaryIdAndSecondaryIds;
return mapIdAndReduceIds;
}
/**
* @param blockIds Shuffle merged chunks starts with SHUFFLE_CHUNK_ID to be parsed
* @param shuffleId shuffle blocks shuffleId
* @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
* of shuffle by an indeterminate stage attempt.
* @return reduceId and chunkIds of the shuffle chunks in the same order as that of the
* blockIds
*
* Shuffle merged chunks format should be
* shuffleChunk_$shuffleId_$shuffleMergeId_$reduceId_$chunkId
*/
private int[] shuffleReduceIdAndChunkIds(
String[] blockIds,
int shuffleId,
int shuffleMergeId) {
final int[] reduceIdAndChunkIds = new int[2 * blockIds.length];
for(int i = 0; i < blockIds.length; i++) {
String[] blockIdParts = blockIds[i].split("_");
if (blockIdParts.length != 5 || !blockIdParts[0].equals(SHUFFLE_CHUNK_ID)) {
throw new IllegalArgumentException("Unexpected shuffle chunk id format: " + blockIds[i]);
}
if (Integer.parseInt(blockIdParts[1]) != shuffleId ||
Integer.parseInt(blockIdParts[2]) != shuffleMergeId) {
throw new IllegalArgumentException(String.format("Expected shuffleId = %s"
+ " and shuffleMergeId = %s but got %s", shuffleId, shuffleMergeId, blockIds[i]));
}
// reduceId
reduceIdAndChunkIds[2 * i] = Integer.parseInt(blockIdParts[3]);
// chunkId
reduceIdAndChunkIds[2 * i + 1] = Integer.parseInt(blockIdParts[4]);
}
return reduceIdAndChunkIds;
}
@Override
@ -511,12 +549,14 @@ public class ExternalBlockHandler extends RpcHandler
private final String appId;
private final int shuffleId;
private final int shuffleMergeId;
private final int[] reduceIds;
private final int[][] chunkIds;
ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
appId = msg.appId;
shuffleId = msg.shuffleId;
shuffleMergeId = msg.shuffleMergeId;
reduceIds = msg.reduceIds;
chunkIds = msg.chunkIds;
// reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
@ -533,7 +573,7 @@ public class ExternalBlockHandler extends RpcHandler
@Override
public ManagedBuffer next() {
ManagedBuffer block = Preconditions.checkNotNull(mergeManager.getMergedBlockData(
appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]));
appId, shuffleId, shuffleMergeId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]));
if (chunkIdx < chunkIds[reduceIdx].length - 1) {
chunkIdx += 1;
} else {
@ -580,12 +620,20 @@ public class ExternalBlockHandler extends RpcHandler
@Override
public ManagedBuffer getMergedBlockData(
String appId, int shuffleId, int reduceId, int chunkId) {
String appId,
int shuffleId,
int shuffleMergeId,
int reduceId,
int chunkId) {
throw new UnsupportedOperationException("Cannot handle shuffle block merge");
}
@Override
public MergedBlockMeta getMergedBlockMeta(String appId, int shuffleId, int reduceId) {
public MergedBlockMeta getMergedBlockMeta(
String appId,
int shuffleId,
int shuffleMergeId,
int reduceId) {
throw new UnsupportedOperationException("Cannot handle shuffle block merge");
}

View file

@ -172,12 +172,14 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
String host,
int port,
int shuffleId,
int shuffleMergeId,
MergeFinalizerListener listener) {
checkInit();
try {
TransportClient client = clientFactory.createClient(host, port);
ByteBuffer finalizeShuffleMerge =
new FinalizeShuffleMerge(appId, conf.appAttemptId(), shuffleId).toByteBuffer();
new FinalizeShuffleMerge(appId, conf.appAttemptId(), shuffleId,
shuffleMergeId).toByteBuffer();
client.sendRpc(finalizeShuffleMerge, new RpcResponseCallback() {
@Override
public void onSuccess(ByteBuffer response) {
@ -202,29 +204,31 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
String host,
int port,
int shuffleId,
int shuffleMergeId,
int reduceId,
MergedBlocksMetaListener listener) {
checkInit();
logger.debug("Get merged blocks meta from {}:{} for shuffleId {} reduceId {}", host, port,
shuffleId, reduceId);
logger.debug("Get merged blocks meta from {}:{} for shuffleId {} shuffleMergeId {}"
+ " reduceId {}", host, port, shuffleId, shuffleMergeId, reduceId);
try {
TransportClient client = clientFactory.createClient(host, port);
client.sendMergedBlockMetaReq(appId, shuffleId, reduceId,
client.sendMergedBlockMetaReq(appId, shuffleId, shuffleMergeId, reduceId,
new MergedBlockMetaResponseCallback() {
@Override
public void onSuccess(int numChunks, ManagedBuffer buffer) {
logger.trace("Successfully got merged block meta for shuffleId {} reduceId {}",
shuffleId, reduceId);
listener.onSuccess(shuffleId, reduceId, new MergedBlockMeta(numChunks, buffer));
logger.trace("Successfully got merged block meta for shuffleId {} shuffleMergeId {}"
+ " reduceId {}", shuffleId, shuffleMergeId, reduceId);
listener.onSuccess(shuffleId, reduceId, shuffleMergeId,
new MergedBlockMeta(numChunks, buffer));
}
@Override
public void onFailure(Throwable e) {
listener.onFailure(shuffleId, reduceId, e);
listener.onFailure(shuffleId, shuffleMergeId, reduceId, e);
}
});
} catch (Exception e) {
listener.onFailure(shuffleId, reduceId, e);
listener.onFailure(shuffleId, shuffleMergeId, reduceId, e);
}
}

View file

@ -30,17 +30,21 @@ public interface MergedBlocksMetaListener extends EventListener {
* Called after successfully receiving the meta of a merged block.
*
* @param shuffleId shuffle id.
* @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
* of shuffle by an indeterminate stage attempt.
* @param reduceId reduce id.
* @param meta contains meta information of a merged block.
*/
void onSuccess(int shuffleId, int reduceId, MergedBlockMeta meta);
void onSuccess(int shuffleId, int shuffleMergeId, int reduceId, MergedBlockMeta meta);
/**
* Called when there is an exception while fetching the meta of a merged block.
*
* @param shuffleId shuffle id.
* @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
* of shuffle by an indeterminate stage attempt.
* @param reduceId reduce id.
* @param exception exception getting chunk counts.
*/
void onFailure(int shuffleId, int reduceId, Throwable exception);
void onFailure(int shuffleId, int shuffleMergeId, int reduceId, Throwable exception);
}

View file

@ -85,21 +85,34 @@ public interface MergedShuffleFileManager {
*
* @param appId application ID
* @param shuffleId shuffle ID
* @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
* of shuffle by an indeterminate stage attempt.
* @param reduceId reducer ID
* @param chunkId merged shuffle file chunk ID
* @return The {@link ManagedBuffer} for the given merged shuffle chunk
*/
ManagedBuffer getMergedBlockData(String appId, int shuffleId, int reduceId, int chunkId);
ManagedBuffer getMergedBlockData(
String appId,
int shuffleId,
int shuffleMergeId,
int reduceId,
int chunkId);
/**
* Get the meta information of a merged block.
*
* @param appId application ID
* @param shuffleId shuffle ID
* @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
* of shuffle by an indeterminate stage attempt.
* @param reduceId reducer ID
* @return meta information of a merged block
*/
MergedBlockMeta getMergedBlockMeta(String appId, int shuffleId, int reduceId);
MergedBlockMeta getMergedBlockMeta(
String appId,
int shuffleId,
int shuffleMergeId,
int reduceId);
/**
* Get the local directories which stores the merged shuffle files.

View file

@ -22,7 +22,7 @@ import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Set;
import java.util.Map;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
@ -56,6 +56,8 @@ public class OneForOneBlockFetcher {
private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);
private static final String SHUFFLE_BLOCK_PREFIX = "shuffle_";
private static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk_";
private static final String SHUFFLE_BLOCK_SPLIT = "shuffle";
private static final String SHUFFLE_CHUNK_SPLIT = "shuffleChunk";
private final TransportClient client;
private final BlockTransferMessage message;
@ -125,63 +127,87 @@ public class OneForOneBlockFetcher {
String execId,
String[] blockIds) {
if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
return createFetchShuffleChunksMsg(appId, execId, blockIds);
} else {
return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
return createFetchShuffleBlocksMsg(appId, execId, blockIds);
}
}
/**
* Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
* analyzing the passed in blockIds.
*/
private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
private AbstractFetchShuffleBlocks createFetchShuffleBlocksMsg(
String appId,
String execId,
String[] blockIds,
boolean areMergedChunks) {
String[] blockIds) {
String[] firstBlock = splitBlockId(blockIds[0]);
int shuffleId = Integer.parseInt(firstBlock[1]);
boolean batchFetchEnabled = firstBlock.length == 5;
// In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
// is reduceId.
LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
Map<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
for (String blockId : blockIds) {
String[] blockIdParts = splitBlockId(blockId);
if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
", got:" + blockId);
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockId);
}
Number primaryId;
if (!areMergedChunks) {
primaryId = Long.parseLong(blockIdParts[2]);
} else {
primaryId = Integer.parseInt(blockIdParts[2]);
}
BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId,
long mapId = Long.parseLong(blockIdParts[2]);
BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.computeIfAbsent(mapId,
id -> new BlocksInfo());
blocksInfoByPrimaryId.blockIds.add(blockId);
// If blockId is a regular shuffle block, then blockIdParts[3] = reduceId. If blockId is a
// shuffleChunk block, then blockIdParts[3] = chunkId
blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
blocksInfoByMapId.blockIds.add(blockId);
blocksInfoByMapId.ids.add(Integer.parseInt(blockIdParts[3]));
if (batchFetchEnabled) {
// It comes here only if the blockId is a regular shuffle block not a shuffleChunk block.
// When we read continuous shuffle blocks in batch, we will reuse reduceIds in
// FetchShuffleBlocks to store the start and end reduce id for range
// [startReduceId, endReduceId).
assert(blockIdParts.length == 5);
// blockIdParts[4] is the end reduce id for the batch range
blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
blocksInfoByMapId.ids.add(Integer.parseInt(blockIdParts[4]));
}
}
int[][] reduceIdsArray = getSecondaryIds(mapIdToBlocksInfo);
long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
return new FetchShuffleBlocks(
appId, execId, shuffleId, mapIds, reduceIdsArray, batchFetchEnabled);
}
private AbstractFetchShuffleBlocks createFetchShuffleChunksMsg(
String appId,
String execId,
String[] blockIds) {
String[] firstBlock = splitBlockId(blockIds[0]);
int shuffleId = Integer.parseInt(firstBlock[1]);
int shuffleMergeId = Integer.parseInt(firstBlock[2]);
Map<Integer, BlocksInfo> reduceIdToBlocksInfo = new LinkedHashMap<>();
for (String blockId : blockIds) {
String[] blockIdParts = splitBlockId(blockId);
if (Integer.parseInt(blockIdParts[1]) != shuffleId ||
Integer.parseInt(blockIdParts[2]) != shuffleMergeId) {
throw new IllegalArgumentException(String.format("Expected shuffleId = %s and"
+ " shuffleMergeId = %s but got %s", shuffleId, shuffleMergeId, blockId));
}
int reduceId = Integer.parseInt(blockIdParts[3]);
BlocksInfo blocksInfoByReduceId = reduceIdToBlocksInfo.computeIfAbsent(reduceId,
id -> new BlocksInfo());
blocksInfoByReduceId.blockIds.add(blockId);
blocksInfoByReduceId.ids.add(Integer.parseInt(blockIdParts[4]));
}
int[][] chunkIdsArray = getSecondaryIds(reduceIdToBlocksInfo);
int[] reduceIds = Ints.toArray(reduceIdToBlocksInfo.keySet());
return new FetchShuffleBlockChunks(appId, execId, shuffleId, shuffleMergeId, reduceIds,
chunkIdsArray);
}
private int[][] getSecondaryIds(Map<? extends Number, BlocksInfo> primaryIdsToBlockInfo) {
// In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
// secondaryIds are chunkIds.
int[][] secondaryIdsArray = new int[primaryIdToBlocksInfo.size()][];
int[][] secondaryIds = new int[primaryIdsToBlockInfo.size()][];
int blockIdIndex = 0;
int secIndex = 0;
for (BlocksInfo blocksInfo: primaryIdToBlocksInfo.values()) {
secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfo.ids);
for (BlocksInfo blocksInfo: primaryIdsToBlockInfo.values()) {
secondaryIds[secIndex++] = Ints.toArray(blocksInfo.ids);
// The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
// FetchShuffleBlockChunks because the shuffle data's return order should match the
@ -191,35 +217,27 @@ public class OneForOneBlockFetcher {
}
}
assert(blockIdIndex == this.blockIds.length);
Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
if (!areMergedChunks) {
long[] mapIds = Longs.toArray(primaryIds);
return new FetchShuffleBlocks(
appId, execId, shuffleId, mapIds, secondaryIdsArray, batchFetchEnabled);
} else {
int[] reduceIds = Ints.toArray(primaryIds);
return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, secondaryIdsArray);
}
return secondaryIds;
}
/** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */
/**
* Split the blockId and return accordingly
* shuffleChunk - return shuffleId, shuffleMergeId, reduceId and chunkIds
* shuffle block - return shuffleId, mapId, reduceId
* shuffle batch block - return shuffleId, mapId, begin reduceId and end reduceId
*/
private String[] splitBlockId(String blockId) {
String[] blockIdParts = blockId.split("_");
// For batch block id, the format contains shuffleId, mapId, begin reduceId, end reduceId.
// For single block id, the format contains shuffleId, mapId, educeId.
// For single block chunk id, the format contains shuffleId, reduceId, chunkId.
if (blockIdParts.length < 4 || blockIdParts.length > 5) {
throw new IllegalArgumentException(
"Unexpected shuffle block id format: " + blockId);
throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockId);
}
if (blockIdParts.length == 5 && !blockIdParts[0].equals("shuffle")) {
throw new IllegalArgumentException(
"Unexpected shuffle block id format: " + blockId);
if (blockIdParts.length == 4 && !blockIdParts[0].equals(SHUFFLE_BLOCK_SPLIT)) {
throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockId);
}
if (blockIdParts.length == 4 &&
!(blockIdParts[0].equals("shuffle") || blockIdParts[0].equals("shuffleChunk"))) {
throw new IllegalArgumentException(
"Unexpected shuffle block id format: " + blockId);
if (blockIdParts.length == 5 &&
!(blockIdParts[0].equals(SHUFFLE_BLOCK_SPLIT) ||
blockIdParts[0].equals(SHUFFLE_CHUNK_SPLIT))) {
throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockId);
}
return blockIdParts;
}

View file

@ -135,13 +135,14 @@ public class OneForOneBlockPusher {
assert buffers.containsKey(blockIds[i]) : "Could not find the block buffer for block "
+ blockIds[i];
String[] blockIdParts = blockIds[i].split("_");
if (blockIdParts.length != 4 || !blockIdParts[0].equals(SHUFFLE_PUSH_BLOCK_PREFIX)) {
if (blockIdParts.length != 5 || !blockIdParts[0].equals(SHUFFLE_PUSH_BLOCK_PREFIX)) {
throw new IllegalArgumentException(
"Unexpected shuffle push block id format: " + blockIds[i]);
}
ByteBuffer header =
new PushBlockStream(appId, appAttemptId, Integer.parseInt(blockIdParts[1]),
Integer.parseInt(blockIdParts[2]), Integer.parseInt(blockIdParts[3]) , i).toByteBuffer();
Integer.parseInt(blockIdParts[2]), Integer.parseInt(blockIdParts[3]),
Integer.parseInt(blockIdParts[4]), i).toByteBuffer();
client.uploadStream(new NioManagedBuffer(header), buffers.get(blockIds[i]),
new BlockPushCallback(i, blockIds[i]));
}

View file

@ -28,6 +28,7 @@ import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
@ -78,6 +79,15 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
public static final String MERGE_DIR_KEY = "mergeDir";
public static final String ATTEMPT_ID_KEY = "attemptId";
private static final int UNDEFINED_ATTEMPT_ID = -1;
// Shuffles of determinate stages will have shuffleMergeId set to 0
private static final int DETERMINATE_SHUFFLE_MERGE_ID = 0;
// ConcurrentHashMap doesn't allow null for keys or values which is why this is required.
// Marker to identify finalized indeterminate shuffle partitions in the case of indeterminate
// stage retries.
@VisibleForTesting
public static final Map<Integer, AppShufflePartitionInfo> INDETERMINATE_SHUFFLE_FINALIZED =
Collections.emptyMap();
/**
* A concurrent hashmap where the key is the applicationId, and the value includes
@ -128,50 +138,79 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
}
/**
* Given the appShuffleInfo, shuffleId and reduceId that uniquely identifies a given shuffle
* partition of an application, retrieves the associated metadata. If not present and the
* corresponding merged shuffle does not exist, initializes the metadata.
* Given the appShuffleInfo, shuffleId, shuffleMergeId and reduceId that uniquely identifies
* a given shuffle partition of an application, retrieves the associated metadata. If not
* present and the corresponding merged shuffle does not exist, initializes the metadata.
*/
private AppShufflePartitionInfo getOrCreateAppShufflePartitionInfo(
AppShuffleInfo appShuffleInfo,
int shuffleId,
int reduceId) {
File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, reduceId);
ConcurrentMap<Integer, Map<Integer, AppShufflePartitionInfo>> partitions =
appShuffleInfo.partitions;
Map<Integer, AppShufflePartitionInfo> shufflePartitions =
partitions.compute(shuffleId, (id, map) -> {
if (map == null) {
int shuffleMergeId,
int reduceId) throws StaleBlockPushException {
ConcurrentMap<Integer, AppShuffleMergePartitionsInfo> shuffles = appShuffleInfo.shuffles;
AppShuffleMergePartitionsInfo shufflePartitionsWithMergeId =
shuffles.compute(shuffleId, (id, appShuffleMergePartitionsInfo) -> {
if (appShuffleMergePartitionsInfo == null) {
File dataFile =
appShuffleInfo.getMergedShuffleDataFile(shuffleId, shuffleMergeId, reduceId);
// If this partition is already finalized then the partitions map will not contain the
// shuffleId but the data file would exist. In that case the block is considered late.
// shuffleId for determinate stages but the data file would exist.
// In that case the block is considered late. In the case of indeterminate stages, most
// recent shuffleMergeId finalized would be pointing to INDETERMINATE_SHUFFLE_FINALIZED
if (dataFile.exists()) {
return null;
}
return new ConcurrentHashMap<>();
} else {
return map;
logger.info("Creating a new attempt for shuffle blocks push request for shuffle {}"
+ " with shuffleMergeId {} for application {}_{}", shuffleId, shuffleMergeId,
appShuffleInfo.appId, appShuffleInfo.attemptId);
return new AppShuffleMergePartitionsInfo(shuffleMergeId, false);
}
} else {
// Reject the request as we have already seen a higher shuffleMergeId than the
// current incoming one
int latestShuffleMergeId = appShuffleMergePartitionsInfo.shuffleMergeId;
if (latestShuffleMergeId > shuffleMergeId) {
throw new StaleBlockPushException(String.format("Rejecting shuffle blocks push request"
+ " for shuffle %s with shuffleMergeId %s for application %s_%s as a higher"
+ " shuffleMergeId %s request is already seen", shuffleId, shuffleMergeId,
appShuffleInfo.appId, appShuffleInfo.attemptId, latestShuffleMergeId));
} else if (latestShuffleMergeId == shuffleMergeId) {
return appShuffleMergePartitionsInfo;
} else {
// Higher shuffleMergeId seen for the shuffle ID meaning new stage attempt is being
// run for the shuffle ID. Close and clean up old shuffleMergeId files,
// happens in the indeterminate stage retries
logger.info("Creating a new attempt for shuffle blocks push request for shuffle {}"
+ " with shuffleMergeId {} for application {}_{} since it is higher than the"
+ " latest shuffleMergeId {} already seen", shuffleId, shuffleMergeId,
appShuffleInfo.appId, appShuffleInfo.attemptId, latestShuffleMergeId);
mergedShuffleCleaner.execute(() ->
closeAndDeletePartitionFiles(appShuffleMergePartitionsInfo.shuffleMergePartitions));
return new AppShuffleMergePartitionsInfo(shuffleMergeId, false);
}
}
});
if (shufflePartitions == null) {
// It only gets here when the shuffle is already finalized.
if (null == shufflePartitionsWithMergeId ||
INDETERMINATE_SHUFFLE_FINALIZED == shufflePartitionsWithMergeId.shuffleMergePartitions) {
return null;
}
return shufflePartitions.computeIfAbsent(reduceId, key -> {
// It only gets here when the key is not present in the map. This could either
// be the first time the merge manager receives a pushed block for a given application
// shuffle partition, or after the merged shuffle file is finalized. We handle these
// two cases accordingly by checking if the file already exists.
Map<Integer, AppShufflePartitionInfo> shuffleMergePartitions =
shufflePartitionsWithMergeId.shuffleMergePartitions;
return shuffleMergePartitions.computeIfAbsent(reduceId, key -> {
// It only gets here when the key is not present in the map. The first time the merge
// manager receives a pushed block for a given application shuffle partition.
File dataFile =
appShuffleInfo.getMergedShuffleDataFile(shuffleId, shuffleMergeId, reduceId);
File indexFile =
appShuffleInfo.getMergedShuffleIndexFile(shuffleId, reduceId);
appShuffleInfo.getMergedShuffleIndexFile(shuffleId, shuffleMergeId, reduceId);
File metaFile =
appShuffleInfo.getMergedShuffleMetaFile(shuffleId, reduceId);
appShuffleInfo.getMergedShuffleMetaFile(shuffleId, shuffleMergeId, reduceId);
try {
if (dataFile.exists()) {
return null;
} else {
return newAppShufflePartitionInfo(
appShuffleInfo.appId, shuffleId, reduceId, dataFile, indexFile, metaFile);
}
return newAppShufflePartitionInfo(appShuffleInfo.appId, shuffleId, shuffleMergeId,
reduceId, dataFile, indexFile, metaFile);
} catch (IOException e) {
logger.error(
"Cannot create merged shuffle partition with data file {}, index file {}, and "
@ -179,7 +218,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
indexFile.getAbsolutePath(), metaFile.getAbsolutePath());
throw new RuntimeException(
String.format("Cannot initialize merged shuffle partition for appId %s shuffleId %s "
+ "reduceId %s", appShuffleInfo.appId, shuffleId, reduceId), e);
+ "shuffleMergeId %s reduceId %s", appShuffleInfo.appId, shuffleId, shuffleMergeId,
reduceId), e);
}
});
}
@ -188,19 +228,31 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
AppShufflePartitionInfo newAppShufflePartitionInfo(
String appId,
int shuffleId,
int shuffleMergeId,
int reduceId,
File dataFile,
File indexFile,
File metaFile) throws IOException {
return new AppShufflePartitionInfo(appId, shuffleId, reduceId, dataFile,
return new AppShufflePartitionInfo(appId, shuffleId, shuffleMergeId, reduceId, dataFile,
new MergeShuffleFile(indexFile), new MergeShuffleFile(metaFile));
}
@Override
public MergedBlockMeta getMergedBlockMeta(String appId, int shuffleId, int reduceId) {
public MergedBlockMeta getMergedBlockMeta(
String appId,
int shuffleId,
int shuffleMergeId,
int reduceId) {
AppShuffleInfo appShuffleInfo = validateAndGetAppShuffleInfo(appId);
AppShuffleMergePartitionsInfo partitionsInfo = appShuffleInfo.shuffles.get(shuffleId);
if (null != partitionsInfo && partitionsInfo.shuffleMergeId > shuffleMergeId) {
throw new RuntimeException(String.format(
"MergedBlockMeta fetch for shuffle %s with shuffleMergeId %s reduceId %s is %s",
shuffleId, shuffleMergeId, reduceId,
ErrorHandler.BlockFetchErrorHandler.STALE_SHUFFLE_BLOCK_FETCH));
}
File indexFile =
appShuffleInfo.getMergedShuffleIndexFile(shuffleId, reduceId);
appShuffleInfo.getMergedShuffleIndexFile(shuffleId, shuffleMergeId, reduceId);
if (!indexFile.exists()) {
throw new RuntimeException(String.format(
"Merged shuffle index file %s not found", indexFile.getPath()));
@ -208,7 +260,7 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
int size = (int) indexFile.length();
// First entry is the zero offset
int numChunks = (size / Long.BYTES) - 1;
File metaFile = appShuffleInfo.getMergedShuffleMetaFile(shuffleId, reduceId);
File metaFile = appShuffleInfo.getMergedShuffleMetaFile(shuffleId, shuffleMergeId, reduceId);
if (!metaFile.exists()) {
throw new RuntimeException(String.format("Merged shuffle meta file %s not found",
metaFile.getPath()));
@ -216,21 +268,30 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
FileSegmentManagedBuffer chunkBitMaps =
new FileSegmentManagedBuffer(conf, metaFile, 0L, metaFile.length());
logger.trace(
"{} shuffleId {} reduceId {} num chunks {}", appId, shuffleId, reduceId, numChunks);
"{} shuffleId {} shuffleMergeId {} reduceId {} num chunks {}",
appId, shuffleId, shuffleMergeId, reduceId, numChunks);
return new MergedBlockMeta(numChunks, chunkBitMaps);
}
@SuppressWarnings("UnstableApiUsage")
@Override
public ManagedBuffer getMergedBlockData(String appId, int shuffleId, int reduceId, int chunkId) {
public ManagedBuffer getMergedBlockData(
String appId, int shuffleId, int shuffleMergeId, int reduceId, int chunkId) {
AppShuffleInfo appShuffleInfo = validateAndGetAppShuffleInfo(appId);
File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, reduceId);
AppShuffleMergePartitionsInfo partitionsInfo = appShuffleInfo.shuffles.get(shuffleId);
if (null != partitionsInfo && partitionsInfo.shuffleMergeId > shuffleMergeId) {
throw new RuntimeException(String.format(
"MergedBlockData fetch for shuffle %s with shuffleMergeId %s reduceId %s is %s",
shuffleId, shuffleMergeId, reduceId,
ErrorHandler.BlockFetchErrorHandler.STALE_SHUFFLE_BLOCK_FETCH));
}
File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, shuffleMergeId, reduceId);
if (!dataFile.exists()) {
throw new RuntimeException(String.format("Merged shuffle data file %s not found",
dataFile.getPath()));
}
File indexFile =
appShuffleInfo.getMergedShuffleIndexFile(shuffleId, reduceId);
appShuffleInfo.getMergedShuffleIndexFile(shuffleId, shuffleMergeId, reduceId);
try {
// If we get here, the merged shuffle file should have been properly finalized. Thus we can
// use the file length to determine the size of the merged shuffle block.
@ -270,18 +331,32 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
void closeAndDeletePartitionFilesIfNeeded(
AppShuffleInfo appShuffleInfo,
boolean cleanupLocalDirs) {
for (Map<Integer, AppShufflePartitionInfo> partitionMap : appShuffleInfo.partitions.values()) {
for (AppShufflePartitionInfo partitionInfo : partitionMap.values()) {
appShuffleInfo.shuffles.forEach((shuffleId, shuffleInfo) -> shuffleInfo.shuffleMergePartitions
.forEach((shuffleMergeId, partitionInfo) -> {
synchronized (partitionInfo) {
partitionInfo.closeAllFiles();
}
}
partitionInfo.closeAllFilesAndDeleteIfNeeded(false);
}
}));
if (cleanupLocalDirs) {
deleteExecutorDirs(appShuffleInfo);
}
}
/**
* Clean up all the AppShufflePartitionInfo for a specific shuffleMergeId. This is done
* since there is a higher shuffleMergeId request made for a shuffleId, therefore clean
* up older shuffleMergeId partitions. The cleanup will be executed in a separate thread.
*/
@VisibleForTesting
void closeAndDeletePartitionFiles(Map<Integer, AppShufflePartitionInfo> partitions) {
partitions
.forEach((partitionId, partitionInfo) -> {
synchronized (partitionInfo) {
partitionInfo.closeAllFilesAndDeleteIfNeeded(true);
}
});
}
/**
* Serially delete local dirs.
*/
@ -304,9 +379,9 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
@Override
public StreamCallbackWithID receiveBlockDataAsStream(PushBlockStream msg) {
AppShuffleInfo appShuffleInfo = validateAndGetAppShuffleInfo(msg.appId);
final String streamId = String.format("%s_%d_%d_%d",
OneForOneBlockPusher.SHUFFLE_PUSH_BLOCK_PREFIX, msg.shuffleId, msg.mapIndex,
msg.reduceId);
final String streamId = String.format("%s_%d_%d_%d_%d",
OneForOneBlockPusher.SHUFFLE_PUSH_BLOCK_PREFIX, msg.shuffleId, msg.shuffleMergeId,
msg.mapIndex, msg.reduceId);
if (appShuffleInfo.attemptId != msg.appAttemptId) {
// If this Block belongs to a former application attempt, it is considered late,
// as only the blocks from the current application attempt will be merged
@ -317,12 +392,19 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
msg.appAttemptId, appShuffleInfo.attemptId, msg.appId));
}
// Retrieve merged shuffle file metadata
AppShufflePartitionInfo partitionInfoBeforeCheck =
getOrCreateAppShufflePartitionInfo(appShuffleInfo, msg.shuffleId, msg.reduceId);
// Here partitionInfo will be null in 2 cases:
AppShufflePartitionInfo partitionInfoBeforeCheck;
try {
partitionInfoBeforeCheck = getOrCreateAppShufflePartitionInfo(appShuffleInfo, msg.shuffleId,
msg.shuffleMergeId, msg.reduceId);
} catch(StaleBlockPushException sbp) {
// Set partitionInfoBeforeCheck to null so that stale block push gets handled.
partitionInfoBeforeCheck = null;
}
// Here partitionInfo will be null in 3 cases:
// 1) The request is received for a block that has already been merged, this is possible due
// to the retry logic.
// 2) The request is received after the merged shuffle is finalized, thus is too late.
// 3) The request is received for a older shuffleMergeId, therefore the block push is rejected.
//
// For case 1, we will drain the data in the channel and just respond success
// to the client. This is required because the response of the previously merged
@ -345,6 +427,13 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
// to notify the client of the failure, so that it can properly halt pushing the remaining
// blocks upon receiving such failures to preserve resources on the server/client side.
//
// For case 3, we will also drain the data in the channel, but throw an exception in
// {@link org.apache.spark.network.client.StreamCallback#onComplete(String)}. This way,
// the client will be notified of the failure but the channel will remain active. It is
// important to notify the client of the failure, so that it can properly halt pushing the
// remaining blocks upon receiving such failures to preserve resources on the server/client
// side.
//
// Speculative execution would also raise a possible scenario with duplicate blocks. Although
// speculative execution would kill the slower task attempt, leading to only 1 task attempt
// succeeding in the end, there is no guarantee that only one copy of the block will be
@ -353,18 +442,19 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
// getting killed. When this happens, we need to distinguish the duplicate blocks as they
// arrive. More details on this is explained in later comments.
// Track if the block is received after shuffle merge finalize
final boolean isTooLate = partitionInfoBeforeCheck == null;
// Check if the given block is already merged by checking the bitmap against the given map index
final AppShufflePartitionInfo partitionInfo = partitionInfoBeforeCheck != null
&& partitionInfoBeforeCheck.mapTracker.contains(msg.mapIndex) ? null
: partitionInfoBeforeCheck;
// Track if the block is received after shuffle merge finalized or from an older
// shuffleMergeId attempt.
final boolean isStaleBlockOrTooLate = partitionInfoBeforeCheck == null;
// Check if the given block is already merged by checking the bitmap against the given map
// index
final AppShufflePartitionInfo partitionInfo = isStaleBlockOrTooLate ? null :
partitionInfoBeforeCheck.mapTracker.contains(msg.mapIndex) ? null : partitionInfoBeforeCheck;
if (partitionInfo != null) {
return new PushBlockStreamCallback(
this, appShuffleInfo, streamId, partitionInfo, msg.mapIndex);
} else {
// For a duplicate block or a block which is late, respond back with a callback that handles
// them differently.
// For a duplicate block or a block which is late or stale block from an older
// shuffleMergeId, respond back with a callback that handles them differently.
return new StreamCallbackWithID() {
@Override
public String getID() {
@ -379,11 +469,11 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
@Override
public void onComplete(String streamId) {
if (isTooLate) {
if (isStaleBlockOrTooLate) {
// Throw an exception here so the block data is drained from channel and server
// responds RpcFailure to the client.
throw new RuntimeException(String.format("Block %s %s", streamId,
ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX));
ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX));
}
// For duplicate block that is received before the shuffle merge finalizes, the
// server should respond success to the client.
@ -397,9 +487,9 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
}
@Override
public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) throws IOException {
logger.info("Finalizing shuffle {} from Application {}_{}.",
msg.shuffleId, msg.appId, msg.appAttemptId);
public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) {
logger.info("Finalizing shuffle {} with shuffleMergeId {} from Application {}_{}.",
msg.shuffleId, msg.shuffleMergeId, msg.appId, msg.appAttemptId);
AppShuffleInfo appShuffleInfo = validateAndGetAppShuffleInfo(msg.appId);
if (appShuffleInfo.attemptId != msg.appAttemptId) {
// If this Block belongs to a former application attempt, it is considered late,
@ -410,17 +500,42 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
+ "with the current attempt id %s stored in shuffle service for application %s",
msg.appAttemptId, appShuffleInfo.attemptId, msg.appId));
}
Map<Integer, AppShufflePartitionInfo> shufflePartitions =
appShuffleInfo.partitions.remove(msg.shuffleId);
MergeStatuses mergeStatuses;
if (shufflePartitions == null || shufflePartitions.isEmpty()) {
mergeStatuses =
new MergeStatuses(msg.shuffleId, new RoaringBitmap[0], new int[0], new long[0]);
AtomicReference<Map<Integer, AppShufflePartitionInfo>> shuffleMergePartitionsRef =
new AtomicReference<>(null);
// Metadata of the determinate stage shuffle can be safely removed as part of finalizing
// shuffle merge. Currently once the shuffle is finalized for a determinate stages, retry
// stages of the same shuffle will have shuffle push disabled.
if (msg.shuffleMergeId == DETERMINATE_SHUFFLE_MERGE_ID) {
AppShuffleMergePartitionsInfo appShuffleMergePartitionsInfo =
appShuffleInfo.shuffles.remove(msg.shuffleId);
if (appShuffleMergePartitionsInfo != null) {
shuffleMergePartitionsRef.set(appShuffleMergePartitionsInfo.shuffleMergePartitions);
}
} else {
List<RoaringBitmap> bitmaps = new ArrayList<>(shufflePartitions.size());
List<Integer> reduceIds = new ArrayList<>(shufflePartitions.size());
List<Long> sizes = new ArrayList<>(shufflePartitions.size());
for (AppShufflePartitionInfo partition: shufflePartitions.values()) {
appShuffleInfo.shuffles.compute(msg.shuffleId, (id, value) -> {
if (null == value || msg.shuffleMergeId != value.shuffleMergeId ||
INDETERMINATE_SHUFFLE_FINALIZED == value.shuffleMergePartitions) {
throw new RuntimeException(String.format(
"Shuffle merge finalize request for shuffle %s with" + " shuffleMergeId %s is %s",
msg.shuffleId, msg.shuffleMergeId,
ErrorHandler.BlockPushErrorHandler.STALE_SHUFFLE_FINALIZE_SUFFIX));
} else {
shuffleMergePartitionsRef.set(value.shuffleMergePartitions);
return new AppShuffleMergePartitionsInfo(msg.shuffleMergeId, true);
}
});
}
Map<Integer, AppShufflePartitionInfo> shuffleMergePartitions = shuffleMergePartitionsRef.get();
MergeStatuses mergeStatuses;
if (null == shuffleMergePartitions || shuffleMergePartitions.isEmpty()) {
mergeStatuses =
new MergeStatuses(msg.shuffleId, msg.shuffleMergeId,
new RoaringBitmap[0], new int[0], new long[0]);
} else {
List<RoaringBitmap> bitmaps = new ArrayList<>(shuffleMergePartitions.size());
List<Integer> reduceIds = new ArrayList<>(shuffleMergePartitions.size());
List<Long> sizes = new ArrayList<>(shuffleMergePartitions.size());
for (AppShufflePartitionInfo partition: shuffleMergePartitions.values()) {
synchronized (partition) {
try {
// This can throw IOException which will marks this shuffle partition as not merged.
@ -432,16 +547,16 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
logger.warn("Exception while finalizing shuffle partition {}_{} {} {}", msg.appId,
msg.appAttemptId, msg.shuffleId, partition.reduceId, ioe);
} finally {
partition.closeAllFiles();
partition.closeAllFilesAndDeleteIfNeeded(false);
}
}
}
mergeStatuses = new MergeStatuses(msg.shuffleId,
mergeStatuses = new MergeStatuses(msg.shuffleId, msg.shuffleMergeId,
bitmaps.toArray(new RoaringBitmap[bitmaps.size()]), Ints.toArray(reduceIds),
Longs.toArray(sizes));
}
logger.info("Finalized shuffle {} from Application {}_{}.",
msg.shuffleId, msg.appId, msg.appAttemptId);
logger.info("Finalized shuffle {} with shuffleMergeId {} from Application {}_{}.",
msg.shuffleId, msg.shuffleMergeId, msg.appId, msg.appAttemptId);
return mergeStatuses;
}
@ -563,9 +678,10 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
private void writeBuf(ByteBuffer buf) throws IOException {
while (buf.hasRemaining()) {
long updatedPos = partitionInfo.getDataFilePos() + length;
logger.debug("{} shuffleId {} reduceId {} current pos {} updated pos {}",
partitionInfo.appId, partitionInfo.shuffleId,
partitionInfo.reduceId, partitionInfo.getDataFilePos(), updatedPos);
logger.debug("{} shuffleId {} shuffleMergeId {} reduceId {} current pos"
+ " {} updated pos {}", partitionInfo.appId, partitionInfo.shuffleId,
partitionInfo.shuffleMergeId, partitionInfo.reduceId,
partitionInfo.getDataFilePos(), updatedPos);
length += partitionInfo.dataChannel.write(buf, updatedPos);
}
}
@ -631,6 +747,23 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
abortIfNecessary();
}
/**
* If appShuffleMergePartitionsInfo is null or shuffleMergePartitions is set to
* INDETERMINATE_SHUFFLE_FINALIZED or if the reduceId is not in the map then the
* shuffle is already finalized. Therefore the block push is too late. If
* appShuffleMergePartitionsInfo's shuffleMergeId is
* greater than the request shuffleMergeId then it is a stale block push.
*/
private boolean isStaleOrTooLate(
AppShuffleMergePartitionsInfo appShuffleMergePartitionsInfo,
int shuffleMergeId,
int reduceId) {
return null == appShuffleMergePartitionsInfo ||
INDETERMINATE_SHUFFLE_FINALIZED == appShuffleMergePartitionsInfo.shuffleMergePartitions ||
appShuffleMergePartitionsInfo.shuffleMergeId > shuffleMergeId ||
!appShuffleMergePartitionsInfo.shuffleMergePartitions.containsKey(reduceId);
}
@Override
public void onData(String streamId, ByteBuffer buf) throws IOException {
// When handling the block data using StreamInterceptor, it can help to reduce the amount
@ -648,14 +781,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
// to disk as well. This way, we avoid having to buffer the entirety of every blocks in
// memory, while still providing the necessary guarantee.
synchronized (partitionInfo) {
Map<Integer, AppShufflePartitionInfo> shufflePartitions =
appShuffleInfo.partitions.get(partitionInfo.shuffleId);
// If the partitionInfo corresponding to (appId, shuffleId, reduceId) is no longer present
// then it means that the shuffle merge has already been finalized. We should thus ignore
// the data and just drain the remaining bytes of this message. This check should be
// placed inside the synchronized block to make sure that checking the key is still
// present and processing the data is atomic.
if (shufflePartitions == null || !shufflePartitions.containsKey(partitionInfo.reduceId)) {
if (isStaleOrTooLate(appShuffleInfo.shuffles.get(partitionInfo.shuffleId),
partitionInfo.shuffleMergeId, partitionInfo.reduceId)) {
deferredBufs = null;
return;
}
@ -668,8 +795,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
return;
}
abortIfNecessary();
logger.trace("{} shuffleId {} reduceId {} onData writable",
partitionInfo.appId, partitionInfo.shuffleId,
logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} onData writable",
partitionInfo.appId, partitionInfo.shuffleId, partitionInfo.shuffleMergeId,
partitionInfo.reduceId);
if (partitionInfo.getCurrentMapIndex() < 0) {
partitionInfo.setCurrentMapIndex(mapIndex);
@ -690,8 +817,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
throw ioe;
}
} else {
logger.trace("{} shuffleId {} reduceId {} onData deferred",
partitionInfo.appId, partitionInfo.shuffleId,
logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} onData deferred",
partitionInfo.appId, partitionInfo.shuffleId, partitionInfo.shuffleMergeId,
partitionInfo.reduceId);
// If we cannot write to disk, we buffer the current block chunk in memory so it could
// potentially be written to disk later. We take our best effort without guarantee
@ -725,19 +852,21 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
@Override
public void onComplete(String streamId) throws IOException {
synchronized (partitionInfo) {
logger.trace("{} shuffleId {} reduceId {} onComplete invoked",
partitionInfo.appId, partitionInfo.shuffleId,
logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} onComplete invoked",
partitionInfo.appId, partitionInfo.shuffleId, partitionInfo.shuffleMergeId,
partitionInfo.reduceId);
Map<Integer, AppShufflePartitionInfo> shufflePartitions =
appShuffleInfo.partitions.get(partitionInfo.shuffleId);
// When this request initially got to the server, the shuffle merge finalize request
// was not received yet. By the time we finish reading this message, the shuffle merge
// however is already finalized. We should thus respond RpcFailure to the client.
if (shufflePartitions == null || !shufflePartitions.containsKey(partitionInfo.reduceId)) {
// Initially when this request got to the server, the shuffle merge finalize request
// was not received yet or this was the latest stage attempt (or latest shuffleMergeId)
// generating shuffle output for the shuffle ID. By the time we finish reading this
// message, the block request is either stale or too late. We should thus respond
// RpcFailure to the client.
if (isStaleOrTooLate(appShuffleInfo.shuffles.get(partitionInfo.shuffleId),
partitionInfo.shuffleMergeId, partitionInfo.reduceId)) {
deferredBufs = null;
throw new RuntimeException(String.format("Block %s %s", streamId,
ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX));
throw new RuntimeException(String.format("Block %s is %s", streamId,
ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX));
}
// Check if we can commit this block
if (allowedToWrite()) {
// Identify duplicate block generated by speculative tasks. We respond success to
@ -800,17 +929,16 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
logger.debug("Encountered issue when merging {}", streamId, throwable);
}
// Only update partitionInfo if the failure corresponds to a valid request. If the
// request is too late, i.e. received after shuffle merge finalize, #onFailure will
// also be triggered, and we can just ignore. Also, if we couldn't find an opportunity
// to write the block data to disk, we should also ignore here.
// request is too late, i.e. received after shuffle merge finalize or stale block push,
// #onFailure will also be triggered, and we can just ignore. Also, if we couldn't find
// an opportunity to write the block data to disk, we should also ignore here.
if (isWriting) {
synchronized (partitionInfo) {
Map<Integer, AppShufflePartitionInfo> shufflePartitions =
appShuffleInfo.partitions.get(partitionInfo.shuffleId);
if (shufflePartitions != null && shufflePartitions.containsKey(partitionInfo.reduceId)) {
logger.debug("{} shuffleId {} reduceId {} encountered failure",
partitionInfo.appId, partitionInfo.shuffleId,
partitionInfo.reduceId);
if (!isStaleOrTooLate(appShuffleInfo.shuffles.get(partitionInfo.shuffleId),
partitionInfo.shuffleMergeId, partitionInfo.reduceId)) {
logger.debug("{} shuffleId {} shuffleMergeId {} reduceId {}"
+ " encountered failure", partitionInfo.appId, partitionInfo.shuffleId,
partitionInfo.shuffleMergeId, partitionInfo.reduceId);
partitionInfo.setCurrentMapIndex(-1);
}
}
@ -824,12 +952,35 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
}
}
/**
* Wrapper class to hold merged Shuffle related information for a specific shuffleMergeId
* required for the shuffles of indeterminate stages.
*/
public static class AppShuffleMergePartitionsInfo {
private final int shuffleMergeId;
private final Map<Integer, AppShufflePartitionInfo> shuffleMergePartitions;
public AppShuffleMergePartitionsInfo(
int shuffleMergeId, boolean shuffleFinalized) {
this.shuffleMergeId = shuffleMergeId;
this.shuffleMergePartitions = shuffleFinalized ?
INDETERMINATE_SHUFFLE_FINALIZED : new ConcurrentHashMap<>();
}
@VisibleForTesting
public Map<Integer, AppShufflePartitionInfo> getShuffleMergePartitions() {
return shuffleMergePartitions;
}
}
/** Metadata tracked for an actively merged shuffle partition */
public static class AppShufflePartitionInfo {
private final String appId;
private final int shuffleId;
private final int shuffleMergeId;
private final int reduceId;
private final File dataFile;
// The merged shuffle data file channel
public final FileChannel dataChannel;
// The index file for a particular merged shuffle contains the chunk offsets.
@ -854,6 +1005,7 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
AppShufflePartitionInfo(
String appId,
int shuffleId,
int shuffleMergeId,
int reduceId,
File dataFile,
MergeShuffleFile indexFile,
@ -861,8 +1013,10 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
Preconditions.checkArgument(appId != null, "app id is null");
this.appId = appId;
this.shuffleId = shuffleId;
this.shuffleMergeId = shuffleMergeId;
this.reduceId = reduceId;
this.dataChannel = new FileOutputStream(dataFile).getChannel();
this.dataFile = dataFile;
this.indexFile = indexFile;
this.metaFile = metaFile;
this.currentMapIndex = -1;
@ -878,8 +1032,9 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
}
public void setDataFilePos(long dataFilePos) {
logger.trace("{} shuffleId {} reduceId {} current pos {} update pos {}", appId,
shuffleId, reduceId, this.dataFilePos, dataFilePos);
logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} current pos {}"
+ " update pos {}", appId, shuffleId, shuffleMergeId, reduceId, this.dataFilePos,
dataFilePos);
this.dataFilePos = dataFilePos;
}
@ -888,8 +1043,9 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
}
void setCurrentMapIndex(int mapIndex) {
logger.trace("{} shuffleId {} reduceId {} updated mapIndex {} current mapIndex {}",
appId, shuffleId, reduceId, currentMapIndex, mapIndex);
logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} updated mapIndex {}"
+ " current mapIndex {}", appId, shuffleId, shuffleMergeId, reduceId,
currentMapIndex, mapIndex);
this.currentMapIndex = mapIndex;
}
@ -898,8 +1054,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
}
void blockMerged(int mapIndex) {
logger.debug("{} shuffleId {} reduceId {} updated merging mapIndex {}", appId,
shuffleId, reduceId, mapIndex);
logger.debug("{} shuffleId {} shuffleMergeId {} reduceId {} updated merging mapIndex {}",
appId, shuffleId, shuffleMergeId, reduceId, mapIndex);
mapTracker.add(mapIndex);
chunkTracker.add(mapIndex);
lastMergedMapIndex = mapIndex;
@ -917,8 +1073,9 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
*/
void updateChunkInfo(long chunkOffset, int mapIndex) throws IOException {
try {
logger.trace("{} shuffleId {} reduceId {} index current {} updated {}",
appId, shuffleId, reduceId, this.lastChunkOffset, chunkOffset);
logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} index current {}"
+ " updated {}", appId, shuffleId, shuffleMergeId, reduceId,
this.lastChunkOffset, chunkOffset);
if (indexMetaUpdateFailed) {
indexFile.getChannel().position(indexFile.getPos());
}
@ -946,8 +1103,8 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
return;
}
chunkTracker.add(mapIndex);
logger.trace("{} shuffleId {} reduceId {} mapIndex {} write chunk to meta file",
appId, shuffleId, reduceId, mapIndex);
logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} mapIndex {}"
+ " write chunk to meta file", appId, shuffleId, shuffleMergeId, reduceId, mapIndex);
if (indexMetaUpdateFailed) {
metaFile.getChannel().position(metaFile.getPos());
}
@ -980,32 +1137,41 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
metaFile.getChannel().truncate(metaFile.getPos());
}
void closeAllFiles() {
void closeAllFilesAndDeleteIfNeeded(boolean delete) {
try {
if (dataChannel.isOpen()) {
dataChannel.close();
if (delete) {
dataFile.delete();
}
}
} catch (IOException ioe) {
logger.warn("Error closing data channel for {} shuffleId {} reduceId {}",
appId, shuffleId, reduceId);
logger.warn("Error closing data channel for {} shuffleId {} shuffleMergeId {}"
+ " reduceId {}", appId, shuffleId, shuffleMergeId, reduceId);
}
try {
metaFile.close();
if (delete) {
metaFile.delete();
}
} catch (IOException ioe) {
logger.warn("Error closing meta file for {} shuffleId {} reduceId {}",
appId, shuffleId, reduceId);
logger.warn("Error closing meta file for {} shuffleId {} shuffleMergeId {}"
+ " reduceId {}", appId, shuffleId, shuffleMergeId, reduceId);
}
try {
indexFile.close();
if (delete) {
indexFile.delete();
}
} catch (IOException ioe) {
logger.warn("Error closing index file for {} shuffleId {} reduceId {}",
appId, shuffleId, reduceId);
logger.warn("Error closing index file for {} shuffleId {} shuffleMergeId {}"
+ " reduceId {}", appId, shuffleId, shuffleMergeId, reduceId);
}
}
@Override
protected void finalize() throws Throwable {
closeAllFiles();
closeAllFilesAndDeleteIfNeeded(false);
}
@VisibleForTesting
@ -1065,7 +1231,12 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
private final String appId;
private final int attemptId;
private final AppPathsInfo appPathsInfo;
private final ConcurrentMap<Integer, Map<Integer, AppShufflePartitionInfo>> partitions;
/**
* 1. Key tracks shuffleId for an application
* 2. Value tracks the AppShuffleMergePartitionsInfo having shuffleMergeId and
* a Map tracking AppShufflePartitionInfo for all the shuffle partitions.
*/
private final ConcurrentMap<Integer, AppShuffleMergePartitionsInfo> shuffles;
AppShuffleInfo(
String appId,
@ -1074,12 +1245,12 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
this.appId = appId;
this.attemptId = attemptId;
this.appPathsInfo = appPathsInfo;
partitions = new ConcurrentHashMap<>();
shuffles = new ConcurrentHashMap<>();
}
@VisibleForTesting
public ConcurrentMap<Integer, Map<Integer, AppShufflePartitionInfo>> getPartitions() {
return partitions;
public ConcurrentMap<Integer, AppShuffleMergePartitionsInfo> getShuffles() {
return shuffles;
}
/**
@ -1098,29 +1269,37 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
private String generateFileName(
String appId,
int shuffleId,
int shuffleMergeId,
int reduceId) {
return String.format(
"%s_%s_%d_%d", MERGED_SHUFFLE_FILE_NAME_PREFIX, appId, shuffleId, reduceId);
"%s_%s_%d_%d_%d", MERGED_SHUFFLE_FILE_NAME_PREFIX, appId, shuffleId,
shuffleMergeId, reduceId);
}
public File getMergedShuffleDataFile(
int shuffleId,
int shuffleMergeId,
int reduceId) {
String fileName = String.format("%s.data", generateFileName(appId, shuffleId, reduceId));
String fileName = String.format("%s.data", generateFileName(appId, shuffleId,
shuffleMergeId, reduceId));
return getFile(fileName);
}
public File getMergedShuffleIndexFile(
int shuffleId,
int shuffleMergeId,
int reduceId) {
String indexName = String.format("%s.index", generateFileName(appId, shuffleId, reduceId));
String indexName = String.format("%s.index", generateFileName(appId, shuffleId,
shuffleMergeId, reduceId));
return getFile(indexName);
}
public File getMergedShuffleMetaFile(
int shuffleId,
int shuffleMergeId,
int reduceId) {
String metaName = String.format("%s.meta", generateFileName(appId, shuffleId, reduceId));
String metaName = String.format("%s.meta", generateFileName(appId, shuffleId,
shuffleMergeId, reduceId));
return getFile(metaName);
}
}
@ -1130,18 +1309,21 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
private final FileChannel channel;
private final DataOutputStream dos;
private long pos;
private File file;
@VisibleForTesting
MergeShuffleFile(File file) throws IOException {
FileOutputStream fos = new FileOutputStream(file);
channel = fos.getChannel();
dos = new DataOutputStream(fos);
this.file = file;
}
@VisibleForTesting
MergeShuffleFile(FileChannel channel, DataOutputStream dos) {
this.channel = channel;
this.dos = dos;
this.file = null;
}
private void updatePos(long numBytes) {
@ -1154,6 +1336,16 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
}
}
void delete() throws IOException {
try {
if (null != file) {
file.delete();
}
} finally {
file = null;
}
}
@VisibleForTesting
DataOutputStream getDos() {
return dos;
@ -1169,4 +1361,10 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
return pos;
}
}
public static class StaleBlockPushException extends RuntimeException {
public StaleBlockPushException(String message) {
super(message);
}
}
}

View file

@ -37,14 +37,19 @@ public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
public final int[] reduceIds;
// The i-th int[] in chunkIds contains all the chunks for the i-th reduceId in reduceIds.
public final int[][] chunkIds;
// shuffleMergeId is used to uniquely identify merging process of shuffle by
// an indeterminate stage attempt.
public final int shuffleMergeId;
public FetchShuffleBlockChunks(
String appId,
String execId,
int shuffleId,
int shuffleMergeId,
int[] reduceIds,
int[][] chunkIds) {
super(appId, execId, shuffleId);
this.shuffleMergeId = shuffleMergeId;
this.reduceIds = reduceIds;
this.chunkIds = chunkIds;
assert(reduceIds.length == chunkIds.length);
@ -56,6 +61,7 @@ public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
@Override
public String toString() {
return toStringHelper()
.append("shuffleMergeId", shuffleMergeId)
.append("reduceIds", Arrays.toString(reduceIds))
.append("chunkIds", Arrays.deepToString(chunkIds))
.toString();
@ -68,13 +74,16 @@ public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
FetchShuffleBlockChunks that = (FetchShuffleBlockChunks) o;
if (!super.equals(that)) return false;
if (!Arrays.equals(reduceIds, that.reduceIds)) return false;
if (shuffleMergeId != that.shuffleMergeId ||
!Arrays.equals(reduceIds, that.reduceIds)) {
return false;
}
return Arrays.deepEquals(chunkIds, that.chunkIds);
}
@Override
public int hashCode() {
int result = super.hashCode();
int result = super.hashCode() * 31 + shuffleMergeId;
result = 31 * result + Arrays.hashCode(reduceIds);
result = 31 * result + Arrays.deepHashCode(chunkIds);
return result;
@ -89,12 +98,14 @@ public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
return super.encodedLength()
+ Encoders.IntArrays.encodedLength(reduceIds)
+ 4 /* encoded length of chunkIds.size() */
+ 4 /* encoded length of shuffleMergeId */
+ encodedLengthOfChunkIds;
}
@Override
public void encode(ByteBuf buf) {
super.encode(buf);
buf.writeInt(shuffleMergeId);
Encoders.IntArrays.encode(buf, reduceIds);
// Even though reduceIds.length == chunkIds.length, we are explicitly setting the length in the
// interest of forward compatibility.
@ -117,12 +128,14 @@ public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
String appId = Encoders.Strings.decode(buf);
String execId = Encoders.Strings.decode(buf);
int shuffleId = buf.readInt();
int shuffleMergeId = buf.readInt();
int[] reduceIds = Encoders.IntArrays.decode(buf);
int chunkIdsLen = buf.readInt();
int[][] chunkIds = new int[chunkIdsLen][];
for (int i = 0; i < chunkIdsLen; i++) {
chunkIds[i] = Encoders.IntArrays.decode(buf);
}
return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, chunkIds);
return new FetchShuffleBlockChunks(appId, execId, shuffleId, shuffleMergeId, reduceIds,
chunkIds);
}
}

View file

@ -34,14 +34,17 @@ public class FinalizeShuffleMerge extends BlockTransferMessage {
public final String appId;
public final int appAttemptId;
public final int shuffleId;
public final int shuffleMergeId;
public FinalizeShuffleMerge(
String appId,
int appAttemptId,
int shuffleId) {
int shuffleId,
int shuffleMergeId) {
this.appId = appId;
this.appAttemptId = appAttemptId;
this.shuffleId = shuffleId;
this.shuffleMergeId = shuffleMergeId;
}
@Override
@ -51,7 +54,7 @@ public class FinalizeShuffleMerge extends BlockTransferMessage {
@Override
public int hashCode() {
return Objects.hashCode(appId, appAttemptId, shuffleId);
return Objects.hashCode(appId, appAttemptId, shuffleId, shuffleMergeId);
}
@Override
@ -60,6 +63,7 @@ public class FinalizeShuffleMerge extends BlockTransferMessage {
.append("appId", appId)
.append("attemptId", appAttemptId)
.append("shuffleId", shuffleId)
.append("shuffleMergeId", shuffleMergeId)
.toString();
}
@ -69,14 +73,15 @@ public class FinalizeShuffleMerge extends BlockTransferMessage {
FinalizeShuffleMerge o = (FinalizeShuffleMerge) other;
return Objects.equal(appId, o.appId)
&& appAttemptId == o.appAttemptId
&& shuffleId == o.shuffleId;
&& shuffleId == o.shuffleId
&& shuffleMergeId == o.shuffleMergeId;
}
return false;
}
@Override
public int encodedLength() {
return Encoders.Strings.encodedLength(appId) + 4 + 4;
return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4;
}
@Override
@ -84,12 +89,14 @@ public class FinalizeShuffleMerge extends BlockTransferMessage {
Encoders.Strings.encode(buf, appId);
buf.writeInt(appAttemptId);
buf.writeInt(shuffleId);
buf.writeInt(shuffleMergeId);
}
public static FinalizeShuffleMerge decode(ByteBuf buf) {
String appId = Encoders.Strings.decode(buf);
int attemptId = buf.readInt();
int shuffleId = buf.readInt();
return new FinalizeShuffleMerge(appId, attemptId, shuffleId);
int shuffleMergeId = buf.readInt();
return new FinalizeShuffleMerge(appId, attemptId, shuffleId, shuffleMergeId);
}
}

View file

@ -40,6 +40,11 @@ import org.apache.spark.network.protocol.Encoders;
public class MergeStatuses extends BlockTransferMessage {
/** Shuffle ID **/
public final int shuffleId;
/**
* shuffleMergeId is used to uniquely identify merging process of shuffle by
* an indeterminate stage attempt.
*/
public final int shuffleMergeId;
/**
* Array of bitmaps tracking the set of mapper partition blocks merged for each
* reducer partition
@ -55,10 +60,12 @@ public class MergeStatuses extends BlockTransferMessage {
public MergeStatuses(
int shuffleId,
int shuffleMergeId,
RoaringBitmap[] bitmaps,
int[] reduceIds,
long[] sizes) {
this.shuffleId = shuffleId;
this.shuffleMergeId = shuffleMergeId;
this.bitmaps = bitmaps;
this.reduceIds = reduceIds;
this.sizes = sizes;
@ -71,7 +78,8 @@ public class MergeStatuses extends BlockTransferMessage {
@Override
public int hashCode() {
int objectHashCode = Objects.hashCode(shuffleId);
int objectHashCode = Objects.hashCode(shuffleId) * 41 +
Objects.hashCode(shuffleMergeId);
return (objectHashCode * 41 + Arrays.hashCode(reduceIds) * 41
+ Arrays.hashCode(bitmaps) * 41 + Arrays.hashCode(sizes));
}
@ -80,6 +88,7 @@ public class MergeStatuses extends BlockTransferMessage {
public String toString() {
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
.append("shuffleId", shuffleId)
.append("shuffleMergeId", shuffleMergeId)
.append("reduceId size", reduceIds.length)
.toString();
}
@ -89,6 +98,7 @@ public class MergeStatuses extends BlockTransferMessage {
if (other != null && other instanceof MergeStatuses) {
MergeStatuses o = (MergeStatuses) other;
return Objects.equal(shuffleId, o.shuffleId)
&& Objects.equal(shuffleMergeId, o.shuffleMergeId)
&& Arrays.equals(bitmaps, o.bitmaps)
&& Arrays.equals(reduceIds, o.reduceIds)
&& Arrays.equals(sizes, o.sizes);
@ -98,7 +108,7 @@ public class MergeStatuses extends BlockTransferMessage {
@Override
public int encodedLength() {
return 4 // int
return 4 + 4 // shuffleId and shuffleMergeId
+ Encoders.BitmapArrays.encodedLength(bitmaps)
+ Encoders.IntArrays.encodedLength(reduceIds)
+ Encoders.LongArrays.encodedLength(sizes);
@ -107,6 +117,7 @@ public class MergeStatuses extends BlockTransferMessage {
@Override
public void encode(ByteBuf buf) {
buf.writeInt(shuffleId);
buf.writeInt(shuffleMergeId);
Encoders.BitmapArrays.encode(buf, bitmaps);
Encoders.IntArrays.encode(buf, reduceIds);
Encoders.LongArrays.encode(buf, sizes);
@ -114,9 +125,10 @@ public class MergeStatuses extends BlockTransferMessage {
public static MergeStatuses decode(ByteBuf buf) {
int shuffleId = buf.readInt();
int shuffleMergeId = buf.readInt();
RoaringBitmap[] bitmaps = Encoders.BitmapArrays.decode(buf);
int[] reduceIds = Encoders.IntArrays.decode(buf);
long[] sizes = Encoders.LongArrays.decode(buf);
return new MergeStatuses(shuffleId, bitmaps, reduceIds, sizes);
return new MergeStatuses(shuffleId, shuffleMergeId, bitmaps, reduceIds, sizes);
}
}

View file

@ -37,6 +37,7 @@ public class PushBlockStream extends BlockTransferMessage {
public final String appId;
public final int appAttemptId;
public final int shuffleId;
public final int shuffleMergeId;
public final int mapIndex;
public final int reduceId;
// Similar to the chunkIndex in StreamChunkId, indicating the index of a block in a batch of
@ -47,12 +48,14 @@ public class PushBlockStream extends BlockTransferMessage {
String appId,
int appAttemptId,
int shuffleId,
int shuffleMergeId,
int mapIndex,
int reduceId,
int index) {
this.appId = appId;
this.appAttemptId = appAttemptId;
this.shuffleId = shuffleId;
this.shuffleMergeId = shuffleMergeId;
this.mapIndex = mapIndex;
this.reduceId = reduceId;
this.index = index;
@ -65,7 +68,8 @@ public class PushBlockStream extends BlockTransferMessage {
@Override
public int hashCode() {
return Objects.hashCode(appId, appAttemptId, shuffleId, mapIndex , reduceId, index);
return Objects.hashCode(appId, appAttemptId, shuffleId, shuffleMergeId, mapIndex , reduceId,
index);
}
@Override
@ -74,6 +78,7 @@ public class PushBlockStream extends BlockTransferMessage {
.append("appId", appId)
.append("attemptId", appAttemptId)
.append("shuffleId", shuffleId)
.append("shuffleMergeId", shuffleMergeId)
.append("mapIndex", mapIndex)
.append("reduceId", reduceId)
.append("index", index)
@ -87,6 +92,7 @@ public class PushBlockStream extends BlockTransferMessage {
return Objects.equal(appId, o.appId)
&& appAttemptId == o.appAttemptId
&& shuffleId == o.shuffleId
&& shuffleMergeId == o.shuffleMergeId
&& mapIndex == o.mapIndex
&& reduceId == o.reduceId
&& index == o.index;
@ -96,7 +102,7 @@ public class PushBlockStream extends BlockTransferMessage {
@Override
public int encodedLength() {
return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4 + 4 + 4;
return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4 + 4 + 4 + 4;
}
@Override
@ -104,6 +110,7 @@ public class PushBlockStream extends BlockTransferMessage {
Encoders.Strings.encode(buf, appId);
buf.writeInt(appAttemptId);
buf.writeInt(shuffleId);
buf.writeInt(shuffleMergeId);
buf.writeInt(mapIndex);
buf.writeInt(reduceId);
buf.writeInt(index);
@ -113,9 +120,11 @@ public class PushBlockStream extends BlockTransferMessage {
String appId = Encoders.Strings.decode(buf);
int attemptId = buf.readInt();
int shuffleId = buf.readInt();
int shuffleMergeId = buf.readInt();
int mapIdx = buf.readInt();
int reduceId = buf.readInt();
int index = buf.readInt();
return new PushBlockStream(appId, attemptId, shuffleId, mapIdx, reduceId, index);
return new PushBlockStream(appId, attemptId, shuffleId, shuffleMergeId, mapIdx, reduceId,
index);
}
}

View file

@ -29,23 +29,31 @@ import static org.junit.Assert.*;
public class ErrorHandlerSuite {
@Test
public void testPushErrorRetry() {
ErrorHandler.BlockPushErrorHandler handler = new ErrorHandler.BlockPushErrorHandler();
assertFalse(handler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))));
assertFalse(handler.shouldRetryError(new RuntimeException(new ConnectException())));
assertTrue(handler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
public void testErrorRetry() {
ErrorHandler.BlockPushErrorHandler pushHandler = new ErrorHandler.BlockPushErrorHandler();
assertFalse(pushHandler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX))));
assertFalse(pushHandler.shouldRetryError(new RuntimeException(new ConnectException())));
assertTrue(pushHandler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))));
assertTrue(handler.shouldRetryError(new Throwable()));
assertTrue(pushHandler.shouldRetryError(new Throwable()));
ErrorHandler.BlockFetchErrorHandler fetchHandler = new ErrorHandler.BlockFetchErrorHandler();
assertFalse(fetchHandler.shouldRetryError(new RuntimeException(
ErrorHandler.BlockFetchErrorHandler.STALE_SHUFFLE_BLOCK_FETCH)));
}
@Test
public void testPushErrorLogging() {
ErrorHandler.BlockPushErrorHandler handler = new ErrorHandler.BlockPushErrorHandler();
assertFalse(handler.shouldLogError(new RuntimeException(new IllegalArgumentException(
ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))));
assertFalse(handler.shouldLogError(new RuntimeException(new IllegalArgumentException(
public void testErrorLogging() {
ErrorHandler.BlockPushErrorHandler pushHandler = new ErrorHandler.BlockPushErrorHandler();
assertFalse(pushHandler.shouldLogError(new RuntimeException(new IllegalArgumentException(
ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX))));
assertFalse(pushHandler.shouldLogError(new RuntimeException(new IllegalArgumentException(
ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))));
assertTrue(handler.shouldLogError(new Throwable()));
assertTrue(pushHandler.shouldLogError(new Throwable()));
ErrorHandler.BlockFetchErrorHandler fetchHandler = new ErrorHandler.BlockFetchErrorHandler();
assertFalse(fetchHandler.shouldLogError(new RuntimeException(
ErrorHandler.BlockFetchErrorHandler.STALE_SHUFFLE_BLOCK_FETCH)));
}
}

View file

@ -243,9 +243,9 @@ public class ExternalBlockHandlerSuite {
public void testFinalizeShuffleMerge() throws IOException {
RpcResponseCallback callback = mock(RpcResponseCallback.class);
FinalizeShuffleMerge req = new FinalizeShuffleMerge("app0", 1, 0);
FinalizeShuffleMerge req = new FinalizeShuffleMerge("app0", 1, 0, 0);
RoaringBitmap bitmap = RoaringBitmap.bitmapOf(0, 1, 2);
MergeStatuses statuses = new MergeStatuses(0, new RoaringBitmap[]{bitmap},
MergeStatuses statuses = new MergeStatuses(0, 0, new RoaringBitmap[]{bitmap},
new int[]{3}, new long[]{30});
when(mergedShuffleManager.finalizeShuffleMerge(req)).thenReturn(statuses);
@ -269,22 +269,22 @@ public class ExternalBlockHandlerSuite {
@Test
public void testFetchMergedBlocksMeta() {
when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0)).thenReturn(
when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0, 0)).thenReturn(
new MergedBlockMeta(1, mock(ManagedBuffer.class)));
when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 1)).thenReturn(
when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0, 1)).thenReturn(
new MergedBlockMeta(3, mock(ManagedBuffer.class)));
when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 2)).thenReturn(
when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0, 2)).thenReturn(
new MergedBlockMeta(5, mock(ManagedBuffer.class)));
int[] expectedCount = new int[]{1, 3, 5};
String appId = "app0";
long requestId = 0L;
for (int reduceId = 0; reduceId < 3; reduceId++) {
MergedBlockMetaRequest req = new MergedBlockMetaRequest(requestId++, appId, 0, reduceId);
MergedBlockMetaRequest req = new MergedBlockMetaRequest(requestId++, appId, 0, 0, reduceId);
MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class);
handler.getMergedBlockMetaReqHandler()
.receiveMergeBlockMetaReq(client, req, callback);
verify(mergedShuffleManager, times(1)).getMergedBlockMeta("app0", 0, reduceId);
verify(mergedShuffleManager, times(1)).getMergedBlockMeta("app0", 0, 0, reduceId);
ArgumentCaptor<Integer> numChunksResponse = ArgumentCaptor.forClass(Integer.class);
ArgumentCaptor<ManagedBuffer> chunkBitmapResponse =
@ -313,12 +313,12 @@ public class ExternalBlockHandlerSuite {
if (useOpenBlocks) {
OpenBlocks openBlocks =
new OpenBlocks("app0", "exec1",
new String[] {"shuffleChunk_0_0_0", "shuffleChunk_0_0_1", "shuffleChunk_0_1_0",
"shuffleChunk_0_1_1"});
new String[] {"shuffleChunk_0_0_0_0", "shuffleChunk_0_0_0_1", "shuffleChunk_0_0_1_0",
"shuffleChunk_0_0_1_1"});
buffer = openBlocks.toByteBuffer();
} else {
FetchShuffleBlockChunks fetchChunks = new FetchShuffleBlockChunks(
"app0", "exec1", 0, new int[] {0, 1}, new int[][] {{0, 1}, {0, 1}});
"app0", "exec1", 0, 0, new int[] {0, 1}, new int[][] {{0, 1}, {0, 1}});
buffer = fetchChunks.toByteBuffer();
}
ManagedBuffer[][] buffers = new ManagedBuffer[][] {
@ -334,7 +334,7 @@ public class ExternalBlockHandlerSuite {
for (int reduceId = 0; reduceId < 2; reduceId++) {
for (int chunkId = 0; chunkId < 2; chunkId++) {
when(mergedShuffleManager.getMergedBlockData(
"app0", 0, reduceId, chunkId)).thenReturn(buffers[reduceId][chunkId]);
"app0", 0, 0, reduceId, chunkId)).thenReturn(buffers[reduceId][chunkId]);
}
}
handler.receive(client, buffer, callback);
@ -356,11 +356,12 @@ public class ExternalBlockHandlerSuite {
}
}
assertFalse(bufferIter.hasNext());
verify(mergedShuffleManager, never()).getMergedBlockMeta(anyString(), anyInt(), anyInt());
verify(mergedShuffleManager, never()).getMergedBlockMeta(anyString(), anyInt(), anyInt(),
anyInt());
verify(blockResolver, never()).getBlockData(
anyString(), anyString(), anyInt(), anyInt(), anyInt());
verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 0);
verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 1);
verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 0, 0);
verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 0, 1);
// Verify open block request latency metrics
Timer openBlockRequestLatencyMillis = (Timer) ((ExternalBlockHandler) handler)

View file

@ -246,51 +246,49 @@ public class OneForOneBlockFetcherSuite {
@Test
public void testShuffleBlockChunksFetch() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
blocks.put("shuffleChunk_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
blocks.put("shuffleChunk_0_0_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
blocks.put("shuffleChunk_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
blocks.put("shuffleChunk_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
blocks.put("shuffleChunk_0_0_0_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
blocks.put("shuffleChunk_0_0_0_2",
new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
BlockFetchingListener listener = fetchBlocks(blocks, blockIds,
new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
new FetchShuffleBlockChunks("app-id", "exec-id", 0, 0, new int[] { 0 },
new int[][] {{ 0, 1, 2 }}), conf);
for (int i = 0; i < 3; i ++) {
verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_" + i,
blocks.get("shuffleChunk_0_0_" + i));
verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0_" + i,
blocks.get("shuffleChunk_0_0_0_" + i));
}
}
@Test
public void testShuffleBlockChunkFetchFailure() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
blocks.put("shuffleChunk_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
blocks.put("shuffleChunk_0_0_1", null);
blocks.put("shuffleChunk_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
blocks.put("shuffleChunk_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
blocks.put("shuffleChunk_0_0_0_1", null);
blocks.put("shuffleChunk_0_0_0_2",
new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
BlockFetchingListener listener = fetchBlocks(blocks, blockIds,
new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[]{0}, new int[][]{{0, 1, 2}}),
new FetchShuffleBlockChunks("app-id", "exec-id", 0, 0, new int[]{0}, new int[][]{{0, 1, 2}}),
conf);
verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0",
blocks.get("shuffleChunk_0_0_0"));
verify(listener, times(1)).onBlockFetchFailure(eq("shuffleChunk_0_0_1"), any());
verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_2",
blocks.get("shuffleChunk_0_0_2"));
verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0_0",
blocks.get("shuffleChunk_0_0_0_0"));
verify(listener, times(1)).onBlockFetchFailure(eq("shuffleChunk_0_0_0_1"), any());
verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0_2",
blocks.get("shuffleChunk_0_0_0_2"));
}
@Test
public void testInvalidShuffleBlockIds() {
assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(),
new String[]{"shuffle_0_0"},
new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
new int[][] {{ 0 }}), conf));
new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 },
new int[][] {{ 0 }}, false), conf));
assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(),
new String[]{"shuffleChunk_0_0_0_0_0"},
new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
new int[][] {{ 0 }}), conf));
assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(),
new String[]{"shuffleChunk_0_0_0_0"},
new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
new FetchShuffleBlockChunks("app-id", "exec-id", 0, 0, new int[] { 0 },
new int[][] {{ 0 }}), conf));
}

View file

@ -45,77 +45,78 @@ public class OneForOneBlockPusherSuite {
@Test
public void testPushOne() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[1])));
blocks.put("shufflePush_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[1])));
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
BlockPushingListener listener = pushBlocks(
blocks,
blockIds,
Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0)));
Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0, 0)));
verify(listener).onBlockPushSuccess(eq("shufflePush_0_0_0"), any());
verify(listener).onBlockPushSuccess(eq("shufflePush_0_0_0_0"), any());
}
@Test
public void testPushThree() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
blocks.put("shufflePush_0_1_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
blocks.put("shufflePush_0_2_0", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
blocks.put("shufflePush_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
blocks.put("shufflePush_0_0_1_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
blocks.put("shufflePush_0_0_2_0",
new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
BlockPushingListener listener = pushBlocks(
blocks,
blockIds,
Arrays.asList(new PushBlockStream("app-id",0, 0, 0, 0, 0),
new PushBlockStream("app-id", 0, 0, 1, 0, 1),
new PushBlockStream("app-id", 0, 0, 2, 0, 2)));
Arrays.asList(new PushBlockStream("app-id",0, 0, 0, 0, 0, 0),
new PushBlockStream("app-id", 0, 0, 0, 1, 0, 1),
new PushBlockStream("app-id", 0, 0, 0, 2, 0, 2)));
verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0"), any());
verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_1_0"), any());
verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_2_0"), any());
verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0_0"), any());
verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_1_0"), any());
verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_2_0"), any());
}
@Test
public void testServerFailures() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
blocks.put("shufflePush_0_1_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
blocks.put("shufflePush_0_2_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
blocks.put("shufflePush_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
blocks.put("shufflePush_0_0_1_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
blocks.put("shufflePush_0_0_2_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
BlockPushingListener listener = pushBlocks(
blocks,
blockIds,
Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0),
new PushBlockStream("app-id", 0, 0, 1, 0, 1),
new PushBlockStream("app-id", 0, 0, 2, 0, 2)));
Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0, 0),
new PushBlockStream("app-id", 0, 0, 0, 1, 0, 1),
new PushBlockStream("app-id", 0, 0, 0, 2, 0, 2)));
verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0"), any());
verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_1_0"), any());
verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_2_0"), any());
verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0_0"), any());
verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_0_1_0"), any());
verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_0_2_0"), any());
}
@Test
public void testHandlingRetriableFailures() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
blocks.put("shufflePush_0_1_0", null);
blocks.put("shufflePush_0_2_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
blocks.put("shufflePush_0_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
blocks.put("shufflePush_0_0_1_0", null);
blocks.put("shufflePush_0_0_2_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
BlockPushingListener listener = pushBlocks(
blocks,
blockIds,
Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0),
new PushBlockStream("app-id", 0, 0, 1, 0, 1),
new PushBlockStream("app-id", 0, 0, 2, 0, 2)));
Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0, 0, 0),
new PushBlockStream("app-id", 0, 0, 0, 1, 0, 1),
new PushBlockStream("app-id", 0, 0, 0, 2, 0, 2)));
verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0"), any());
verify(listener, times(0)).onBlockPushSuccess(not(eq("shufflePush_0_0_0")), any());
verify(listener, times(0)).onBlockPushFailure(eq("shufflePush_0_0_0"), any());
verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_1_0"), any());
verify(listener, times(2)).onBlockPushFailure(eq("shufflePush_0_2_0"), any());
verify(listener, times(1)).onBlockPushSuccess(eq("shufflePush_0_0_0_0"), any());
verify(listener, times(0)).onBlockPushSuccess(not(eq("shufflePush_0_0_0_0")), any());
verify(listener, times(0)).onBlockPushFailure(eq("shufflePush_0_0_0_0"), any());
verify(listener, times(1)).onBlockPushFailure(eq("shufflePush_0_0_1_0"), any());
verify(listener, times(2)).onBlockPushFailure(eq("shufflePush_0_0_2_0"), any());
}
/**
@ -147,7 +148,7 @@ public class OneForOneBlockPusherSuite {
+ ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX));
} else {
callback.onFailure(new RuntimeException("Quick fail " + entry.getKey()
+ ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX));
+ ErrorHandler.BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX));
}
assertEquals(msgIterator.next(), message);
return null;

View file

@ -29,10 +29,10 @@ public class FetchShuffleBlockChunksSuite {
@Test
public void testFetchShuffleBlockChunksEncodeDecode() {
FetchShuffleBlockChunks shuffleBlockChunks =
new FetchShuffleBlockChunks("app0", "exec1", 0, new int[] {0}, new int[][] {{0, 1}});
new FetchShuffleBlockChunks("app0", "exec1", 0, 0, new int[] {0}, new int[][] {{0, 1}});
Assert.assertEquals(2, shuffleBlockChunks.getNumBlocks());
int len = shuffleBlockChunks.encodedLength();
Assert.assertEquals(45, len);
Assert.assertEquals(49, len);
ByteBuf buf = Unpooled.buffer(len);
shuffleBlockChunks.encode(buf);

View file

@ -20,6 +20,7 @@ package org.apache.spark
import scala.reflect.ClassTag
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor}
@ -78,7 +79,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false,
val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor)
extends Dependency[Product2[K, V]] {
extends Dependency[Product2[K, V]] with Logging {
if (mapSideCombine) {
require(aggregator.isDefined, "Map-side combine without Aggregator specified!")
@ -101,10 +102,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
// By default, shuffle merge is enabled for ShuffleDependency if push based shuffle
// is enabled
private[this] var _shuffleMergeEnabled =
Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf) &&
// TODO: SPARK-35547: Push based shuffle is currently unsupported for Barrier stages
!rdd.isBarrier()
private[this] var _shuffleMergeEnabled = canShuffleMergeBeEnabled()
private[spark] def setShuffleMergeEnabled(shuffleMergeEnabled: Boolean): Unit = {
_shuffleMergeEnabled = shuffleMergeEnabled
@ -124,6 +122,14 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
*/
private[this] var _shuffleMergedFinalized: Boolean = false
/**
* shuffleMergeId is used to uniquely identify merging process of shuffle
* by an indeterminate stage attempt.
*/
private[this] var _shuffleMergeId: Int = 0
def shuffleMergeId: Int = _shuffleMergeId
def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = {
if (mergerLocs != null) {
this.mergerLocs = mergerLocs
@ -150,6 +156,22 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
}
}
def newShuffleMergeState(): Unit = {
_shuffleMergedFinalized = false
mergerLocs = Nil
_shuffleMergeId += 1
}
private def canShuffleMergeBeEnabled(): Boolean = {
val isPushShuffleEnabled = Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf)
if (isPushShuffleEnabled && rdd.isBarrier()) {
logWarning("Push-based shuffle is currently not supported for barrier stages")
}
isPushShuffleEnabled &&
// TODO: SPARK-35547: Push based shuffle is currently unsupported for Barrier stages
!rdd.isBarrier()
}
_rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
_rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId)
}

View file

@ -38,7 +38,7 @@ import org.apache.spark.io.CompressionCodec
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus}
import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId}
import org.apache.spark.util._
/**
@ -1450,11 +1450,10 @@ private[spark] object MapOutputTracker extends Logging {
val remainingMapStatuses = if (mergeStatus != null && mergeStatus.totalSize > 0) {
// If MergeStatus is available for the given partition, add location of the
// pre-merged shuffle partition for this partition ID. Here we create a
// ShuffleBlockId with mapId being SHUFFLE_PUSH_MAP_ID to indicate this is
// a merged shuffle block.
// ShuffleMergedBlockId to indicate this is a merged shuffle block.
splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) +=
((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize,
SHUFFLE_PUSH_MAP_ID))
((ShuffleMergedBlockId(shuffleId, mergeStatus.shuffleMergeId, partId),
mergeStatus.totalSize, SHUFFLE_PUSH_MAP_ID))
// For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper
// shuffle partition blocks, fetch the original map produced shuffle partition blocks
val mapStatusesWithIndex = mapStatuses.zipWithIndex

View file

@ -1323,9 +1323,8 @@ private[spark] class DAGScheduler(
// `findMissingPartitions()` returns all partitions every time.
stage match {
case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
// TODO: SPARK-32923: Clean all push-based shuffle metadata like merge enabled and
// TODO: finalized as we are clearing all the merge results.
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
sms.shuffleDep.newShuffleMergeState()
case _ =>
}
@ -2057,7 +2056,7 @@ private[spark] class DAGScheduler(
// TODO: SPARK-35536: Cancel finalizeShuffleMerge if the stage is cancelled
// TODO: during shuffleMergeFinalizeWaitSec
shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
shuffleServiceLoc.port, shuffleId,
shuffleServiceLoc.port, shuffleId, stage.shuffleDep.shuffleMergeId,
new MergeFinalizerListener {
override def onShuffleMergeSuccess(statuses: MergeStatuses): Unit = {
assert(shuffleId == statuses.shuffleId)

View file

@ -43,14 +43,17 @@ import org.apache.spark.util.Utils
*/
private[spark] class MergeStatus(
private[this] var loc: BlockManagerId,
private[this] var _shuffleMergeId: Int,
private[this] var mapTracker: RoaringBitmap,
private[this] var size: Long)
extends Externalizable with ShuffleOutputStatus {
protected def this() = this(null, null, -1) // For deserialization only
protected def this() = this(null, -1, null, -1) // For deserialization only
def location: BlockManagerId = loc
def shuffleMergeId: Int = _shuffleMergeId
def totalSize: Long = size
def tracker: RoaringBitmap = mapTracker
@ -73,12 +76,14 @@ private[spark] class MergeStatus(
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
out.writeInt(_shuffleMergeId)
mapTracker.writeExternal(out)
out.writeLong(size)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
loc = BlockManagerId(in)
_shuffleMergeId = in.readInt()
mapTracker = new RoaringBitmap()
mapTracker.readExternal(in)
size = in.readLong()
@ -100,14 +105,20 @@ private[spark] object MergeStatus {
assert(mergeStatuses.bitmaps.length == mergeStatuses.reduceIds.length &&
mergeStatuses.bitmaps.length == mergeStatuses.sizes.length)
val mergerLoc = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, loc.host, loc.port)
val shuffleMergeId = mergeStatuses.shuffleMergeId
mergeStatuses.bitmaps.zipWithIndex.map {
case (bitmap, index) =>
val mergeStatus = new MergeStatus(mergerLoc, bitmap, mergeStatuses.sizes(index))
val mergeStatus = new MergeStatus(mergerLoc, shuffleMergeId, bitmap,
mergeStatuses.sizes(index))
(mergeStatuses.reduceIds(index), mergeStatus)
}
}
def apply(loc: BlockManagerId, bitmap: RoaringBitmap, size: Long): MergeStatus = {
new MergeStatus(loc, bitmap, size)
def apply(
loc: BlockManagerId,
shuffleMergeId: Int,
bitmap: RoaringBitmap,
size: Long): MergeStatus = {
new MergeStatus(loc, shuffleMergeId, bitmap, size)
}
}

View file

@ -116,28 +116,31 @@ private[spark] class IndexShuffleBlockResolver(
private def getMergedBlockDataFile(
appId: String,
shuffleId: Int,
shuffleMergeId: Int,
reduceId: Int,
dirs: Option[Array[String]] = None): File = {
blockManager.diskBlockManager.getMergedShuffleFile(
ShuffleMergedDataBlockId(appId, shuffleId, reduceId), dirs)
ShuffleMergedDataBlockId(appId, shuffleId, shuffleMergeId, reduceId), dirs)
}
private def getMergedBlockIndexFile(
appId: String,
shuffleId: Int,
shuffleMergeId: Int,
reduceId: Int,
dirs: Option[Array[String]] = None): File = {
blockManager.diskBlockManager.getMergedShuffleFile(
ShuffleMergedIndexBlockId(appId, shuffleId, reduceId), dirs)
ShuffleMergedIndexBlockId(appId, shuffleId, shuffleMergeId, reduceId), dirs)
}
private def getMergedBlockMetaFile(
appId: String,
shuffleId: Int,
shuffleMergeId: Int,
reduceId: Int,
dirs: Option[Array[String]] = None): File = {
blockManager.diskBlockManager.getMergedShuffleFile(
ShuffleMergedMetaBlockId(appId, shuffleId, reduceId), dirs)
ShuffleMergedMetaBlockId(appId, shuffleId, shuffleMergeId, reduceId), dirs)
}
/**
@ -466,11 +469,13 @@ private[spark] class IndexShuffleBlockResolver(
* knows how to consume local merged shuffle file as multiple chunks.
*/
override def getMergedBlockData(
blockId: ShuffleBlockId,
blockId: ShuffleMergedBlockId,
dirs: Option[Array[String]]): Seq[ManagedBuffer] = {
val indexFile =
getMergedBlockIndexFile(conf.getAppId, blockId.shuffleId, blockId.reduceId, dirs)
val dataFile = getMergedBlockDataFile(conf.getAppId, blockId.shuffleId, blockId.reduceId, dirs)
getMergedBlockIndexFile(conf.getAppId, blockId.shuffleId, blockId.shuffleMergeId,
blockId.reduceId, dirs)
val dataFile = getMergedBlockDataFile(conf.getAppId, blockId.shuffleId,
blockId.shuffleMergeId, blockId.reduceId, dirs)
// Load all the indexes in order to identify all chunks in the specified merged shuffle file.
val size = indexFile.length.toInt
val offsets = Utils.tryWithResource {
@ -493,13 +498,15 @@ private[spark] class IndexShuffleBlockResolver(
* This is only used for reading local merged block meta data.
*/
override def getMergedBlockMeta(
blockId: ShuffleBlockId,
blockId: ShuffleMergedBlockId,
dirs: Option[Array[String]]): MergedBlockMeta = {
val indexFile =
getMergedBlockIndexFile(conf.getAppId, blockId.shuffleId, blockId.reduceId, dirs)
getMergedBlockIndexFile(conf.getAppId, blockId.shuffleId,
blockId.shuffleMergeId, blockId.reduceId, dirs)
val size = indexFile.length.toInt
val numChunks = (size / 8) - 1
val metaFile = getMergedBlockMetaFile(conf.getAppId, blockId.shuffleId, blockId.reduceId, dirs)
val metaFile = getMergedBlockMetaFile(conf.getAppId, blockId.shuffleId,
blockId.shuffleMergeId, blockId.reduceId, dirs)
val chunkBitMaps = new FileSegmentManagedBuffer(transportConf, metaFile, 0L, metaFile.length)
new MergedBlockMeta(numChunks, chunkBitMaps)
}

View file

@ -69,7 +69,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
new BlockPushErrorHandler() {
// For a connection exception against a particular host, we will stop pushing any
// blocks to just that host and continue push blocks to other hosts. So, here push of
// all blocks will only stop when it is "Too Late". Also see updateStateAndCheckIfPushMore.
// all blocks will only stop when it is "Too Late" or "Invalid Block push.
// Also see updateStateAndCheckIfPushMore.
override def shouldRetryError(t: Throwable): Boolean = {
// If it is a FileNotFoundException originating from the client while pushing the shuffle
// blocks to the server, then we stop pushing all the blocks because this indicates the
@ -77,8 +78,10 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
if (t.getCause != null && t.getCause.isInstanceOf[FileNotFoundException]) {
return false
}
// If the block is too late, there is no need to retry it
!Throwables.getStackTraceAsString(t).contains(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)
val errorStackTraceString = Throwables.getStackTraceAsString(t)
// If the block is too late or the invalid block push, there is no need to retry it
!errorStackTraceString.contains(
BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX)
}
}
}
@ -99,8 +102,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
mapIndex: Int): Unit = {
val numPartitions = dep.partitioner.numPartitions
val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
val requests = prepareBlockPushRequests(numPartitions, mapIndex, dep.shuffleId, dataFile,
partitionLengths, dep.getMergerLocs, transportConf)
val requests = prepareBlockPushRequests(numPartitions, mapIndex, dep.shuffleId,
dep.shuffleMergeId, dataFile, partitionLengths, dep.getMergerLocs, transportConf)
// Randomize the orders of the PushRequest, so different mappers pushing blocks at the same
// time won't be pushing the same ranges of shuffle partitions.
pushRequests ++= Utils.randomize(requests)
@ -335,6 +338,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
* @param numPartitions number of shuffle partitions in the shuffle file
* @param partitionId map index of the current mapper
* @param shuffleId shuffleId of current shuffle
* @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
* of shuffle by an indeterminate stage attempt.
* @param dataFile shuffle data file
* @param partitionLengths array of sizes of blocks in the shuffle data file
* @param mergerLocs target locations to push blocks to
@ -347,6 +352,7 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
numPartitions: Int,
partitionId: Int,
shuffleId: Int,
shuffleMergeId: Int,
dataFile: File,
partitionLengths: Array[Long],
mergerLocs: Seq[BlockManagerId],
@ -361,7 +367,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
for (reduceId <- 0 until numPartitions) {
val blockSize = partitionLengths(reduceId)
logDebug(
s"Block ${ShufflePushBlockId(shuffleId, partitionId, reduceId)} is of size $blockSize")
s"Block ${ShufflePushBlockId(shuffleId, shuffleMergeId, partitionId,
reduceId)} is of size $blockSize")
// Skip 0-length blocks and blocks that are large enough
if (blockSize > 0) {
val mergerId = math.min(math.floor(reduceId * 1.0 / numPartitions * numMergers),
@ -394,7 +401,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
// Only push blocks under the size limit
if (blockSize <= maxBlockSizeToPush) {
val blockSizeInt = blockSize.toInt
blocks += ((ShufflePushBlockId(shuffleId, partitionId, reduceId), blockSizeInt))
blocks += ((ShufflePushBlockId(shuffleId, shuffleMergeId, partitionId,
reduceId), blockSizeInt))
// Only update currentReqOffset if the current block is the first in the request
if (currentReqOffset == -1) {
currentReqOffset = offset

View file

@ -19,7 +19,7 @@ package org.apache.spark.shuffle
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.MergedBlockMeta
import org.apache.spark.storage.{BlockId, ShuffleBlockId}
import org.apache.spark.storage.{BlockId, ShuffleMergedBlockId}
private[spark]
/**
@ -44,12 +44,16 @@ trait ShuffleBlockResolver {
/**
* Retrieve the data for the specified merged shuffle block as multiple chunks.
*/
def getMergedBlockData(blockId: ShuffleBlockId, dirs: Option[Array[String]]): Seq[ManagedBuffer]
def getMergedBlockData(
blockId: ShuffleMergedBlockId,
dirs: Option[Array[String]]): Seq[ManagedBuffer]
/**
* Retrieve the meta data for the specified merged shuffle block.
*/
def getMergedBlockMeta(blockId: ShuffleBlockId, dirs: Option[Array[String]]): MergedBlockMeta
def getMergedBlockMeta(
blockId: ShuffleMergedBlockId,
dirs: Option[Array[String]]): MergedBlockMeta
def stop(): Unit
}

View file

@ -77,9 +77,11 @@ case class ShuffleBlockBatchId(
@DeveloperApi
case class ShuffleBlockChunkId(
shuffleId: Int,
shuffleMergeId: Int,
reduceId: Int,
chunkId: Int) extends BlockId {
override def name: String = "shuffleChunk_" + shuffleId + "_" + reduceId + "_" + chunkId
override def name: String =
"shuffleChunk_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId + "_" + chunkId
}
@DeveloperApi
@ -100,15 +102,34 @@ case class ShuffleChecksumBlockId(shuffleId: Int, mapId: Long, reduceId: Int) ex
@Since("3.2.0")
@DeveloperApi
case class ShufflePushBlockId(shuffleId: Int, mapIndex: Int, reduceId: Int) extends BlockId {
override def name: String = "shufflePush_" + shuffleId + "_" + mapIndex + "_" + reduceId
case class ShufflePushBlockId(
shuffleId: Int,
shuffleMergeId: Int,
mapIndex: Int,
reduceId: Int) extends BlockId {
override def name: String = "shufflePush_" + shuffleId + "_" +
shuffleMergeId + "_" + mapIndex + "_" + reduceId + ""
}
@Since("3.2.0")
@DeveloperApi
case class ShuffleMergedDataBlockId(appId: String, shuffleId: Int, reduceId: Int) extends BlockId {
case class ShuffleMergedBlockId(
shuffleId: Int,
shuffleMergeId: Int,
reduceId: Int) extends BlockId {
override def name: String = "shuffleMerged_" + shuffleId + "_" +
shuffleMergeId + "_" + reduceId
}
@Since("3.2.0")
@DeveloperApi
case class ShuffleMergedDataBlockId(
appId: String,
shuffleId: Int,
shuffleMergeId: Int,
reduceId: Int) extends BlockId {
override def name: String = RemoteBlockPushResolver.MERGED_SHUFFLE_FILE_NAME_PREFIX + "_" +
appId + "_" + shuffleId + "_" + reduceId + ".data"
appId + "_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId + ".data"
}
@Since("3.2.0")
@ -116,9 +137,10 @@ case class ShuffleMergedDataBlockId(appId: String, shuffleId: Int, reduceId: Int
case class ShuffleMergedIndexBlockId(
appId: String,
shuffleId: Int,
shuffleMergeId: Int,
reduceId: Int) extends BlockId {
override def name: String = RemoteBlockPushResolver.MERGED_SHUFFLE_FILE_NAME_PREFIX + "_" +
appId + "_" + shuffleId + "_" + reduceId + ".index"
appId + "_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId + ".index"
}
@Since("3.2.0")
@ -126,9 +148,10 @@ case class ShuffleMergedIndexBlockId(
case class ShuffleMergedMetaBlockId(
appId: String,
shuffleId: Int,
shuffleMergeId: Int,
reduceId: Int) extends BlockId {
override def name: String = RemoteBlockPushResolver.MERGED_SHUFFLE_FILE_NAME_PREFIX + "_" +
appId + "_" + shuffleId + "_" + reduceId + ".meta"
appId + "_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId + ".meta"
}
@DeveloperApi
@ -172,11 +195,15 @@ object BlockId {
val SHUFFLE_BATCH = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r
val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r
val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
val SHUFFLE_PUSH = "shufflePush_([0-9]+)_([0-9]+)_([0-9]+)".r
val SHUFFLE_MERGED_DATA = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).data".r
val SHUFFLE_MERGED_INDEX = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).index".r
val SHUFFLE_MERGED_META = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).meta".r
val SHUFFLE_CHUNK = "shuffleChunk_([0-9]+)_([0-9]+)_([0-9]+)".r
val SHUFFLE_PUSH = "shufflePush_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r
val SHUFFLE_MERGED = "shuffleMerged_([0-9]+)_([0-9]+)_([0-9]+)".r
val SHUFFLE_MERGED_DATA =
"shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+)_([0-9]+).data".r
val SHUFFLE_MERGED_INDEX =
"shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+)_([0-9]+).index".r
val SHUFFLE_MERGED_META =
"shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+)_([0-9]+).meta".r
val SHUFFLE_CHUNK = "shuffleChunk_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r
val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
@ -195,16 +222,22 @@ object BlockId {
ShuffleDataBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt)
case SHUFFLE_INDEX(shuffleId, mapId, reduceId) =>
ShuffleIndexBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt)
case SHUFFLE_PUSH(shuffleId, mapIndex, reduceId) =>
ShufflePushBlockId(shuffleId.toInt, mapIndex.toInt, reduceId.toInt)
case SHUFFLE_MERGED_DATA(appId, shuffleId, reduceId) =>
ShuffleMergedDataBlockId(appId, shuffleId.toInt, reduceId.toInt)
case SHUFFLE_MERGED_INDEX(appId, shuffleId, reduceId) =>
ShuffleMergedIndexBlockId(appId, shuffleId.toInt, reduceId.toInt)
case SHUFFLE_MERGED_META(appId, shuffleId, reduceId) =>
ShuffleMergedMetaBlockId(appId, shuffleId.toInt, reduceId.toInt)
case SHUFFLE_CHUNK(shuffleId, reduceId, chunkId) =>
ShuffleBlockChunkId(shuffleId.toInt, reduceId.toInt, chunkId.toInt)
case SHUFFLE_PUSH(shuffleId, shuffleMergeId, mapIndex, reduceId) =>
ShufflePushBlockId(shuffleId.toInt, shuffleMergeId.toInt, mapIndex.toInt,
reduceId.toInt)
case SHUFFLE_MERGED(shuffleId, shuffleMergeId, reduceId) =>
ShuffleMergedBlockId(shuffleId.toInt, shuffleMergeId.toInt, reduceId.toInt)
case SHUFFLE_MERGED_DATA(appId, shuffleId, shuffleMergeId, reduceId) =>
ShuffleMergedDataBlockId(appId, shuffleId.toInt, shuffleMergeId.toInt, reduceId.toInt)
case SHUFFLE_MERGED_INDEX(appId, shuffleId, shuffleMergeId, reduceId) =>
ShuffleMergedIndexBlockId(appId, shuffleId.toInt, shuffleMergeId.toInt,
reduceId.toInt)
case SHUFFLE_MERGED_META(appId, shuffleId, shuffleMergeId, reduceId) =>
ShuffleMergedMetaBlockId(appId, shuffleId.toInt, shuffleMergeId.toInt,
reduceId.toInt)
case SHUFFLE_CHUNK(shuffleId, shuffleMergeId, reduceId, chunkId) =>
ShuffleBlockChunkId(shuffleId.toInt, shuffleMergeId.toInt, reduceId.toInt,
chunkId.toInt)
case BROADCAST(broadcastId, field) =>
BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
case TASKRESULT(taskId) =>

View file

@ -748,7 +748,7 @@ private[spark] class BlockManager(
* which will be memory efficient when performing certain operations.
*/
def getLocalMergedBlockData(
blockId: ShuffleBlockId,
blockId: ShuffleMergedBlockId,
dirs: Array[String]): Seq[ManagedBuffer] = {
shuffleManager.shuffleBlockResolver.getMergedBlockData(blockId, Some(dirs))
}
@ -757,7 +757,7 @@ private[spark] class BlockManager(
* Get the local merged shuffle block meta data for the given block ID.
*/
def getLocalMergedBlockMeta(
blockId: ShuffleBlockId,
blockId: ShuffleMergedBlockId,
dirs: Array[String]): MergedBlockMeta = {
shuffleManager.shuffleBlockResolver.getMergedBlockMeta(blockId, Some(dirs))
}

View file

@ -110,13 +110,14 @@ private class PushBasedFetchHelper(
*/
def createChunkBlockInfosFromMetaResponse(
shuffleId: Int,
shuffleMergeId: Int,
reduceId: Int,
blockSize: Long,
bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
val approxChunkSize = blockSize / bitmaps.length
val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
for (i <- bitmaps.indices) {
val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
val blockChunkId = ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, i)
chunksMetaMap.put(blockChunkId, bitmaps(i))
logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
@ -134,37 +135,41 @@ private class PushBasedFetchHelper(
def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
val sizeMap = req.blocks.map {
case FetchBlockInfo(blockId, size, _) =>
val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId]
((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
}.toMap
val address = req.address
val mergedBlocksMetaListener = new MergedBlocksMetaListener {
override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
logInfo(s"Received the meta of push-merged block for ($shuffleId, $reduceId) " +
s"from ${req.address.host}:${req.address.port}")
override def onSuccess(shuffleId: Int, shuffleMergeId: Int, reduceId: Int,
meta: MergedBlockMeta): Unit = {
logInfo(s"Received the meta of push-merged block for ($shuffleId, $shuffleMergeId," +
s" $reduceId) from ${req.address.host}:${req.address.port}")
try {
iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, reduceId,
sizeMap((shuffleId, reduceId)), meta.readChunkBitmaps(), address))
iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, shuffleMergeId,
reduceId, sizeMap((shuffleId, reduceId)), meta.readChunkBitmaps(), address))
} catch {
case exception: Exception =>
logError(s"Failed to parse the meta of push-merged block for ($shuffleId, " +
s"$reduceId) from ${req.address.host}:${req.address.port}", exception)
s"$shuffleMergeId, $reduceId) from" +
s" ${req.address.host}:${req.address.port}", exception)
iterator.addToResultsQueue(
PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId,
address))
}
}
override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
override def onFailure(shuffleId: Int, shuffleMergeId: Int, reduceId: Int,
exception: Throwable): Unit = {
logError(s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " +
s"from ${req.address.host}:${req.address.port}", exception)
iterator.addToResultsQueue(
PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address))
}
}
req.blocks.foreach { block =>
val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
val shuffleBlockId = block.blockId.asInstanceOf[ShuffleMergedBlockId]
shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
shuffleBlockId.reduceId, mergedBlocksMetaListener)
shuffleBlockId.shuffleMergeId, shuffleBlockId.reduceId, mergedBlocksMetaListener)
}
}
@ -241,11 +246,11 @@ private class PushBasedFetchHelper(
localDirs: Array[String],
blockManagerId: BlockManagerId): Unit = {
try {
val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId]
val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs)
iterator.addToResultsQueue(PushMergedLocalMetaFetchResult(
shuffleBlockId.shuffleId, shuffleBlockId.reduceId, chunksMeta.readChunkBitmaps(),
localDirs))
shuffleBlockId.shuffleId, shuffleBlockId.shuffleMergeId,
shuffleBlockId.reduceId, chunksMeta.readChunkBitmaps(), localDirs))
} catch {
case e: Exception =>
// If we see an exception with reading a push-merged-local meta, we fallback to
@ -283,13 +288,13 @@ private class PushBasedFetchHelper(
def initiateFallbackFetchForPushMergedBlock(
blockId: BlockId,
address: BlockManagerId): Unit = {
assert(blockId.isInstanceOf[ShuffleBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId])
assert(blockId.isInstanceOf[ShuffleMergedBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId])
logWarning(s"Falling back to fetch the original blocks for push-merged block $blockId")
// Increase the blocks processed since we will process another block in the next iteration of
// the while loop in ShuffleBlockFetcherIterator.next().
val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
blockId match {
case shuffleBlockId: ShuffleBlockId =>
case shuffleBlockId: ShuffleMergedBlockId =>
iterator.decreaseNumBlocksToFetch(1)
mapOutputTracker.getMapSizesForMergeResult(
shuffleBlockId.shuffleId, shuffleBlockId.reduceId)

View file

@ -490,14 +490,14 @@ final class ShuffleBlockFetcherIterator(
// Either all blocks are push-merged blocks, shuffle chunks, or original blocks.
// Based on these types, we decide to do batch fetch and create FetchRequests with
// forMergedMetas set.
case ShuffleBlockChunkId(_, _, _) =>
case ShuffleBlockChunkId(_, _, _, _) =>
if (curRequestSize >= targetRemoteRequestSize ||
curBlocks.size >= maxBlocksInFlightPerAddress) {
curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false,
collectedRemoteRequests, enableBatchFetch = false)
curRequestSize = curBlocks.map(_.size).sum
}
case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) =>
case ShuffleMergedBlockId(_, _, _) =>
if (curBlocks.size >= maxBlocksInFlightPerAddress) {
curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false,
collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true)
@ -516,8 +516,8 @@ final class ShuffleBlockFetcherIterator(
if (curBlocks.nonEmpty) {
val (enableBatchFetch, forMergedMetas) = {
curBlocks.head.blockId match {
case ShuffleBlockChunkId(_, _, _) => (false, false)
case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => (false, true)
case ShuffleBlockChunkId(_, _, _, _) => (false, false)
case ShuffleMergedBlockId(_, _, _) => (false, true)
case _ => (doBatchFetch, false)
}
}
@ -901,9 +901,10 @@ final class ShuffleBlockFetcherIterator(
// a SuccessFetchResult or a FailureFetchResult.
result = null
case PushMergedLocalMetaFetchResult(shuffleId, reduceId, bitmaps, localDirs) =>
case PushMergedLocalMetaFetchResult(
shuffleId, shuffleMergeId, reduceId, bitmaps, localDirs) =>
// Fetch push-merged-local shuffle block data as multiple shuffle chunks
val shuffleBlockId = ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId)
val shuffleBlockId = ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId)
try {
val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId,
localDirs)
@ -915,7 +916,8 @@ final class ShuffleBlockFetcherIterator(
numBlocksToFetch += bufs.size
bufs.zipWithIndex.foreach { case (buf, chunkId) =>
buf.retain()
val shuffleChunkId = ShuffleBlockChunkId(shuffleId, reduceId, chunkId)
val shuffleChunkId = ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId,
chunkId)
pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId))
results.put(SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID,
pushBasedFetchHelper.localShuffleMergerBlockMgrId, buf.size(), buf,
@ -933,27 +935,29 @@ final class ShuffleBlockFetcherIterator(
}
result = null
case PushMergedRemoteMetaFetchResult(shuffleId, reduceId, blockSize, bitmaps, address) =>
case PushMergedRemoteMetaFetchResult(
shuffleId, shuffleMergeId, reduceId, blockSize, bitmaps, address) =>
// The original meta request is processed so we decrease numBlocksToFetch and
// numBlocksInFlightPerAddress by 1. We will collect new shuffle chunks request and the
// count of this is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks.
numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
numBlocksToFetch -= 1
val blocksToFetch = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse(
shuffleId, reduceId, blockSize, bitmaps)
shuffleId, shuffleMergeId, reduceId, blockSize, bitmaps)
val additionalRemoteReqs = new ArrayBuffer[FetchRequest]
collectFetchRequests(address, blocksToFetch.toSeq, additionalRemoteReqs)
fetchRequests ++= additionalRemoteReqs
// Set result to null to force another iteration.
result = null
case PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address) =>
case PushMergedRemoteMetaFailedFetchResult(
shuffleId, shuffleMergeId, reduceId, address) =>
// The original meta request failed so we decrease numBlocksInFlightPerAddress by 1.
numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
// If we fail to fetch the meta of a push-merged block, we fall back to fetching the
// original blocks.
pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(
ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId), address)
ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId), address)
// Set result to null to force another iteration.
result = null
}
@ -1421,6 +1425,8 @@ object ShuffleBlockFetcherIterator {
* Result of a successful fetch of meta information for a remote push-merged block.
*
* @param shuffleId shuffle id.
* @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
* of shuffle by an indeterminate stage attempt.
* @param reduceId reduce id.
* @param blockSize size of each push-merged block.
* @param bitmaps bitmaps for every chunk.
@ -1428,6 +1434,7 @@ object ShuffleBlockFetcherIterator {
*/
private[storage] case class PushMergedRemoteMetaFetchResult(
shuffleId: Int,
shuffleMergeId: Int,
reduceId: Int,
blockSize: Long,
bitmaps: Array[RoaringBitmap],
@ -1437,11 +1444,14 @@ object ShuffleBlockFetcherIterator {
* Result of a failure while fetching the meta information for a remote push-merged block.
*
* @param shuffleId shuffle id.
* @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
* of shuffle by an indeterminate stage attempt.
* @param reduceId reduce id.
* @param address BlockManager that the meta was fetched from.
*/
private[storage] case class PushMergedRemoteMetaFailedFetchResult(
shuffleId: Int,
shuffleMergeId: Int,
reduceId: Int,
address: BlockManagerId) extends FetchResult
@ -1449,12 +1459,15 @@ object ShuffleBlockFetcherIterator {
* Result of a successful fetch of meta information for a push-merged-local block.
*
* @param shuffleId shuffle id.
* @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
* of shuffle by an indeterminate stage attempt.
* @param reduceId reduce id.
* @param bitmaps bitmaps for every chunk.
* @param localDirs local directories where the push-merged shuffle files are storedl
*/
private[storage] case class PushMergedLocalMetaFetchResult(
shuffleId: Int,
shuffleMergeId: Int,
reduceId: Int,
bitmaps: Array[RoaringBitmap],
localDirs: Array[String]) extends FetchResult

View file

@ -31,7 +31,7 @@ import org.apache.spark.internal.config.Tests.IS_TESTING
import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv}
import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus, MergeStatus}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}
import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId}
class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
private val conf = new SparkConf
@ -347,9 +347,9 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
bitmap.add(0)
bitmap.add(1)
tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), 0,
bitmap, 1000L))
tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000),
tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000), 0,
bitmap, 1000L))
assert(tracker.getNumAvailableMergeResults(10) == 2)
tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000))
@ -386,12 +386,12 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId,
masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, 0,
bitmap, 3000L))
slaveTracker.updateEpoch(masterTracker.getEpoch)
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq ===
Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, -1, 0), 3000, -1),
Seq((blockMgrId, ArrayBuffer((ShuffleMergedBlockId(10, 0, 0), 3000, -1),
(ShuffleBlockId(10, 2, 0), size1000, 2)))))
masterTracker.stop()
@ -431,7 +431,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId,
masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, 0,
bitmap, 4000L))
slaveTracker.updateEpoch(masterTracker.getEpoch)
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
@ -523,7 +523,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
bitmap80.add(2)
bitmap80.add(3)
bitmap80.add(4)
tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), 0,
bitmap80, 11))
val preferredLocs1 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0)
@ -535,7 +535,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
// Prepare another MergeStatus that merges only 1 out of 5 blocks
val bitmap20 = new RoaringBitmap()
bitmap20.add(0)
tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), 0,
bitmap20, 2))
val preferredLocs2 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0)
@ -612,7 +612,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5))
}
masterTracker.registerMergeResult(20, 0, MergeStatus(BlockManagerId("999", "mps", 1000),
masterTracker.registerMergeResult(20, 0, MergeStatus(BlockManagerId("999", "mps", 1000), 0,
bitmap1, 1000L))
val mapWorkerRpcEnv = createRpcEnv("spark-worker", "localhost", 0, new SecurityManager(conf))
@ -646,13 +646,13 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
val bitmap1 = new RoaringBitmap()
bitmap1.add(0)
bitmap1.add(1)
tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), 0,
bitmap1, 1000L))
val bitmap2 = new RoaringBitmap()
bitmap2.add(5)
bitmap2.add(6)
tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000),
tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000), 0,
bitmap2, 1000L))
assert(tracker.getNumAvailableMergeResults(10) == 2)
tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000), Option(0))

View file

@ -315,7 +315,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
shuffleMapStage: ShuffleMapStage): Unit = {
if (shuffleMergeRegister) {
for (part <- 0 until shuffleMapStage.shuffleDep.partitioner.numPartitions) {
val mergeStatuses = Seq((part, makeMergeStatus("")))
val mergeStatuses = Seq((part, makeMergeStatus("",
shuffleMapStage.shuffleDep.shuffleMergeId)))
handleRegisterMergeStatuses(shuffleMapStage, mergeStatuses)
}
if (shuffleMergeFinalize) {
@ -3726,9 +3727,11 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
(Success, makeMapStatus("hostA", parts))
}.toSeq)
val shuffleMapStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((0, makeMergeStatus("hostA"))))
scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((0, makeMergeStatus("hostA",
shuffleDep.shuffleMergeId))))
scheduler.handleShuffleMergeFinalized(shuffleMapStage)
scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((1, makeMergeStatus("hostA"))))
scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((1, makeMergeStatus("hostA",
shuffleDep.shuffleMergeId))))
assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) == 1)
}
@ -3779,6 +3782,71 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
assert(mapOutputTracker.getNumAvailableOutputs(shuffleDep.shuffleId) == parts)
}
test("SPARK-32923: handle stage failure for indeterminate map stage with push-based shuffle") {
initPushBasedShuffleConfs(conf)
DAGSchedulerSuite.clearMergerLocs
DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
val (shuffleId1, shuffleId2) = constructIndeterminateStageFetchFailed()
// Check status for all failedStages
val failedStages = scheduler.failedStages.toSeq
assert(failedStages.map(_.id) == Seq(1, 2))
// Shuffle blocks of "hostC" is lost, so first task of the `shuffleMapRdd2` needs to retry.
assert(failedStages.collect {
case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 => stage
}.head.findMissingPartitions() == Seq(0))
// The result stage is still waiting for its 2 tasks to complete
assert(failedStages.collect {
case stage: ResultStage => stage
}.head.findMissingPartitions() == Seq(0, 1))
// shuffleMergeId for indeterminate stages would start from 1
assert(failedStages.collect {
case stage: ShuffleMapStage => stage.shuffleDep.shuffleMergeId
}.forall(x => x == 1))
scheduler.resubmitFailedStages()
// The first task of the `shuffleMapRdd2` failed with fetch failure
runEvent(makeCompletionEvent(
taskSets(3).tasks(0),
FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, "ignored"),
null))
val newFailedStages = scheduler.failedStages.toSeq
assert(newFailedStages.map(_.id) == Seq(0, 1))
// shuffleMergeId for indeterminate failed stages should be 2
assert(failedStages.collect {
case stage: ShuffleMapStage => stage.shuffleDep.shuffleMergeId
}.forall(x => x == 2))
scheduler.resubmitFailedStages()
// First shuffle map stage resubmitted and reran all tasks.
assert(taskSets(4).stageId == 0)
assert(taskSets(4).stageAttemptId == 1)
assert(taskSets(4).tasks.length == 2)
// Finish all stage.
completeShuffleMapStageSuccessfully(0, 1, 2)
assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty))
// shuffleMergeId should be 2 for the attempt number 1 for stage 0
assert(mapOutputTracker.shuffleStatuses.get(shuffleId1).forall(
_.mergeStatuses.forall(x => x.shuffleMergeId == 2)))
assert(mapOutputTracker.getNumAvailableMergeResults(shuffleId1) == 2)
completeShuffleMapStageSuccessfully(1, 2, 2, Seq("hostC", "hostD"))
assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty))
// shuffleMergeId should be 2 for the attempt number 2 for stage 1
assert(mapOutputTracker.shuffleStatuses.get(shuffleId2).forall(
_.mergeStatuses.forall(x => x.shuffleMergeId == 3)))
assert(mapOutputTracker.getNumAvailableMergeResults(shuffleId2) == 2)
complete(taskSets(6), Seq((Success, 11), (Success, 12)))
// Job successful ended.
assert(results === Map(0 -> 11, 1 -> 12))
results.clear()
assertDataStructuresEmpty()
}
/**
* Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
* Note that this checks only the host and not the executor ID.
@ -3843,8 +3911,8 @@ object DAGSchedulerSuite {
BlockManagerId(host + "-exec", host, 12345)
}
def makeMergeStatus(host: String, size: Long = 1000): MergeStatus =
MergeStatus(makeBlockManagerId(host), mock(classOf[RoaringBitmap]), size)
def makeMergeStatus(host: String, shuffleMergeId: Int, size: Long = 1000): MergeStatus =
MergeStatus(makeBlockManagerId(host), shuffleMergeId, mock(classOf[RoaringBitmap]), size)
def addMergerLocs(locs: Seq[String]): Unit = {
locs.foreach { loc => mergerLocs.append(makeBlockManagerId(loc)) }

View file

@ -96,7 +96,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
val blockPusher = new TestShuffleBlockPusher(conf)
val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
val largeBlockSize = 2 * 1024 * 1024
val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0,
val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
mock(classOf[File]), Array(2, 2, 2, largeBlockSize, largeBlockSize), mergerLocs,
mock(classOf[TransportConf]))
assert(pushRequests.length == 3)
@ -107,7 +107,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k")
val blockPusher = new TestShuffleBlockPusher(conf)
val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0,
val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
mock(classOf[File]), Array(2, 2, 2, 1028, 1024), mergerLocs, mock(classOf[TransportConf]))
assert(pushRequests.length == 2)
verifyPushRequests(pushRequests, Seq(6, 1024))
@ -117,7 +117,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
val blockPusher = new TestShuffleBlockPusher(conf)
val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0,
val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
mock(classOf[File]), Array(2, 2, 2, 2, 2), mergerLocs, mock(classOf[TransportConf]))
assert(pushRequests.length == 5)
verifyPushRequests(pushRequests, Seq(2, 2, 2, 2, 2))
@ -220,7 +220,8 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
val errorHandler = pusher.createErrorHandler()
assert(
!errorHandler.shouldRetryError(new RuntimeException(
new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))))
new IllegalArgumentException(
BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX))))
assert(errorHandler.shouldRetryError(new RuntimeException(new ConnectException())))
assert(
errorHandler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
@ -233,7 +234,8 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
val errorHandler = pusher.createErrorHandler()
assert(
!errorHandler.shouldLogError(new RuntimeException(
new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))))
new IllegalArgumentException(
BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX))))
assert(!errorHandler.shouldLogError(new RuntimeException(
new IllegalArgumentException(
BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))))
@ -284,7 +286,8 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
failBlock = false
// Fail the first block with the too late exception.
blockPushListener.onBlockPushFailure(blockId, new RuntimeException(
new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)))
new IllegalArgumentException(
BlockPushErrorHandler.TOO_LATE_OR_STALE_BLOCK_PUSH_MESSAGE_SUFFIX)))
} else {
pushedBlocks += blockId
blockPushListener.onBlockPushSuccess(blockId, mock(classOf[ManagedBuffer]))

View file

@ -27,7 +27,7 @@ import org.mockito.invocation.InvocationOnMock
import org.roaringbitmap.RoaringBitmap
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.{MapOutputTracker, SparkConf, SparkFunSuite}
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.internal.config
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleBlockInfo}
import org.apache.spark.storage._
@ -172,8 +172,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
test("getMergedBlockData should return expected FileSegmentManagedBuffer list") {
val shuffleId = 1
val shuffleMergeId = 0
val reduceId = 1
val dataFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.data"
val dataFileName = s"shuffleMerged_${appId}_${shuffleId}_${shuffleMergeId}_$reduceId.data"
val dataFile = new File(tempDir.getAbsolutePath, dataFileName)
val out = new FileOutputStream(dataFile)
Utils.tryWithSafeFinally {
@ -181,12 +182,13 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
} {
out.close()
}
val indexFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.index"
val indexFileName = s"shuffleMerged_${appId}_${shuffleId}_${shuffleMergeId}_$reduceId.index"
generateMergedShuffleIndexFile(indexFileName)
val resolver = new IndexShuffleBlockResolver(conf, blockManager)
val dirs = Some(Array[String](tempDir.getAbsolutePath))
val managedBufferList =
resolver.getMergedBlockData(ShuffleBlockId(shuffleId, -1, reduceId), dirs)
resolver.getMergedBlockData(ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId),
dirs)
assert(managedBufferList.size === 3)
assert(managedBufferList(0).size === 10)
assert(managedBufferList(1).size === 0)
@ -195,8 +197,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
test("getMergedBlockMeta should return expected MergedBlockMeta") {
val shuffleId = 1
val shuffleMergeId = 0
val reduceId = 1
val metaFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.meta"
val metaFileName = s"shuffleMerged_${appId}_${shuffleId}_${shuffleMergeId}_$reduceId.meta"
val metaFile = new File(tempDir.getAbsolutePath, metaFileName)
val chunkTracker = new RoaringBitmap()
val metaFileOutputStream = new FileOutputStream(metaFile)
@ -216,13 +219,14 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
}{
outMeta.close()
}
val indexFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.index"
val indexFileName = s"shuffleMerged_${appId}_${shuffleId}_${shuffleMergeId}_$reduceId.index"
generateMergedShuffleIndexFile(indexFileName)
val resolver = new IndexShuffleBlockResolver(conf, blockManager)
val dirs = Some(Array[String](tempDir.getAbsolutePath))
val mergedBlockMeta =
resolver.getMergedBlockMeta(
ShuffleBlockId(shuffleId, MapOutputTracker.SHUFFLE_PUSH_MAP_ID, reduceId), dirs)
ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId),
dirs)
assert(mergedBlockMeta.getNumChunks === 3)
assert(mergedBlockMeta.readChunkBitmaps().size === 3)
assert(mergedBlockMeta.readChunkBitmaps()(0).contains(1))

View file

@ -105,37 +105,52 @@ class BlockIdSuite extends SparkFunSuite {
}
test("shuffle merged data") {
val id = ShuffleMergedDataBlockId("app_000", 8, 9)
assertSame(id, ShuffleMergedDataBlockId("app_000", 8, 9))
assertDifferent(id, ShuffleMergedDataBlockId("app_000", 9, 9))
assert(id.name === "shuffleMerged_app_000_8_9.data")
val id = ShuffleMergedDataBlockId("app_000", 8, 0, 9)
assertSame(id, ShuffleMergedDataBlockId("app_000", 8, 0, 9))
assertDifferent(id, ShuffleMergedDataBlockId("app_000", 9, 0, 9))
assert(id.name === "shuffleMerged_app_000_8_0_9.data")
assert(id.asRDDId === None)
assert(id.appId === "app_000")
assert(id.shuffleMergeId == 0)
assert(id.shuffleId=== 8)
assert(id.reduceId === 9)
assertSame(id, BlockId(id.toString))
}
test("shuffle merged index") {
val id = ShuffleMergedIndexBlockId("app_000", 8, 9)
assertSame(id, ShuffleMergedIndexBlockId("app_000", 8, 9))
assertDifferent(id, ShuffleMergedIndexBlockId("app_000", 9, 9))
assert(id.name === "shuffleMerged_app_000_8_9.index")
val id = ShuffleMergedIndexBlockId("app_000", 8, 0, 9)
assertSame(id, ShuffleMergedIndexBlockId("app_000", 8, 0, 9))
assertDifferent(id, ShuffleMergedIndexBlockId("app_000", 9, 0, 9))
assert(id.name === "shuffleMerged_app_000_8_0_9.index")
assert(id.asRDDId === None)
assert(id.appId === "app_000")
assert(id.shuffleId=== 8)
assert(id.shuffleMergeId == 0)
assert(id.reduceId === 9)
assertSame(id, BlockId(id.toString))
}
test("shuffle merged meta") {
val id = ShuffleMergedMetaBlockId("app_000", 8, 9)
assertSame(id, ShuffleMergedMetaBlockId("app_000", 8, 9))
assertDifferent(id, ShuffleMergedMetaBlockId("app_000", 9, 9))
assert(id.name === "shuffleMerged_app_000_8_9.meta")
val id = ShuffleMergedMetaBlockId("app_000", 8, 0, 9)
assertSame(id, ShuffleMergedMetaBlockId("app_000", 8, 0, 9))
assertDifferent(id, ShuffleMergedMetaBlockId("app_000", 9, 0, 9))
assert(id.name === "shuffleMerged_app_000_8_0_9.meta")
assert(id.asRDDId === None)
assert(id.appId === "app_000")
assert(id.shuffleId=== 8)
assert(id.shuffleMergeId == 0)
assert(id.reduceId === 9)
assertSame(id, BlockId(id.toString))
}
test("shuffle merged block") {
val id = ShuffleMergedBlockId(8, 0, 9)
assertSame(id, ShuffleMergedBlockId(8, 0, 9))
assertDifferent(id, ShuffleMergedBlockId(8, 1, 9))
assert(id.name === "shuffleMerged_8_0_9")
assert(id.asRDDId === None)
assert(id.shuffleId=== 8)
assert(id.shuffleMergeId == 0)
assert(id.reduceId === 9)
assertSame(id, BlockId(id.toString))
}
@ -224,10 +239,10 @@ class BlockIdSuite extends SparkFunSuite {
}
test("shuffle chunk") {
val id = ShuffleBlockChunkId(1, 1, 0)
assertSame(id, ShuffleBlockChunkId(1, 1, 0))
assertDifferent(id, ShuffleBlockChunkId(1, 1, 1))
assert(id.name === "shuffleChunk_1_1_0")
val id = ShuffleBlockChunkId(1, 0, 1, 0)
assertSame(id, ShuffleBlockChunkId(1, 0, 1, 0))
assertDifferent(id, ShuffleBlockChunkId(1, 0, 1, 1))
assert(id.name === "shuffleChunk_1_0_1_0")
assert(id.asRDDId === None)
assert(id.shuffleId === 1)
assert(id.reduceId === 1)

View file

@ -1059,15 +1059,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
test("SPARK-32922: fetch remote push-merged block meta") {
val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
(BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1),
toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)),
toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)), 2L,
SHUFFLE_PUSH_MAP_ID)),
(BlockManagerId("remote-client-1", "remote-host-1", 1),
toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1))
)
val blockChunks = Map[BlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(),
ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 2, 1) -> createMockManagedBuffer()
ShuffleBlockChunkId(0, 0, 2, 0) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 0, 2, 1) -> createMockManagedBuffer()
)
val blocksSem = new Semaphore(0)
configureMockTransferForPushShuffle(blocksSem, blockChunks)
@ -1078,17 +1079,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer]))
val roaringBitmaps = Array(new RoaringBitmap, new RoaringBitmap)
when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps)
when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
.thenAnswer((invocation: InvocationOnMock) => {
val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
Future {
val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
logInfo(s"acquiring semaphore for host = ${invocation.getArguments()(0)}, " +
s"port = ${invocation.getArguments()(1)}, " +
s"shuffleId = $shuffleId, reduceId = $reduceId")
s"shuffleId = $shuffleId, shuffleMergeId = $shuffleMergeId, reduceId = $reduceId")
metaSem.acquire()
metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
metaListener.onSuccess(shuffleId, shuffleMergeId, reduceId, pushMergedBlockMeta)
}
})
val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress)
@ -1101,10 +1103,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
metaSem.release()
val (id3, _) = iterator.next()
blocksSem.acquire()
assert(id3 === ShuffleBlockChunkId(0, 2, 0))
assert(id3 === ShuffleBlockChunkId(0, 0, 2, 0))
val (id4, _) = iterator.next()
blocksSem.acquire()
assert(id4 === ShuffleBlockChunkId(0, 2, 1))
assert(id4 === ShuffleBlockChunkId(0, 0, 2, 1))
assert(!iterator.hasNext)
}
@ -1113,7 +1115,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1)
val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
(BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1),
toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)),
toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)), 2L,
SHUFFLE_PUSH_MAP_ID)),
(remoteBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)))
val blockChunks = Map[BlockId, ManagedBuffer](
@ -1127,13 +1130,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator)
val blocksSem = new Semaphore(0)
configureMockTransferForPushShuffle(blocksSem, blockChunks)
when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
.thenAnswer((invocation: InvocationOnMock) => {
val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
Future {
metaListener.onFailure(shuffleId, reduceId, new RuntimeException("forced error"))
metaListener.onFailure(shuffleId, shuffleMergeId, reduceId,
new RuntimeException("forced error"))
}
})
val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress)
@ -1154,7 +1159,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1)
val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
(BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1),
toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)), 2L,
SHUFFLE_PUSH_MAP_ID)))
val blockChunks = Map[BlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(),
@ -1165,13 +1171,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator)
val blocksSem = new Semaphore(0)
configureMockTransferForPushShuffle(blocksSem, blockChunks)
when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
.thenAnswer((invocation: InvocationOnMock) => {
val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
Future {
metaListener.onFailure(shuffleId, reduceId, new RuntimeException("forced error"))
metaListener.onFailure(shuffleId, shuffleMergeId, reduceId,
new RuntimeException("forced error"))
}
})
val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress)
@ -1225,7 +1233,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val dirsForMergedData = localDirsMap(SHUFFLE_MERGER_IDENTIFIER)
doReturn(Seq(createMockManagedBuffer(2))).when(blockManager)
.getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), dirsForMergedData)
.getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), dirsForMergedData)
// Get a valid chunk meta for this test
val bitmaps = Array(new RoaringBitmap)
@ -1236,7 +1244,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
} else {
createMockPushMergedBlockMeta(bitmaps.length, bitmaps)
}
when(blockManager.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
when(blockManager.getLocalMergedBlockMeta(ShuffleMergedBlockId(0, 0, 2),
dirsForMergedData)).thenReturn(pushMergedBlockMeta)
when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
Seq((localBmId,
@ -1248,7 +1256,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
(localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)),
(pushMergedBmId, toBlockList(
Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
Seq(ShuffleMergedBlockId(0, 0, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
}
private def verifyLocalBlocksFromFallback(iterator: ShuffleBlockFetcherIterator): Unit = {
@ -1270,7 +1278,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val blocksByAddress = prepareForFallbackToLocalBlocks(
blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
doThrow(new RuntimeException("Forced error")).when(blockManager)
.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
.getLocalMergedBlockMeta(ShuffleMergedBlockId(0, 0, 2), localDirs)
val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
blockManager = Some(blockManager))
verifyLocalBlocksFromFallback(iterator)
@ -1295,7 +1303,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val blocksByAddress = prepareForFallbackToLocalBlocks(
blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
doThrow(new RuntimeException("Forced error")).when(blockManager)
.getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
.getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), localDirs)
val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
blockManager = Some(blockManager))
verifyLocalBlocksFromFallback(iterator)
@ -1312,18 +1320,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val pushMergedBmId = BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, localHost, 1)
val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
(localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)),
(pushMergedBmId, toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3)), 2L, SHUFFLE_PUSH_MAP_ID)))
(pushMergedBmId, toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2),
ShuffleMergedBlockId(0, 0, 3)), 2L, SHUFFLE_PUSH_MAP_ID)))
doThrow(new RuntimeException("Forced error")).when(blockManager)
.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
.getLocalMergedBlockMeta(ShuffleMergedBlockId(0, 0, 2), localDirs)
// Create a valid chunk meta for partition 3
val bitmaps = Array(new RoaringBitmap)
bitmaps(0).add(1) // chunk 0 has mapId 1
doReturn(createMockPushMergedBlockMeta(bitmaps.length, bitmaps)).when(blockManager)
.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3), localDirs)
.getLocalMergedBlockMeta(ShuffleMergedBlockId(0, 0, 3), localDirs)
// Return valid buffer for chunk in partition 3
doReturn(Seq(createMockManagedBuffer(2))).when(blockManager)
.getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3), localDirs)
.getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 3), localDirs)
val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
blockManager = Some(blockManager))
val (id1, _) = iterator.next()
@ -1335,7 +1343,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val (id4, _) = iterator.next()
assert(id4 === ShuffleBlockId(0, 2, 2))
val (id5, _) = iterator.next()
assert(id5 === ShuffleBlockChunkId(0, 3, 0))
assert(id5 === ShuffleBlockChunkId(0, 0, 3, 0))
assert(!iterator.hasNext)
}
@ -1358,7 +1366,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
// Since bitmaps are null, this will fail reading the push-merged block meta causing fallback to
// initiate.
val pushMergedBlockMeta: MergedBlockMeta = createMockPushMergedBlockMeta(2, null)
when(blockManager.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
when(blockManager.getLocalMergedBlockMeta(ShuffleMergedBlockId(0, 0, 2),
dirsForMergedData)).thenReturn(pushMergedBlockMeta)
when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
Seq((localBmId,
@ -1366,7 +1374,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
(BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-local-host", 1), toBlockList(
Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
Seq(ShuffleMergedBlockId(0, 0, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
blockManager = Some(blockManager))
// 1st instance of iterator.next() returns the original shuffle block (0, 0, 2)
@ -1385,7 +1393,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val blocksByAddress = prepareForFallbackToLocalBlocks(blockManager, hostLocalDirs)
doThrow(new RuntimeException("Forced error")).when(blockManager)
.getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir"))
.getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), Array("local-dir"))
// host local read for a shuffle block
doReturn(createMockManagedBuffer()).when(blockManager)
.getHostLocalShuffleData(ShuffleBlockId(0, 2, 2), Array("local-dir"))
@ -1416,7 +1424,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
blockManager, hostLocalDirs) ++ hostLocalBlocks
doThrow(new RuntimeException("Forced error")).when(blockManager)
.getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir"))
.getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), Array("local-dir"))
// host Local read for this original shuffle block
doReturn(createMockManagedBuffer()).when(blockManager)
.getHostLocalShuffleData(ShuffleBlockId(0, 1, 2), Array("local-dir"))
@ -1452,7 +1460,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
doReturn(Seq({
new FileSegmentManagedBuffer(null, new File("non-existent"), 0, 100)
})).when(blockManager).getLocalMergedBlockData(
ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
ShuffleMergedBlockId(0, 0, 2), localDirs)
val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
blockManager = Some(blockManager))
verifyLocalBlocksFromFallback(iterator)
@ -1466,7 +1474,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
val corruptBuffer = createMockManagedBuffer(2)
doReturn(Seq({corruptBuffer})).when(blockManager)
.getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
.getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), localDirs)
val corruptStream = mock(classOf[InputStream])
when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
doReturn(corruptStream).when(corruptBuffer).createInputStream()
@ -1477,7 +1485,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
test("SPARK-32922: fallback to original blocks when failed to fetch remote shuffle chunk") {
val blockChunks = Map[BlockId, ManagedBuffer](
ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 0, 2, 0) -> createMockManagedBuffer(),
ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer()
@ -1489,13 +1497,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
bitmaps(1).add(4)
bitmaps(1).add(5)
val pushMergedBlockMeta = createMockPushMergedBlockMeta(2, bitmaps)
when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
.thenAnswer((invocation: InvocationOnMock) => {
val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
Future {
metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
metaListener.onSuccess(shuffleId, shuffleMergeId, reduceId, pushMergedBlockMeta)
}
})
val fallbackBlocksByAddr = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
@ -1506,10 +1515,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
.thenReturn(fallbackBlocksByAddr.iterator)
val iterator = createShuffleBlockIteratorWithDefaults(Map(
BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "remote-client-1", 1) ->
toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 12L, SHUFFLE_PUSH_MAP_ID)))
toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)),
12L, SHUFFLE_PUSH_MAP_ID)))
val (id1, _) = iterator.next()
blocksSem.acquire(1)
assert(id1 === ShuffleBlockChunkId(0, 2, 0))
assert(id1 === ShuffleBlockChunkId(0, 0, 2, 0))
val (id3, _) = iterator.next()
blocksSem.acquire(3)
assert(id3 === ShuffleBlockId(0, 3, 2))
@ -1531,20 +1541,21 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val blocksSem = new Semaphore(0)
configureMockTransferForPushShuffle(blocksSem, blockChunks)
val pushMergedBlockMeta = createMockPushMergedBlockMeta(2, null)
when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
.thenAnswer((invocation: InvocationOnMock) => {
val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
Future {
metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
metaListener.onSuccess(shuffleId, shuffleMergeId, reduceId, pushMergedBlockMeta)
}
})
val remoteMergedBlockMgrId = BlockManagerId(
SHUFFLE_MERGER_IDENTIFIER, "remote-host-2", 1)
val iterator = createShuffleBlockIteratorWithDefaults(
Map(remoteMergedBlockMgrId -> toBlockList(
Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
Seq(ShuffleMergedBlockId(0, 0, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
val (id1, _) = iterator.next()
blocksSem.acquire(2)
assert(id1 === ShuffleBlockId(0, 0, 2))
@ -1556,10 +1567,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
test("SPARK-32922: failure to fetch a remote shuffle chunk initiates the fallback of " +
"pending shuffle chunks immediately") {
val blockChunks = Map[BlockId, ManagedBuffer](
ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 0, 2, 0) -> createMockManagedBuffer(),
// ShuffleBlockChunk(0, 2, 1) will cause a failure as it is not in block-chunks.
ShuffleBlockChunkId(0, 2, 2) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 2, 3) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 0, 2, 2) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 0, 2, 3) -> createMockManagedBuffer(),
ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer(),
@ -1574,17 +1585,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer]))
val roaringBitmaps = Array.fill[RoaringBitmap](4)(new RoaringBitmap)
when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps)
when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
.thenAnswer((invocation: InvocationOnMock) => {
val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
Future {
logInfo(s"acquiring semaphore for host = ${invocation.getArguments()(0)}, " +
s"port = ${invocation.getArguments()(1)}, " +
s"shuffleId = $shuffleId, reduceId = $reduceId")
s"shuffleId = $shuffleId, shuffleMergeId = $shuffleMergeId, reduceId = $reduceId")
metaSem.release()
metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
metaListener.onSuccess(shuffleId, shuffleMergeId, reduceId, pushMergedBlockMeta)
}
})
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
@ -1596,12 +1608,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val iterator = createShuffleBlockIteratorWithDefaults(Map(
BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) ->
toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 16L, SHUFFLE_PUSH_MAP_ID)),
toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)),
16L, SHUFFLE_PUSH_MAP_ID)),
maxBytesInFlight = 4)
metaSem.acquire(1)
val (id1, _) = iterator.next()
blocksSem.acquire(1)
assert(id1 === ShuffleBlockChunkId(0, 2, 0))
assert(id1 === ShuffleBlockChunkId(0, 0, 2, 0))
val regularBlocks = new mutable.HashSet[BlockId]()
val (id2, _) = iterator.next()
blocksSem.acquire(1)
@ -1623,12 +1636,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
test("SPARK-32922: failure to fetch a remote shuffle chunk initiates the fallback of " +
"pending shuffle chunks immediately which got deferred") {
val blockChunks = Map[BlockId, ManagedBuffer](
ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 2, 1) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 2, 2) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 0, 2, 0) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 0, 2, 1) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 0, 2, 2) -> createMockManagedBuffer(),
// ShuffleBlockChunkId(0, 2, 3) will cause failure as it is not in bock chunks
ShuffleBlockChunkId(0, 2, 4) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 2, 5) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 0, 2, 4) -> createMockManagedBuffer(),
ShuffleBlockChunkId(0, 0, 2, 5) -> createMockManagedBuffer(),
ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer(),
@ -1642,17 +1655,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer]))
val roaringBitmaps = Array.fill[RoaringBitmap](6)(new RoaringBitmap)
when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps)
when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any(), any()))
.thenAnswer((invocation: InvocationOnMock) => {
val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
val metaListener = invocation.getArguments()(5).asInstanceOf[MergedBlocksMetaListener]
val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
val shuffleMergeId = invocation.getArguments()(3).asInstanceOf[Int]
val reduceId = invocation.getArguments()(4).asInstanceOf[Int]
Future {
logInfo(s"acquiring semaphore for host = ${invocation.getArguments()(0)}, " +
s"port = ${invocation.getArguments()(1)}, " +
s"shuffleId = $shuffleId, reduceId = $reduceId")
s"shuffleId = $shuffleId, shuffleMergeId = $shuffleMergeId, reduceId = $reduceId")
metaSem.release()
metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
metaListener.onSuccess(shuffleId, shuffleMergeId, reduceId, pushMergedBlockMeta)
}
})
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
@ -1664,17 +1678,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val iterator = createShuffleBlockIteratorWithDefaults(Map(
BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) ->
toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 24L, SHUFFLE_PUSH_MAP_ID)),
toBlockList(Seq(ShuffleMergedBlockId(0, 0, 2)), 24L,
SHUFFLE_PUSH_MAP_ID)),
maxBytesInFlight = 8, maxBlocksInFlightPerAddress = 1)
metaSem.acquire(1)
val (id1, _) = iterator.next()
blocksSem.acquire(2)
assert(id1 === ShuffleBlockChunkId(0, 2, 0))
assert(id1 === ShuffleBlockChunkId(0, 0, 2, 0))
val (id2, _) = iterator.next()
assert(id2 === ShuffleBlockChunkId(0, 2, 1))
assert(id2 === ShuffleBlockChunkId(0, 0, 2, 1))
val (id3, _) = iterator.next()
blocksSem.acquire(1)
assert(id3 === ShuffleBlockChunkId(0, 2, 2))
assert(id3 === ShuffleBlockChunkId(0, 0, 2, 2))
val regularBlocks = new mutable.HashSet[BlockId]()
val (id4, _) = iterator.next()
blocksSem.acquire(1)