diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java new file mode 100644 index 0000000000..d9b7fb2b3b --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * A basic callback. This is extended by {@link RpcResponseCallback} and + * {@link MergedBlockMetaResponseCallback} so that both RpcRequests and MergedBlockMetaRequests + * can be handled by {@link TransportResponseHandler} in a similar way. + * + * @since 3.2.0 + */ +public interface BaseResponseCallback { + + /** Exception either propagated from server or raised on client side. */ + void onFailure(Throwable e); +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/MergedBlockMetaResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/MergedBlockMetaResponseCallback.java new file mode 100644 index 0000000000..60d010ef08 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/client/MergedBlockMetaResponseCallback.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Callback for the result of a single + * {@link org.apache.spark.network.protocol.MergedBlockMetaRequest}. + * + * @since 3.2.0 + */ +public interface MergedBlockMetaResponseCallback extends BaseResponseCallback { + /** + * Called upon receipt of a particular merged block meta. + * + * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this + * call returns. You must therefore either retain() the buffer or copy its contents before + * returning.
+ * + * @param numChunks number of merged chunks in the merged block + * @param buffer the buffer containing an array of roaring bitmaps. The i-th roaring bitmap + * contains the mapIds that were merged into the i-th merged chunk. + */ + void onSuccess(int numChunks, ManagedBuffer buffer); +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java index 6afc63f71b..a3b8cb1d90 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java @@ -23,7 +23,7 @@ import java.nio.ByteBuffer; * Callback for the result of a single RPC. This will be invoked once with either success or * failure. */ -public interface RpcResponseCallback { +public interface RpcResponseCallback extends BaseResponseCallback { /** * Successful serialized result from server. * @@ -31,7 +31,4 @@ public interface RpcResponseCallback { * Please copy the content of `response` if you want to use it after `onSuccess` returns. */ void onSuccess(ByteBuffer response); - - /** Exception either propagated from server or raised on client side. */ - void onFailure(Throwable e); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index eb2882074d..a50c04cf80 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -200,6 +200,31 @@ public class TransportClient implements Closeable { return requestId; } + /** + * Sends a MergedBlockMetaRequest message to the server. The response of this message is + * either a {@link MergedBlockMetaSuccess} or {@link RpcFailure}. + * + * @param appId applicationId. + * @param shuffleId shuffle id. + * @param reduceId reduce id. + * @param callback the callback to handle the reply. + */ + public void sendMergedBlockMetaReq( + String appId, + int shuffleId, + int reduceId, + MergedBlockMetaResponseCallback callback) { + long requestId = requestId(); + if (logger.isTraceEnabled()) { + logger.trace( + "Sending RPC {} to fetch merged block meta to {}", requestId, getRemoteAddress(channel)); + } + handler.addRpcRequest(requestId, callback); + RpcChannelListener listener = new RpcChannelListener(requestId, callback); + channel.writeAndFlush( + new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId)).addListener(listener); + } + /** * Send data to the remote end as a stream. This differs from stream() in that this is a request * to *send* data to the remote end, not to receive it from the remote.
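[Reviewer note] To make the callback contract above concrete, here is a minimal client-side usage sketch of sendMergedBlockMetaReq() (the wrapper class, the "app-123" appId and the shuffle/reduce ids are invented for illustration, not part of this patch). The key point is that onSuccess() must copy or retain() the buffer before returning, since it is release()'d as soon as the call returns:

import java.io.IOException;
import java.nio.ByteBuffer;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.MergedBlockMetaResponseCallback;
import org.apache.spark.network.client.TransportClient;

// Hypothetical caller, for illustration only.
class MergedBlockMetaExample {
  static void fetchMeta(TransportClient client) {
    client.sendMergedBlockMetaReq("app-123", /* shuffleId */ 0, /* reduceId */ 5,
      new MergedBlockMetaResponseCallback() {
        @Override
        public void onSuccess(int numChunks, ManagedBuffer buffer) {
          // The buffer is release()'d once this method returns, so copy its
          // contents here; retain() would be the zero-copy alternative.
          try {
            ByteBuffer bitmaps = ByteBuffer.allocate((int) buffer.size());
            bitmaps.put(buffer.nioByteBuffer());
            bitmaps.flip();
            // ... decode numChunks roaring bitmaps from `bitmaps` ...
          } catch (IOException e) {
            onFailure(e);
          }
        }

        @Override
        public void onFailure(Throwable e) {
          // Either an RpcFailure propagated from the server or a client-side error.
        }
      });
  }
}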
@@ -349,9 +374,9 @@ public class TransportClient implements Closeable { private class RpcChannelListener extends StdChannelListener { final long rpcRequestId; - final RpcResponseCallback callback; + final BaseResponseCallback callback; - RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) { + RpcChannelListener(long rpcRequestId, BaseResponseCallback callback) { super("RPC " + rpcRequestId); this.rpcRequestId = rpcRequestId; this.callback = callback; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 3aac2d2441..576c08858d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -33,6 +33,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.MergedBlockMetaSuccess; import org.apache.spark.network.protocol.ResponseMessage; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; @@ -56,7 +57,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> { private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches; - private final Map<Long, RpcResponseCallback> outstandingRpcs; + private final Map<Long, BaseResponseCallback> outstandingRpcs; private final Queue<Pair<String, StreamCallback>> streamCallbacks; private volatile boolean streamActive; @@ -81,7 +82,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> { outstandingFetches.remove(streamChunkId); } - public void addRpcRequest(long requestId, RpcResponseCallback callback) { + public void addRpcRequest(long requestId, BaseResponseCallback callback) { updateTimeOfLastRequest(); outstandingRpcs.put(requestId, callback); } @@ -112,7 +113,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> { logger.warn("ChunkReceivedCallback.onFailure throws exception", e); } } - for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) { + for (Map.Entry<Long, BaseResponseCallback> entry : outstandingRpcs.entrySet()) { try { entry.getValue().onFailure(cause); } catch (Exception e) { @@ -184,7 +185,7 @@ } } else if (message instanceof RpcResponse) { RpcResponse resp = (RpcResponse) message; - RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); + RpcResponseCallback listener = (RpcResponseCallback) outstandingRpcs.get(resp.requestId); if (listener == null) { logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", resp.requestId, getRemoteAddress(channel), resp.body().size()); @@ -199,7 +200,7 @@ } } else if (message instanceof RpcFailure) { RpcFailure resp = (RpcFailure) message; - RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); + BaseResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding", resp.requestId, getRemoteAddress(channel), resp.errorString); @@ -207,6 +208,22 @@ outstandingRpcs.remove(resp.requestId); listener.onFailure(new RuntimeException(resp.errorString)); } + } else if (message instanceof MergedBlockMetaSuccess) { + MergedBlockMetaSuccess resp =
(MergedBlockMetaSuccess) message; + try { + MergedBlockMetaResponseCallback listener = + (MergedBlockMetaResponseCallback) outstandingRpcs.get(resp.requestId); + if (listener == null) { + logger.warn( + "Ignoring response for MergedBlockMetaRequest {} from {} ({} bytes) since it is not" + + " outstanding", resp.requestId, getRemoteAddress(channel), resp.body().size()); + } else { + outstandingRpcs.remove(resp.requestId); + listener.onSuccess(resp.getNumChunks(), resp.body()); + } + } finally { + resp.body().release(); + } } else if (message instanceof StreamResponse) { StreamResponse resp = (StreamResponse) message; Pair<String, StreamCallback> entry = streamCallbacks.poll(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index dd31c95535..8f0a40c380 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -138,4 +138,9 @@ class AuthRpcHandler extends AbstractAuthRpcHandler { LOG.debug("Authorization successful for client {}.", channel.remoteAddress()); return true; } + + @Override + public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() { + return saslHandler.getMergedBlockMetaReqHandler(); + } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java new file mode 100644 index 0000000000..cf7c22d241 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; + +/** + * Request to find the meta information for the specified merged block. The meta information + * contains the number of chunks in the merged block and the map ids in each chunk.
+ * + * @since 3.2.0 + */ +public class MergedBlockMetaRequest extends AbstractMessage implements RequestMessage { + public final long requestId; + public final String appId; + public final int shuffleId; + public final int reduceId; + + public MergedBlockMetaRequest(long requestId, String appId, int shuffleId, int reduceId) { + super(null, false); + this.requestId = requestId; + this.appId = appId; + this.shuffleId = shuffleId; + this.reduceId = reduceId; + } + + @Override + public Type type() { + return Type.MergedBlockMetaRequest; + } + + @Override + public int encodedLength() { + return 8 + Encoders.Strings.encodedLength(appId) + 4 + 4; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + Encoders.Strings.encode(buf, appId); + buf.writeInt(shuffleId); + buf.writeInt(reduceId); + } + + public static MergedBlockMetaRequest decode(ByteBuf buf) { + long requestId = buf.readLong(); + String appId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int reduceId = buf.readInt(); + return new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId); + } + + @Override + public int hashCode() { + return Objects.hashCode(requestId, appId, shuffleId, reduceId); + } + + @Override + public boolean equals(Object other) { + if (other instanceof MergedBlockMetaRequest) { + MergedBlockMetaRequest o = (MergedBlockMetaRequest) other; + return requestId == o.requestId && shuffleId == o.shuffleId && reduceId == o.reduceId + && Objects.equal(appId, o.appId); + } + return false; + } + + @Override + public String toString() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("requestId", requestId) + .append("appId", appId) + .append("shuffleId", shuffleId) + .append("reduceId", reduceId) + .toString(); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaSuccess.java new file mode 100644 index 0000000000..d2edaf4532 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaSuccess.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * Response to a {@link MergedBlockMetaRequest}. + * Note that the server-side encoding of this message does NOT include the buffer itself.
+ * + * @since 3.2.0 + */ +public class MergedBlockMetaSuccess extends AbstractResponseMessage { + public final long requestId; + public final int numChunks; + + public MergedBlockMetaSuccess( + long requestId, + int numChunks, + ManagedBuffer chunkBitmapsBuffer) { + super(chunkBitmapsBuffer, true); + this.requestId = requestId; + this.numChunks = numChunks; + } + + @Override + public Type type() { + return Type.MergedBlockMetaSuccess; + } + + @Override + public int hashCode() { + return Objects.hashCode(requestId, numChunks); + } + + @Override + public String toString() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("requestId", requestId).append("numChunks", numChunks).toString(); + } + + @Override + public int encodedLength() { + return 8 + 4; + } + + /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + buf.writeInt(numChunks); + } + + public int getNumChunks() { + return numChunks; + } + + /** Decoding uses the given ByteBuf as our data, and will retain() it. */ + public static MergedBlockMetaSuccess decode(ByteBuf buf) { + long requestId = buf.readLong(); + int numChunks = buf.readInt(); + buf.retain(); + NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); + return new MergedBlockMetaSuccess(requestId, numChunks, managedBuf); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new RpcFailure(requestId, error); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java index 0ccd70c03a..12ebee8da9 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -37,7 +37,8 @@ public interface Message extends Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9), UploadStream(10), User(-1); + OneWayMessage(9), UploadStream(10), MergedBlockMetaRequest(11), MergedBlockMetaSuccess(12), + User(-1); private final byte id; @@ -66,6 +67,8 @@ public interface Message extends Encodable { case 8: return StreamFailure; case 9: return OneWayMessage; case 10: return UploadStream; + case 11: return MergedBlockMetaRequest; + case 12: return MergedBlockMetaSuccess; case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index bf80aed0af..98f7f612a4 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -83,6 +83,12 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> { case UploadStream: return UploadStream.decode(in); + case MergedBlockMetaRequest: + return MergedBlockMetaRequest.decode(in); + + case MergedBlockMetaSuccess: + return MergedBlockMetaSuccess.decode(in); + default: throw new IllegalArgumentException("Unexpected message type: " + msgType); } diff --git
a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java index 92eb886283..95fde67762 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java @@ -104,4 +104,9 @@ public abstract class AbstractAuthRpcHandler extends RpcHandler { public boolean isAuthenticated() { return isAuthenticated; } + + @Override + public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() { + return delegate.getMergedBlockMetaReqHandler(); + } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 38569baf82..0b89427756 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -22,9 +22,11 @@ import java.nio.ByteBuffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.client.MergedBlockMetaResponseCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.MergedBlockMetaRequest; /** * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. @@ -32,6 +34,8 @@ import org.apache.spark.network.client.TransportClient; public abstract class RpcHandler { private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback(); + private static final MergedBlockMetaReqHandler NOOP_MERGED_BLOCK_META_REQ_HANDLER = + new NoopMergedBlockMetaReqHandler(); /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to @@ -100,6 +104,10 @@ public abstract class RpcHandler { receive(client, message, ONE_WAY_CALLBACK); } + public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() { + return NOOP_MERGED_BLOCK_META_REQ_HANDLER; + } + /** * Invoked when the channel associated with the given client is active. */ @@ -129,4 +137,40 @@ public abstract class RpcHandler { } + /** + * Handler for {@link MergedBlockMetaRequest}. + * + * @since 3.2.0 + */ + public interface MergedBlockMetaReqHandler { + + /** + * Receive a {@link MergedBlockMetaRequest}. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. + * @param mergedBlockMetaRequest Request for merged block meta. + * @param callback Callback which should be invoked exactly once upon success or failure. + */ + void receiveMergeBlockMetaReq( + TransportClient client, + MergedBlockMetaRequest mergedBlockMetaRequest, + MergedBlockMetaResponseCallback callback); + } + + /** + * A no-op implementation of {@link MergedBlockMetaReqHandler}. It is used by RPC handlers + * which do not delegate {@link MergedBlockMetaRequest}s to the ExternalBlockHandler in the + * network-shuffle module.
+ * + * @since 3.2.0 + */ + private static class NoopMergedBlockMetaReqHandler implements MergedBlockMetaReqHandler { + + @Override + public void receiveMergeBlockMetaReq(TransportClient client, + MergedBlockMetaRequest mergedBlockMetaRequest, MergedBlockMetaResponseCallback callback) { + // do nothing + } + } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 4a30f8de07..ab2deac20f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -113,6 +113,8 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { processStreamRequest((StreamRequest) request); } else if (request instanceof UploadStream) { processStreamUpload((UploadStream) request); + } else if (request instanceof MergedBlockMetaRequest) { + processMergedBlockMetaRequest((MergedBlockMetaRequest) request); } else { throw new IllegalArgumentException("Unknown request type: " + request); } @@ -260,6 +262,30 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { } } + private void processMergedBlockMetaRequest(final MergedBlockMetaRequest req) { + try { + rpcHandler.getMergedBlockMetaReqHandler().receiveMergeBlockMetaReq(reverseClient, req, + new MergedBlockMetaResponseCallback() { + + @Override + public void onSuccess(int numChunks, ManagedBuffer buffer) { + logger.trace("Sending meta for request {} numChunks {}", req, numChunks); + respond(new MergedBlockMetaSuccess(req.requestId, numChunks, buffer)); + } + + @Override + public void onFailure(Throwable e) { + logger.trace("Failed to send meta for {}", req); + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } + }); + } catch (Exception e) { + logger.error("Error while invoking receiveMergeBlockMetaReq() for appId {} shuffleId {} " + + "reduceId {}", req.appId, req.shuffleId, req.reduceId, e); + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } + } + /** * Responds to a single message with some Encodable object. If a failure occurs while sending, * it will be logged and the channel closed.
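[Reviewer note] The hunks above complete the server-side wiring: TransportRequestHandler looks up the RpcHandler's MergedBlockMetaReqHandler (the no-op by default) and turns the callback result into a MergedBlockMetaSuccess or RpcFailure frame. A hedged sketch of a custom handler overriding the default follows; the two-chunk answer and empty payload are fabricated for illustration, and the real implementation is ExternalBlockHandler, further below in this patch:

import java.nio.ByteBuffer;

import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;

// Hypothetical handler, for illustration only.
class MetaServingRpcHandler extends RpcHandler {
  @Override
  public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
    callback.onFailure(new UnsupportedOperationException("plain RPCs not supported"));
  }

  @Override
  public StreamManager getStreamManager() {
    return null; // a real handler would return e.g. a OneForOneStreamManager
  }

  @Override
  public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
    // Override the NoopMergedBlockMetaReqHandler default: answer every request with
    // two chunks and an empty bitmap payload (both made up for this sketch).
    return (client, req, callback) ->
      callback.onSuccess(2, new NioManagedBuffer(ByteBuffer.allocate(0)));
  }
}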
diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java index 0a64471762..b3befb8baf 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; @@ -24,16 +25,19 @@ import io.netty.channel.Channel; import org.junit.Assert; import org.junit.Test; +import static org.junit.Assert.*; import static org.mockito.Mockito.*; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.protocol.*; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportRequestHandler; public class TransportRequestHandlerSuite { @@ -109,4 +113,55 @@ public class TransportRequestHandlerSuite { streamManager.connectionTerminated(channel); Assert.assertEquals(0, streamManager.numStreamStates()); } + + @Test + public void handleMergedBlockMetaRequest() throws Exception { + RpcHandler.MergedBlockMetaReqHandler metaHandler = (client, request, callback) -> { + if (request.shuffleId != -1 && request.reduceId != -1) { + callback.onSuccess(2, mock(ManagedBuffer.class)); + } else { + callback.onFailure(new RuntimeException("empty block")); + } + }; + RpcHandler rpcHandler = new RpcHandler() { + @Override + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) {} + + @Override + public StreamManager getStreamManager() { + return null; + } + + @Override + public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() { + return metaHandler; + } + }; + Channel channel = mock(Channel.class); + List<Pair<Object, ExtendedChannelPromise>> responseAndPromisePairs = new ArrayList<>(); + when(channel.writeAndFlush(any())).thenAnswer(invocationOnMock0 -> { + Object response = invocationOnMock0.getArguments()[0]; + ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel); + responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture)); + return channelFuture; + }); + + TransportClient reverseClient = mock(TransportClient.class); + TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient, + rpcHandler, 2L, null); + MergedBlockMetaRequest validMetaReq = new MergedBlockMetaRequest(19, "app1", 0, 0); + requestHandler.handle(validMetaReq); + assertEquals(1, responseAndPromisePairs.size()); + assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof MergedBlockMetaSuccess); + assertEquals(2, + ((MergedBlockMetaSuccess) (responseAndPromisePairs.get(0).getLeft())).getNumChunks()); + + MergedBlockMetaRequest invalidMetaReq = new MergedBlockMetaRequest(21, "app1", -1, 1); + requestHandler.handle(invalidMetaReq); + assertEquals(2, responseAndPromisePairs.size()); + assertTrue(responseAndPromisePairs.get(1).getLeft() instanceof RpcFailure); + } } diff --git
a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index b4032c4c3f..4de13f951d 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -23,17 +23,20 @@ import java.nio.ByteBuffer; import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; import org.junit.Test; +import org.mockito.ArgumentCaptor; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.MergedBlockMetaResponseCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.MergedBlockMetaSuccess; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; @@ -167,4 +170,40 @@ public class TransportResponseHandlerSuite { verify(cb).onFailure(eq("stream-1"), isA(IOException.class)); } + + @Test + public void handleSuccessfulMergedBlockMeta() throws Exception { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class); + handler.addRpcRequest(13, callback); + assertEquals(1, handler.numOutstandingRequests()); + + // This response should be ignored. + handler.handle(new MergedBlockMetaSuccess(22, 2, + new NioManagedBuffer(ByteBuffer.allocate(7)))); + assertEquals(1, handler.numOutstandingRequests()); + + ByteBuffer resp = ByteBuffer.allocate(10); + handler.handle(new MergedBlockMetaSuccess(13, 2, new NioManagedBuffer(resp))); + ArgumentCaptor<NioManagedBuffer> bufferCaptor = ArgumentCaptor.forClass(NioManagedBuffer.class); + verify(callback, times(1)).onSuccess(eq(2), bufferCaptor.capture()); + assertEquals(resp, bufferCaptor.getValue().nioByteBuffer()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleFailedMergedBlockMeta() throws Exception { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class); + handler.addRpcRequest(51, callback); + assertEquals(1, handler.numOutstandingRequests()); + + // This response should be ignored.
+ handler.handle(new RpcFailure(6, "failed")); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcFailure(51, "failed")); + verify(callback, times(1)).onFailure(any()); + assertEquals(0, handler.numOutstandingRequests()); + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/MergedBlockMetaSuccessSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/MergedBlockMetaSuccessSuite.java new file mode 100644 index 0000000000..f4a055188c --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/MergedBlockMetaSuccessSuite.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.nio.file.Files; +import java.util.List; + +import com.google.common.collect.Lists; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import org.junit.Assert; +import org.junit.Test; +import org.roaringbitmap.RoaringBitmap; + +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.TransportConf; + +/** + * Test for {@link MergedBlockMetaSuccess}. 
+ */ +public class MergedBlockMetaSuccessSuite { + + @Test + public void testMergedBlocksMetaEncodeDecode() throws Exception { + File chunkMetaFile = new File("target/mergedBlockMetaTest"); + Files.deleteIfExists(chunkMetaFile.toPath()); + RoaringBitmap chunk1 = new RoaringBitmap(); + chunk1.add(1); + chunk1.add(3); + RoaringBitmap chunk2 = new RoaringBitmap(); + chunk2.add(2); + chunk2.add(4); + RoaringBitmap[] expectedChunks = new RoaringBitmap[]{chunk1, chunk2}; + try (DataOutputStream metaOutput = new DataOutputStream(new FileOutputStream(chunkMetaFile))) { + for (int i = 0; i < expectedChunks.length; i++) { + expectedChunks[i].serialize(metaOutput); + } + } + TransportConf conf = mock(TransportConf.class); + when(conf.lazyFileDescriptor()).thenReturn(false); + long requestId = 1L; + MergedBlockMetaSuccess expectedMeta = new MergedBlockMetaSuccess(requestId, 2, + new FileSegmentManagedBuffer(conf, chunkMetaFile, 0, chunkMetaFile.length())); + + List<Object> out = Lists.newArrayList(); + ChannelHandlerContext context = mock(ChannelHandlerContext.class); + when(context.alloc()).thenReturn(ByteBufAllocator.DEFAULT); + + MessageEncoder.INSTANCE.encode(context, expectedMeta, out); + Assert.assertEquals(1, out.size()); + MessageWithHeader msgWithHeader = (MessageWithHeader) out.remove(0); + + ByteArrayWritableChannel writableChannel = + new ByteArrayWritableChannel((int) msgWithHeader.count()); + while (msgWithHeader.transfered() < msgWithHeader.count()) { + msgWithHeader.transferTo(writableChannel, msgWithHeader.transfered()); + } + ByteBuf messageBuf = Unpooled.wrappedBuffer(writableChannel.getData()); + messageBuf.readLong(); // frame length + MessageDecoder.INSTANCE.decode(mock(ChannelHandlerContext.class), messageBuf, out); + Assert.assertEquals(1, out.size()); + MergedBlockMetaSuccess decoded = (MergedBlockMetaSuccess) out.get(0); + Assert.assertEquals("merged block", expectedMeta.requestId, decoded.requestId); + Assert.assertEquals("num chunks", expectedMeta.getNumChunks(), decoded.getNumChunks()); + + ByteBuf responseBuf = Unpooled.wrappedBuffer(decoded.body().nioByteBuffer()); + RoaringBitmap[] responseBitmaps = new RoaringBitmap[expectedMeta.getNumChunks()]; + for (int i = 0; i < expectedMeta.getNumChunks(); i++) { + responseBitmaps[i] = Encoders.Bitmaps.decode(responseBuf); + } + Assert.assertEquals( + "num of roaring bitmaps", expectedMeta.getNumChunks(), responseBitmaps.length); + for (int i = 0; i < expectedMeta.getNumChunks(); i++) { + Assert.assertEquals("chunk bitmap " + i, expectedChunks[i], responseBitmaps[i]); + } + Files.delete(chunkMetaFile.toPath()); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java index a6bdc13e93..238d26ee50 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java @@ -178,4 +178,24 @@ public abstract class BlockStoreClient implements Closeable { MergeFinalizerListener listener) { throw new UnsupportedOperationException(); } + + /** + * Get the meta information of a merged block from the remote shuffle service. + * + * @param host the host of the remote node. + * @param port the port of the remote node. + * @param shuffleId shuffle id. + * @param reduceId reduce id. + * @param listener the listener to receive the merged block meta.
+ * + * @since 3.2.0 + */ + public void getMergedBlockMeta( + String host, + int port, + int shuffleId, + int reduceId, + MergedBlocksMetaListener listener) { + throw new UnsupportedOperationException(); + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java index 09e29cafa9..c5f5834c0b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java @@ -17,12 +17,14 @@ package org.apache.spark.network.shuffle; +import com.google.common.base.Preconditions; import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; import java.util.HashMap; import java.util.Iterator; import java.util.Map; +import java.util.Set; import java.util.function.Function; import com.codahale.metrics.Gauge; @@ -31,14 +33,17 @@ import com.codahale.metrics.Metric; import com.codahale.metrics.MetricSet; import com.codahale.metrics.Timer; import com.codahale.metrics.Counter; +import com.google.common.collect.Sets; import com.google.common.annotations.VisibleForTesting; -import org.apache.spark.network.client.StreamCallbackWithID; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.MergedBlockMetaResponseCallback; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.MergedBlockMetaRequest; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; @@ -54,8 +59,12 @@ import org.apache.spark.network.util.TransportConf; * Blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk * is equivalent to one block. 
*/ -public class ExternalBlockHandler extends RpcHandler { +public class ExternalBlockHandler extends RpcHandler + implements RpcHandler.MergedBlockMetaReqHandler { private static final Logger logger = LoggerFactory.getLogger(ExternalBlockHandler.class); + private static final String SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger"; + private static final String SHUFFLE_BLOCK_ID = "shuffle"; + private static final String SHUFFLE_CHUNK_ID = "shuffleChunk"; @VisibleForTesting final ExternalShuffleBlockResolver blockManager; @@ -128,24 +137,23 @@ public class ExternalBlockHandler extends RpcHandler { BlockTransferMessage msgObj, TransportClient client, RpcResponseCallback callback) { - if (msgObj instanceof FetchShuffleBlocks || msgObj instanceof OpenBlocks) { + if (msgObj instanceof AbstractFetchShuffleBlocks || msgObj instanceof OpenBlocks) { final Timer.Context responseDelayContext = metrics.openBlockRequestLatencyMillis.time(); try { int numBlockIds; long streamId; - if (msgObj instanceof FetchShuffleBlocks) { - FetchShuffleBlocks msg = (FetchShuffleBlocks) msgObj; + if (msgObj instanceof AbstractFetchShuffleBlocks) { + AbstractFetchShuffleBlocks msg = (AbstractFetchShuffleBlocks) msgObj; checkAuth(client, msg.appId); - numBlockIds = 0; - if (msg.batchFetchEnabled) { - numBlockIds = msg.mapIds.length; + numBlockIds = ((AbstractFetchShuffleBlocks) msgObj).getNumBlocks(); + Iterator<ManagedBuffer> iterator; + if (msgObj instanceof FetchShuffleBlocks) { + iterator = new ShuffleManagedBufferIterator((FetchShuffleBlocks) msgObj); } else { - for (int[] ids: msg.reduceIds) { - numBlockIds += ids.length; - } + iterator = new ShuffleChunkManagedBufferIterator((FetchShuffleBlockChunks) msgObj); } - streamId = streamManager.registerStream(client.getClientId(), - new ShuffleManagedBufferIterator(msg), client.getChannel()); + streamId = streamManager.registerStream(client.getClientId(), iterator, + client.getChannel()); } else { // For the compatibility with the old version, still keep the support for OpenBlocks.
OpenBlocks msg = (OpenBlocks) msgObj; @@ -189,9 +197,14 @@ } else if (msgObj instanceof GetLocalDirsForExecutors) { GetLocalDirsForExecutors msg = (GetLocalDirsForExecutors) msgObj; checkAuth(client, msg.appId); - Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, msg.execIds); + Set<String> execIdsForBlockResolver = Sets.newHashSet(msg.execIds); + boolean fetchMergedBlockDirs = execIdsForBlockResolver.remove(SHUFFLE_MERGER_IDENTIFIER); + Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, + execIdsForBlockResolver); + if (fetchMergedBlockDirs) { + localDirs.put(SHUFFLE_MERGER_IDENTIFIER, mergeManager.getMergedBlockDirs(msg.appId)); + } callback.onSuccess(new LocalDirsForExecutors(localDirs).toByteBuffer()); - } else if (msgObj instanceof FinalizeShuffleMerge) { final Timer.Context responseDelayContext = metrics.finalizeShuffleMergeLatencyMillis.time(); @@ -211,6 +224,32 @@ } } + @Override + public void receiveMergeBlockMetaReq( + TransportClient client, + MergedBlockMetaRequest metaRequest, + MergedBlockMetaResponseCallback callback) { + final Timer.Context responseDelayContext = metrics.fetchMergedBlocksMetaLatencyMillis.time(); + try { + checkAuth(client, metaRequest.appId); + MergedBlockMeta mergedMeta = + mergeManager.getMergedBlockMeta(metaRequest.appId, metaRequest.shuffleId, + metaRequest.reduceId); + logger.debug( + "Merged block chunks appId {} shuffleId {} reduceId {} num-chunks : {} ", + metaRequest.appId, metaRequest.shuffleId, metaRequest.reduceId, + mergedMeta.getNumChunks()); + callback.onSuccess(mergedMeta.getNumChunks(), mergedMeta.getChunksBitmapBuffer()); + } finally { + responseDelayContext.stop(); + } + } + + @Override + public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() { + return this; + } + @Override public void exceptionCaught(Throwable cause, TransportClient client) { metrics.caughtExceptions.inc(); @@ -262,6 +301,8 @@ private final Timer openBlockRequestLatencyMillis = new Timer(); // Time latency for executor registration latency in ms private final Timer registerExecutorRequestLatencyMillis = new Timer(); + // Time latency for processing fetch merged blocks meta request latency in ms + private final Timer fetchMergedBlocksMetaLatencyMillis = new Timer(); // Time latency for processing finalize shuffle merge request latency in ms private final Timer finalizeShuffleMergeLatencyMillis = new Timer(); // Block transfer rate in byte per second @@ -275,6 +316,7 @@ allMetrics = new HashMap<>(); allMetrics.put("openBlockRequestLatencyMillis", openBlockRequestLatencyMillis); allMetrics.put("registerExecutorRequestLatencyMillis", registerExecutorRequestLatencyMillis); + allMetrics.put("fetchMergedBlocksMetaLatencyMillis", fetchMergedBlocksMetaLatencyMillis); allMetrics.put("finalizeShuffleMergeLatencyMillis", finalizeShuffleMergeLatencyMillis); allMetrics.put("blockTransferRateBytes", blockTransferRateBytes); allMetrics.put("registeredExecutorsSize", @@ -294,18 +336,26 @@ private int index = 0; private final Function<Integer, ManagedBuffer> blockDataForIndexFn; private final int size; + private boolean requestForMergedBlockChunks; ManagedBufferIterator(OpenBlocks msg) { String appId = msg.appId; String execId = msg.execId; String[] blockIds = msg.blockIds; String[] blockId0Parts = blockIds[0].split("_"); - if
(blockId0Parts.length == 4 && blockId0Parts[0].equals("shuffle")) { + if (blockId0Parts.length == 4 && blockId0Parts[0].equals(SHUFFLE_BLOCK_ID)) { final int shuffleId = Integer.parseInt(blockId0Parts[1]); final int[] mapIdAndReduceIds = shuffleMapIdAndReduceIds(blockIds, shuffleId); size = mapIdAndReduceIds.length; blockDataForIndexFn = index -> blockManager.getBlockData(appId, execId, shuffleId, mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]); + } else if (blockId0Parts.length == 4 && blockId0Parts[0].equals(SHUFFLE_CHUNK_ID)) { + requestForMergedBlockChunks = true; + final int shuffleId = Integer.parseInt(blockId0Parts[1]); + final int[] reduceIdAndChunkIds = shuffleMapIdAndReduceIds(blockIds, shuffleId); + size = reduceIdAndChunkIds.length; + blockDataForIndexFn = index -> mergeManager.getMergedBlockData(msg.appId, shuffleId, + reduceIdAndChunkIds[index], reduceIdAndChunkIds[index + 1]); } else if (blockId0Parts.length == 3 && blockId0Parts[0].equals("rdd")) { final int[] rddAndSplitIds = rddAndSplitIds(blockIds); size = rddAndSplitIds.length; @@ -330,20 +380,26 @@ } private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) { - final int[] mapIdAndReduceIds = new int[2 * blockIds.length]; + // For regular shuffle blocks, primaryId is mapId and secondaryIds are reduceIds. + // For shuffle chunks, primaryId is reduceId and secondaryIds are chunkIds. + final int[] primaryIdAndSecondaryIds = new int[2 * blockIds.length]; for (int i = 0; i < blockIds.length; i++) { String[] blockIdParts = blockIds[i].split("_"); - if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) { + if (blockIdParts.length != 4 + || (!requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_BLOCK_ID)) + || (requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_CHUNK_ID))) { throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]); } if (Integer.parseInt(blockIdParts[1]) != shuffleId) { throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockIds[i]); } - mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]); - mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]); + // For regular blocks, blockIdParts[2] is mapId. For chunks, it is reduceId. + primaryIdAndSecondaryIds[2 * i] = Integer.parseInt(blockIdParts[2]); + // For regular blocks, blockIdParts[3] is reduceId. For chunks, it is chunkId. + primaryIdAndSecondaryIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]); } - return mapIdAndReduceIds; + return primaryIdAndSecondaryIds; } @Override @@ -413,6 +469,47 @@ } } + private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> { + + private int reduceIdx = 0; + private int chunkIdx = 0; + + private final String appId; + private final int shuffleId; + private final int[] reduceIds; + private final int[][] chunkIds; + + ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) { + appId = msg.appId; + shuffleId = msg.shuffleId; + reduceIds = msg.reduceIds; + chunkIds = msg.chunkIds; + // reduceIds.length must equal chunkIds.length, and the passed in FetchShuffleBlockChunks + // must have non-empty reduceIds and chunkIds, see the checking logic in + // OneForOneBlockFetcher.
+ assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length); + } + + @Override + public boolean hasNext() { + return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length; + } + + @Override + public ManagedBuffer next() { + ManagedBuffer block = Preconditions.checkNotNull(mergeManager.getMergedBlockData( + appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx])); + if (chunkIdx < chunkIds[reduceIdx].length - 1) { + chunkIdx += 1; + } else { + chunkIdx = 0; + reduceIdx += 1; + } + metrics.blockTransferRateBytes.mark(block.size()); + return block; + } + } + /** * Dummy implementation of merged shuffle file manager. Suitable for when push-based shuffle * is not enabled. diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java index 56c06e640a..f44140b124 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -31,6 +31,7 @@ import com.google.common.collect.Lists; import org.apache.spark.network.TransportContext; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.MergedBlockMetaResponseCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; @@ -187,6 +188,37 @@ public class ExternalBlockStoreClient extends BlockStoreClient { } } + @Override + public void getMergedBlockMeta( + String host, + int port, + int shuffleId, + int reduceId, + MergedBlocksMetaListener listener) { + checkInit(); + logger.debug("Get merged blocks meta from {}:{} for shuffleId {} reduceId {}", host, port, + shuffleId, reduceId); + try { + TransportClient client = clientFactory.createClient(host, port); + client.sendMergedBlockMetaReq(appId, shuffleId, reduceId, + new MergedBlockMetaResponseCallback() { + @Override + public void onSuccess(int numChunks, ManagedBuffer buffer) { + logger.trace("Successfully got merged block meta for shuffleId {} reduceId {}", + shuffleId, reduceId); + listener.onSuccess(shuffleId, reduceId, new MergedBlockMeta(numChunks, buffer)); + } + + @Override + public void onFailure(Throwable e) { + listener.onFailure(shuffleId, reduceId, e); + } + }); + } catch (Exception e) { + listener.onFailure(shuffleId, reduceId, e); + } + } + @Override public MetricSet shuffleMetrics() { checkInit(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index a095bf2723..493edd2b34 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -361,8 +361,8 @@ public class ExternalShuffleBlockResolver { return numRemovedBlocks; } - public Map<String, String[]> getLocalDirs(String appId, String[] execIds) { - return Arrays.stream(execIds) + public Map<String, String[]> getLocalDirs(String appId, Set<String> execIds) { + return execIds.stream() .map(exec -> { ExecutorShuffleInfo info = executors.get(new AppExecId(appId, exec)); if (info == null) { diff --git
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlocksMetaListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlocksMetaListener.java new file mode 100644 index 0000000000..0e277d3303 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlocksMetaListener.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.util.EventListener; + +/** + * Listener for receiving success or failure events when fetching the meta of merged blocks. + * + * @since 3.2.0 + */ +public interface MergedBlocksMetaListener extends EventListener { + + /** + * Called after successfully receiving the meta of a merged block. + * + * @param shuffleId shuffle id. + * @param reduceId reduce id. + * @param meta contains meta information of a merged block. + */ + void onSuccess(int shuffleId, int reduceId, MergedBlockMeta meta); + + /** + * Called when there is an exception while fetching the meta of a merged block. + * + * @param shuffleId shuffle id. + * @param reduceId reduce id. + * @param exception the exception encountered while fetching the merged block meta.
+ */ + void onFailure(int shuffleId, int reduceId, Throwable exception); +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 0b7eaa6225..2bf16b0097 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -22,6 +22,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.LinkedHashMap; +import java.util.Set; import com.google.common.primitives.Ints; import com.google.common.primitives.Longs; @@ -34,8 +35,10 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.shuffle.protocol.AbstractFetchShuffleBlocks; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks; +import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks; import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.StreamHandle; import org.apache.spark.network.util.TransportConf; @@ -51,6 +54,8 @@ import org.apache.spark.network.util.TransportConf; */ public class OneForOneBlockFetcher { private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class); + private static final String SHUFFLE_BLOCK_PREFIX = "shuffle_"; + private static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk_"; private final TransportClient client; private final BlockTransferMessage message; @@ -88,74 +93,113 @@ public class OneForOneBlockFetcher { if (blockIds.length == 0) { throw new IllegalArgumentException("Zero-sized blockIds array"); } - if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) { + if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) { this.blockIds = new String[blockIds.length]; - this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds); + this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds); } else { this.blockIds = blockIds; this.message = new OpenBlocks(appId, execId, blockIds); } } - private boolean isShuffleBlocks(String[] blockIds) { - for (String blockId : blockIds) { - if (!blockId.startsWith("shuffle_")) { - return false; - } + /** + * Check whether the block IDs are all shuffle block IDs. With push-based shuffle, + * a shuffle block ID can be either an unmerged shuffle block ID or a merged shuffle chunk + * ID. For a given stream of shuffle blocks to be fetched in one request, they would be either + * all unmerged shuffle blocks or all merged shuffle chunks. + * @param blockIds block ID array + * @return whether the array contains only shuffle block IDs + */ + private boolean areShuffleBlocksOrChunks(String[] blockIds) { + if (Arrays.stream(blockIds).anyMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX))) { + // We get here because some blockId doesn't have the "shuffle_" prefix, so check whether + // all the block IDs are shuffle chunk IDs.
+ /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */ + private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg( + String appId, + String execId, + String[] blockIds) { + if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) { + return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true); + } else { + return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false); + } + } + /** - * Create FetchShuffleBlocks message and rebuild internal blockIds by - * analyzing the pass in blockIds. + * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by + * analyzing the passed in blockIds. */ - private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds( - String appId, String execId, String[] blockIds) { + private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds( + String appId, + String execId, + String[] blockIds, + boolean areMergedChunks) { String[] firstBlock = splitBlockId(blockIds[0]); int shuffleId = Integer.parseInt(firstBlock[1]); boolean batchFetchEnabled = firstBlock.length == 5; - LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>(); + // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId + // is reduceId. + LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>(); for (String blockId : blockIds) { String[] blockIdParts = splitBlockId(blockId); if (Integer.parseInt(blockIdParts[1]) != shuffleId) { throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockId); } - long mapId = Long.parseLong(blockIdParts[2]); - if (!mapIdToBlocksInfo.containsKey(mapId)) { - mapIdToBlocksInfo.put(mapId, new BlocksInfo()); + Number primaryId; + if (!areMergedChunks) { + primaryId = Long.parseLong(blockIdParts[2]); + } else { + primaryId = Integer.parseInt(blockIdParts[2]); } - BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId); - blocksInfoByMapId.blockIds.add(blockId); - blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3])); + BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId, + id -> new BlocksInfo()); + blocksInfoByPrimaryId.blockIds.add(blockId); + // If blockId is a regular shuffle block, then blockIdParts[3] = reduceId. If blockId is a + // shuffleChunk block, then blockIdParts[3] = chunkId. + blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3])); if (batchFetchEnabled) { + // It comes here only if the blockId is a regular shuffle block, not a shuffleChunk block. // When we read continuous shuffle blocks in batch, we will reuse reduceIds in // FetchShuffleBlocks to store the start and end reduce id for range // [startReduceId, endReduceId). assert(blockIdParts.length == 5); - blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4])); + // blockIdParts[4] is the end reduce id for the batch range + blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4])); } } - long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet()); - int[][] reduceIdArr = new int[mapIds.length][]; + // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks, + // secondaryIds are chunkIds.
+ int[][] secondaryIdsArray = new int[primaryIdToBlocksInfo.size()][]; int blockIdIndex = 0; - for (int i = 0; i < mapIds.length; i++) { - BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]); - reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds); + int secIndex = 0; + for (BlocksInfo blocksInfo: primaryIdToBlocksInfo.values()) { + secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfo.ids); - // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks - // because the shuffle data's return order should match the `blockIds`'s order to ensure - // blockId and data match. - for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) { - this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j); + // The `blockIds`'s order must be the same as the read order specified in FetchShuffleBlocks/ + // FetchShuffleBlockChunks because the shuffle data's return order should match the + // `blockIds`'s order to ensure blockId and data match. + for (String blockId : blocksInfo.blockIds) { + this.blockIds[blockIdIndex++] = blockId; } } assert(blockIdIndex == this.blockIds.length); - - return new FetchShuffleBlocks( - appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled); + Set<Number> primaryIds = primaryIdToBlocksInfo.keySet(); + if (!areMergedChunks) { + long[] mapIds = Longs.toArray(primaryIds); + return new FetchShuffleBlocks( + appId, execId, shuffleId, mapIds, secondaryIdsArray, batchFetchEnabled); + } else { + int[] reduceIds = Ints.toArray(primaryIds); + return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, secondaryIdsArray); + } }
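The grouping step above is easiest to see with concrete IDs. A self-contained sketch (illustrative names, plain JDK collections) of how chunk block IDs of the form shuffleChunk_<shuffleId>_<reduceId>_<chunkId> collapse into the two parallel arrays carried by FetchShuffleBlockChunks:

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class ChunkGroupingSketch {
  public static void main(String[] args) {
    String[] blockIds = {
        "shuffleChunk_0_1_0", "shuffleChunk_0_1_1", "shuffleChunk_0_2_0"};
    // LinkedHashMap preserves insertion order, which is what keeps the
    // eventual data stream aligned with the caller's blockIds order.
    Map<Integer, List<Integer>> reduceToChunks = new LinkedHashMap<>();
    for (String blockId : blockIds) {
      String[] parts = blockId.split("_");
      int reduceId = Integer.parseInt(parts[2]);
      int chunkId = Integer.parseInt(parts[3]);
      reduceToChunks.computeIfAbsent(reduceId, k -> new ArrayList<>()).add(chunkId);
    }
    System.out.println(reduceToChunks); // {1=[0, 1], 2=[0]}
    // These become reduceIds = [1, 2] and chunkIds = [[0, 1], [0]] in a
    // FetchShuffleBlockChunks message; for regular shuffle blocks the same
    // shape holds with mapId as the key and reduceIds as the values.
  }
}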
/** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */ @@ -163,7 +207,17 @@ public class OneForOneBlockFetcher { String[] blockIdParts = blockId.split("_"); // For batch block id, the format contains shuffleId, mapId, begin reduceId, end reduceId. // For single block id, the format contains shuffleId, mapId, reduceId. - if (blockIdParts.length < 4 || blockIdParts.length > 5 || !blockIdParts[0].equals("shuffle")) { + // For single block chunk id, the format contains shuffleId, reduceId, chunkId. + if (blockIdParts.length < 4 || blockIdParts.length > 5) { + throw new IllegalArgumentException( + "Unexpected shuffle block id format: " + blockId); + } + if (blockIdParts.length == 5 && !blockIdParts[0].equals("shuffle")) { + throw new IllegalArgumentException( + "Unexpected shuffle block id format: " + blockId); + } + if (blockIdParts.length == 4 && + !(blockIdParts[0].equals("shuffle") || blockIdParts[0].equals("shuffleChunk"))) { throw new IllegalArgumentException( "Unexpected shuffle block id format: " + blockId); } @@ -173,11 +227,15 @@ public class OneForOneBlockFetcher { /** The reduceIds and blocks in a single mapId */ private class BlocksInfo { - final ArrayList<Integer> reduceIds; + /** + * For {@link FetchShuffleBlocks} message, the ids are reduceIds. + * For {@link FetchShuffleBlockChunks} message, the ids are chunkIds. + */ + final ArrayList<Integer> ids; final ArrayList<String> blockIds; BlocksInfo() { - this.reduceIds = new ArrayList<>(); + this.ids = new ArrayList<>(); this.blockIds = new ArrayList<>(); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/AbstractFetchShuffleBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/AbstractFetchShuffleBlocks.java new file mode 100644 index 0000000000..0fca27cf26 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/AbstractFetchShuffleBlocks.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; +import org.apache.spark.network.protocol.Encoders; + +/** + * Base class for fetch shuffle blocks and chunks. + * + * @since 3.2.0 + */ +public abstract class AbstractFetchShuffleBlocks extends BlockTransferMessage { + public final String appId; + public final String execId; + public final int shuffleId; + + protected AbstractFetchShuffleBlocks( + String appId, + String execId, + int shuffleId) { + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; + } + + public ToStringBuilder toStringHelper() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("appId", appId) + .append("execId", execId) + .append("shuffleId", shuffleId); + } + + /** + * Returns number of blocks in the request.
+ */ + public abstract int getNumBlocks(); + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AbstractFetchShuffleBlocks that = (AbstractFetchShuffleBlocks) o; + return shuffleId == that.shuffleId + && Objects.equal(appId, that.appId) && Objects.equal(execId, that.execId); + } + + @Override + public int hashCode() { + int result = appId.hashCode(); + result = 31 * result + execId.hashCode(); + result = 31 * result + shuffleId; + return result; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + 4; /* encoded length of shuffleId */ + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + buf.writeInt(shuffleId); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index 7f50581249..a55a6cf7ed 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -48,7 +48,8 @@ public abstract class BlockTransferMessage implements Encodable { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8), FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11), - PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14); + PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14), + FETCH_SHUFFLE_BLOCK_CHUNKS(15); private final byte id; @@ -82,6 +83,7 @@ public abstract class BlockTransferMessage implements Encodable { case 12: return PushBlockStream.decode(buf); case 13: return FinalizeShuffleMerge.decode(buf); case 14: return MergeStatuses.decode(buf); + case 15: return FetchShuffleBlockChunks.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java new file mode 100644 index 0000000000..27345dd8e7 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + + +/** + * Request to read a set of block chunks. Returns {@link StreamHandle}. + * + * @since 3.2.0 + */ +public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks { + // The length of reduceIds must equal chunkIds.length. + public final int[] reduceIds; + // The i-th int[] in chunkIds contains all the chunks for the i-th reduceId in reduceIds. + public final int[][] chunkIds; + + public FetchShuffleBlockChunks( + String appId, + String execId, + int shuffleId, + int[] reduceIds, + int[][] chunkIds) { + super(appId, execId, shuffleId); + this.reduceIds = reduceIds; + this.chunkIds = chunkIds; + assert(reduceIds.length == chunkIds.length); + } + + @Override + protected Type type() { return Type.FETCH_SHUFFLE_BLOCK_CHUNKS; } + + @Override + public String toString() { + return toStringHelper() + .append("reduceIds", Arrays.toString(reduceIds)) + .append("chunkIds", Arrays.deepToString(chunkIds)) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + FetchShuffleBlockChunks that = (FetchShuffleBlockChunks) o; + if (!super.equals(that)) return false; + if (!Arrays.equals(reduceIds, that.reduceIds)) return false; + return Arrays.deepEquals(chunkIds, that.chunkIds); + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + Arrays.hashCode(reduceIds); + result = 31 * result + Arrays.deepHashCode(chunkIds); + return result; + } + + @Override + public int encodedLength() { + int encodedLengthOfChunkIds = 0; + for (int[] ids: chunkIds) { + encodedLengthOfChunkIds += Encoders.IntArrays.encodedLength(ids); + } + return super.encodedLength() + + Encoders.IntArrays.encodedLength(reduceIds) + + 4 /* encoded length of chunkIds.length */ + + encodedLengthOfChunkIds; + } + + @Override + public void encode(ByteBuf buf) { + super.encode(buf); + Encoders.IntArrays.encode(buf, reduceIds); + // Even though reduceIds.length == chunkIds.length, we are explicitly setting the length in the + // interest of forward compatibility. + buf.writeInt(chunkIds.length); + for (int[] ids: chunkIds) { + Encoders.IntArrays.encode(buf, ids); + } + }
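For readers tracing bytes on the wire, here is a sketch of the resulting layout using a plain java.nio.ByteBuffer instead of Netty's ByteBuf. The helper methods are illustrative stand-ins for Encoders.Strings and Encoders.IntArrays, assuming their 4-byte-length-prefix layout:

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

public class ChunkMessageLayoutSketch {
  static void putString(ByteBuffer buf, String s) {
    byte[] bytes = s.getBytes(StandardCharsets.UTF_8);
    buf.putInt(bytes.length); // Encoders.Strings: 4-byte length + UTF-8 bytes
    buf.put(bytes);
  }

  static void putIntArray(ByteBuffer buf, int[] arr) {
    buf.putInt(arr.length);   // Encoders.IntArrays: 4-byte length + 4 bytes per int
    for (int v : arr) buf.putInt(v);
  }

  public static void main(String[] args) {
    int[] reduceIds = {0};
    int[][] chunkIds = {{0, 1}};
    ByteBuffer buf = ByteBuffer.allocate(64);
    putString(buf, "app0");
    putString(buf, "exec1");
    buf.putInt(0);               // shuffleId (written by super.encode)
    putIntArray(buf, reduceIds);
    buf.putInt(chunkIds.length); // explicit row count, kept for forward compatibility
    for (int[] row : chunkIds) putIntArray(buf, row);
    System.out.println("encoded bytes: " + buf.position()); // 45
  }
}

The printed 45 matches the encodedLength assertion in FetchShuffleBlockChunksSuite further down.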
+ + @Override + public int getNumBlocks() { + int numBlocks = 0; + for (int[] ids : chunkIds) { + numBlocks += ids.length; + } + return numBlocks; + } + + public static FetchShuffleBlockChunks decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int[] reduceIds = Encoders.IntArrays.decode(buf); + int chunkIdsLen = buf.readInt(); + int[][] chunkIds = new int[chunkIdsLen][]; + for (int i = 0; i < chunkIdsLen; i++) { + chunkIds[i] = Encoders.IntArrays.decode(buf); + } + return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, chunkIds); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java index 98057d58f7..68550a2fba 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java @@ -20,8 +20,6 @@ package org.apache.spark.network.shuffle.protocol; import java.util.Arrays; import io.netty.buffer.ByteBuf; -import org.apache.commons.lang3.builder.ToStringBuilder; -import org.apache.commons.lang3.builder.ToStringStyle; import org.apache.spark.network.protocol.Encoders; @@ -29,10 +27,7 @@ import org.apache.spark.network.protocol.Encoders; import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; /** Request to read a set of blocks. Returns {@link StreamHandle}. */ -public class FetchShuffleBlocks extends BlockTransferMessage { - public final String appId; - public final String execId; - public final int shuffleId; +public class FetchShuffleBlocks extends AbstractFetchShuffleBlocks { // The length of mapIds must equal reduceIds.length; the i-th mapId in mapIds // corresponds to the i-th int[] in reduceIds, which contains all reduce ids for this map id.
public final long[] mapIds; @@ -50,9 +45,7 @@ public class FetchShuffleBlocks extends BlockTransferMessage { long[] mapIds, int[][] reduceIds, boolean batchFetchEnabled) { - this.appId = appId; - this.execId = execId; - this.shuffleId = shuffleId; + super(appId, execId, shuffleId); this.mapIds = mapIds; this.reduceIds = reduceIds; assert(mapIds.length == reduceIds.length); @@ -69,10 +62,7 @@ public class FetchShuffleBlocks extends BlockTransferMessage { @Override public String toString() { - return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) - .append("appId", appId) - .append("execId", execId) - .append("shuffleId", shuffleId) + return toStringHelper() .append("mapIds", Arrays.toString(mapIds)) .append("reduceIds", Arrays.deepToString(reduceIds)) .append("batchFetchEnabled", batchFetchEnabled) @@ -85,35 +75,40 @@ public class FetchShuffleBlocks extends BlockTransferMessage { if (o == null || getClass() != o.getClass()) return false; FetchShuffleBlocks that = (FetchShuffleBlocks) o; - - if (shuffleId != that.shuffleId) return false; + if (!super.equals(that)) return false; if (batchFetchEnabled != that.batchFetchEnabled) return false; - if (!appId.equals(that.appId)) return false; - if (!execId.equals(that.execId)) return false; if (!Arrays.equals(mapIds, that.mapIds)) return false; return Arrays.deepEquals(reduceIds, that.reduceIds); } @Override public int hashCode() { - int result = appId.hashCode(); - result = 31 * result + execId.hashCode(); - result = 31 * result + shuffleId; + int result = super.hashCode(); result = 31 * result + Arrays.hashCode(mapIds); result = 31 * result + Arrays.deepHashCode(reduceIds); result = 31 * result + (batchFetchEnabled ? 1 : 0); return result; } + @Override + public int getNumBlocks() { + if (batchFetchEnabled) { + return mapIds.length; + } + int numBlocks = 0; + for (int[] ids : reduceIds) { + numBlocks += ids.length; + } + return numBlocks; + } + @Override public int encodedLength() { int encodedLengthOfReduceIds = 0; for (int[] ids: reduceIds) { encodedLengthOfReduceIds += Encoders.IntArrays.encodedLength(ids); } - return Encoders.Strings.encodedLength(appId) - + Encoders.Strings.encodedLength(execId) - + 4 /* encoded length of shuffleId */ + return super.encodedLength() + Encoders.LongArrays.encodedLength(mapIds) + 4 /* encoded length of reduceIds.size() */ + encodedLengthOfReduceIds @@ -122,9 +117,7 @@ public class FetchShuffleBlocks extends BlockTransferMessage { @Override public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, appId); - Encoders.Strings.encode(buf, execId); - buf.writeInt(shuffleId); + super.encode(buf); Encoders.LongArrays.encode(buf, mapIds); buf.writeInt(reduceIds.length); for (int[] ids: reduceIds) {
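getNumBlocks() differs between the two subclasses in one notable way: with batched fetches, each mapId's reduce-id range counts as a single block. A standalone sketch of the semantics (illustrative class name, same arithmetic as the method above):

public class NumBlocksSketch {
  // With batchFetchEnabled, each mapId contributes exactly one (batched)
  // block whose reduceIds row holds [startReduceId, endReduceId); without
  // batching, every listed reduceId is its own block.
  static int getNumBlocks(long[] mapIds, int[][] reduceIds, boolean batchFetchEnabled) {
    if (batchFetchEnabled) {
      return mapIds.length;
    }
    int numBlocks = 0;
    for (int[] ids : reduceIds) {
      numBlocks += ids.length;
    }
    return numBlocks;
  }

  public static void main(String[] args) {
    long[] mapIds = {0L, 1L};
    System.out.println(getNumBlocks(mapIds, new int[][] {{0, 1}, {3}}, false)); // 3
    // In batch mode the same shape carries ranges, e.g. {{0, 2}, {3, 4}},
    // and counts one block per map:
    System.out.println(getNumBlocks(mapIds, new int[][] {{0, 2}, {3, 4}}, true)); // 2
  }
}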
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index f06e7cb047..bad61d30d7 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -34,13 +34,16 @@ import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.MergedBlockMetaResponseCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.MergedBlockMetaRequest; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks; +import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks; import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge; import org.apache.spark.network.shuffle.protocol.MergeStatuses; import org.apache.spark.network.shuffle.protocol.OpenBlocks; @@ -258,4 +261,113 @@ public class ExternalBlockHandlerSuite { .get("finalizeShuffleMergeLatencyMillis"); assertEquals(1, finalizeShuffleMergeLatencyMillis.getCount()); } + + @Test + public void testFetchMergedBlocksMeta() { + when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0)).thenReturn( + new MergedBlockMeta(1, mock(ManagedBuffer.class))); + when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 1)).thenReturn( + new MergedBlockMeta(3, mock(ManagedBuffer.class))); + when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 2)).thenReturn( + new MergedBlockMeta(5, mock(ManagedBuffer.class))); + + int[] expectedCount = new int[]{1, 3, 5}; + String appId = "app0"; + long requestId = 0L; + for (int reduceId = 0; reduceId < 3; reduceId++) { + MergedBlockMetaRequest req = new MergedBlockMetaRequest(requestId++, appId, 0, reduceId); + MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class); + handler.getMergedBlockMetaReqHandler() + .receiveMergeBlockMetaReq(client, req, callback); + verify(mergedShuffleManager, times(1)).getMergedBlockMeta("app0", 0, reduceId); + + ArgumentCaptor<Integer> numChunksResponse = ArgumentCaptor.forClass(Integer.class); + ArgumentCaptor<ManagedBuffer> chunkBitmapResponse = + ArgumentCaptor.forClass(ManagedBuffer.class); + verify(callback, times(1)).onSuccess(numChunksResponse.capture(), + chunkBitmapResponse.capture()); + assertEquals("num chunks in merged block " + reduceId, expectedCount[reduceId], + numChunksResponse.getValue().intValue()); + assertNotNull("chunks bitmap buffer " + reduceId, chunkBitmapResponse.getValue()); + } + } + + @Test + public void testOpenBlocksWithShuffleChunks() { + verifyBlockChunkFetches(true); + } + + @Test + public void testFetchShuffleChunks() { + verifyBlockChunkFetches(false); + } + + private void verifyBlockChunkFetches(boolean useOpenBlocks) { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + ByteBuffer buffer; + if (useOpenBlocks) { + OpenBlocks openBlocks = + new OpenBlocks("app0", "exec1", + new String[] {"shuffleChunk_0_0_0", "shuffleChunk_0_0_1", "shuffleChunk_0_1_0", + "shuffleChunk_0_1_1"}); + buffer = openBlocks.toByteBuffer(); + } else { + FetchShuffleBlockChunks fetchChunks = new FetchShuffleBlockChunks( + "app0", "exec1", 0, new int[] {0, 1}, new int[][] {{0, 1}, {0, 1}}); + buffer = fetchChunks.toByteBuffer(); + } + ManagedBuffer[][] buffers = new ManagedBuffer[][] { + { + new NioManagedBuffer(ByteBuffer.wrap(new byte[5])), + new NioManagedBuffer(ByteBuffer.wrap(new byte[7])) + }, + { + new NioManagedBuffer(ByteBuffer.wrap(new byte[5])), + new NioManagedBuffer(ByteBuffer.wrap(new byte[7])) + } + }; + for (int reduceId = 0; reduceId < 2; reduceId++) { + for (int chunkId = 0; chunkId < 2; chunkId++) { + when(mergedShuffleManager.getMergedBlockData( + "app0", 0, reduceId,
chunkId)).thenReturn(buffers[reduceId][chunkId]); + } + } + handler.receive(client, buffer, callback); + ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class); + verify(callback, times(1)).onSuccess(response.capture()); + verify(callback, never()).onFailure(any()); + StreamHandle handle = + (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); + assertEquals(4, handle.numChunks); + + @SuppressWarnings("unchecked") + ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>) + (ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class); + verify(streamManager, times(1)).registerStream(any(), stream.capture(), any()); + Iterator<ManagedBuffer> bufferIter = stream.getValue(); + for (int reduceId = 0; reduceId < 2; reduceId++) { + for (int chunkId = 0; chunkId < 2; chunkId++) { + assertEquals(buffers[reduceId][chunkId], bufferIter.next()); + } + } + assertFalse(bufferIter.hasNext()); + verify(mergedShuffleManager, never()).getMergedBlockMeta(anyString(), anyInt(), anyInt()); + verify(blockResolver, never()).getBlockData( + anyString(), anyString(), anyInt(), anyInt(), anyInt()); + verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 0); + verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 1); + + // Verify open block request latency metrics + Timer openBlockRequestLatencyMillis = (Timer) ((ExternalBlockHandler) handler) + .getAllMetrics() + .getMetrics() + .get("openBlockRequestLatencyMillis"); + assertEquals(1, openBlockRequestLatencyMillis.getCount()); + // Verify block transfer metrics + Meter blockTransferRateBytes = (Meter) ((ExternalBlockHandler) handler) + .getAllMetrics() + .getMetrics() + .get("blockTransferRateBytes"); + assertEquals(24, blockTransferRateBytes.getCount()); + } }
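The test above relies on Mockito's double-cast idiom to obtain a captor for a generic Iterator<ManagedBuffer>: ArgumentCaptor.forClass(Iterator.class) can only yield an ArgumentCaptor<Iterator>, and Java forbids casting directly between different parameterizations, so the code hops through a wildcard first. A minimal sketch of the idiom (assumes Mockito on the classpath; the helper method is illustrative):

import java.util.Iterator;
import org.mockito.ArgumentCaptor;

public class CaptorCastSketch {
  // Erase the type argument via a wildcard so the final cast is legal;
  // the unchecked warning is expected and suppressed.
  @SuppressWarnings("unchecked")
  static <T> ArgumentCaptor<Iterator<T>> iteratorCaptor() {
    return (ArgumentCaptor<Iterator<T>>) (ArgumentCaptor<?>)
        ArgumentCaptor.forClass(Iterator.class);
  }
}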
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index a7eb59d366..c4967eab31 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -27,8 +27,7 @@ import com.google.common.collect.Maps; import io.netty.buffer.Unpooled; import org.junit.Test; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; +import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; @@ -46,6 +45,7 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks; +import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks; import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.StreamHandle; import org.apache.spark.network.util.MapConfigProvider; @@ -243,6 +243,57 @@ public class OneForOneBlockFetcherSuite { } } + @Test + public void testShuffleBlockChunksFetch() { + LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap(); + blocks.put("shuffleChunk_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("shuffleChunk_0_0_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23]))); + blocks.put("shuffleChunk_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); + String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + + BlockFetchingListener listener = fetchBlocks(blocks, blockIds, + new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 }, + new int[][] {{ 0, 1, 2 }}), conf); + for (int i = 0; i < 3; i++) { + verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_" + i, + blocks.get("shuffleChunk_0_0_" + i)); + } + } + + @Test + public void testShuffleBlockChunkFetchFailure() { + LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap(); + blocks.put("shuffleChunk_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("shuffleChunk_0_0_1", null); + blocks.put("shuffleChunk_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); + String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + + BlockFetchingListener listener = fetchBlocks(blocks, blockIds, + new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[]{0}, new int[][]{{0, 1, 2}}), + conf); + verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0", + blocks.get("shuffleChunk_0_0_0")); + verify(listener, times(1)).onBlockFetchFailure(eq("shuffleChunk_0_0_1"), any()); + verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_2", + blocks.get("shuffleChunk_0_0_2")); + } + + @Test + public void testInvalidShuffleBlockIds() { + assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(), + new String[]{"shuffle_0_0"}, + new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 }, + new int[][] {{ 0 }}), conf)); + assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(), + new String[]{"shuffleChunk_0_0_0_0_0"}, + new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 }, + new int[][] {{ 0 }}), conf)); + assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(), + new String[]{"shuffleChunk_0_0_0_0"}, + new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 }, + new int[][] {{ 0 }}), conf)); + } + /** * Begins a fetch on the given set of blocks by mocking out the server side of the RPC which * simply returns the given (BlockId, Block) pairs. diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java new file mode 100644 index 0000000000..91f319ded4 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.junit.Assert; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class FetchShuffleBlockChunksSuite { + + @Test + public void testFetchShuffleBlockChunksEncodeDecode() { + FetchShuffleBlockChunks shuffleBlockChunks = + new FetchShuffleBlockChunks("app0", "exec1", 0, new int[] {0}, new int[][] {{0, 1}}); + Assert.assertEquals(2, shuffleBlockChunks.getNumBlocks()); + int len = shuffleBlockChunks.encodedLength(); + Assert.assertEquals(45, len); + ByteBuf buf = Unpooled.buffer(len); + shuffleBlockChunks.encode(buf); + + FetchShuffleBlockChunks decoded = FetchShuffleBlockChunks.decode(buf); + assertEquals(shuffleBlockChunks, decoded); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocksSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocksSuite.java new file mode 100644 index 0000000000..a1681f58e7 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocksSuite.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.junit.Assert; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class FetchShuffleBlocksSuite { + + @Test + public void testFetchShuffleBlockEncodeDecode() { + FetchShuffleBlocks fetchShuffleBlocks = + new FetchShuffleBlocks("app0", "exec1", 0, new long[] {0}, new int[][] {{0, 1}}, false); + Assert.assertEquals(2, fetchShuffleBlocks.getNumBlocks()); + int len = fetchShuffleBlocks.encodedLength(); + Assert.assertEquals(50, len); + ByteBuf buf = Unpooled.buffer(len); + fetchShuffleBlocks.encode(buf); + + FetchShuffleBlocks decoded = FetchShuffleBlocks.decode(buf); + assertEquals(fetchShuffleBlocks, decoded); + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceMetricsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceMetricsSuite.scala index ea4d252f0d..4ce46156c0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceMetricsSuite.scala @@ -62,7 +62,8 @@ class ExternalShuffleServiceMetricsSuite extends SparkFunSuite { "registerExecutorRequestLatencyMillis", "shuffle-server.usedDirectMemory", "shuffle-server.usedHeapMemory", - "finalizeShuffleMergeLatencyMillis") + "finalizeShuffleMergeLatencyMillis", + "fetchMergedBlocksMetaLatencyMillis") ) } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala index 9239d891aa..d866fac726 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala @@ -40,7 +40,8 @@ class YarnShuffleServiceMetricsSuite extends SparkFunSuite with Matchers { val allMetrics = Set( "openBlockRequestLatencyMillis", "registerExecutorRequestLatencyMillis", "blockTransferRateBytes", "registeredExecutorsSize", "numActiveConnections", - "numCaughtExceptions", "finalizeShuffleMergeLatencyMillis") + "numCaughtExceptions", "finalizeShuffleMergeLatencyMillis", + "fetchMergedBlocksMetaLatencyMillis") metrics.getMetrics.keySet().asScala should be (allMetrics) } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala index d6d1715223..afe85b3d40 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala @@ -412,7 +412,8 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd "registerExecutorRequestLatencyMillis", "finalizeShuffleMergeLatencyMillis", "shuffle-server.usedDirectMemory", - "shuffle-server.usedHeapMemory" + "shuffle-server.usedHeapMemory", + "fetchMergedBlocksMetaLatencyMillis" )) }
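The two encode/decode suites above pin the encoded sizes to 45 and 50 bytes. Those constants follow directly from the length-prefixed encodings: Encoders.Strings writes a 4-byte length plus UTF-8 bytes, IntArrays a 4-byte length plus 4 bytes per element, and LongArrays a 4-byte length plus 8 bytes per element. A worked check of the arithmetic as a runnable snippet (illustrative class name):

public class EncodedLengthArithmetic {
  public static void main(String[] args) {
    // FetchShuffleBlockChunks("app0", "exec1", 0, {0}, {{0, 1}})
    int chunks = (4 + 4)      // "app0"
               + (4 + 5)      // "exec1"
               + 4            // shuffleId
               + (4 + 1 * 4)  // reduceIds {0}
               + 4            // chunkIds row count
               + (4 + 2 * 4); // chunkIds row {0, 1}
    System.out.println(chunks); // 45

    // FetchShuffleBlocks("app0", "exec1", 0, {0L}, {{0, 1}}, false)
    int blocks = (4 + 4)      // "app0"
               + (4 + 5)      // "exec1"
               + 4            // shuffleId
               + (4 + 1 * 8)  // mapIds {0L}
               + 4            // reduceIds row count
               + (4 + 2 * 4)  // reduceIds row {0, 1}
               + 1;           // batchFetchEnabled flag
    System.out.println(blocks); // 50
  }
}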