From 8ce1e344e58dbfbddecd9e9fd9f0a5a6f15dbea9 Mon Sep 17 00:00:00 2001
From: Chandni Singh
Date: Sun, 20 Jun 2021 17:22:37 -0500
Subject: [PATCH] [SPARK-35671][SHUFFLE][CORE] Add support in the ESS to serve merged shuffle block meta and data to executors

### What changes were proposed in this pull request?
This adds support in the ESS to serve merged shuffle block meta and data requests to executors. This change is needed for fetching remote merged shuffle data from the remote shuffle services. It is part of the push-based shuffle SPIP [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).

This change introduces new messages between clients and the external shuffle service:

1. `MergedBlockMetaRequest`: The client sends this to the external shuffle service to get the meta information for a merged block. The response is one of the following:
   - `MergedBlockMetaSuccess`: contains the request id, the number of chunks, and a `ManagedBuffer` which is a `FileSegmentManagedBuffer` backed by the merged block meta file.
   - `RpcFailure`: sent back to the client in case of failure. This is an existing message.
2. `FetchShuffleBlockChunks`: Similar to the `FetchShuffleBlocks` message, but it fetches merged shuffle chunks instead of regular shuffle blocks.

### Why are the changes needed?
These changes are needed for push-based shuffle. Refer to the SPIP in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added unit tests. The reference PR with the consolidated changes covering the complete implementation is also provided in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602). We have already verified the functionality and the improved performance as documented in the SPIP doc.

Closes #32811 from otterc/SPARK-35671.
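To illustrate the intended executor-side flow, a minimal sketch (illustrative only, not code from this PR; the host, port, ids, and an already-`init()`'d `ExternalBlockStoreClient` named `client` are assumptions):

```java
// Ask the remote shuffle service for the meta of one merged block, then decide
// what to fetch based on the reply.
client.getMergedBlockMeta("ess-host", 7337, /* shuffleId */ 0, /* reduceId */ 5,
  new MergedBlocksMetaListener() {
    @Override
    public void onSuccess(int shuffleId, int reduceId, MergedBlockMeta meta) {
      // meta.getNumChunks() chunks can now be fetched as
      // shuffleChunk_<shuffleId>_<reduceId>_<chunkId> blocks, which
      // OneForOneBlockFetcher translates into a FetchShuffleBlockChunks message.
    }

    @Override
    public void onFailure(int shuffleId, int reduceId, Throwable t) {
      // Callers would typically fall back to fetching the original unmerged blocks.
    }
  });
```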
Lead-authored-by: Chandni Singh <chsingh@linkedin.com>
Co-authored-by: Min Shen <mshen@linkedin.com>
Co-authored-by: Chandni Singh
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../network/client/BaseResponseCallback.java       |  31 ++++
 .../MergedBlockMetaResponseCallback.java           |  41 ++++++
 .../network/client/RpcResponseCallback.java        |   5 +-
 .../spark/network/client/TransportClient.java      |  29 +++-
 .../client/TransportResponseHandler.java           |  27 +++-
 .../spark/network/crypto/AuthRpcHandler.java       |   5 +
 .../protocol/MergedBlockMetaRequest.java           |  95 ++++++++++
 .../protocol/MergedBlockMetaSuccess.java           |  92 ++++++++++
 .../spark/network/protocol/Message.java            |   5 +-
 .../network/protocol/MessageDecoder.java           |   6 +
 .../server/AbstractAuthRpcHandler.java             |   5 +
 .../spark/network/server/RpcHandler.java           |  44 ++++++
 .../server/TransportRequestHandler.java            |  26 ++++
 .../network/TransportRequestHandlerSuite.java      |  55 +++++++
 .../TransportResponseHandlerSuite.java             |  39 +++++
 .../protocol/MergedBlockMetaSuccessSuite.java      | 101 +++++++++++
 .../network/shuffle/BlockStoreClient.java          |  20 +++
 .../network/shuffle/ExternalBlockHandler.java      | 139 +++++++++++++---
 .../shuffle/ExternalBlockStoreClient.java          |  32 ++++
 .../shuffle/ExternalShuffleBlockResolver.java      |   4 +-
 .../shuffle/MergedBlocksMetaListener.java          |  46 ++++++
 .../shuffle/OneForOneBlockFetcher.java             | 128 +++++++++-----
 .../protocol/AbstractFetchShuffleBlocks.java       |  88 +++++++++
 .../protocol/BlockTransferMessage.java             |   4 +-
 .../protocol/FetchShuffleBlockChunks.java          | 128 ++++++++++++++
 .../shuffle/protocol/FetchShuffleBlocks.java       |  45 +++---
 .../shuffle/ExternalBlockHandlerSuite.java         | 112 ++++++++++++
 .../shuffle/OneForOneBlockFetcherSuite.java        |  55 ++++++-
 .../FetchShuffleBlockChunksSuite.java              |  42 ++++++
 .../protocol/FetchShuffleBlocksSuite.java          |  42 ++++++
 .../ExternalShuffleServiceMetricsSuite.scala       |   3 +-
 .../yarn/YarnShuffleServiceMetricsSuite.scala      |   3 +-
 .../yarn/YarnShuffleServiceSuite.scala             |   3 +-
 33 files changed, 1398 insertions(+), 102 deletions(-)
 create mode 100644 common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java
 create mode 100644 common/network-common/src/main/java/org/apache/spark/network/client/MergedBlockMetaResponseCallback.java
 create mode 100644 common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java
 create mode 100644 common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaSuccess.java
 create mode 100644 common/network-common/src/test/java/org/apache/spark/network/protocol/MergedBlockMetaSuccessSuite.java
 create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlocksMetaListener.java
 create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/AbstractFetchShuffleBlocks.java
 create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java
 create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java
 create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocksSuite.java

diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java
new file mode 100644
index 0000000000..d9b7fb2b3b
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * A basic callback. This is extended by {@link RpcResponseCallback} and
+ * {@link MergedBlockMetaResponseCallback} so that both RpcRequests and MergedBlockMetaRequests
+ * can be handled in {@link TransportResponseHandler} in a similar way.
+ *
+ * @since 3.2.0
+ */
+public interface BaseResponseCallback {
+
+  /** Exception either propagated from server or raised on client side. */
+  void onFailure(Throwable e);
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/MergedBlockMetaResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/MergedBlockMetaResponseCallback.java
new file mode 100644
index 0000000000..60d010ef08
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/MergedBlockMetaResponseCallback.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Callback for the result of a single
+ * {@link org.apache.spark.network.protocol.MergedBlockMetaRequest}.
+ *
+ * @since 3.2.0
+ */
+public interface MergedBlockMetaResponseCallback extends BaseResponseCallback {
+  /**
+   * Called upon receipt of a particular merged block meta.
+   *
+   * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this
+   * call returns. You must therefore either retain() the buffer or copy its contents before
+   * returning.
+   *
+   * @param numChunks number of merged chunks in the merged block
+   * @param buffer the buffer containing an array of roaring bitmaps; the i-th roaring bitmap
+   *               contains the mapIds that were merged to the i-th merged chunk
+   */
+  void onSuccess(int numChunks, ManagedBuffer buffer);
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
index 6afc63f71b..a3b8cb1d90 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer;
  * Callback for the result of a single RPC. This will be invoked once with either success or
  * failure.
  */
-public interface RpcResponseCallback {
+public interface RpcResponseCallback extends BaseResponseCallback {
   /**
    * Successful serialized result from server.
    *
@@ -31,7 +31,4 @@ public interface RpcResponseCallback {
    * Please copy the content of `response` if you want to use it after `onSuccess` returns.
    */
   void onSuccess(ByteBuffer response);
-
-  /** Exception either propagated from server or raised on client side. */
-  void onFailure(Throwable e);
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
index eb2882074d..a50c04cf80 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -200,6 +200,31 @@ public class TransportClient implements Closeable {
     return requestId;
   }
 
+  /**
+   * Sends a MergedBlockMetaRequest message to the server. The response to this message is
+   * either a {@link MergedBlockMetaSuccess} or {@link RpcFailure}.
+   *
+   * @param appId applicationId.
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param callback callback to handle the reply.
+   */
+  public void sendMergedBlockMetaReq(
+      String appId,
+      int shuffleId,
+      int reduceId,
+      MergedBlockMetaResponseCallback callback) {
+    long requestId = requestId();
+    if (logger.isTraceEnabled()) {
+      logger.trace(
+        "Sending RPC {} to fetch merged block meta to {}", requestId, getRemoteAddress(channel));
+    }
+    handler.addRpcRequest(requestId, callback);
+    RpcChannelListener listener = new RpcChannelListener(requestId, callback);
+    channel.writeAndFlush(
+      new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId)).addListener(listener);
+  }
+
   /**
    * Send data to the remote end as a stream. This differs from stream() in that this is a request
    * to *send* data to the remote end, not to receive it from the remote.
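For reference, a minimal sketch of driving this transport-level API directly (illustrative; `clientFactory` is an assumed, already-created `TransportClientFactory`, and production callers go through `ExternalBlockStoreClient` in the network-shuffle module instead):

```java
TransportClient client = clientFactory.createClient("ess-host", 7337);
client.sendMergedBlockMetaReq("app-20210620-0000", /* shuffleId */ 0, /* reduceId */ 5,
  new MergedBlockMetaResponseCallback() {
    @Override
    public void onSuccess(int numChunks, ManagedBuffer buffer) {
      // `buffer` holds `numChunks` serialized roaring bitmaps and is release()'d
      // when this method returns: retain() it or copy its contents before returning.
    }

    @Override
    public void onFailure(Throwable e) {
      // The server replied with RpcFailure, or a client-side error occurred.
    }
  });
```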
@@ -349,9 +374,9 @@ public class TransportClient implements Closeable {
   private class RpcChannelListener extends StdChannelListener {
     final long rpcRequestId;
-    final RpcResponseCallback callback;
+    final BaseResponseCallback callback;
 
-    RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) {
+    RpcChannelListener(long rpcRequestId, BaseResponseCallback callback) {
       super("RPC " + rpcRequestId);
       this.rpcRequestId = rpcRequestId;
       this.callback = callback;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index 3aac2d2441..576c08858d 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -33,6 +33,7 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.protocol.ChunkFetchFailure;
 import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.MergedBlockMetaSuccess;
 import org.apache.spark.network.protocol.ResponseMessage;
 import org.apache.spark.network.protocol.RpcFailure;
 import org.apache.spark.network.protocol.RpcResponse;
@@ -56,7 +57,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
 
   private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
 
-  private final Map<Long, RpcResponseCallback> outstandingRpcs;
+  private final Map<Long, BaseResponseCallback> outstandingRpcs;
 
   private final Queue<Pair<String, StreamCallback>> streamCallbacks;
   private volatile boolean streamActive;
@@ -81,7 +82,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
     outstandingFetches.remove(streamChunkId);
   }
 
-  public void addRpcRequest(long requestId, RpcResponseCallback callback) {
+  public void addRpcRequest(long requestId, BaseResponseCallback callback) {
     updateTimeOfLastRequest();
     outstandingRpcs.put(requestId, callback);
   }
@@ -112,7 +113,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
         logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
       }
     }
-    for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
+    for (Map.Entry<Long, BaseResponseCallback> entry : outstandingRpcs.entrySet()) {
       try {
         entry.getValue().onFailure(cause);
       } catch (Exception e) {
@@ -184,7 +185,7 @@
       }
     } else if (message instanceof RpcResponse) {
       RpcResponse resp = (RpcResponse) message;
-      RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+      RpcResponseCallback listener = (RpcResponseCallback) outstandingRpcs.get(resp.requestId);
       if (listener == null) {
         logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
           resp.requestId, getRemoteAddress(channel), resp.body().size());
@@ -199,7 +200,7 @@
       }
     } else if (message instanceof RpcFailure) {
       RpcFailure resp = (RpcFailure) message;
-      RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+      BaseResponseCallback listener = outstandingRpcs.get(resp.requestId);
       if (listener == null) {
         logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
           resp.requestId, getRemoteAddress(channel), resp.errorString);
@@ -207,6 +208,22 @@
         outstandingRpcs.remove(resp.requestId);
         listener.onFailure(new RuntimeException(resp.errorString));
       }
+    } else if (message instanceof MergedBlockMetaSuccess) {
+      MergedBlockMetaSuccess resp = (MergedBlockMetaSuccess) message;
+      try {
+        MergedBlockMetaResponseCallback listener =
+          (MergedBlockMetaResponseCallback) outstandingRpcs.get(resp.requestId);
+        if (listener == null) {
+          logger.warn(
+            "Ignoring response for MergedBlockMetaRequest {} from {} ({} bytes) since it is not"
+              + " outstanding", resp.requestId, getRemoteAddress(channel), resp.body().size());
+        } else {
+          outstandingRpcs.remove(resp.requestId);
+          listener.onSuccess(resp.getNumChunks(), resp.body());
+        }
+      } finally {
+        resp.body().release();
+      }
     } else if (message instanceof StreamResponse) {
       StreamResponse resp = (StreamResponse) message;
       Pair<String, StreamCallback> entry = streamCallbacks.poll();
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
index dd31c95535..8f0a40c380 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
@@ -138,4 +138,9 @@ class AuthRpcHandler extends AbstractAuthRpcHandler {
     LOG.debug("Authorization successful for client {}.", channel.remoteAddress());
     return true;
   }
+
+  @Override
+  public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
+    return saslHandler.getMergedBlockMetaReqHandler();
+  }
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java
new file mode 100644
index 0000000000..cf7c22d241
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import org.apache.commons.lang3.builder.ToStringBuilder;
+import org.apache.commons.lang3.builder.ToStringStyle;
+
+/**
+ * Request to find the meta information for the specified merged block. The meta information
+ * contains the number of chunks in the merged block and the map ids in each chunk.
+ *
+ * @since 3.2.0
+ */
+public class MergedBlockMetaRequest extends AbstractMessage implements RequestMessage {
+  public final long requestId;
+  public final String appId;
+  public final int shuffleId;
+  public final int reduceId;
+
+  public MergedBlockMetaRequest(long requestId, String appId, int shuffleId, int reduceId) {
+    super(null, false);
+    this.requestId = requestId;
+    this.appId = appId;
+    this.shuffleId = shuffleId;
+    this.reduceId = reduceId;
+  }
+
+  @Override
+  public Type type() {
+    return Type.MergedBlockMetaRequest;
+  }
+
+  @Override
+  public int encodedLength() {
+    return 8 + Encoders.Strings.encodedLength(appId) + 4 + 4;
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeLong(requestId);
+    Encoders.Strings.encode(buf, appId);
+    buf.writeInt(shuffleId);
+    buf.writeInt(reduceId);
+  }
+
+  public static MergedBlockMetaRequest decode(ByteBuf buf) {
+    long requestId = buf.readLong();
+    String appId = Encoders.Strings.decode(buf);
+    int shuffleId = buf.readInt();
+    int reduceId = buf.readInt();
+    return new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(requestId, appId, shuffleId, reduceId);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof MergedBlockMetaRequest) {
+      MergedBlockMetaRequest o = (MergedBlockMetaRequest) other;
+      return requestId == o.requestId && shuffleId == o.shuffleId && reduceId == o.reduceId
+        && Objects.equal(appId, o.appId);
+    }
+    return false;
+  }
+
+  @Override
+  public String toString() {
+    return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
+      .append("requestId", requestId)
+      .append("appId", appId)
+      .append("shuffleId", shuffleId)
+      .append("reduceId", reduceId)
+      .toString();
+  }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaSuccess.java
new file mode 100644
index 0000000000..d2edaf4532
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaSuccess.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import org.apache.commons.lang3.builder.ToStringBuilder;
+import org.apache.commons.lang3.builder.ToStringStyle;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Response to a {@link MergedBlockMetaRequest}.
+ * Note that the server-side encoding of this message does NOT include the buffer itself.
+ * + * @since 3.2.0 + */ +public class MergedBlockMetaSuccess extends AbstractResponseMessage { + public final long requestId; + public final int numChunks; + + public MergedBlockMetaSuccess( + long requestId, + int numChunks, + ManagedBuffer chunkBitmapsBuffer) { + super(chunkBitmapsBuffer, true); + this.requestId = requestId; + this.numChunks = numChunks; + } + + @Override + public Type type() { + return Type.MergedBlockMetaSuccess; + } + + @Override + public int hashCode() { + return Objects.hashCode(requestId, numChunks); + } + + @Override + public String toString() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("requestId", requestId).append("numChunks", numChunks).toString(); + } + + @Override + public int encodedLength() { + return 8 + 4; + } + + /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + buf.writeInt(numChunks); + } + + public int getNumChunks() { + return numChunks; + } + + /** Decoding uses the given ByteBuf as our data, and will retain() it. */ + public static MergedBlockMetaSuccess decode(ByteBuf buf) { + long requestId = buf.readLong(); + int numChunks = buf.readInt(); + buf.retain(); + NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); + return new MergedBlockMetaSuccess(requestId, numChunks, managedBuf); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new RpcFailure(requestId, error); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java index 0ccd70c03a..12ebee8da9 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -37,7 +37,8 @@ public interface Message extends Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9), UploadStream(10), User(-1); + OneWayMessage(9), UploadStream(10), MergedBlockMetaRequest(11), MergedBlockMetaSuccess(12), + User(-1); private final byte id; @@ -66,6 +67,8 @@ public interface Message extends Encodable { case 8: return StreamFailure; case 9: return OneWayMessage; case 10: return UploadStream; + case 11: return MergedBlockMetaRequest; + case 12: return MergedBlockMetaSuccess; case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index bf80aed0af..98f7f612a4 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -83,6 +83,12 @@ public final class MessageDecoder extends MessageToMessageDecoder { case UploadStream: return UploadStream.decode(in); + case MergedBlockMetaRequest: + return MergedBlockMetaRequest.decode(in); + + case MergedBlockMetaSuccess: + return MergedBlockMetaSuccess.decode(in); + default: throw new IllegalArgumentException("Unexpected message type: " + msgType); } diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java index 92eb886283..95fde67762 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java @@ -104,4 +104,9 @@ public abstract class AbstractAuthRpcHandler extends RpcHandler { public boolean isAuthenticated() { return isAuthenticated; } + + @Override + public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() { + return delegate.getMergedBlockMetaReqHandler(); + } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 38569baf82..0b89427756 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -22,9 +22,11 @@ import java.nio.ByteBuffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.client.MergedBlockMetaResponseCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.MergedBlockMetaRequest; /** * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. @@ -32,6 +34,8 @@ import org.apache.spark.network.client.TransportClient; public abstract class RpcHandler { private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback(); + private static final MergedBlockMetaReqHandler NOOP_MERGED_BLOCK_META_REQ_HANDLER = + new NoopMergedBlockMetaReqHandler(); /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to @@ -100,6 +104,10 @@ public abstract class RpcHandler { receive(client, message, ONE_WAY_CALLBACK); } + public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() { + return NOOP_MERGED_BLOCK_META_REQ_HANDLER; + } + /** * Invoked when the channel associated with the given client is active. */ @@ -129,4 +137,40 @@ public abstract class RpcHandler { } + /** + * Handler for {@link MergedBlockMetaRequest}. + * + * @since 3.2.0 + */ + public interface MergedBlockMetaReqHandler { + + /** + * Receive a {@link MergedBlockMetaRequest}. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. + * @param mergedBlockMetaRequest Request for merged block meta. + * @param callback Callback which should be invoked exactly once upon success or failure. + */ + void receiveMergeBlockMetaReq( + TransportClient client, + MergedBlockMetaRequest mergedBlockMetaRequest, + MergedBlockMetaResponseCallback callback); + } + + /** + * A Noop implementation of {@link MergedBlockMetaReqHandler}. This Noop implementation is used + * by all the RPC handlers which don't eventually delegate the {@link MergedBlockMetaRequest} to + * ExternalBlockHandler in the network-shuffle module. 
+ *
+ * @since 3.2.0
+ */
+  private static class NoopMergedBlockMetaReqHandler implements MergedBlockMetaReqHandler {
+
+    @Override
+    public void receiveMergeBlockMetaReq(TransportClient client,
+        MergedBlockMetaRequest mergedBlockMetaRequest, MergedBlockMetaResponseCallback callback) {
+      // do nothing
+    }
+  }
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 4a30f8de07..ab2deac20f 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -113,6 +113,8 @@
       processStreamRequest((StreamRequest) request);
     } else if (request instanceof UploadStream) {
       processStreamUpload((UploadStream) request);
+    } else if (request instanceof MergedBlockMetaRequest) {
+      processMergedBlockMetaRequest((MergedBlockMetaRequest) request);
     } else {
       throw new IllegalArgumentException("Unknown request type: " + request);
     }
@@ -260,6 +262,30 @@
     }
   }
 
+  private void processMergedBlockMetaRequest(final MergedBlockMetaRequest req) {
+    try {
+      rpcHandler.getMergedBlockMetaReqHandler().receiveMergeBlockMetaReq(reverseClient, req,
+        new MergedBlockMetaResponseCallback() {
+
+          @Override
+          public void onSuccess(int numChunks, ManagedBuffer buffer) {
+            logger.trace("Sending meta for request {} numChunks {}", req, numChunks);
+            respond(new MergedBlockMetaSuccess(req.requestId, numChunks, buffer));
+          }
+
+          @Override
+          public void onFailure(Throwable e) {
+            logger.trace("Failed to send meta for {}", req);
+            respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+          }
+        });
+    } catch (Exception e) {
+      logger.error("Error while invoking receiveMergeBlockMetaReq() for appId {} shuffleId {} "
+        + "reduceId {}", req.appId, req.shuffleId, req.reduceId, e);
+      respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+    }
+  }
+
   /**
    * Responds to a single message with some Encodable object. If a failure occurs while sending,
    * it will be logged and the channel closed.
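As a sanity check of the wire format these handlers exchange, a round-trip sketch using the encode/decode methods added above (illustrative only; the ids are made up and `io.netty.buffer.Unpooled` is assumed to be imported):

```java
ByteBuf buf = Unpooled.buffer();
MergedBlockMetaRequest req = new MergedBlockMetaRequest(19L, "app1", 0, 5);
req.encode(buf);  // writes requestId, appId, shuffleId, reduceId
if (buf.readableBytes() != req.encodedLength()) {
  throw new AssertionError("encodedLength() should match the bytes written");
}
MergedBlockMetaRequest decoded = MergedBlockMetaRequest.decode(buf);
if (!decoded.equals(req)) {
  throw new AssertionError("decode(encode(msg)) should round-trip");
}
buf.release();
```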
diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java index 0a64471762..b3befb8baf 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; @@ -24,16 +25,19 @@ import io.netty.channel.Channel; import org.junit.Assert; import org.junit.Test; +import static org.junit.Assert.*; import static org.mockito.Mockito.*; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.protocol.*; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportRequestHandler; public class TransportRequestHandlerSuite { @@ -109,4 +113,55 @@ public class TransportRequestHandlerSuite { streamManager.connectionTerminated(channel); Assert.assertEquals(0, streamManager.numStreamStates()); } + + @Test + public void handleMergedBlockMetaRequest() throws Exception { + RpcHandler.MergedBlockMetaReqHandler metaHandler = (client, request, callback) -> { + if (request.shuffleId != -1 && request.reduceId != -1) { + callback.onSuccess(2, mock(ManagedBuffer.class)); + } else { + callback.onFailure(new RuntimeException("empty block")); + } + }; + RpcHandler rpcHandler = new RpcHandler() { + @Override + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) {} + + @Override + public StreamManager getStreamManager() { + return null; + } + + @Override + public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() { + return metaHandler; + } + }; + Channel channel = mock(Channel.class); + List> responseAndPromisePairs = new ArrayList<>(); + when(channel.writeAndFlush(any())).thenAnswer(invocationOnMock0 -> { + Object response = invocationOnMock0.getArguments()[0]; + ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel); + responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture)); + return channelFuture; + }); + + TransportClient reverseClient = mock(TransportClient.class); + TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient, + rpcHandler, 2L, null); + MergedBlockMetaRequest validMetaReq = new MergedBlockMetaRequest(19, "app1", 0, 0); + requestHandler.handle(validMetaReq); + assertEquals(1, responseAndPromisePairs.size()); + assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof MergedBlockMetaSuccess); + assertEquals(2, + ((MergedBlockMetaSuccess) (responseAndPromisePairs.get(0).getLeft())).getNumChunks()); + + MergedBlockMetaRequest invalidMetaReq = new MergedBlockMetaRequest(21, "app1", -1, 1); + requestHandler.handle(invalidMetaReq); + assertEquals(2, responseAndPromisePairs.size()); + assertTrue(responseAndPromisePairs.get(1).getLeft() instanceof RpcFailure); + } } diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index b4032c4c3f..4de13f951d 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -23,17 +23,20 @@ import java.nio.ByteBuffer; import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; import org.junit.Test; +import org.mockito.ArgumentCaptor; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.MergedBlockMetaResponseCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.MergedBlockMetaSuccess; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; @@ -167,4 +170,40 @@ public class TransportResponseHandlerSuite { verify(cb).onFailure(eq("stream-1"), isA(IOException.class)); } + + @Test + public void handleSuccessfulMergedBlockMeta() throws Exception { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class); + handler.addRpcRequest(13, callback); + assertEquals(1, handler.numOutstandingRequests()); + + // This response should be ignored. + handler.handle(new MergedBlockMetaSuccess(22, 2, + new NioManagedBuffer(ByteBuffer.allocate(7)))); + assertEquals(1, handler.numOutstandingRequests()); + + ByteBuffer resp = ByteBuffer.allocate(10); + handler.handle(new MergedBlockMetaSuccess(13, 2, new NioManagedBuffer(resp))); + ArgumentCaptor bufferCaptor = ArgumentCaptor.forClass(NioManagedBuffer.class); + verify(callback, times(1)).onSuccess(eq(2), bufferCaptor.capture()); + assertEquals(resp, bufferCaptor.getValue().nioByteBuffer()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleFailedMergedBlockMeta() throws Exception { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class); + handler.addRpcRequest(51, callback); + assertEquals(1, handler.numOutstandingRequests()); + + // This response should be ignored. 
+ handler.handle(new RpcFailure(6, "failed")); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcFailure(51, "failed")); + verify(callback, times(1)).onFailure(any()); + assertEquals(0, handler.numOutstandingRequests()); + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/MergedBlockMetaSuccessSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/MergedBlockMetaSuccessSuite.java new file mode 100644 index 0000000000..f4a055188c --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/MergedBlockMetaSuccessSuite.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.nio.file.Files; +import java.util.List; + +import com.google.common.collect.Lists; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import org.junit.Assert; +import org.junit.Test; +import org.roaringbitmap.RoaringBitmap; + +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.TransportConf; + +/** + * Test for {@link MergedBlockMetaSuccess}. 
+ */ +public class MergedBlockMetaSuccessSuite { + + @Test + public void testMergedBlocksMetaEncodeDecode() throws Exception { + File chunkMetaFile = new File("target/mergedBlockMetaTest"); + Files.deleteIfExists(chunkMetaFile.toPath()); + RoaringBitmap chunk1 = new RoaringBitmap(); + chunk1.add(1); + chunk1.add(3); + RoaringBitmap chunk2 = new RoaringBitmap(); + chunk2.add(2); + chunk2.add(4); + RoaringBitmap[] expectedChunks = new RoaringBitmap[]{chunk1, chunk2}; + try (DataOutputStream metaOutput = new DataOutputStream(new FileOutputStream(chunkMetaFile))) { + for (int i = 0; i < expectedChunks.length; i++) { + expectedChunks[i].serialize(metaOutput); + } + } + TransportConf conf = mock(TransportConf.class); + when(conf.lazyFileDescriptor()).thenReturn(false); + long requestId = 1L; + MergedBlockMetaSuccess expectedMeta = new MergedBlockMetaSuccess(requestId, 2, + new FileSegmentManagedBuffer(conf, chunkMetaFile, 0, chunkMetaFile.length())); + + List out = Lists.newArrayList(); + ChannelHandlerContext context = mock(ChannelHandlerContext.class); + when(context.alloc()).thenReturn(ByteBufAllocator.DEFAULT); + + MessageEncoder.INSTANCE.encode(context, expectedMeta, out); + Assert.assertEquals(1, out.size()); + MessageWithHeader msgWithHeader = (MessageWithHeader) out.remove(0); + + ByteArrayWritableChannel writableChannel = + new ByteArrayWritableChannel((int) msgWithHeader.count()); + while (msgWithHeader.transfered() < msgWithHeader.count()) { + msgWithHeader.transferTo(writableChannel, msgWithHeader.transfered()); + } + ByteBuf messageBuf = Unpooled.wrappedBuffer(writableChannel.getData()); + messageBuf.readLong(); // frame length + MessageDecoder.INSTANCE.decode(mock(ChannelHandlerContext.class), messageBuf, out); + Assert.assertEquals(1, out.size()); + MergedBlockMetaSuccess decoded = (MergedBlockMetaSuccess) out.get(0); + Assert.assertEquals("merged block", expectedMeta.requestId, decoded.requestId); + Assert.assertEquals("num chunks", expectedMeta.getNumChunks(), decoded.getNumChunks()); + + ByteBuf responseBuf = Unpooled.wrappedBuffer(decoded.body().nioByteBuffer()); + RoaringBitmap[] responseBitmaps = new RoaringBitmap[expectedMeta.getNumChunks()]; + for (int i = 0; i < expectedMeta.getNumChunks(); i++) { + responseBitmaps[i] = Encoders.Bitmaps.decode(responseBuf); + } + Assert.assertEquals( + "num of roaring bitmaps", expectedMeta.getNumChunks(), responseBitmaps.length); + for (int i = 0; i < expectedMeta.getNumChunks(); i++) { + Assert.assertEquals("chunk bitmap " + i, expectedChunks[i], responseBitmaps[i]); + } + Files.delete(chunkMetaFile.toPath()); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java index a6bdc13e93..238d26ee50 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java @@ -178,4 +178,24 @@ public abstract class BlockStoreClient implements Closeable { MergeFinalizerListener listener) { throw new UnsupportedOperationException(); } + + /** + * Get the meta information of a merged block from the remote shuffle service. + * + * @param host the host of the remote node. + * @param port the port of the remote node. + * @param shuffleId shuffle id. + * @param reduceId reduce id. + * @param listener the listener to receive chunk counts. 
+ * + * @since 3.2.0 + */ + public void getMergedBlockMeta( + String host, + int port, + int shuffleId, + int reduceId, + MergedBlocksMetaListener listener) { + throw new UnsupportedOperationException(); + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java index 09e29cafa9..c5f5834c0b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java @@ -17,12 +17,14 @@ package org.apache.spark.network.shuffle; +import com.google.common.base.Preconditions; import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; import java.util.HashMap; import java.util.Iterator; import java.util.Map; +import java.util.Set; import java.util.function.Function; import com.codahale.metrics.Gauge; @@ -31,14 +33,17 @@ import com.codahale.metrics.Metric; import com.codahale.metrics.MetricSet; import com.codahale.metrics.Timer; import com.codahale.metrics.Counter; +import com.google.common.collect.Sets; import com.google.common.annotations.VisibleForTesting; -import org.apache.spark.network.client.StreamCallbackWithID; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.MergedBlockMetaResponseCallback; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.MergedBlockMetaRequest; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; @@ -54,8 +59,12 @@ import org.apache.spark.network.util.TransportConf; * Blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk * is equivalent to one block. 
  */
-public class ExternalBlockHandler extends RpcHandler {
+public class ExternalBlockHandler extends RpcHandler
+    implements RpcHandler.MergedBlockMetaReqHandler {
   private static final Logger logger = LoggerFactory.getLogger(ExternalBlockHandler.class);
+  private static final String SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger";
+  private static final String SHUFFLE_BLOCK_ID = "shuffle";
+  private static final String SHUFFLE_CHUNK_ID = "shuffleChunk";
 
   @VisibleForTesting
   final ExternalShuffleBlockResolver blockManager;
@@ -128,24 +137,23 @@ public class ExternalBlockHandler extends RpcHandler {
       BlockTransferMessage msgObj,
       TransportClient client,
       RpcResponseCallback callback) {
-    if (msgObj instanceof FetchShuffleBlocks || msgObj instanceof OpenBlocks) {
+    if (msgObj instanceof AbstractFetchShuffleBlocks || msgObj instanceof OpenBlocks) {
       final Timer.Context responseDelayContext = metrics.openBlockRequestLatencyMillis.time();
       try {
         int numBlockIds;
         long streamId;
-        if (msgObj instanceof FetchShuffleBlocks) {
-          FetchShuffleBlocks msg = (FetchShuffleBlocks) msgObj;
+        if (msgObj instanceof AbstractFetchShuffleBlocks) {
+          AbstractFetchShuffleBlocks msg = (AbstractFetchShuffleBlocks) msgObj;
           checkAuth(client, msg.appId);
-          numBlockIds = 0;
-          if (msg.batchFetchEnabled) {
-            numBlockIds = msg.mapIds.length;
+          numBlockIds = ((AbstractFetchShuffleBlocks) msgObj).getNumBlocks();
+          Iterator<ManagedBuffer> iterator;
+          if (msgObj instanceof FetchShuffleBlocks) {
+            iterator = new ShuffleManagedBufferIterator((FetchShuffleBlocks) msgObj);
           } else {
-            for (int[] ids: msg.reduceIds) {
-              numBlockIds += ids.length;
-            }
+            iterator = new ShuffleChunkManagedBufferIterator((FetchShuffleBlockChunks) msgObj);
           }
-          streamId = streamManager.registerStream(client.getClientId(),
-            new ShuffleManagedBufferIterator(msg), client.getChannel());
+          streamId = streamManager.registerStream(client.getClientId(), iterator,
+            client.getChannel());
         } else {
           // For the compatibility with the old version, still keep the support for OpenBlocks.
           OpenBlocks msg = (OpenBlocks) msgObj;
@@ -189,9 +197,14 @@
     } else if (msgObj instanceof GetLocalDirsForExecutors) {
       GetLocalDirsForExecutors msg = (GetLocalDirsForExecutors) msgObj;
       checkAuth(client, msg.appId);
-      Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, msg.execIds);
+      Set<String> execIdsForBlockResolver = Sets.newHashSet(msg.execIds);
+      boolean fetchMergedBlockDirs = execIdsForBlockResolver.remove(SHUFFLE_MERGER_IDENTIFIER);
+      Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId,
+        execIdsForBlockResolver);
+      if (fetchMergedBlockDirs) {
+        localDirs.put(SHUFFLE_MERGER_IDENTIFIER, mergeManager.getMergedBlockDirs(msg.appId));
+      }
       callback.onSuccess(new LocalDirsForExecutors(localDirs).toByteBuffer());
-
     } else if (msgObj instanceof FinalizeShuffleMerge) {
       final Timer.Context responseDelayContext =
         metrics.finalizeShuffleMergeLatencyMillis.time();
@@ -211,6 +224,32 @@
     }
   }
 
+  @Override
+  public void receiveMergeBlockMetaReq(
+      TransportClient client,
+      MergedBlockMetaRequest metaRequest,
+      MergedBlockMetaResponseCallback callback) {
+    final Timer.Context responseDelayContext = metrics.fetchMergedBlocksMetaLatencyMillis.time();
+    try {
+      checkAuth(client, metaRequest.appId);
+      MergedBlockMeta mergedMeta =
+        mergeManager.getMergedBlockMeta(metaRequest.appId, metaRequest.shuffleId,
+          metaRequest.reduceId);
+      logger.debug(
+        "Merged block chunks appId {} shuffleId {} reduceId {} num-chunks: {}",
+        metaRequest.appId, metaRequest.shuffleId, metaRequest.reduceId,
+        mergedMeta.getNumChunks());
+      callback.onSuccess(mergedMeta.getNumChunks(), mergedMeta.getChunksBitmapBuffer());
+    } finally {
+      responseDelayContext.stop();
+    }
+  }
+
+  @Override
+  public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
+    return this;
+  }
+
   @Override
   public void exceptionCaught(Throwable cause, TransportClient client) {
     metrics.caughtExceptions.inc();
@@ -262,6 +301,8 @@
     private final Timer openBlockRequestLatencyMillis = new Timer();
     // Time latency for executor registration latency in ms
     private final Timer registerExecutorRequestLatencyMillis = new Timer();
+    // Time latency for processing fetch merged blocks meta request latency in ms
+    private final Timer fetchMergedBlocksMetaLatencyMillis = new Timer();
     // Time latency for processing finalize shuffle merge request latency in ms
     private final Timer finalizeShuffleMergeLatencyMillis = new Timer();
     // Block transfer rate in byte per second
@@ -275,6 +316,7 @@
       allMetrics = new HashMap<>();
       allMetrics.put("openBlockRequestLatencyMillis", openBlockRequestLatencyMillis);
       allMetrics.put("registerExecutorRequestLatencyMillis", registerExecutorRequestLatencyMillis);
+      allMetrics.put("fetchMergedBlocksMetaLatencyMillis", fetchMergedBlocksMetaLatencyMillis);
       allMetrics.put("finalizeShuffleMergeLatencyMillis", finalizeShuffleMergeLatencyMillis);
       allMetrics.put("blockTransferRateBytes", blockTransferRateBytes);
       allMetrics.put("registeredExecutorsSize",
@@ -294,18 +336,26 @@
     private int index = 0;
     private final Function<Integer, ManagedBuffer> blockDataForIndexFn;
     private final int size;
+    private boolean requestForMergedBlockChunks;
 
     ManagedBufferIterator(OpenBlocks msg) {
       String appId = msg.appId;
       String execId = msg.execId;
       String[] blockIds = msg.blockIds;
       String[] blockId0Parts = blockIds[0].split("_");
-      if (blockId0Parts.length == 4 && blockId0Parts[0].equals("shuffle")) {
+      if (blockId0Parts.length == 4 && blockId0Parts[0].equals(SHUFFLE_BLOCK_ID)) {
         final int shuffleId = Integer.parseInt(blockId0Parts[1]);
         final int[] mapIdAndReduceIds = shuffleMapIdAndReduceIds(blockIds, shuffleId);
         size = mapIdAndReduceIds.length;
         blockDataForIndexFn = index -> blockManager.getBlockData(appId, execId, shuffleId,
           mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
+      } else if (blockId0Parts.length == 4 && blockId0Parts[0].equals(SHUFFLE_CHUNK_ID)) {
+        requestForMergedBlockChunks = true;
+        final int shuffleId = Integer.parseInt(blockId0Parts[1]);
+        final int[] reduceIdAndChunkIds = shuffleMapIdAndReduceIds(blockIds, shuffleId);
+        size = reduceIdAndChunkIds.length;
+        blockDataForIndexFn = index -> mergeManager.getMergedBlockData(msg.appId, shuffleId,
+          reduceIdAndChunkIds[index], reduceIdAndChunkIds[index + 1]);
       } else if (blockId0Parts.length == 3 && blockId0Parts[0].equals("rdd")) {
         final int[] rddAndSplitIds = rddAndSplitIds(blockIds);
         size = rddAndSplitIds.length;
@@ -330,20 +380,26 @@
     }
 
     private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) {
-      final int[] mapIdAndReduceIds = new int[2 * blockIds.length];
+      // For regular shuffle blocks, the primaryId is the mapId and the secondaryIds are the
+      // reduceIds. For shuffle chunks, the primaryId is the reduceId and the secondaryIds are
+      // the chunkIds.
+      final int[] primaryIdAndSecondaryIds = new int[2 * blockIds.length];
       for (int i = 0; i < blockIds.length; i++) {
         String[] blockIdParts = blockIds[i].split("_");
-        if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
+        if (blockIdParts.length != 4
+            || (!requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_BLOCK_ID))
+            || (requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_CHUNK_ID))) {
           throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
         }
         if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
           throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:"
             + blockIds[i]);
         }
-        mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
-        mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
+        // For regular blocks, blockIdParts[2] is the mapId. For chunks, it is the reduceId.
+        primaryIdAndSecondaryIds[2 * i] = Integer.parseInt(blockIdParts[2]);
+        // For regular blocks, blockIdParts[3] is the reduceId. For chunks, it is the chunkId.
+        primaryIdAndSecondaryIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
       }
-      return mapIdAndReduceIds;
+      return primaryIdAndSecondaryIds;
     }
 
     @Override
@@ -413,6 +469,47 @@
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+      // reduceIds.length must equal chunkIds.length, and the passed-in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds; see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
+    }
+
+    @Override
+    public boolean hasNext() {
+      return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
+    }
+
+    @Override
+    public ManagedBuffer next() {
+      ManagedBuffer block = Preconditions.checkNotNull(mergeManager.getMergedBlockData(
+        appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]));
+      if (chunkIdx < chunkIds[reduceIdx].length - 1) {
+        chunkIdx += 1;
+      } else {
+        chunkIdx = 0;
+        reduceIdx += 1;
+      }
+      metrics.blockTransferRateBytes.mark(block.size());
+      return block;
+    }
+  }
+
   /**
    * Dummy implementation of merged shuffle file manager. Suitable for when push-based shuffle
    * is not enabled.
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
index 56c06e640a..f44140b124 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
@@ -31,6 +31,7 @@ import com.google.common.collect.Lists;
 
 import org.apache.spark.network.TransportContext;
 import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.MergedBlockMetaResponseCallback;
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportClientBootstrap;
@@ -187,6 +188,37 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
     }
   }
 
+  @Override
+  public void getMergedBlockMeta(
+      String host,
+      int port,
+      int shuffleId,
+      int reduceId,
+      MergedBlocksMetaListener listener) {
+    checkInit();
+    logger.debug("Get merged blocks meta from {}:{} for shuffleId {} reduceId {}", host, port,
+      shuffleId, reduceId);
+    try {
+      TransportClient client = clientFactory.createClient(host, port);
+      client.sendMergedBlockMetaReq(appId, shuffleId, reduceId,
+        new MergedBlockMetaResponseCallback() {
+          @Override
+          public void onSuccess(int numChunks, ManagedBuffer buffer) {
+            logger.trace("Successfully got merged block meta for shuffleId {} reduceId {}",
+              shuffleId, reduceId);
+            listener.onSuccess(shuffleId, reduceId, new MergedBlockMeta(numChunks, buffer));
+          }
+
+          @Override
+          public void onFailure(Throwable e) {
+            listener.onFailure(shuffleId, reduceId, e);
+          }
+        });
+    } catch (Exception e) {
+      listener.onFailure(shuffleId, reduceId, e);
+    }
+  }
+
   @Override
   public MetricSet shuffleMetrics() {
     checkInit();
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
index a095bf2723..493edd2b34 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
@@ -361,8 +361,8 @@ public class ExternalShuffleBlockResolver {
     return numRemovedBlocks;
   }
 
-  public Map<String, String[]> getLocalDirs(String appId, String[] execIds) {
-    return Arrays.stream(execIds)
+  public Map<String, String[]> getLocalDirs(String appId, Set<String> execIds) {
+    return execIds.stream()
       .map(exec -> {
         ExecutorShuffleInfo info = executors.get(new AppExecId(appId, exec));
         if (info == null) {
diff --git
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlocksMetaListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlocksMetaListener.java new file mode 100644 index 0000000000..0e277d3303 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlocksMetaListener.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.util.EventListener; + +/** + * Listener for receiving success or failure events when fetching meta of merged blocks. + * + * @since 3.2.0 + */ +public interface MergedBlocksMetaListener extends EventListener { + + /** + * Called after successfully receiving the meta of a merged block. + * + * @param shuffleId shuffle id. + * @param reduceId reduce id. + * @param meta contains meta information of a merged block. + */ + void onSuccess(int shuffleId, int reduceId, MergedBlockMeta meta); + + /** + * Called when there is an exception while fetching the meta of a merged block. + * + * @param shuffleId shuffle id. + * @param reduceId reduce id. + * @param exception exception getting chunk counts. 
+ */ + void onFailure(int shuffleId, int reduceId, Throwable exception); +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 0b7eaa6225..2bf16b0097 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -22,6 +22,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.LinkedHashMap; +import java.util.Set; import com.google.common.primitives.Ints; import com.google.common.primitives.Longs; @@ -34,8 +35,10 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.shuffle.protocol.AbstractFetchShuffleBlocks; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks; +import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks; import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.StreamHandle; import org.apache.spark.network.util.TransportConf; @@ -51,6 +54,8 @@ import org.apache.spark.network.util.TransportConf; */ public class OneForOneBlockFetcher { private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class); + private static final String SHUFFLE_BLOCK_PREFIX = "shuffle_"; + private static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk_"; private final TransportClient client; private final BlockTransferMessage message; @@ -88,74 +93,113 @@ public class OneForOneBlockFetcher { if (blockIds.length == 0) { throw new IllegalArgumentException("Zero-sized blockIds array"); } - if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) { + if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) { this.blockIds = new String[blockIds.length]; - this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds); + this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds); } else { this.blockIds = blockIds; this.message = new OpenBlocks(appId, execId, blockIds); } } - private boolean isShuffleBlocks(String[] blockIds) { - for (String blockId : blockIds) { - if (!blockId.startsWith("shuffle_")) { - return false; - } + /** + * Check whether the given block IDs are all shuffle block IDs. With push-based shuffle, + * the block IDs can be either unmerged shuffle block IDs or merged shuffle chunk IDs. + * For a given stream of shuffle blocks to be fetched in one request, they would be either + * all unmerged shuffle blocks or all merged shuffle chunks. + * @param blockIds block ID array + * @return whether the array contains only shuffle block IDs + */ + private boolean areShuffleBlocksOrChunks(String[] blockIds) { + if (Arrays.stream(blockIds).anyMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX))) { + // At least one block ID is missing the "shuffle_" prefix, so check whether all of the + // block IDs are shuffle chunk IDs.
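+ // For example, ["shuffleChunk_0_1_0", "shuffleChunk_0_1_1"] returns true here, while a
+ // mix such as ["shuffle_0_1_0", "shuffleChunk_0_1_1"] returns false.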
+ return Arrays.stream(blockIds).allMatch(blockId -> blockId.startsWith(SHUFFLE_CHUNK_PREFIX)); } return true; } + /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */ + private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg( + String appId, + String execId, + String[] blockIds) { + if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) { + return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true); + } else { + return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false); + } + } + /** - * Create FetchShuffleBlocks message and rebuild internal blockIds by - * analyzing the pass in blockIds. + * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by + * analyzing the passed in blockIds. */ - private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds( - String appId, String execId, String[] blockIds) { + private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds( + String appId, + String execId, + String[] blockIds, + boolean areMergedChunks) { String[] firstBlock = splitBlockId(blockIds[0]); int shuffleId = Integer.parseInt(firstBlock[1]); boolean batchFetchEnabled = firstBlock.length == 5; - LinkedHashMap mapIdToBlocksInfo = new LinkedHashMap<>(); + // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId + // is reduceId. + LinkedHashMap primaryIdToBlocksInfo = new LinkedHashMap<>(); for (String blockId : blockIds) { String[] blockIdParts = splitBlockId(blockId); if (Integer.parseInt(blockIdParts[1]) != shuffleId) { throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockId); } - long mapId = Long.parseLong(blockIdParts[2]); - if (!mapIdToBlocksInfo.containsKey(mapId)) { - mapIdToBlocksInfo.put(mapId, new BlocksInfo()); + Number primaryId; + if (!areMergedChunks) { + primaryId = Long.parseLong(blockIdParts[2]); + } else { + primaryId = Integer.parseInt(blockIdParts[2]); } - BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId); - blocksInfoByMapId.blockIds.add(blockId); - blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3])); + BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId, + id -> new BlocksInfo()); + blocksInfoByPrimaryId.blockIds.add(blockId); + // If blockId is a regular shuffle block, then blockIdParts[3] = reduceId. If blockId is a + // shuffleChunk block, then blockIdParts[3] = chunkId + blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3])); if (batchFetchEnabled) { + // It comes here only if the blockId is a regular shuffle block not a shuffleChunk block. // When we read continuous shuffle blocks in batch, we will reuse reduceIds in // FetchShuffleBlocks to store the start and end reduce id for range // [startReduceId, endReduceId). assert(blockIdParts.length == 5); - blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4])); + // blockIdParts[4] is the end reduce id for the batch range + blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4])); } } - long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet()); - int[][] reduceIdArr = new int[mapIds.length][]; + // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks, + // secondaryIds are chunkIds. 
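+ // For example, chunk ids shuffleChunk_0_5_0, shuffleChunk_0_5_1 and shuffleChunk_0_6_0
+ // group into primaryIds (reduceIds) [5, 6] with secondaryIds (chunkIds) [[0, 1], [0]].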
+ int[][] secondaryIdsArray = new int[primaryIdToBlocksInfo.size()][]; int blockIdIndex = 0; - for (int i = 0; i < mapIds.length; i++) { - BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]); - reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds); + int secIndex = 0; + for (BlocksInfo blocksInfo: primaryIdToBlocksInfo.values()) { + secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfo.ids); - // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks - // because the shuffle data's return order should match the `blockIds`'s order to ensure - // blockId and data match. - for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) { - this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j); + // The `blockIds`'s order must be the same as the read order specified in FetchShuffleBlocks/ + // FetchShuffleBlockChunks because the shuffle data's return order should match the + // `blockIds`'s order to ensure blockId and data match. + for (String blockId : blocksInfo.blockIds) { + this.blockIds[blockIdIndex++] = blockId; } } assert(blockIdIndex == this.blockIds.length); - - return new FetchShuffleBlocks( - appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled); + Set<Number> primaryIds = primaryIdToBlocksInfo.keySet(); + if (!areMergedChunks) { + long[] mapIds = Longs.toArray(primaryIds); + return new FetchShuffleBlocks( + appId, execId, shuffleId, mapIds, secondaryIdsArray, batchFetchEnabled); + } else { + int[] reduceIds = Ints.toArray(primaryIds); + return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, secondaryIdsArray); + } } /** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */ @@ -163,7 +207,17 @@ public class OneForOneBlockFetcher { String[] blockIdParts = blockId.split("_"); // For batch block id, the format contains shuffleId, mapId, begin reduceId, end reduceId. // For single block id, the format contains shuffleId, mapId, reduceId. - if (blockIdParts.length < 4 || blockIdParts.length > 5 || !blockIdParts[0].equals("shuffle")) { + // For single block chunk id, the format contains shuffleId, reduceId, chunkId. + if (blockIdParts.length < 4 || blockIdParts.length > 5) { + throw new IllegalArgumentException( + "Unexpected shuffle block id format: " + blockId); + } + if (blockIdParts.length == 5 && !blockIdParts[0].equals("shuffle")) { + throw new IllegalArgumentException( + "Unexpected shuffle block id format: " + blockId); + } + if (blockIdParts.length == 4 && + !(blockIdParts[0].equals("shuffle") || blockIdParts[0].equals("shuffleChunk"))) { + throw new IllegalArgumentException( + "Unexpected shuffle block id format: " + blockId); + } @@ -173,11 +227,15 @@ public class OneForOneBlockFetcher { /** The reduceIds and blocks in a single mapId */ private class BlocksInfo { - final ArrayList<Integer> reduceIds; + /** + * For {@link FetchShuffleBlocks} message, the ids are reduceIds. + * For {@link FetchShuffleBlockChunks} message, the ids are chunkIds.
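+ * For example, ids = [0, 1] under primaryId 5 stands for the blocks shuffle_<shuffleId>_5_0
+ * and shuffle_<shuffleId>_5_1, or for the chunks shuffleChunk_<shuffleId>_5_0 and
+ * shuffleChunk_<shuffleId>_5_1.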
+ */ + final ArrayList ids; final ArrayList blockIds; BlocksInfo() { - this.reduceIds = new ArrayList<>(); + this.ids = new ArrayList<>(); this.blockIds = new ArrayList<>(); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/AbstractFetchShuffleBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/AbstractFetchShuffleBlocks.java new file mode 100644 index 0000000000..0fca27cf26 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/AbstractFetchShuffleBlocks.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; +import org.apache.spark.network.protocol.Encoders; + +/** + * Base class for fetch shuffle blocks and chunks. + * + * @since 3.2.0 + */ +public abstract class AbstractFetchShuffleBlocks extends BlockTransferMessage { + public final String appId; + public final String execId; + public final int shuffleId; + + protected AbstractFetchShuffleBlocks( + String appId, + String execId, + int shuffleId) { + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; + } + + public ToStringBuilder toStringHelper() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("appId", appId) + .append("execId", execId) + .append("shuffleId", shuffleId); + } + + /** + * Returns number of blocks in the request. 
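+ * For {@link FetchShuffleBlocks} this is the total count of reduceIds (or of mapIds when
+ * batch fetching is enabled); for {@link FetchShuffleBlockChunks} it is the total count of
+ * chunkIds.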
+ */ + public abstract int getNumBlocks(); + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AbstractFetchShuffleBlocks that = (AbstractFetchShuffleBlocks) o; + return shuffleId == that.shuffleId + && Objects.equal(appId, that.appId) && Objects.equal(execId, that.execId); + } + + @Override + public int hashCode() { + int result = appId.hashCode(); + result = 31 * result + execId.hashCode(); + result = 31 * result + shuffleId; + return result; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + 4; /* encoded length of shuffleId */ + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + buf.writeInt(shuffleId); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index 7f50581249..a55a6cf7ed 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -48,7 +48,8 @@ public abstract class BlockTransferMessage implements Encodable { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8), FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11), - PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14); + PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14), + FETCH_SHUFFLE_BLOCK_CHUNKS(15); private final byte id; @@ -82,6 +83,7 @@ public abstract class BlockTransferMessage implements Encodable { case 12: return PushBlockStream.decode(buf); case 13: return FinalizeShuffleMerge.decode(buf); case 14: return MergeStatuses.decode(buf); + case 15: return FetchShuffleBlockChunks.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java new file mode 100644 index 0000000000..27345dd8e7 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + + +/** + * Request to read a set of block chunks. Returns {@link StreamHandle}. + * + * @since 3.2.0 + */ +public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks { + // The length of reduceIds must equal chunkIds.length. + public final int[] reduceIds; + // The i-th int[] in chunkIds contains all the chunks for the i-th reduceId in reduceIds. + public final int[][] chunkIds; + + public FetchShuffleBlockChunks( + String appId, + String execId, + int shuffleId, + int[] reduceIds, + int[][] chunkIds) { + super(appId, execId, shuffleId); + this.reduceIds = reduceIds; + this.chunkIds = chunkIds; + assert(reduceIds.length == chunkIds.length); + } + + @Override + protected Type type() { return Type.FETCH_SHUFFLE_BLOCK_CHUNKS; } + + @Override + public String toString() { + return toStringHelper() + .append("reduceIds", Arrays.toString(reduceIds)) + .append("chunkIds", Arrays.deepToString(chunkIds)) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + FetchShuffleBlockChunks that = (FetchShuffleBlockChunks) o; + if (!super.equals(that)) return false; + if (!Arrays.equals(reduceIds, that.reduceIds)) return false; + return Arrays.deepEquals(chunkIds, that.chunkIds); + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + Arrays.hashCode(reduceIds); + result = 31 * result + Arrays.deepHashCode(chunkIds); + return result; + } + + @Override + public int encodedLength() { + int encodedLengthOfChunkIds = 0; + for (int[] ids: chunkIds) { + encodedLengthOfChunkIds += Encoders.IntArrays.encodedLength(ids); + } + return super.encodedLength() + + Encoders.IntArrays.encodedLength(reduceIds) + + 4 /* encoded length of chunkIds.length */ + + encodedLengthOfChunkIds; + } + + @Override + public void encode(ByteBuf buf) { + super.encode(buf); + Encoders.IntArrays.encode(buf, reduceIds); + // Even though reduceIds.length == chunkIds.length, we are explicitly setting the length in the + // interest of forward compatibility.
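+ // Wire layout after the base fields: reduceIds as a length-prefixed int array, followed by
+ // chunkIds.length, followed by each chunkIds[i] as a length-prefixed int array.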
+ buf.writeInt(chunkIds.length); + for (int[] ids: chunkIds) { + Encoders.IntArrays.encode(buf, ids); + } + } + + @Override + public int getNumBlocks() { + int numBlocks = 0; + for (int[] ids : chunkIds) { + numBlocks += ids.length; + } + return numBlocks; + } + + public static FetchShuffleBlockChunks decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int[] reduceIds = Encoders.IntArrays.decode(buf); + int chunkIdsLen = buf.readInt(); + int[][] chunkIds = new int[chunkIdsLen][]; + for (int i = 0; i < chunkIdsLen; i++) { + chunkIds[i] = Encoders.IntArrays.decode(buf); + } + return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, chunkIds); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java index 98057d58f7..68550a2fba 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java @@ -20,8 +20,6 @@ package org.apache.spark.network.shuffle.protocol; import java.util.Arrays; import io.netty.buffer.ByteBuf; -import org.apache.commons.lang3.builder.ToStringBuilder; -import org.apache.commons.lang3.builder.ToStringStyle; import org.apache.spark.network.protocol.Encoders; @@ -29,10 +27,7 @@ import org.apache.spark.network.protocol.Encoders; import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; /** Request to read a set of blocks. Returns {@link StreamHandle}. */ -public class FetchShuffleBlocks extends BlockTransferMessage { - public final String appId; - public final String execId; - public final int shuffleId; +public class FetchShuffleBlocks extends AbstractFetchShuffleBlocks { // The length of mapIds must equal to reduceIds.size(), for the i-th mapId in mapIds, // it corresponds to the i-th int[] in reduceIds, which contains all reduce id for this map id. 
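// For example, mapIds = [2, 3] with reduceIds = [[0, 1], [7]] requests the blocks
// shuffle_<shuffleId>_2_0, shuffle_<shuffleId>_2_1 and shuffle_<shuffleId>_3_7 when batch
// fetching is disabled.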
public final long[] mapIds; @@ -50,9 +45,7 @@ public class FetchShuffleBlocks extends BlockTransferMessage { long[] mapIds, int[][] reduceIds, boolean batchFetchEnabled) { - this.appId = appId; - this.execId = execId; - this.shuffleId = shuffleId; + super(appId, execId, shuffleId); this.mapIds = mapIds; this.reduceIds = reduceIds; assert(mapIds.length == reduceIds.length); @@ -69,10 +62,7 @@ public class FetchShuffleBlocks extends BlockTransferMessage { @Override public String toString() { - return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) - .append("appId", appId) - .append("execId", execId) - .append("shuffleId", shuffleId) + return toStringHelper() .append("mapIds", Arrays.toString(mapIds)) .append("reduceIds", Arrays.deepToString(reduceIds)) .append("batchFetchEnabled", batchFetchEnabled) @@ -85,35 +75,40 @@ public class FetchShuffleBlocks extends BlockTransferMessage { if (o == null || getClass() != o.getClass()) return false; FetchShuffleBlocks that = (FetchShuffleBlocks) o; - - if (shuffleId != that.shuffleId) return false; + if (!super.equals(that)) return false; if (batchFetchEnabled != that.batchFetchEnabled) return false; - if (!appId.equals(that.appId)) return false; - if (!execId.equals(that.execId)) return false; if (!Arrays.equals(mapIds, that.mapIds)) return false; return Arrays.deepEquals(reduceIds, that.reduceIds); } @Override public int hashCode() { - int result = appId.hashCode(); - result = 31 * result + execId.hashCode(); - result = 31 * result + shuffleId; + int result = super.hashCode(); result = 31 * result + Arrays.hashCode(mapIds); result = 31 * result + Arrays.deepHashCode(reduceIds); result = 31 * result + (batchFetchEnabled ? 1 : 0); return result; } + @Override + public int getNumBlocks() { + if (batchFetchEnabled) { + return mapIds.length; + } + int numBlocks = 0; + for (int[] ids : reduceIds) { + numBlocks += ids.length; + } + return numBlocks; + } + @Override public int encodedLength() { int encodedLengthOfReduceIds = 0; for (int[] ids: reduceIds) { encodedLengthOfReduceIds += Encoders.IntArrays.encodedLength(ids); } - return Encoders.Strings.encodedLength(appId) - + Encoders.Strings.encodedLength(execId) - + 4 /* encoded length of shuffleId */ + return super.encodedLength() + Encoders.LongArrays.encodedLength(mapIds) + 4 /* encoded length of reduceIds.size() */ + encodedLengthOfReduceIds @@ -122,9 +117,7 @@ public class FetchShuffleBlocks extends BlockTransferMessage { @Override public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, appId); - Encoders.Strings.encode(buf, execId); - buf.writeInt(shuffleId); + super.encode(buf); Encoders.LongArrays.encode(buf, mapIds); buf.writeInt(reduceIds.length); for (int[] ids: reduceIds) { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index f06e7cb047..bad61d30d7 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -34,13 +34,16 @@ import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.MergedBlockMetaResponseCallback; import org.apache.spark.network.client.RpcResponseCallback; import 
org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.MergedBlockMetaRequest; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks; +import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks; import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge; import org.apache.spark.network.shuffle.protocol.MergeStatuses; import org.apache.spark.network.shuffle.protocol.OpenBlocks; @@ -258,4 +261,113 @@ public class ExternalBlockHandlerSuite { .get("finalizeShuffleMergeLatencyMillis"); assertEquals(1, finalizeShuffleMergeLatencyMillis.getCount()); } + + @Test + public void testFetchMergedBlocksMeta() { + when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0)).thenReturn( + new MergedBlockMeta(1, mock(ManagedBuffer.class))); + when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 1)).thenReturn( + new MergedBlockMeta(3, mock(ManagedBuffer.class))); + when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 2)).thenReturn( + new MergedBlockMeta(5, mock(ManagedBuffer.class))); + + int[] expectedCount = new int[]{1, 3, 5}; + String appId = "app0"; + long requestId = 0L; + for (int reduceId = 0; reduceId < 3; reduceId++) { + MergedBlockMetaRequest req = new MergedBlockMetaRequest(requestId++, appId, 0, reduceId); + MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class); + handler.getMergedBlockMetaReqHandler() + .receiveMergeBlockMetaReq(client, req, callback); + verify(mergedShuffleManager, times(1)).getMergedBlockMeta("app0", 0, reduceId); + + ArgumentCaptor numChunksResponse = ArgumentCaptor.forClass(Integer.class); + ArgumentCaptor chunkBitmapResponse = + ArgumentCaptor.forClass(ManagedBuffer.class); + verify(callback, times(1)).onSuccess(numChunksResponse.capture(), + chunkBitmapResponse.capture()); + assertEquals("num chunks in merged block " + reduceId, expectedCount[reduceId], + numChunksResponse.getValue().intValue()); + assertNotNull("chunks bitmap buffer " + reduceId, chunkBitmapResponse.getValue()); + } + } + + @Test + public void testOpenBlocksWithShuffleChunks() { + verifyBlockChunkFetches(true); + } + + @Test + public void testFetchShuffleChunks() { + verifyBlockChunkFetches(false); + } + + private void verifyBlockChunkFetches(boolean useOpenBlocks) { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + ByteBuffer buffer; + if (useOpenBlocks) { + OpenBlocks openBlocks = + new OpenBlocks("app0", "exec1", + new String[] {"shuffleChunk_0_0_0", "shuffleChunk_0_0_1", "shuffleChunk_0_1_0", + "shuffleChunk_0_1_1"}); + buffer = openBlocks.toByteBuffer(); + } else { + FetchShuffleBlockChunks fetchChunks = new FetchShuffleBlockChunks( + "app0", "exec1", 0, new int[] {0, 1}, new int[][] {{0, 1}, {0, 1}}); + buffer = fetchChunks.toByteBuffer(); + } + ManagedBuffer[][] buffers = new ManagedBuffer[][] { + { + new NioManagedBuffer(ByteBuffer.wrap(new byte[5])), + new NioManagedBuffer(ByteBuffer.wrap(new byte[7])) + }, + { + new NioManagedBuffer(ByteBuffer.wrap(new byte[5])), + new NioManagedBuffer(ByteBuffer.wrap(new byte[7])) + } + }; + for (int reduceId = 0; reduceId < 2; reduceId++) { + for (int chunkId = 0; chunkId < 2; chunkId++) { + when(mergedShuffleManager.getMergedBlockData( + "app0", 0, reduceId, 
chunkId)).thenReturn(buffers[reduceId][chunkId]); + } + } + handler.receive(client, buffer, callback); + ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); + verify(callback, times(1)).onSuccess(response.capture()); + verify(callback, never()).onFailure(any()); + StreamHandle handle = + (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); + assertEquals(4, handle.numChunks); + + @SuppressWarnings("unchecked") + ArgumentCaptor> stream = (ArgumentCaptor>) + (ArgumentCaptor) ArgumentCaptor.forClass(Iterator.class); + verify(streamManager, times(1)).registerStream(any(), stream.capture(), any()); + Iterator bufferIter = stream.getValue(); + for (int reduceId = 0; reduceId < 2; reduceId++) { + for (int chunkId = 0; chunkId < 2; chunkId++) { + assertEquals(buffers[reduceId][chunkId], bufferIter.next()); + } + } + assertFalse(bufferIter.hasNext()); + verify(mergedShuffleManager, never()).getMergedBlockMeta(anyString(), anyInt(), anyInt()); + verify(blockResolver, never()).getBlockData( + anyString(), anyString(), anyInt(), anyInt(), anyInt()); + verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 0); + verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 1); + + // Verify open block request latency metrics + Timer openBlockRequestLatencyMillis = (Timer) ((ExternalBlockHandler) handler) + .getAllMetrics() + .getMetrics() + .get("openBlockRequestLatencyMillis"); + assertEquals(1, openBlockRequestLatencyMillis.getCount()); + // Verify block transfer metrics + Meter blockTransferRateBytes = (Meter) ((ExternalBlockHandler) handler) + .getAllMetrics() + .getMetrics() + .get("blockTransferRateBytes"); + assertEquals(24, blockTransferRateBytes.getCount()); + } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index a7eb59d366..c4967eab31 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -27,8 +27,7 @@ import com.google.common.collect.Maps; import io.netty.buffer.Unpooled; import org.junit.Test; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; +import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; @@ -46,6 +45,7 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks; +import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks; import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.StreamHandle; import org.apache.spark.network.util.MapConfigProvider; @@ -243,6 +243,57 @@ public class OneForOneBlockFetcherSuite { } } + @Test + public void testShuffleBlockChunksFetch() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("shuffleChunk_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("shuffleChunk_0_0_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23]))); + blocks.put("shuffleChunk_0_0_2", new 
NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); + String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + + BlockFetchingListener listener = fetchBlocks(blocks, blockIds, + new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 }, + new int[][] {{ 0, 1, 2 }}), conf); + for (int i = 0; i < 3; i ++) { + verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_" + i, + blocks.get("shuffleChunk_0_0_" + i)); + } + } + + @Test + public void testShuffleBlockChunkFetchFailure() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("shuffleChunk_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("shuffleChunk_0_0_1", null); + blocks.put("shuffleChunk_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); + String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + + BlockFetchingListener listener = fetchBlocks(blocks, blockIds, + new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[]{0}, new int[][]{{0, 1, 2}}), + conf); + verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0", + blocks.get("shuffleChunk_0_0_0")); + verify(listener, times(1)).onBlockFetchFailure(eq("shuffleChunk_0_0_1"), any()); + verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_2", + blocks.get("shuffleChunk_0_0_2")); + } + + @Test + public void testInvalidShuffleBlockIds() { + assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(), + new String[]{"shuffle_0_0"}, + new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 }, + new int[][] {{ 0 }}), conf)); + assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(), + new String[]{"shuffleChunk_0_0_0_0_0"}, + new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 }, + new int[][] {{ 0 }}), conf)); + assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(), + new String[]{"shuffleChunk_0_0_0_0"}, + new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 }, + new int[][] {{ 0 }}), conf)); + } + /** * Begins a fetch on the given set of blocks by mocking out the server side of the RPC which * simply returns the given (BlockId, Block) pairs. diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java new file mode 100644 index 0000000000..91f319ded4 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.junit.Assert; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class FetchShuffleBlockChunksSuite { + + @Test + public void testFetchShuffleBlockChunksEncodeDecode() { + FetchShuffleBlockChunks shuffleBlockChunks = + new FetchShuffleBlockChunks("app0", "exec1", 0, new int[] {0}, new int[][] {{0, 1}}); + Assert.assertEquals(2, shuffleBlockChunks.getNumBlocks()); + int len = shuffleBlockChunks.encodedLength(); + Assert.assertEquals(45, len); + ByteBuf buf = Unpooled.buffer(len); + shuffleBlockChunks.encode(buf); + + FetchShuffleBlockChunks decoded = FetchShuffleBlockChunks.decode(buf); + assertEquals(shuffleBlockChunks, decoded); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocksSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocksSuite.java new file mode 100644 index 0000000000..a1681f58e7 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocksSuite.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.junit.Assert; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class FetchShuffleBlocksSuite { + + @Test + public void testFetchShuffleBlockEncodeDecode() { + FetchShuffleBlocks fetchShuffleBlocks = + new FetchShuffleBlocks("app0", "exec1", 0, new long[] {0}, new int[][] {{0, 1}}, false); + Assert.assertEquals(2, fetchShuffleBlocks.getNumBlocks()); + int len = fetchShuffleBlocks.encodedLength(); + Assert.assertEquals(50, len); + ByteBuf buf = Unpooled.buffer(len); + fetchShuffleBlocks.encode(buf); + + FetchShuffleBlocks decoded = FetchShuffleBlocks.decode(buf); + assertEquals(fetchShuffleBlocks, decoded); + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceMetricsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceMetricsSuite.scala index ea4d252f0d..4ce46156c0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceMetricsSuite.scala @@ -62,7 +62,8 @@ class ExternalShuffleServiceMetricsSuite extends SparkFunSuite { "registerExecutorRequestLatencyMillis", "shuffle-server.usedDirectMemory", "shuffle-server.usedHeapMemory", - "finalizeShuffleMergeLatencyMillis") + "finalizeShuffleMergeLatencyMillis", + "fetchMergedBlocksMetaLatencyMillis") ) } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala index 9239d891aa..d866fac726 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala @@ -40,7 +40,8 @@ class YarnShuffleServiceMetricsSuite extends SparkFunSuite with Matchers { val allMetrics = Set( "openBlockRequestLatencyMillis", "registerExecutorRequestLatencyMillis", "blockTransferRateBytes", "registeredExecutorsSize", "numActiveConnections", - "numCaughtExceptions", "finalizeShuffleMergeLatencyMillis") + "numCaughtExceptions", "finalizeShuffleMergeLatencyMillis", + "fetchMergedBlocksMetaLatencyMillis") metrics.getMetrics.keySet().asScala should be (allMetrics) } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala index d6d1715223..afe85b3d40 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala @@ -412,7 +412,8 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd "registerExecutorRequestLatencyMillis", "finalizeShuffleMergeLatencyMillis", "shuffle-server.usedDirectMemory", - "shuffle-server.usedHeapMemory" + "shuffle-server.usedHeapMemory", + "fetchMergedBlocksMetaLatencyMillis" )) }
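As a usage illustration for reviewers (not part of the patch): a minimal sketch of how an executor-side caller could exercise the new getMergedBlockMeta API added here. The host, port, shuffleId, and reduceId values are placeholders, and the client is assumed to be an ExternalBlockStoreClient on which init(appId) has already been called.

import org.apache.spark.network.shuffle.ExternalBlockStoreClient;
import org.apache.spark.network.shuffle.MergedBlockMeta;
import org.apache.spark.network.shuffle.MergedBlocksMetaListener;

class MergedBlockMetaExample {
  // Fetch the meta of the merged block for (shuffleId = 0, reduceId = 5) from a remote ESS.
  static void fetchMergedMeta(ExternalBlockStoreClient client) {
    client.getMergedBlockMeta("shuffle-server.example.com", 7337, /* shuffleId = */ 0,
      /* reduceId = */ 5,
      new MergedBlocksMetaListener() {
        @Override
        public void onSuccess(int shuffleId, int reduceId, MergedBlockMeta meta) {
          // With the chunk count from the meta, the merged data can then be fetched via
          // OneForOneBlockFetcher using block ids of the form
          // shuffleChunk_<shuffleId>_<reduceId>_<chunkId>.
        }

        @Override
        public void onFailure(int shuffleId, int reduceId, Throwable t) {
          // Fall back to fetching the original unmerged shuffle blocks.
        }
      });
  }
}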