[SPARK-32915][CORE] Network-layer and shuffle RPC layer changes to support push shuffle blocks

### What changes were proposed in this pull request?

This is the first patch for SPIP SPARK-30602 for push-based shuffle.
Summary of changes:
* Introduce a new API in ExternalBlockStoreClient to push blocks to a remote shuffle service (a usage sketch follows this list).
* Leverage the streaming upload functionality from SPARK-6237 to let ExternalBlockHandler delegate the handling of block push requests to MergedShuffleFileManager.
* Propose the API for MergedShuffleFileManager, which defines the core logic on the shuffle service side for handling block push requests. The actual implementation of this API is deferred to a later PR to limit the size of this one.
* Introduce OneForOneBlockPusher to push blocks to remote shuffle services in the shuffle RPC layer.
* Add new protocols to the shuffle RPC layer to support these functionalities.
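To make the new client-side API concrete, here is a minimal, hypothetical sketch of pushing a single block through `ExternalBlockStoreClient#pushBlocks`. The host, port, block id, and payload are illustrative values only, and the client is assumed to have been initialized elsewhere:

```java
import java.nio.ByteBuffer;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.ExternalBlockStoreClient;

public class PushBlocksSketch {
  // Assumes `client` has already been initialized; the host, port, block id,
  // and payload below are made-up values for illustration.
  static void pushOneBlock(ExternalBlockStoreClient client) {
    String[] blockIds = { "shuffle_0_1_2" };
    ManagedBuffer[] buffers = {
        new NioManagedBuffer(ByteBuffer.wrap(new byte[] { 1, 2, 3 }))
    };
    client.pushBlocks("shuffle-host", 7337, blockIds, buffers,
        new BlockFetchingListener() {
          @Override
          public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
            // Invoked once per successfully pushed block; the listener type is
            // reused from the fetch path, hence the "fetch" naming.
          }

          @Override
          public void onBlockFetchFailure(String blockId, Throwable t) {
            // Block push is best effort: one failed block does not fail the rest.
          }
        });
  }
}
```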

### Why are the changes needed?

Refer to the SPIP in SPARK-30602.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added unit tests.
A reference PR consolidating the changes for the complete implementation is also provided in SPARK-30602.
We have already verified the functionality and the improved performance as documented in the SPIP doc.


Closes #29855 from Victsm/SPARK-32915.

Lead-authored-by: Min Shen <mshen@linkedin.com>
Co-authored-by: Chandni Singh <chsingh@linkedin.com>
Co-authored-by: Ye Zhou <yezhou@linkedin.com>
Co-authored-by: Chandni Singh <singh.chandni@gmail.com>
Co-authored-by: Min Shen <victor.nju@gmail.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
Authored by Min Shen on 2020-10-15 12:34:52 -05:00; committed by Mridul Muralidharan
parent b089fe5376
commit 82eea13c76
21 changed files with 1212 additions and 15 deletions


@ -91,6 +91,10 @@
<groupId>org.apache.commons</groupId>
<artifactId>commons-crypto</artifactId>
</dependency>
<dependency>
<groupId>org.roaringbitmap</groupId>
<artifactId>RoaringBitmap</artifactId>
</dependency>
<!-- Test dependencies -->
<dependency>


@ -17,9 +17,11 @@
package org.apache.spark.network.protocol;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import io.netty.buffer.ByteBuf;
import org.roaringbitmap.RoaringBitmap;
/** Provides a canonical set of Encoders for simple types. */
public class Encoders {
@ -44,6 +46,40 @@ public class Encoders {
}
}
/** Bitmaps are encoded with their serialization length followed by the serialization bytes. */
public static class Bitmaps {
public static int encodedLength(RoaringBitmap b) {
// Compress the bitmap before serializing it. Note that since BlockTransferMessage
// needs to invoke encodedLength first to figure out the length for the ByteBuf, it
// guarantees that the bitmap will always be compressed before being serialized.
b.trim();
b.runOptimize();
return b.serializedSizeInBytes();
}
public static void encode(ByteBuf buf, RoaringBitmap b) {
int encodedLength = b.serializedSizeInBytes();
// RoaringBitmap requires nio ByteBuffer for serde. We expose the netty ByteBuf as a nio
// ByteBuffer. Here, we need to explicitly manage the index so we can write into the
// ByteBuffer, and the write is reflected in the underlying ByteBuf.
b.serialize(buf.nioBuffer(buf.writerIndex(), encodedLength));
buf.writerIndex(buf.writerIndex() + encodedLength);
}
public static RoaringBitmap decode(ByteBuf buf) {
RoaringBitmap bitmap = new RoaringBitmap();
try {
bitmap.deserialize(buf.nioBuffer());
// RoaringBitmap deserialize does not advance the reader index of the underlying ByteBuf.
// Manually update the index here.
buf.readerIndex(buf.readerIndex() + bitmap.serializedSizeInBytes());
} catch (IOException e) {
throw new RuntimeException("Exception while decoding bitmap", e);
}
return bitmap;
}
}
/** Byte arrays are encoded with their length followed by bytes. */
public static class ByteArrays {
public static int encodedLength(byte[] arr) {
@ -135,4 +171,31 @@ public class Encoders {
return longs;
}
}
/** Bitmap arrays are encoded with the number of bitmaps followed by per-Bitmap encoding. */
public static class BitmapArrays {
public static int encodedLength(RoaringBitmap[] bitmaps) {
int totalLength = 4;
for (RoaringBitmap b : bitmaps) {
totalLength += Bitmaps.encodedLength(b);
}
return totalLength;
}
public static void encode(ByteBuf buf, RoaringBitmap[] bitmaps) {
buf.writeInt(bitmaps.length);
for (RoaringBitmap b : bitmaps) {
Bitmaps.encode(buf, b);
}
}
public static RoaringBitmap[] decode(ByteBuf buf) {
int numBitmaps = buf.readInt();
RoaringBitmap[] bitmaps = new RoaringBitmap[numBitmaps];
for (int i = 0; i < bitmaps.length; i++) {
bitmaps[i] = Bitmaps.decode(buf);
}
return bitmaps;
}
}
}
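As a quick illustration of the new encoders above, here is a hedged round-trip sketch for `Encoders.Bitmaps`, sizing the buffer the same way `BlockTransferMessage` does by calling `encodedLength` first (which also trims and run-optimizes the bitmap):

```java
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.roaringbitmap.RoaringBitmap;

import org.apache.spark.network.protocol.Encoders;

public class BitmapEncodingSketch {
  public static void main(String[] args) {
    // A bitmap of map ids, as used to track which map outputs were merged.
    RoaringBitmap mapIds = RoaringBitmap.bitmapOf(1, 5, 42);
    // encodedLength compresses the bitmap and returns the exact serialized size.
    ByteBuf buf = Unpooled.buffer(Encoders.Bitmaps.encodedLength(mapIds));
    Encoders.Bitmaps.encode(buf, mapIds);
    RoaringBitmap decoded = Encoders.Bitmaps.decode(buf);
    assert decoded.equals(mapIds) : "round trip should preserve the bitmap";
  }
}
```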


@ -57,6 +57,10 @@
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
<dependency>
<groupId>org.roaringbitmap</groupId>
<artifactId>RoaringBitmap</artifactId>
</dependency>
<!-- Test dependencies -->
<dependency>
@ -93,6 +97,11 @@
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>


@ -29,6 +29,7 @@ import com.codahale.metrics.MetricSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
@ -135,4 +136,24 @@ public abstract class BlockStoreClient implements Closeable {
hostLocalDirsCompletable.completeExceptionally(e);
}
}
/**
* Push a sequence of shuffle blocks in a best-effort manner to a remote node asynchronously.
* These shuffle blocks, along with blocks pushed by other clients, will be merged into
* per-shuffle partition merged shuffle files on the destination node.
*
* @param host the host of the remote node.
* @param port the port of the remote node.
* @param blockIds block ids to be pushed
* @param buffers buffers to be pushed
* @param listener the listener to receive block push status.
*/
public void pushBlocks(
String host,
int port,
String[] blockIds,
ManagedBuffer[] buffers,
BlockFetchingListener listener) {
throw new UnsupportedOperationException();
}
}


@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle;
import java.net.ConnectException;
import com.google.common.base.Throwables;
/**
* Plugs into {@link RetryingBlockFetcher} to further control when an exception should be retried
* and logged.
* Note: {@link RetryingBlockFetcher} will delegate the exception to this handler only when
* - remaining retries < max retries
* - exception is an IOException
*/
public interface ErrorHandler {
boolean shouldRetryError(Throwable t);
default boolean shouldLogError(Throwable t) {
return true;
}
/**
* A no-op error handler instance.
*/
ErrorHandler NOOP_ERROR_HANDLER = t -> true;
/**
* The error handler for pushing shuffle blocks to remote shuffle services.
*/
class BlockPushErrorHandler implements ErrorHandler {
/**
* String constant used for generating exception messages indicating a block to be merged
* arrives too late on the server side, and also for later checking such exceptions on the
client side. When we get a block push failure because the block arrives too late, we
* will not retry pushing the block nor log the exception on the client side.
*/
public static final String TOO_LATE_MESSAGE_SUFFIX =
"received after merged shuffle is finalized";
/**
* String constant used for generating exception messages indicating the server couldn't
* append a block after all available attempts due to collision with other blocks belonging
* to the same shuffle partition, and also for later checking such exceptions on the client
side. When we get a block push failure because the block couldn't be written for this
reason, we will not log the exception on the client side.
*/
public static final String BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX =
"Couldn't find an opportunity to write block";
@Override
public boolean shouldRetryError(Throwable t) {
// If it is a connection timeout or a connection closed exception, no need to retry.
if (t.getCause() != null && t.getCause() instanceof ConnectException) {
return false;
}
// If the block is too late, there is no need to retry it
return !Throwables.getStackTraceAsString(t).contains(TOO_LATE_MESSAGE_SUFFIX);
}
@Override
public boolean shouldLogError(Throwable t) {
String errorStackTrace = Throwables.getStackTraceAsString(t);
return !errorStackTrace.contains(BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX) &&
!errorStackTrace.contains(TOO_LATE_MESSAGE_SUFFIX);
}
}
}
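The following sketch shows how `BlockPushErrorHandler` classifies failures, mirroring the semantics documented above; the exception messages are hand-constructed from the public constants for illustration:

```java
import java.net.ConnectException;

import org.apache.spark.network.shuffle.ErrorHandler;

public class BlockPushErrorClassificationSketch {
  public static void main(String[] args) {
    ErrorHandler.BlockPushErrorHandler handler = new ErrorHandler.BlockPushErrorHandler();

    // Connection failures are terminal: no point retrying against an unreachable service.
    assert !handler.shouldRetryError(new RuntimeException(new ConnectException()));

    // A block arriving after merge finalization is neither retried nor logged as an error.
    Throwable tooLate = new RuntimeException(
        "block " + ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX);
    assert !handler.shouldRetryError(tooLate);
    assert !handler.shouldLogError(tooLate);

    // An append collision is retriable but not logged as an error.
    Throwable collision = new RuntimeException(
        ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX);
    assert handler.shouldRetryError(collision);
    assert !handler.shouldLogError(collision);
  }
}
```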


@ -32,6 +32,7 @@ import com.codahale.metrics.MetricSet;
import com.codahale.metrics.Timer;
import com.codahale.metrics.Counter;
import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.network.client.StreamCallbackWithID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -61,11 +62,21 @@ public class ExternalBlockHandler extends RpcHandler {
final ExternalShuffleBlockResolver blockManager;
private final OneForOneStreamManager streamManager;
private final ShuffleMetrics metrics;
private final MergedShuffleFileManager mergeManager;
public ExternalBlockHandler(TransportConf conf, File registeredExecutorFile)
throws IOException {
this(new OneForOneStreamManager(),
new ExternalShuffleBlockResolver(conf, registeredExecutorFile));
new ExternalShuffleBlockResolver(conf, registeredExecutorFile),
new NoOpMergedShuffleFileManager());
}
public ExternalBlockHandler(
TransportConf conf,
File registeredExecutorFile,
MergedShuffleFileManager mergeManager) throws IOException {
this(new OneForOneStreamManager(),
new ExternalShuffleBlockResolver(conf, registeredExecutorFile), mergeManager);
}
@VisibleForTesting
@ -78,9 +89,19 @@ public class ExternalBlockHandler extends RpcHandler {
public ExternalBlockHandler(
OneForOneStreamManager streamManager,
ExternalShuffleBlockResolver blockManager) {
this(streamManager, blockManager, new NoOpMergedShuffleFileManager());
}
/** Enables mocking out the StreamManager, BlockManager, and MergeManager. */
@VisibleForTesting
public ExternalBlockHandler(
OneForOneStreamManager streamManager,
ExternalShuffleBlockResolver blockManager,
MergedShuffleFileManager mergeManager) {
this.metrics = new ShuffleMetrics();
this.streamManager = streamManager;
this.blockManager = blockManager;
this.mergeManager = mergeManager;
}
@Override
@ -89,6 +110,21 @@ public class ExternalBlockHandler extends RpcHandler {
handleMessage(msgObj, client, callback);
}
@Override
public StreamCallbackWithID receiveStream(
TransportClient client,
ByteBuffer messageHeader,
RpcResponseCallback callback) {
BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(messageHeader);
if (msgObj instanceof PushBlockStream) {
PushBlockStream message = (PushBlockStream) msgObj;
checkAuth(client, message.appId);
return mergeManager.receiveBlockDataAsStream(message);
} else {
throw new UnsupportedOperationException("Unexpected message with #receiveStream: " + msgObj);
}
}
protected void handleMessage(
BlockTransferMessage msgObj,
TransportClient client,
@ -139,6 +175,7 @@ public class ExternalBlockHandler extends RpcHandler {
RegisterExecutor msg = (RegisterExecutor) msgObj;
checkAuth(client, msg.appId);
blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
mergeManager.registerExecutor(msg.appId, msg.executorInfo.localDirs);
callback.onSuccess(ByteBuffer.wrap(new byte[0]));
} finally {
responseDelayContext.stop();
@ -156,6 +193,20 @@ public class ExternalBlockHandler extends RpcHandler {
Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, msg.execIds);
callback.onSuccess(new LocalDirsForExecutors(localDirs).toByteBuffer());
} else if (msgObj instanceof FinalizeShuffleMerge) {
final Timer.Context responseDelayContext =
metrics.finalizeShuffleMergeLatencyMillis.time();
FinalizeShuffleMerge msg = (FinalizeShuffleMerge) msgObj;
try {
checkAuth(client, msg.appId);
MergeStatuses statuses = mergeManager.finalizeShuffleMerge(msg);
callback.onSuccess(statuses.toByteBuffer());
} catch (IOException e) {
throw new RuntimeException(String.format("Error while finalizing shuffle merge "
+ "for application %s shuffle %d", msg.appId, msg.shuffleId), e);
} finally {
responseDelayContext.stop();
}
} else {
throw new UnsupportedOperationException("Unexpected message: " + msgObj);
}
@ -225,6 +276,8 @@ public class ExternalBlockHandler extends RpcHandler {
private final Timer openBlockRequestLatencyMillis = new Timer();
// Time latency for executor registration latency in ms
private final Timer registerExecutorRequestLatencyMillis = new Timer();
// Time latency for processing finalize shuffle merge requests in ms
private final Timer finalizeShuffleMergeLatencyMillis = new Timer();
// Block transfer rate in byte per second
private final Meter blockTransferRateBytes = new Meter();
// Number of active connections to the shuffle service
@ -236,6 +289,7 @@ public class ExternalBlockHandler extends RpcHandler {
allMetrics = new HashMap<>();
allMetrics.put("openBlockRequestLatencyMillis", openBlockRequestLatencyMillis);
allMetrics.put("registerExecutorRequestLatencyMillis", registerExecutorRequestLatencyMillis);
allMetrics.put("finalizeShuffleMergeLatencyMillis", finalizeShuffleMergeLatencyMillis);
allMetrics.put("blockTransferRateBytes", blockTransferRateBytes);
allMetrics.put("registeredExecutorsSize",
(Gauge<Integer>) () -> blockManager.getRegisteredExecutorsSize());
@ -373,6 +427,54 @@ public class ExternalBlockHandler extends RpcHandler {
}
}
/**
* Dummy implementation of merged shuffle file manager. Suitable for when push-based shuffle
* is not enabled.
*/
private static class NoOpMergedShuffleFileManager implements MergedShuffleFileManager {
@Override
public StreamCallbackWithID receiveBlockDataAsStream(PushBlockStream msg) {
throw new UnsupportedOperationException("Cannot handle shuffle block merge");
}
@Override
public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) throws IOException {
throw new UnsupportedOperationException("Cannot handle shuffle block merge");
}
@Override
public void registerApplication(String appId, String user) {
// No-op. Do nothing.
}
@Override
public void registerExecutor(String appId, String[] localDirs) {
// No-Op. Do nothing.
}
@Override
public void applicationRemoved(String appId, boolean cleanupLocalDirs) {
throw new UnsupportedOperationException("Cannot handle shuffle block merge");
}
@Override
public ManagedBuffer getMergedBlockData(
String appId, int shuffleId, int reduceId, int chunkId) {
throw new UnsupportedOperationException("Cannot handle shuffle block merge");
}
@Override
public MergedBlockMeta getMergedBlockMeta(String appId, int shuffleId, int reduceId) {
throw new UnsupportedOperationException("Cannot handle shuffle block merge");
}
@Override
public String[] getMergedBlockDirs(String appId) {
throw new UnsupportedOperationException("Cannot handle shuffle block merge");
}
}
@Override
public void channelActive(TransportClient client) {
metrics.activeConnections.inc();


@ -20,21 +20,24 @@ package org.apache.spark.network.shuffle;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import com.codahale.metrics.MetricSet;
import com.google.common.collect.Lists;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.shuffle.protocol.*;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.crypto.AuthClientBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.shuffle.protocol.*;
import org.apache.spark.network.util.TransportConf;
/**
@ -43,6 +46,8 @@ import org.apache.spark.network.util.TransportConf;
* (via BlockTransferService), which has the downside of losing the data if we lose the executors.
*/
public class ExternalBlockStoreClient extends BlockStoreClient {
private static final ErrorHandler PUSH_ERROR_HANDLER = new ErrorHandler.BlockPushErrorHandler();
private final TransportConf conf;
private final boolean authEnabled;
private final SecretKeyHolder secretKeyHolder;
@ -90,12 +95,12 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
try {
int maxRetries = conf.maxIORetries();
RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
(blockIds1, listener1) -> {
(inputBlockId, inputListener) -> {
// Unless this client is closed.
if (clientFactory != null) {
TransportClient client = clientFactory.createClient(host, port, maxRetries > 0);
new OneForOneBlockFetcher(client, appId, execId,
blockIds1, listener1, conf, downloadFileManager).start();
inputBlockId, inputListener, conf, downloadFileManager).start();
} else {
logger.info("This clientFactory was closed. Skipping further block fetch retries.");
}
@ -116,6 +121,43 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
}
}
@Override
public void pushBlocks(
String host,
int port,
String[] blockIds,
ManagedBuffer[] buffers,
BlockFetchingListener listener) {
checkInit();
assert blockIds.length == buffers.length : "Number of block ids and buffers do not match.";
Map<String, ManagedBuffer> buffersWithId = new HashMap<>();
for (int i = 0; i < blockIds.length; i++) {
buffersWithId.put(blockIds[i], buffers[i]);
}
logger.debug("Push {} shuffle blocks to {}:{}", blockIds.length, host, port);
try {
RetryingBlockFetcher.BlockFetchStarter blockPushStarter =
(inputBlockId, inputListener) -> {
TransportClient client = clientFactory.createClient(host, port);
new OneForOneBlockPusher(client, appId, inputBlockId, inputListener, buffersWithId)
.start();
};
int maxRetries = conf.maxIORetries();
if (maxRetries > 0) {
new RetryingBlockFetcher(
conf, blockPushStarter, blockIds, listener, PUSH_ERROR_HANDLER).start();
} else {
blockPushStarter.createAndStart(blockIds, listener);
}
} catch (Exception e) {
logger.error("Exception while beginning pushBlocks", e);
for (String blockId : blockIds) {
listener.onBlockFetchFailure(blockId, e);
}
}
}
@Override
public MetricSet shuffleMetrics() {
checkInit();


@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.roaringbitmap.RoaringBitmap;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.protocol.Encoders;
/**
* Contains meta information for a merged block. Currently this information includes:
* 1. Number of chunks in a merged shuffle block.
* 2. Bitmaps for each chunk in the merged block. A chunk bitmap contains all the mapIds that were
* merged to that merged block chunk.
*/
public class MergedBlockMeta {
private final int numChunks;
private final ManagedBuffer chunksBitmapBuffer;
public MergedBlockMeta(int numChunks, ManagedBuffer chunksBitmapBuffer) {
this.numChunks = numChunks;
this.chunksBitmapBuffer = Preconditions.checkNotNull(chunksBitmapBuffer);
}
public int getNumChunks() {
return numChunks;
}
public ManagedBuffer getChunksBitmapBuffer() {
return chunksBitmapBuffer;
}
public RoaringBitmap[] readChunkBitmaps() throws IOException {
ByteBuf buf = Unpooled.wrappedBuffer(chunksBitmapBuffer.nioByteBuffer());
List<RoaringBitmap> bitmaps = new ArrayList<>();
while (buf.isReadable()) {
bitmaps.add(Encoders.Bitmaps.decode(buf));
}
assert (bitmaps.size() == numChunks);
return bitmaps.toArray(new RoaringBitmap[0]);
}
}
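To make the meta layout concrete, here is a hedged sketch that builds the chunk-bitmap buffer by hand, the way a `MergedShuffleFileManager` implementation presumably would, and reads it back through `readChunkBitmaps`:

```java
import java.io.IOException;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.roaringbitmap.RoaringBitmap;

import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.protocol.Encoders;
import org.apache.spark.network.shuffle.MergedBlockMeta;

public class MergedBlockMetaSketch {
  public static void main(String[] args) throws IOException {
    // One bitmap of map ids per chunk, concatenated via Encoders.Bitmaps.
    RoaringBitmap chunk0 = RoaringBitmap.bitmapOf(0, 1);
    RoaringBitmap chunk1 = RoaringBitmap.bitmapOf(2, 3);
    ByteBuf buf = Unpooled.buffer(
        Encoders.Bitmaps.encodedLength(chunk0) + Encoders.Bitmaps.encodedLength(chunk1));
    Encoders.Bitmaps.encode(buf, chunk0);
    Encoders.Bitmaps.encode(buf, chunk1);

    MergedBlockMeta meta = new MergedBlockMeta(2, new NioManagedBuffer(buf.nioBuffer()));
    RoaringBitmap[] bitmaps = meta.readChunkBitmaps();
    assert bitmaps.length == meta.getNumChunks();
  }
}
```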


@ -0,0 +1,116 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle;
import java.io.IOException;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge;
import org.apache.spark.network.shuffle.protocol.MergeStatuses;
import org.apache.spark.network.shuffle.protocol.PushBlockStream;
/**
* The MergedShuffleFileManager is used to process push-based shuffle when it is enabled. It
* works alongside {@link ExternalBlockHandler} and serves as the handler for
* {@link org.apache.spark.network.server.RpcHandler#receiveStream}, where it processes the
* remotely pushed streams of shuffle blocks and merges them into merged shuffle files. Right
* now, support for push-based shuffle is only implemented for the external shuffle service in
* YARN mode.
*/
public interface MergedShuffleFileManager {
/**
* Provides the stream callback used to process a remotely pushed block. The callback is
* used by the {@link org.apache.spark.network.client.StreamInterceptor} installed on the
* channel to process the block data in the channel outside of the message frame.
*
* @param msg metadata of the remotely pushed blocks. This is processed inside the message frame
* @return A stream callback to process the block data in streaming fashion as it arrives
*/
StreamCallbackWithID receiveBlockDataAsStream(PushBlockStream msg);
/**
* Handles the request to finalize shuffle merge for a given shuffle.
*
* @param msg contains appId and shuffleId to uniquely identify a shuffle to be finalized
* @return The statuses of the merged shuffle partitions for the given shuffle on this
* shuffle service
* @throws IOException
*/
MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) throws IOException;
/**
* Registers an application when it starts. It also stores the username which is necessary
* for generating the host local directories for merged shuffle files.
* Right now, this is invoked by YarnShuffleService.
*
* @param appId application ID
* @param user username
*/
void registerApplication(String appId, String user);
/**
* Registers an executor with its local dir list when it starts. This provides the specific path
* so MergedShuffleFileManager knows where to store and look for shuffle data for a
given application. It is invoked by the RPC call when an executor tries to register with the
* local shuffle service.
*
* @param appId application ID
* @param localDirs The list of local dirs that this executor gets granted from NodeManager
*/
void registerExecutor(String appId, String[] localDirs);
/**
* Invoked when an application finishes. This cleans up any remaining metadata associated with
* this application, and optionally deletes the application specific directory path.
*
* @param appId application ID
* @param cleanupLocalDirs flag indicating whether MergedShuffleFileManager should handle
* deletion of local dirs itself.
*/
void applicationRemoved(String appId, boolean cleanupLocalDirs);
/**
* Get the buffer for a given merged shuffle chunk when serving merged shuffle to reducers
*
* @param appId application ID
* @param shuffleId shuffle ID
* @param reduceId reducer ID
* @param chunkId merged shuffle file chunk ID
* @return The {@link ManagedBuffer} for the given merged shuffle chunk
*/
ManagedBuffer getMergedBlockData(String appId, int shuffleId, int reduceId, int chunkId);
/**
* Get the meta information of a merged block.
*
* @param appId application ID
* @param shuffleId shuffle ID
* @param reduceId reducer ID
* @return meta information of a merged block
*/
MergedBlockMeta getMergedBlockMeta(String appId, int shuffleId, int reduceId);
/**
* Get the local directories which stores the merged shuffle files.
*
* @param appId application ID
*/
String[] getMergedBlockDirs(String appId);
}
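Since the actual implementation of this interface is deferred to a later PR, here is only a hypothetical, in-memory sketch of the stream-callback side, showing how the bytes of a block pushed via `receiveBlockDataAsStream` would flow through a `StreamCallbackWithID`:

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;

import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.shuffle.protocol.PushBlockStream;

// Hypothetical fragment: buffers a pushed block in memory. A real implementation
// would append the bytes to the appropriate merged shuffle file instead.
public class InMemoryPushStreamCallback implements StreamCallbackWithID {
  private final PushBlockStream msg;
  private final ByteArrayOutputStream received = new ByteArrayOutputStream();

  public InMemoryPushStreamCallback(PushBlockStream msg) {
    this.msg = msg;
  }

  @Override
  public String getID() {
    // The stream id just needs to identify the pushed block.
    return msg.blockId;
  }

  @Override
  public void onData(String streamId, ByteBuffer buf) throws IOException {
    byte[] chunk = new byte[buf.remaining()];
    buf.get(chunk);
    received.write(chunk, 0, chunk.length);
  }

  @Override
  public void onComplete(String streamId) {
    // A real implementation would record the block as merged here.
  }

  @Override
  public void onFailure(String streamId, Throwable cause) {
    // A real implementation would release any partially written data here.
  }
}
```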


@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.shuffle.protocol.PushBlockStream;
/**
* Similar to {@link OneForOneBlockFetcher}, but for pushing blocks to a remote shuffle service
* to be merged instead of fetching them from remote shuffle services. This is used by
* ShuffleWriter when the block push process is initiated. The supplied BlockFetchingListener
* is used to handle the success or failure of pushing each block.
*/
public class OneForOneBlockPusher {
private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockPusher.class);
private static final ErrorHandler PUSH_ERROR_HANDLER = new ErrorHandler.BlockPushErrorHandler();
private final TransportClient client;
private final String appId;
private final String[] blockIds;
private final BlockFetchingListener listener;
private final Map<String, ManagedBuffer> buffers;
public OneForOneBlockPusher(
TransportClient client,
String appId,
String[] blockIds,
BlockFetchingListener listener,
Map<String, ManagedBuffer> buffers) {
this.client = client;
this.appId = appId;
this.blockIds = blockIds;
this.listener = listener;
this.buffers = buffers;
}
private class BlockPushCallback implements RpcResponseCallback {
private int index;
private String blockId;
BlockPushCallback(int index, String blockId) {
this.index = index;
this.blockId = blockId;
}
@Override
public void onSuccess(ByteBuffer response) {
// On receipt of a successful block push
listener.onBlockFetchSuccess(blockId, new NioManagedBuffer(ByteBuffer.allocate(0)));
}
@Override
public void onFailure(Throwable e) {
// Since block push is best effort, i.e., if we encounter a block push failure that's not
// retriable or that exceeds the max retries, we should not fail all remaining block pushes.
// The best effort nature makes block push tolerant of partial completion. Thus, we only
// fail the block that actually failed. Note that, on the RetryingBlockFetcher side, once
// a retry is initiated, it would still invalidate the previous active retry listener, and
// retry all outstanding blocks. We are preventing forwarding unnecessary block push failures
// to the parent listener of the retry listener. The only exceptions would be if the block
// push failure is due to the block arriving on the server side after merge finalization, or
// the client failing to establish a connection to the server side. In both cases, we would
// fail all remaining blocks.
if (PUSH_ERROR_HANDLER.shouldRetryError(e)) {
String[] targetBlockId = Arrays.copyOfRange(blockIds, index, index + 1);
failRemainingBlocks(targetBlockId, e);
} else {
String[] targetBlockId = Arrays.copyOfRange(blockIds, index, blockIds.length);
failRemainingBlocks(targetBlockId, e);
}
}
}
private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
for (String blockId : failedBlockIds) {
try {
listener.onBlockFetchFailure(blockId, e);
} catch (Exception e2) {
logger.error("Error in block push failure callback", e2);
}
}
}
/**
* Begins the block pushing process, calling the listener with every block pushed.
*/
public void start() {
logger.debug("Start pushing {} blocks", blockIds.length);
for (int i = 0; i < blockIds.length; i++) {
assert buffers.containsKey(blockIds[i]) : "Could not find the block buffer for block "
+ blockIds[i];
ByteBuffer header = new PushBlockStream(appId, blockIds[i], i).toByteBuffer();
client.uploadStream(new NioManagedBuffer(header), buffers.get(blockIds[i]),
new BlockPushCallback(i, blockIds[i]));
}
}
}


@ -99,11 +99,14 @@ public class RetryingBlockFetcher {
*/
private RetryingBlockFetchListener currentListener;
private final ErrorHandler errorHandler;
public RetryingBlockFetcher(
TransportConf conf,
RetryingBlockFetcher.BlockFetchStarter fetchStarter,
String[] blockIds,
BlockFetchingListener listener) {
BlockFetchingListener listener,
ErrorHandler errorHandler) {
this.fetchStarter = fetchStarter;
this.listener = listener;
this.maxRetries = conf.maxIORetries();
@ -111,6 +114,15 @@ public class RetryingBlockFetcher {
this.outstandingBlocksIds = Sets.newLinkedHashSet();
Collections.addAll(outstandingBlocksIds, blockIds);
this.currentListener = new RetryingBlockFetchListener();
this.errorHandler = errorHandler;
}
public RetryingBlockFetcher(
TransportConf conf,
BlockFetchStarter fetchStarter,
String[] blockIds,
BlockFetchingListener listener) {
this(conf, fetchStarter, blockIds, listener, ErrorHandler.NOOP_ERROR_HANDLER);
}
/**
@ -178,7 +190,7 @@ public class RetryingBlockFetcher {
boolean isIOException = e instanceof IOException
|| (e.getCause() != null && e.getCause() instanceof IOException);
boolean hasRemainingRetries = retryCount < maxRetries;
return isIOException && hasRemainingRetries;
return isIOException && hasRemainingRetries && errorHandler.shouldRetryError(e);
}
/**
@ -215,8 +227,15 @@ public class RetryingBlockFetcher {
if (shouldRetry(exception)) {
initiateRetry();
} else {
logger.error(String.format("Failed to fetch block %s, and will not retry (%s retries)",
blockId, retryCount), exception);
if (errorHandler.shouldLogError(exception)) {
logger.error(
String.format("Failed to fetch block %s, and will not retry (%s retries)",
blockId, retryCount), exception);
} else {
logger.debug(
String.format("Failed to fetch block %s, and will not retry (%s retries)",
blockId, retryCount), exception);
}
outstandingBlocksIds.remove(blockId);
shouldForwardFailure = true;
}


@ -47,7 +47,8 @@ public abstract class BlockTransferMessage implements Encodable {
public enum Type {
OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4),
HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8),
FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11);
FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11),
PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14);
private final byte id;
@ -78,6 +79,9 @@ public abstract class BlockTransferMessage implements Encodable {
case 9: return FetchShuffleBlocks.decode(buf);
case 10: return GetLocalDirsForExecutors.decode(buf);
case 11: return LocalDirsForExecutors.decode(buf);
case 12: return PushBlockStream.decode(buf);
case 13: return FinalizeShuffleMerge.decode(buf);
case 14: return MergeStatuses.decode(buf);
default: throw new IllegalArgumentException("Unknown message type: " + type);
}
}


@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle.protocol;
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
import org.apache.spark.network.protocol.Encoders;
/**
* Request to finalize merge for a given shuffle.
* Returns {@link MergeStatuses}
*/
public class FinalizeShuffleMerge extends BlockTransferMessage {
public final String appId;
public final int shuffleId;
public FinalizeShuffleMerge(
String appId,
int shuffleId) {
this.appId = appId;
this.shuffleId = shuffleId;
}
@Override
protected BlockTransferMessage.Type type() {
return Type.FINALIZE_SHUFFLE_MERGE;
}
@Override
public int hashCode() {
return Objects.hashCode(appId, shuffleId);
}
@Override
public String toString() {
return Objects.toStringHelper(this)
.add("appId", appId)
.add("shuffleId", shuffleId)
.toString();
}
@Override
public boolean equals(Object other) {
if (other != null && other instanceof FinalizeShuffleMerge) {
FinalizeShuffleMerge o = (FinalizeShuffleMerge) other;
return Objects.equal(appId, o.appId)
&& shuffleId == o.shuffleId;
}
return false;
}
@Override
public int encodedLength() {
return Encoders.Strings.encodedLength(appId) + 4;
}
@Override
public void encode(ByteBuf buf) {
Encoders.Strings.encode(buf, appId);
buf.writeInt(shuffleId);
}
public static FinalizeShuffleMerge decode(ByteBuf buf) {
String appId = Encoders.Strings.decode(buf);
int shuffleId = buf.readInt();
return new FinalizeShuffleMerge(appId, shuffleId);
}
}


@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle.protocol;
import java.util.Arrays;
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
import org.roaringbitmap.RoaringBitmap;
import org.apache.spark.network.protocol.Encoders;
/**
* Result returned by an ExternalShuffleService to the DAGScheduler. This represents the result
* of all the remote shuffle block merge operations performed by an ExternalShuffleService
* for a given shuffle ID. It includes the shuffle ID, an array of bitmaps each representing
* the set of mapper partition blocks that are merged for a given reducer partition, an array
* of reducer IDs, and an array of merged shuffle partition sizes. The 3 arrays list information
* about all the reducer partitions merged by the ExternalShuffleService in the same order.
*/
public class MergeStatuses extends BlockTransferMessage {
/** Shuffle ID **/
public final int shuffleId;
/**
* Array of bitmaps tracking the set of mapper partition blocks merged for each
* reducer partition
*/
public final RoaringBitmap[] bitmaps;
/** Array of reducer IDs **/
public final int[] reduceIds;
/**
* Array of merged shuffle partition block sizes. Each entry is the total size of all
* merged shuffle partition blocks for one reducer partition.
* **/
public final long[] sizes;
public MergeStatuses(
int shuffleId,
RoaringBitmap[] bitmaps,
int[] reduceIds,
long[] sizes) {
this.shuffleId = shuffleId;
this.bitmaps = bitmaps;
this.reduceIds = reduceIds;
this.sizes = sizes;
}
@Override
protected Type type() {
return Type.MERGE_STATUSES;
}
@Override
public int hashCode() {
int objectHashCode = Objects.hashCode(shuffleId);
return (objectHashCode * 41 + Arrays.hashCode(reduceIds) * 41
+ Arrays.hashCode(bitmaps) * 41 + Arrays.hashCode(sizes));
}
@Override
public String toString() {
return Objects.toStringHelper(this)
.add("shuffleId", shuffleId)
.add("reduceId size", reduceIds.length)
.toString();
}
@Override
public boolean equals(Object other) {
if (other != null && other instanceof MergeStatuses) {
MergeStatuses o = (MergeStatuses) other;
return Objects.equal(shuffleId, o.shuffleId)
&& Arrays.equals(bitmaps, o.bitmaps)
&& Arrays.equals(reduceIds, o.reduceIds)
&& Arrays.equals(sizes, o.sizes);
}
return false;
}
@Override
public int encodedLength() {
return 4 // int
+ Encoders.BitmapArrays.encodedLength(bitmaps)
+ Encoders.IntArrays.encodedLength(reduceIds)
+ Encoders.LongArrays.encodedLength(sizes);
}
@Override
public void encode(ByteBuf buf) {
buf.writeInt(shuffleId);
Encoders.BitmapArrays.encode(buf, bitmaps);
Encoders.IntArrays.encode(buf, reduceIds);
Encoders.LongArrays.encode(buf, sizes);
}
public static MergeStatuses decode(ByteBuf buf) {
int shuffleId = buf.readInt();
RoaringBitmap[] bitmaps = Encoders.BitmapArrays.decode(buf);
int[] reduceIds = Encoders.IntArrays.decode(buf);
long[] sizes = Encoders.LongArrays.decode(buf);
return new MergeStatuses(shuffleId, bitmaps, reduceIds, sizes);
}
}
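For an end-to-end check of the new encoders, here is a sketch that round-trips a `MergeStatuses` message through the `BlockTransferMessage` envelope, using the same values as the `testFinalizeShuffleMerge` test further below:

```java
import java.nio.ByteBuffer;

import org.roaringbitmap.RoaringBitmap;

import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.MergeStatuses;

public class MergeStatusesRoundTripSketch {
  public static void main(String[] args) {
    // Reducer 3 has map outputs 0-2 merged, totaling 30 bytes.
    MergeStatuses statuses = new MergeStatuses(
        0, new RoaringBitmap[] { RoaringBitmap.bitmapOf(0, 1, 2) },
        new int[] { 3 }, new long[] { 30 });
    ByteBuffer wire = statuses.toByteBuffer();
    MergeStatuses decoded =
        (MergeStatuses) BlockTransferMessage.Decoder.fromByteBuffer(wire);
    assert statuses.equals(decoded);
  }
}
```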


@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle.protocol;
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
import org.apache.spark.network.protocol.Encoders;
// Needed by ScalaDoc. See SPARK-7726
import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
/**
* Request to push a block to a remote shuffle service to be merged in push-based shuffle.
* The remote shuffle service will also include this message when responding to push requests.
*/
public class PushBlockStream extends BlockTransferMessage {
public final String appId;
public final String blockId;
// Similar to the chunkIndex in StreamChunkId, indicating the index of a block in a batch of
// blocks to be pushed.
public final int index;
public PushBlockStream(String appId, String blockId, int index) {
this.appId = appId;
this.blockId = blockId;
this.index = index;
}
@Override
protected Type type() {
return Type.PUSH_BLOCK_STREAM;
}
@Override
public int hashCode() {
return Objects.hashCode(appId, blockId, index);
}
@Override
public String toString() {
return Objects.toStringHelper(this)
.add("appId", appId)
.add("blockId", blockId)
.add("index", index)
.toString();
}
@Override
public boolean equals(Object other) {
if (other != null && other instanceof PushBlockStream) {
PushBlockStream o = (PushBlockStream) other;
return Objects.equal(appId, o.appId)
&& Objects.equal(blockId, o.blockId)
&& index == o.index;
}
return false;
}
@Override
public int encodedLength() {
return Encoders.Strings.encodedLength(appId)
+ Encoders.Strings.encodedLength(blockId) + 4;
}
@Override
public void encode(ByteBuf buf) {
Encoders.Strings.encode(buf, appId);
Encoders.Strings.encode(buf, blockId);
buf.writeInt(index);
}
public static PushBlockStream decode(ByteBuf buf) {
String appId = Encoders.Strings.decode(buf);
String blockId = Encoders.Strings.decode(buf);
int index = buf.readInt();
return new PushBlockStream(appId, blockId, index);
}
}


@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle;
import java.net.ConnectException;
import org.junit.Test;
import static org.junit.Assert.*;
/**
* Test suite for {@link ErrorHandler}
*/
public class ErrorHandlerSuite {
@Test
public void testPushErrorRetry() {
ErrorHandler.BlockPushErrorHandler handler = new ErrorHandler.BlockPushErrorHandler();
assertFalse(handler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))));
assertFalse(handler.shouldRetryError(new RuntimeException(new ConnectException())));
assertTrue(handler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))));
assertTrue(handler.shouldRetryError(new Throwable()));
}
@Test
public void testPushErrorLogging() {
ErrorHandler.BlockPushErrorHandler handler = new ErrorHandler.BlockPushErrorHandler();
assertFalse(handler.shouldLogError(new RuntimeException(new IllegalArgumentException(
ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))));
assertFalse(handler.shouldLogError(new RuntimeException(new IllegalArgumentException(
ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))));
assertTrue(handler.shouldLogError(new Throwable()));
}
}


@ -17,6 +17,7 @@
package org.apache.spark.network.shuffle;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Iterator;
@ -25,6 +26,7 @@ import com.codahale.metrics.Timer;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.roaringbitmap.RoaringBitmap;
import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.any;
@ -39,6 +41,8 @@ import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge;
import org.apache.spark.network.shuffle.protocol.MergeStatuses;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
@ -50,6 +54,7 @@ public class ExternalBlockHandlerSuite {
OneForOneStreamManager streamManager;
ExternalShuffleBlockResolver blockResolver;
RpcHandler handler;
MergedShuffleFileManager mergedShuffleManager;
ManagedBuffer[] blockMarkers = {
new NioManagedBuffer(ByteBuffer.wrap(new byte[3])),
new NioManagedBuffer(ByteBuffer.wrap(new byte[7]))
@ -59,17 +64,20 @@ public class ExternalBlockHandlerSuite {
public void beforeEach() {
streamManager = mock(OneForOneStreamManager.class);
blockResolver = mock(ExternalShuffleBlockResolver.class);
handler = new ExternalBlockHandler(streamManager, blockResolver);
mergedShuffleManager = mock(MergedShuffleFileManager.class);
handler = new ExternalBlockHandler(streamManager, blockResolver, mergedShuffleManager);
}
@Test
public void testRegisterExecutor() {
RpcResponseCallback callback = mock(RpcResponseCallback.class);
ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort");
String[] localDirs = new String[] {"/a", "/b"};
ExecutorShuffleInfo config = new ExecutorShuffleInfo(localDirs, 16, "sort");
ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer();
handler.receive(client, registerMessage, callback);
verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config);
verify(mergedShuffleManager, times(1)).registerExecutor("app0", localDirs);
verify(callback, times(1)).onSuccess(any(ByteBuffer.class));
verify(callback, never()).onFailure(any(Throwable.class));
@ -222,4 +230,32 @@ public class ExternalBlockHandlerSuite {
verify(callback, never()).onSuccess(any(ByteBuffer.class));
verify(callback, never()).onFailure(any(Throwable.class));
}
@Test
public void testFinalizeShuffleMerge() throws IOException {
RpcResponseCallback callback = mock(RpcResponseCallback.class);
FinalizeShuffleMerge req = new FinalizeShuffleMerge("app0", 0);
RoaringBitmap bitmap = RoaringBitmap.bitmapOf(0, 1, 2);
MergeStatuses statuses = new MergeStatuses(0, new RoaringBitmap[]{bitmap},
new int[]{3}, new long[]{30});
when(mergedShuffleManager.finalizeShuffleMerge(req)).thenReturn(statuses);
ByteBuffer reqBuf = req.toByteBuffer();
handler.receive(client, reqBuf, callback);
verify(mergedShuffleManager, times(1)).finalizeShuffleMerge(req);
ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class);
verify(callback, times(1)).onSuccess(response.capture());
verify(callback, never()).onFailure(any());
MergeStatuses mergeStatuses =
(MergeStatuses) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue());
assertEquals(mergeStatuses, statuses);
Timer finalizeShuffleMergeLatencyMillis = (Timer) ((ExternalBlockHandler) handler)
.getAllMetrics()
.getMetrics()
.get("finalizeShuffleMergeLatencyMillis");
assertEquals(1, finalizeShuffleMergeLatencyMillis.getCount());
}
}


@ -0,0 +1,159 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import com.google.common.collect.Maps;
import io.netty.buffer.Unpooled;
import org.junit.Test;
import static org.junit.Assert.*;
import static org.mockito.AdditionalMatchers.*;
import static org.mockito.Mockito.*;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NettyManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.PushBlockStream;
public class OneForOneBlockPusherSuite {
@Test
public void testPushOne() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[1])));
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
BlockFetchingListener listener = pushBlocks(
blocks,
blockIds,
Arrays.asList(new PushBlockStream("app-id", "shuffle_0_0_0", 0)));
verify(listener).onBlockFetchSuccess(eq("shuffle_0_0_0"), any());
}
@Test
public void testPushThree() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
BlockFetchingListener listener = pushBlocks(
blocks,
blockIds,
Arrays.asList(new PushBlockStream("app-id", "b0", 0),
new PushBlockStream("app-id", "b1", 1),
new PushBlockStream("app-id", "b2", 2)));
for (int i = 0; i < 3; i++) {
verify(listener, times(1)).onBlockFetchSuccess(eq("b" + i), any());
}
}
@Test
public void testServerFailures() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
BlockFetchingListener listener = pushBlocks(
blocks,
blockIds,
Arrays.asList(new PushBlockStream("app-id", "b0", 0),
new PushBlockStream("app-id", "b1", 1),
new PushBlockStream("app-id", "b2", 2)));
verify(listener, times(1)).onBlockFetchSuccess(eq("b0"), any());
verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any());
verify(listener, times(1)).onBlockFetchFailure(eq("b2"), any());
}
@Test
public void testHandlingRetriableFailures() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
blocks.put("b1", null);
blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
BlockFetchingListener listener = pushBlocks(
blocks,
blockIds,
Arrays.asList(new PushBlockStream("app-id", "b0", 0),
new PushBlockStream("app-id", "b1", 1),
new PushBlockStream("app-id", "b2", 2)));
verify(listener, times(1)).onBlockFetchSuccess(eq("b0"), any());
verify(listener, times(0)).onBlockFetchSuccess(not(eq("b0")), any());
verify(listener, times(0)).onBlockFetchFailure(eq("b0"), any());
verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any());
verify(listener, times(2)).onBlockFetchFailure(eq("b2"), any());
}
/**
* Begins a push on the given set of blocks by mocking the response from the server side.
* If a block is an empty byte buffer, a server-side retriable exception will be thrown.
* If a block is null, a non-retriable exception will be thrown.
*/
private static BlockFetchingListener pushBlocks(
LinkedHashMap<String, ManagedBuffer> blocks,
String[] blockIds,
Iterable<BlockTransferMessage> expectMessages) {
TransportClient client = mock(TransportClient.class);
BlockFetchingListener listener = mock(BlockFetchingListener.class);
OneForOneBlockPusher pusher =
new OneForOneBlockPusher(client, "app-id", blockIds, listener, blocks);
Iterator<Map.Entry<String, ManagedBuffer>> blockIterator = blocks.entrySet().iterator();
Iterator<BlockTransferMessage> msgIterator = expectMessages.iterator();
doAnswer(invocation -> {
ByteBuffer header = ((ManagedBuffer) invocation.getArguments()[0]).nioByteBuffer();
BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer(header);
RpcResponseCallback callback = (RpcResponseCallback) invocation.getArguments()[2];
Map.Entry<String, ManagedBuffer> entry = blockIterator.next();
ManagedBuffer block = entry.getValue();
if (block != null && block.nioByteBuffer().capacity() > 0) {
callback.onSuccess(header);
} else if (block != null) {
callback.onFailure(new RuntimeException("Failed " + entry.getKey()
+ ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX));
} else {
callback.onFailure(new RuntimeException("Quick fail " + entry.getKey()
+ ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX));
}
assertEquals(msgIterator.next(), message);
return null;
}).when(client).uploadStream(any(ManagedBuffer.class), any(), any(RpcResponseCallback.class));
pusher.start();
return listener;
}
}


@ -61,7 +61,8 @@ class ExternalShuffleServiceMetricsSuite extends SparkFunSuite {
"registeredExecutorsSize",
"registerExecutorRequestLatencyMillis",
"shuffle-server.usedDirectMemory",
"shuffle-server.usedHeapMemory")
"shuffle-server.usedHeapMemory",
"finalizeShuffleMergeLatencyMillis")
)
}
}


@ -40,7 +40,7 @@ class YarnShuffleServiceMetricsSuite extends SparkFunSuite with Matchers {
val allMetrics = Set(
"openBlockRequestLatencyMillis", "registerExecutorRequestLatencyMillis",
"blockTransferRateBytes", "registeredExecutorsSize", "numActiveConnections",
"numCaughtExceptions")
"numCaughtExceptions", "finalizeShuffleMergeLatencyMillis")
metrics.getMetrics.keySet().asScala should be (allMetrics)
}


@ -405,6 +405,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
"openBlockRequestLatencyMillis",
"registeredExecutorsSize",
"registerExecutorRequestLatencyMillis",
"finalizeShuffleMergeLatencyMillis",
"shuffle-server.usedDirectMemory",
"shuffle-server.usedHeapMemory"
))