[SPARK-35671][SHUFFLE][CORE] Add support in the ESS to serve merged shuffle block meta and data to executors

### What changes were proposed in this pull request?
This adds support in the ESS to serve merged shuffle block meta and data requests to executors.
This change is needed so that executors can fetch merged shuffle data from remote shuffle services. It is part of the push-based shuffle SPIP [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).

This change introduces new messages between clients and the external shuffle service:

1. `MergedBlockMetaRequest`: The client sends this to the external shuffle service to get the meta information for a merged block. The response is one of:
  - `MergedBlockMetaSuccess`: contains the request id, the number of chunks, and a `ManagedBuffer` which is a `FileSegmentManagedBuffer` backed by the merged block meta file.
  - `RpcFailure`: sent back to the client in case of failure. This is an existing message.

2. `FetchShuffleBlockChunks`: This is similar to the `FetchShuffleBlocks` message, but it fetches merged shuffle chunks instead of blocks. A minimal client-side sketch of how these messages are used follows below.
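
For illustration, here is a minimal client-side sketch of requesting merged block meta with the APIs added in this PR (`TransportClient.sendMergedBlockMetaReq` and `MergedBlockMetaResponseCallback`). The wrapper class, method name, and logging are hypothetical; only the callback and request signatures come from the diff below.

```java
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.MergedBlockMetaResponseCallback;
import org.apache.spark.network.client.TransportClient;

public class MergedBlockMetaFetchExample {
  /** Requests the meta of the merged block (shuffleId, reduceId) from the shuffle service. */
  static void fetchMergedBlockMeta(
      TransportClient client, String appId, int shuffleId, int reduceId) {
    client.sendMergedBlockMetaReq(appId, shuffleId, reduceId,
      new MergedBlockMetaResponseCallback() {
        @Override
        public void onSuccess(int numChunks, ManagedBuffer buffer) {
          // The server replied with MergedBlockMetaSuccess: `buffer` holds one roaring
          // bitmap per chunk with the mapIds merged into that chunk. It is released once
          // this call returns, so retain() or copy it before returning.
          System.out.println("Merged block has " + numChunks + " chunks");
        }

        @Override
        public void onFailure(Throwable e) {
          // The server replied with RpcFailure, or the request failed on the client side.
          System.err.println("Failed to fetch merged block meta: " + e);
        }
      });
  }
}
```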

### Why are the changes needed?
These changes are needed for push-based shuffle. Refer to the SPIP in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added unit tests.
The reference PR with the consolidated changes covering the complete implementation is also provided in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
We have already verified the functionality and the improved performance as documented in the SPIP doc.

Closes #32811 from otterc/SPARK-35671.

Lead-authored-by: Chandni Singh <singh.chandni@gmail.com>
Co-authored-by: Min Shen <mshen@linkedin.com>
Co-authored-by: Chandni Singh <chsingh@linkedin.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
Authored by Chandni Singh on 2021-06-20 17:22:37 -05:00, committed by Mridul Muralidharan
parent af20474c67
commit 8ce1e344e5
33 changed files with 1398 additions and 102 deletions


@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.client;
/**
* A basic callback. This is extended by {@link RpcResponseCallback} and
* {@link MergedBlockMetaResponseCallback} so that both RpcRequests and MergedBlockMetaRequests
* can be handled in {@link TransportResponseHandler} in a similar way.
*
* @since 3.2.0
*/
public interface BaseResponseCallback {
/** Exception either propagated from server or raised on client side. */
void onFailure(Throwable e);
}


@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.client;
import org.apache.spark.network.buffer.ManagedBuffer;
/**
* Callback for the result of a single
* {@link org.apache.spark.network.protocol.MergedBlockMetaRequest}.
*
* @since 3.2.0
*/
public interface MergedBlockMetaResponseCallback extends BaseResponseCallback {
/**
* Called upon receipt of a particular merged block meta.
*
* The given buffer will initially have a refcount of 1, but will be release()'d as soon as this
* call returns. You must therefore either retain() the buffer or copy its contents before
* returning.
*
* @param numChunks number of merged chunks in the merged block
* @param buffer the buffer contains an array of roaring bitmaps. The i-th roaring bitmap
* contains the mapIds that were merged to the i-th merged chunk.
*/
void onSuccess(int numChunks, ManagedBuffer buffer);
}


@ -23,7 +23,7 @@ import java.nio.ByteBuffer;
* Callback for the result of a single RPC. This will be invoked once with either success or
* failure.
*/
public interface RpcResponseCallback {
public interface RpcResponseCallback extends BaseResponseCallback {
/**
* Successful serialized result from server.
*
@ -31,7 +31,4 @@ public interface RpcResponseCallback {
* Please copy the content of `response` if you want to use it after `onSuccess` returns.
*/
void onSuccess(ByteBuffer response);
/** Exception either propagated from server or raised on client side. */
void onFailure(Throwable e);
}


@ -200,6 +200,31 @@ public class TransportClient implements Closeable {
return requestId;
}
/**
* Sends a MergedBlockMetaRequest message to the server. The response of this message is
* either a {@link MergedBlockMetaSuccess} or {@link RpcFailure}.
*
* @param appId applicationId.
* @param shuffleId shuffle id.
* @param reduceId reduce id.
* @param callback the callback to handle the reply.
*/
public void sendMergedBlockMetaReq(
String appId,
int shuffleId,
int reduceId,
MergedBlockMetaResponseCallback callback) {
long requestId = requestId();
if (logger.isTraceEnabled()) {
logger.trace(
"Sending RPC {} to fetch merged block meta to {}", requestId, getRemoteAddress(channel));
}
handler.addRpcRequest(requestId, callback);
RpcChannelListener listener = new RpcChannelListener(requestId, callback);
channel.writeAndFlush(
new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId)).addListener(listener);
}
/**
* Send data to the remote end as a stream. This differs from stream() in that this is a request
* to *send* data to the remote end, not to receive it from the remote.
@ -349,9 +374,9 @@ public class TransportClient implements Closeable {
private class RpcChannelListener extends StdChannelListener {
final long rpcRequestId;
final RpcResponseCallback callback;
final BaseResponseCallback callback;
RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) {
RpcChannelListener(long rpcRequestId, BaseResponseCallback callback) {
super("RPC " + rpcRequestId);
this.rpcRequestId = rpcRequestId;
this.callback = callback;


@ -33,6 +33,7 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.MergedBlockMetaSuccess;
import org.apache.spark.network.protocol.ResponseMessage;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
@ -56,7 +57,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
private final Map<Long, RpcResponseCallback> outstandingRpcs;
private final Map<Long, BaseResponseCallback> outstandingRpcs;
private final Queue<Pair<String, StreamCallback>> streamCallbacks;
private volatile boolean streamActive;
@ -81,7 +82,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
outstandingFetches.remove(streamChunkId);
}
public void addRpcRequest(long requestId, RpcResponseCallback callback) {
public void addRpcRequest(long requestId, BaseResponseCallback callback) {
updateTimeOfLastRequest();
outstandingRpcs.put(requestId, callback);
}
@ -112,7 +113,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
}
}
for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
for (Map.Entry<Long, BaseResponseCallback> entry : outstandingRpcs.entrySet()) {
try {
entry.getValue().onFailure(cause);
} catch (Exception e) {
@ -184,7 +185,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
}
} else if (message instanceof RpcResponse) {
RpcResponse resp = (RpcResponse) message;
RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
RpcResponseCallback listener = (RpcResponseCallback) outstandingRpcs.get(resp.requestId);
if (listener == null) {
logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
resp.requestId, getRemoteAddress(channel), resp.body().size());
@ -199,7 +200,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
}
} else if (message instanceof RpcFailure) {
RpcFailure resp = (RpcFailure) message;
RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
BaseResponseCallback listener = outstandingRpcs.get(resp.requestId);
if (listener == null) {
logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
resp.requestId, getRemoteAddress(channel), resp.errorString);
@ -207,6 +208,22 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
outstandingRpcs.remove(resp.requestId);
listener.onFailure(new RuntimeException(resp.errorString));
}
} else if (message instanceof MergedBlockMetaSuccess) {
MergedBlockMetaSuccess resp = (MergedBlockMetaSuccess) message;
try {
MergedBlockMetaResponseCallback listener =
(MergedBlockMetaResponseCallback) outstandingRpcs.get(resp.requestId);
if (listener == null) {
logger.warn(
"Ignoring response for MergedBlockMetaRequest {} from {} ({} bytes) since it is not"
+ " outstanding", resp.requestId, getRemoteAddress(channel), resp.body().size());
} else {
outstandingRpcs.remove(resp.requestId);
listener.onSuccess(resp.getNumChunks(), resp.body());
}
} finally {
resp.body().release();
}
} else if (message instanceof StreamResponse) {
StreamResponse resp = (StreamResponse) message;
Pair<String, StreamCallback> entry = streamCallbacks.poll();


@ -138,4 +138,9 @@ class AuthRpcHandler extends AbstractAuthRpcHandler {
LOG.debug("Authorization successful for client {}.", channel.remoteAddress());
return true;
}
@Override
public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
return saslHandler.getMergedBlockMetaReqHandler();
}
}


@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.protocol;
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;
/**
* Request to find the meta information for the specified merged block. The meta information
* contains the number of chunks in the merged block and the map ids in each chunk.
*
* @since 3.2.0
*/
public class MergedBlockMetaRequest extends AbstractMessage implements RequestMessage {
public final long requestId;
public final String appId;
public final int shuffleId;
public final int reduceId;
public MergedBlockMetaRequest(long requestId, String appId, int shuffleId, int reduceId) {
super(null, false);
this.requestId = requestId;
this.appId = appId;
this.shuffleId = shuffleId;
this.reduceId = reduceId;
}
@Override
public Type type() {
return Type.MergedBlockMetaRequest;
}
@Override
public int encodedLength() {
return 8 + Encoders.Strings.encodedLength(appId) + 4 + 4;
}
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
Encoders.Strings.encode(buf, appId);
buf.writeInt(shuffleId);
buf.writeInt(reduceId);
}
public static MergedBlockMetaRequest decode(ByteBuf buf) {
long requestId = buf.readLong();
String appId = Encoders.Strings.decode(buf);
int shuffleId = buf.readInt();
int reduceId = buf.readInt();
return new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId);
}
@Override
public int hashCode() {
return Objects.hashCode(requestId, appId, shuffleId, reduceId);
}
@Override
public boolean equals(Object other) {
if (other instanceof MergedBlockMetaRequest) {
MergedBlockMetaRequest o = (MergedBlockMetaRequest) other;
return requestId == o.requestId && shuffleId == o.shuffleId && reduceId == o.reduceId
&& Objects.equal(appId, o.appId);
}
return false;
}
@Override
public String toString() {
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
.append("requestId", requestId)
.append("appId", appId)
.append("shuffleId", shuffleId)
.append("reduceId", reduceId)
.toString();
}
}


@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.protocol;
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NettyManagedBuffer;
/**
* Response to {@link MergedBlockMetaRequest} request.
* Note that the server-side encoding of this message does NOT include the buffer itself.
*
* @since 3.2.0
*/
public class MergedBlockMetaSuccess extends AbstractResponseMessage {
public final long requestId;
public final int numChunks;
public MergedBlockMetaSuccess(
long requestId,
int numChunks,
ManagedBuffer chunkBitmapsBuffer) {
super(chunkBitmapsBuffer, true);
this.requestId = requestId;
this.numChunks = numChunks;
}
@Override
public Type type() {
return Type.MergedBlockMetaSuccess;
}
@Override
public int hashCode() {
return Objects.hashCode(requestId, numChunks);
}
@Override
public String toString() {
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
.append("requestId", requestId).append("numChunks", numChunks).toString();
}
@Override
public int encodedLength() {
return 8 + 4;
}
/** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
@Override
public void encode(ByteBuf buf) {
buf.writeLong(requestId);
buf.writeInt(numChunks);
}
public int getNumChunks() {
return numChunks;
}
/** Decoding uses the given ByteBuf as our data, and will retain() it. */
public static MergedBlockMetaSuccess decode(ByteBuf buf) {
long requestId = buf.readLong();
int numChunks = buf.readInt();
buf.retain();
NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate());
return new MergedBlockMetaSuccess(requestId, numChunks, managedBuf);
}
@Override
public ResponseMessage createFailureResponse(String error) {
return new RpcFailure(requestId, error);
}
}


@ -37,7 +37,8 @@ public interface Message extends Encodable {
ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
RpcRequest(3), RpcResponse(4), RpcFailure(5),
StreamRequest(6), StreamResponse(7), StreamFailure(8),
OneWayMessage(9), UploadStream(10), User(-1);
OneWayMessage(9), UploadStream(10), MergedBlockMetaRequest(11), MergedBlockMetaSuccess(12),
User(-1);
private final byte id;
@ -66,6 +67,8 @@ public interface Message extends Encodable {
case 8: return StreamFailure;
case 9: return OneWayMessage;
case 10: return UploadStream;
case 11: return MergedBlockMetaRequest;
case 12: return MergedBlockMetaSuccess;
case -1: throw new IllegalArgumentException("User type messages cannot be decoded.");
default: throw new IllegalArgumentException("Unknown message type: " + id);
}


@ -83,6 +83,12 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
case UploadStream:
return UploadStream.decode(in);
case MergedBlockMetaRequest:
return MergedBlockMetaRequest.decode(in);
case MergedBlockMetaSuccess:
return MergedBlockMetaSuccess.decode(in);
default:
throw new IllegalArgumentException("Unexpected message type: " + msgType);
}


@ -104,4 +104,9 @@ public abstract class AbstractAuthRpcHandler extends RpcHandler {
public boolean isAuthenticated() {
return isAuthenticated;
}
@Override
public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
return delegate.getMergedBlockMetaReqHandler();
}
}


@ -22,9 +22,11 @@ import java.nio.ByteBuffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.MergedBlockMetaResponseCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.protocol.MergedBlockMetaRequest;
/**
* Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
@ -32,6 +34,8 @@ import org.apache.spark.network.client.TransportClient;
public abstract class RpcHandler {
private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback();
private static final MergedBlockMetaReqHandler NOOP_MERGED_BLOCK_META_REQ_HANDLER =
new NoopMergedBlockMetaReqHandler();
/**
* Receive a single RPC message. Any exception thrown while in this method will be sent back to
@ -100,6 +104,10 @@ public abstract class RpcHandler {
receive(client, message, ONE_WAY_CALLBACK);
}
public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
return NOOP_MERGED_BLOCK_META_REQ_HANDLER;
}
/**
* Invoked when the channel associated with the given client is active.
*/
@ -129,4 +137,40 @@ public abstract class RpcHandler {
}
/**
* Handler for {@link MergedBlockMetaRequest}.
*
* @since 3.2.0
*/
public interface MergedBlockMetaReqHandler {
/**
* Receive a {@link MergedBlockMetaRequest}.
*
* @param client A channel client which enables the handler to make requests back to the sender
* of this RPC.
* @param mergedBlockMetaRequest Request for merged block meta.
* @param callback Callback which should be invoked exactly once upon success or failure.
*/
void receiveMergeBlockMetaReq(
TransportClient client,
MergedBlockMetaRequest mergedBlockMetaRequest,
MergedBlockMetaResponseCallback callback);
}
/**
* A Noop implementation of {@link MergedBlockMetaReqHandler}. This Noop implementation is used
* by all the RPC handlers which don't eventually delegate the {@link MergedBlockMetaRequest} to
* ExternalBlockHandler in the network-shuffle module.
*
* @since 3.2.0
*/
private static class NoopMergedBlockMetaReqHandler implements MergedBlockMetaReqHandler {
@Override
public void receiveMergeBlockMetaReq(TransportClient client,
MergedBlockMetaRequest mergedBlockMetaRequest, MergedBlockMetaResponseCallback callback) {
// do nothing
}
}
}


@ -113,6 +113,8 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
processStreamRequest((StreamRequest) request);
} else if (request instanceof UploadStream) {
processStreamUpload((UploadStream) request);
} else if (request instanceof MergedBlockMetaRequest) {
processMergedBlockMetaRequest((MergedBlockMetaRequest) request);
} else {
throw new IllegalArgumentException("Unknown request type: " + request);
}
@ -260,6 +262,30 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
}
}
private void processMergedBlockMetaRequest(final MergedBlockMetaRequest req) {
try {
rpcHandler.getMergedBlockMetaReqHandler().receiveMergeBlockMetaReq(reverseClient, req,
new MergedBlockMetaResponseCallback() {
@Override
public void onSuccess(int numChunks, ManagedBuffer buffer) {
logger.trace("Sending meta for request {} numChunks {}", req, numChunks);
respond(new MergedBlockMetaSuccess(req.requestId, numChunks, buffer));
}
@Override
public void onFailure(Throwable e) {
logger.trace("Failed to send meta for {}", req);
respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
}
});
} catch (Exception e) {
logger.error("Error while invoking receiveMergeBlockMetaReq() for appId {} shuffleId {} "
+ "reduceId {}", req.appId, req.shuffleId, req.appId, e);
respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
}
}
/**
* Responds to a single message with some Encodable object. If a failure occurs while sending,
* it will be logged and the channel closed.


@ -17,6 +17,7 @@
package org.apache.spark.network;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
@ -24,16 +25,19 @@ import io.netty.channel.Channel;
import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.protocol.*;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportRequestHandler;
public class TransportRequestHandlerSuite {
@ -109,4 +113,55 @@ public class TransportRequestHandlerSuite {
streamManager.connectionTerminated(channel);
Assert.assertEquals(0, streamManager.numStreamStates());
}
@Test
public void handleMergedBlockMetaRequest() throws Exception {
RpcHandler.MergedBlockMetaReqHandler metaHandler = (client, request, callback) -> {
if (request.shuffleId != -1 && request.reduceId != -1) {
callback.onSuccess(2, mock(ManagedBuffer.class));
} else {
callback.onFailure(new RuntimeException("empty block"));
}
};
RpcHandler rpcHandler = new RpcHandler() {
@Override
public void receive(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback) {}
@Override
public StreamManager getStreamManager() {
return null;
}
@Override
public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
return metaHandler;
}
};
Channel channel = mock(Channel.class);
List<Pair<Object, ExtendedChannelPromise>> responseAndPromisePairs = new ArrayList<>();
when(channel.writeAndFlush(any())).thenAnswer(invocationOnMock0 -> {
Object response = invocationOnMock0.getArguments()[0];
ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel);
responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture));
return channelFuture;
});
TransportClient reverseClient = mock(TransportClient.class);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
rpcHandler, 2L, null);
MergedBlockMetaRequest validMetaReq = new MergedBlockMetaRequest(19, "app1", 0, 0);
requestHandler.handle(validMetaReq);
assertEquals(1, responseAndPromisePairs.size());
assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof MergedBlockMetaSuccess);
assertEquals(2,
((MergedBlockMetaSuccess) (responseAndPromisePairs.get(0).getLeft())).getNumChunks());
MergedBlockMetaRequest invalidMetaReq = new MergedBlockMetaRequest(21, "app1", -1, 1);
requestHandler.handle(invalidMetaReq);
assertEquals(2, responseAndPromisePairs.size());
assertTrue(responseAndPromisePairs.get(1).getLeft() instanceof RpcFailure);
}
}


@ -23,17 +23,20 @@ import java.nio.ByteBuffer;
import io.netty.channel.Channel;
import io.netty.channel.local.LocalChannel;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.*;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.MergedBlockMetaResponseCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallback;
import org.apache.spark.network.client.TransportResponseHandler;
import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.MergedBlockMetaSuccess;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
@ -167,4 +170,40 @@ public class TransportResponseHandlerSuite {
verify(cb).onFailure(eq("stream-1"), isA(IOException.class));
}
@Test
public void handleSuccessfulMergedBlockMeta() throws Exception {
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class);
handler.addRpcRequest(13, callback);
assertEquals(1, handler.numOutstandingRequests());
// This response should be ignored.
handler.handle(new MergedBlockMetaSuccess(22, 2,
new NioManagedBuffer(ByteBuffer.allocate(7))));
assertEquals(1, handler.numOutstandingRequests());
ByteBuffer resp = ByteBuffer.allocate(10);
handler.handle(new MergedBlockMetaSuccess(13, 2, new NioManagedBuffer(resp)));
ArgumentCaptor<NioManagedBuffer> bufferCaptor = ArgumentCaptor.forClass(NioManagedBuffer.class);
verify(callback, times(1)).onSuccess(eq(2), bufferCaptor.capture());
assertEquals(resp, bufferCaptor.getValue().nioByteBuffer());
assertEquals(0, handler.numOutstandingRequests());
}
@Test
public void handleFailedMergedBlockMeta() throws Exception {
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class);
handler.addRpcRequest(51, callback);
assertEquals(1, handler.numOutstandingRequests());
// This response should be ignored.
handler.handle(new RpcFailure(6, "failed"));
assertEquals(1, handler.numOutstandingRequests());
handler.handle(new RpcFailure(51, "failed"));
verify(callback, times(1)).onFailure(any());
assertEquals(0, handler.numOutstandingRequests());
}
}


@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.protocol;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.nio.file.Files;
import java.util.List;
import com.google.common.collect.Lists;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import org.junit.Assert;
import org.junit.Test;
import org.roaringbitmap.RoaringBitmap;
import static org.mockito.Mockito.*;
import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.TransportConf;
/**
* Test for {@link MergedBlockMetaSuccess}.
*/
public class MergedBlockMetaSuccessSuite {
@Test
public void testMergedBlocksMetaEncodeDecode() throws Exception {
File chunkMetaFile = new File("target/mergedBlockMetaTest");
Files.deleteIfExists(chunkMetaFile.toPath());
RoaringBitmap chunk1 = new RoaringBitmap();
chunk1.add(1);
chunk1.add(3);
RoaringBitmap chunk2 = new RoaringBitmap();
chunk2.add(2);
chunk2.add(4);
RoaringBitmap[] expectedChunks = new RoaringBitmap[]{chunk1, chunk2};
try (DataOutputStream metaOutput = new DataOutputStream(new FileOutputStream(chunkMetaFile))) {
for (int i = 0; i < expectedChunks.length; i++) {
expectedChunks[i].serialize(metaOutput);
}
}
TransportConf conf = mock(TransportConf.class);
when(conf.lazyFileDescriptor()).thenReturn(false);
long requestId = 1L;
MergedBlockMetaSuccess expectedMeta = new MergedBlockMetaSuccess(requestId, 2,
new FileSegmentManagedBuffer(conf, chunkMetaFile, 0, chunkMetaFile.length()));
List<Object> out = Lists.newArrayList();
ChannelHandlerContext context = mock(ChannelHandlerContext.class);
when(context.alloc()).thenReturn(ByteBufAllocator.DEFAULT);
MessageEncoder.INSTANCE.encode(context, expectedMeta, out);
Assert.assertEquals(1, out.size());
MessageWithHeader msgWithHeader = (MessageWithHeader) out.remove(0);
ByteArrayWritableChannel writableChannel =
new ByteArrayWritableChannel((int) msgWithHeader.count());
while (msgWithHeader.transfered() < msgWithHeader.count()) {
msgWithHeader.transferTo(writableChannel, msgWithHeader.transfered());
}
ByteBuf messageBuf = Unpooled.wrappedBuffer(writableChannel.getData());
messageBuf.readLong(); // frame length
MessageDecoder.INSTANCE.decode(mock(ChannelHandlerContext.class), messageBuf, out);
Assert.assertEquals(1, out.size());
MergedBlockMetaSuccess decoded = (MergedBlockMetaSuccess) out.get(0);
Assert.assertEquals("merged block", expectedMeta.requestId, decoded.requestId);
Assert.assertEquals("num chunks", expectedMeta.getNumChunks(), decoded.getNumChunks());
ByteBuf responseBuf = Unpooled.wrappedBuffer(decoded.body().nioByteBuffer());
RoaringBitmap[] responseBitmaps = new RoaringBitmap[expectedMeta.getNumChunks()];
for (int i = 0; i < expectedMeta.getNumChunks(); i++) {
responseBitmaps[i] = Encoders.Bitmaps.decode(responseBuf);
}
Assert.assertEquals(
"num of roaring bitmaps", expectedMeta.getNumChunks(), responseBitmaps.length);
for (int i = 0; i < expectedMeta.getNumChunks(); i++) {
Assert.assertEquals("chunk bitmap " + i, expectedChunks[i], responseBitmaps[i]);
}
Files.delete(chunkMetaFile.toPath());
}
}


@ -178,4 +178,24 @@ public abstract class BlockStoreClient implements Closeable {
MergeFinalizerListener listener) {
throw new UnsupportedOperationException();
}
/**
* Get the meta information of a merged block from the remote shuffle service.
*
* @param host the host of the remote node.
* @param port the port of the remote node.
* @param shuffleId shuffle id.
* @param reduceId reduce id.
* @param listener the listener to receive chunk counts.
*
* @since 3.2.0
*/
public void getMergedBlockMeta(
String host,
int port,
int shuffleId,
int reduceId,
MergedBlocksMetaListener listener) {
throw new UnsupportedOperationException();
}
}


@ -17,12 +17,14 @@
package org.apache.spark.network.shuffle;
import com.google.common.base.Preconditions;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import com.codahale.metrics.Gauge;
@ -31,14 +33,17 @@ import com.codahale.metrics.Metric;
import com.codahale.metrics.MetricSet;
import com.codahale.metrics.Timer;
import com.codahale.metrics.Counter;
import com.google.common.collect.Sets;
import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.network.client.StreamCallbackWithID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.MergedBlockMetaResponseCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.protocol.MergedBlockMetaRequest;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
@ -54,8 +59,12 @@ import org.apache.spark.network.util.TransportConf;
* Blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk
* is equivalent to one block.
*/
public class ExternalBlockHandler extends RpcHandler {
public class ExternalBlockHandler extends RpcHandler
implements RpcHandler.MergedBlockMetaReqHandler {
private static final Logger logger = LoggerFactory.getLogger(ExternalBlockHandler.class);
private static final String SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger";
private static final String SHUFFLE_BLOCK_ID = "shuffle";
private static final String SHUFFLE_CHUNK_ID = "shuffleChunk";
@VisibleForTesting
final ExternalShuffleBlockResolver blockManager;
@ -128,24 +137,23 @@ public class ExternalBlockHandler extends RpcHandler {
BlockTransferMessage msgObj,
TransportClient client,
RpcResponseCallback callback) {
if (msgObj instanceof FetchShuffleBlocks || msgObj instanceof OpenBlocks) {
if (msgObj instanceof AbstractFetchShuffleBlocks || msgObj instanceof OpenBlocks) {
final Timer.Context responseDelayContext = metrics.openBlockRequestLatencyMillis.time();
try {
int numBlockIds;
long streamId;
if (msgObj instanceof FetchShuffleBlocks) {
FetchShuffleBlocks msg = (FetchShuffleBlocks) msgObj;
if (msgObj instanceof AbstractFetchShuffleBlocks) {
AbstractFetchShuffleBlocks msg = (AbstractFetchShuffleBlocks) msgObj;
checkAuth(client, msg.appId);
numBlockIds = 0;
if (msg.batchFetchEnabled) {
numBlockIds = msg.mapIds.length;
numBlockIds = ((AbstractFetchShuffleBlocks) msgObj).getNumBlocks();
Iterator<ManagedBuffer> iterator;
if (msgObj instanceof FetchShuffleBlocks) {
iterator = new ShuffleManagedBufferIterator((FetchShuffleBlocks)msgObj);
} else {
for (int[] ids: msg.reduceIds) {
numBlockIds += ids.length;
}
iterator = new ShuffleChunkManagedBufferIterator((FetchShuffleBlockChunks) msgObj);
}
streamId = streamManager.registerStream(client.getClientId(),
new ShuffleManagedBufferIterator(msg), client.getChannel());
streamId = streamManager.registerStream(client.getClientId(), iterator,
client.getChannel());
} else {
// For the compatibility with the old version, still keep the support for OpenBlocks.
OpenBlocks msg = (OpenBlocks) msgObj;
@ -189,9 +197,14 @@ public class ExternalBlockHandler extends RpcHandler {
} else if (msgObj instanceof GetLocalDirsForExecutors) {
GetLocalDirsForExecutors msg = (GetLocalDirsForExecutors) msgObj;
checkAuth(client, msg.appId);
Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, msg.execIds);
Set<String> execIdsForBlockResolver = Sets.newHashSet(msg.execIds);
boolean fetchMergedBlockDirs = execIdsForBlockResolver.remove(SHUFFLE_MERGER_IDENTIFIER);
Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId,
execIdsForBlockResolver);
if (fetchMergedBlockDirs) {
localDirs.put(SHUFFLE_MERGER_IDENTIFIER, mergeManager.getMergedBlockDirs(msg.appId));
}
callback.onSuccess(new LocalDirsForExecutors(localDirs).toByteBuffer());
} else if (msgObj instanceof FinalizeShuffleMerge) {
final Timer.Context responseDelayContext =
metrics.finalizeShuffleMergeLatencyMillis.time();
@ -211,6 +224,32 @@ public class ExternalBlockHandler extends RpcHandler {
}
}
@Override
public void receiveMergeBlockMetaReq(
TransportClient client,
MergedBlockMetaRequest metaRequest,
MergedBlockMetaResponseCallback callback) {
final Timer.Context responseDelayContext = metrics.fetchMergedBlocksMetaLatencyMillis.time();
try {
checkAuth(client, metaRequest.appId);
MergedBlockMeta mergedMeta =
mergeManager.getMergedBlockMeta(metaRequest.appId, metaRequest.shuffleId,
metaRequest.reduceId);
logger.debug(
"Merged block chunks appId {} shuffleId {} reduceId {} num-chunks : {} ",
metaRequest.appId, metaRequest.shuffleId, metaRequest.reduceId,
mergedMeta.getNumChunks());
callback.onSuccess(mergedMeta.getNumChunks(), mergedMeta.getChunksBitmapBuffer());
} finally {
responseDelayContext.stop();
}
}
@Override
public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
return this;
}
@Override
public void exceptionCaught(Throwable cause, TransportClient client) {
metrics.caughtExceptions.inc();
@ -262,6 +301,8 @@ public class ExternalBlockHandler extends RpcHandler {
private final Timer openBlockRequestLatencyMillis = new Timer();
// Time latency for executor registration latency in ms
private final Timer registerExecutorRequestLatencyMillis = new Timer();
// Time latency for processing fetch merged blocks meta request latency in ms
private final Timer fetchMergedBlocksMetaLatencyMillis = new Timer();
// Time latency for processing finalize shuffle merge request latency in ms
private final Timer finalizeShuffleMergeLatencyMillis = new Timer();
// Block transfer rate in byte per second
@ -275,6 +316,7 @@ public class ExternalBlockHandler extends RpcHandler {
allMetrics = new HashMap<>();
allMetrics.put("openBlockRequestLatencyMillis", openBlockRequestLatencyMillis);
allMetrics.put("registerExecutorRequestLatencyMillis", registerExecutorRequestLatencyMillis);
allMetrics.put("fetchMergedBlocksMetaLatencyMillis", fetchMergedBlocksMetaLatencyMillis);
allMetrics.put("finalizeShuffleMergeLatencyMillis", finalizeShuffleMergeLatencyMillis);
allMetrics.put("blockTransferRateBytes", blockTransferRateBytes);
allMetrics.put("registeredExecutorsSize",
@ -294,18 +336,26 @@ public class ExternalBlockHandler extends RpcHandler {
private int index = 0;
private final Function<Integer, ManagedBuffer> blockDataForIndexFn;
private final int size;
private boolean requestForMergedBlockChunks;
ManagedBufferIterator(OpenBlocks msg) {
String appId = msg.appId;
String execId = msg.execId;
String[] blockIds = msg.blockIds;
String[] blockId0Parts = blockIds[0].split("_");
if (blockId0Parts.length == 4 && blockId0Parts[0].equals("shuffle")) {
if (blockId0Parts.length == 4 && blockId0Parts[0].equals(SHUFFLE_BLOCK_ID)) {
final int shuffleId = Integer.parseInt(blockId0Parts[1]);
final int[] mapIdAndReduceIds = shuffleMapIdAndReduceIds(blockIds, shuffleId);
size = mapIdAndReduceIds.length;
blockDataForIndexFn = index -> blockManager.getBlockData(appId, execId, shuffleId,
mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
} else if (blockId0Parts.length == 4 && blockId0Parts[0].equals(SHUFFLE_CHUNK_ID)) {
requestForMergedBlockChunks = true;
final int shuffleId = Integer.parseInt(blockId0Parts[1]);
final int[] reduceIdAndChunkIds = shuffleMapIdAndReduceIds(blockIds, shuffleId);
size = reduceIdAndChunkIds.length;
blockDataForIndexFn = index -> mergeManager.getMergedBlockData(msg.appId, shuffleId,
reduceIdAndChunkIds[index], reduceIdAndChunkIds[index + 1]);
} else if (blockId0Parts.length == 3 && blockId0Parts[0].equals("rdd")) {
final int[] rddAndSplitIds = rddAndSplitIds(blockIds);
size = rddAndSplitIds.length;
@ -330,20 +380,26 @@ public class ExternalBlockHandler extends RpcHandler {
}
private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) {
final int[] mapIdAndReduceIds = new int[2 * blockIds.length];
// For regular shuffle blocks, primaryId is mapId and secondaryIds are reduceIds.
// For shuffle chunks, primaryId is reduceId and secondaryIds are chunkIds.
final int[] primaryIdAndSecondaryIds = new int[2 * blockIds.length];
for (int i = 0; i < blockIds.length; i++) {
String[] blockIdParts = blockIds[i].split("_");
if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
if (blockIdParts.length != 4
|| (!requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_BLOCK_ID))
|| (requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_CHUNK_ID))) {
throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
}
if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
", got:" + blockIds[i]);
}
mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
// For regular blocks, blockIdParts[2] is mapId. For chunks, it is reduceId.
primaryIdAndSecondaryIds[2 * i] = Integer.parseInt(blockIdParts[2]);
// For regular blocks, blockIdParts[3] is reduceId. For chunks, it is chunkId.
primaryIdAndSecondaryIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
}
return mapIdAndReduceIds;
return primaryIdAndSecondaryIds;
}
@Override
@ -413,6 +469,47 @@ public class ExternalBlockHandler extends RpcHandler {
}
}
private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
private int reduceIdx = 0;
private int chunkIdx = 0;
private final String appId;
private final int shuffleId;
private final int[] reduceIds;
private final int[][] chunkIds;
ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
appId = msg.appId;
shuffleId = msg.shuffleId;
reduceIds = msg.reduceIds;
chunkIds = msg.chunkIds;
// reduceIds.length must equal chunkIds.length, and the passed-in FetchShuffleBlockChunks
// must have non-empty reduceIds and chunkIds, see the checking logic in
// OneForOneBlockFetcher.
assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
}
@Override
public boolean hasNext() {
return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
}
@Override
public ManagedBuffer next() {
ManagedBuffer block = Preconditions.checkNotNull(mergeManager.getMergedBlockData(
appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]));
if (chunkIdx < chunkIds[reduceIdx].length - 1) {
chunkIdx += 1;
} else {
chunkIdx = 0;
reduceIdx += 1;
}
metrics.blockTransferRateBytes.mark(block.size());
return block;
}
}
/**
* Dummy implementation of merged shuffle file manager. Suitable for when push-based shuffle
* is not enabled.


@ -31,6 +31,7 @@ import com.google.common.collect.Lists;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.MergedBlockMetaResponseCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
@ -187,6 +188,37 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
}
}
@Override
public void getMergedBlockMeta(
String host,
int port,
int shuffleId,
int reduceId,
MergedBlocksMetaListener listener) {
checkInit();
logger.debug("Get merged blocks meta from {}:{} for shuffleId {} reduceId {}", host, port,
shuffleId, reduceId);
try {
TransportClient client = clientFactory.createClient(host, port);
client.sendMergedBlockMetaReq(appId, shuffleId, reduceId,
new MergedBlockMetaResponseCallback() {
@Override
public void onSuccess(int numChunks, ManagedBuffer buffer) {
logger.trace("Successfully got merged block meta for shuffleId {} reduceId {}",
shuffleId, reduceId);
listener.onSuccess(shuffleId, reduceId, new MergedBlockMeta(numChunks, buffer));
}
@Override
public void onFailure(Throwable e) {
listener.onFailure(shuffleId, reduceId, e);
}
});
} catch (Exception e) {
listener.onFailure(shuffleId, reduceId, e);
}
}
@Override
public MetricSet shuffleMetrics() {
checkInit();


@ -361,8 +361,8 @@ public class ExternalShuffleBlockResolver {
return numRemovedBlocks;
}
public Map<String, String[]> getLocalDirs(String appId, String[] execIds) {
return Arrays.stream(execIds)
public Map<String, String[]> getLocalDirs(String appId, Set<String> execIds) {
return execIds.stream()
.map(exec -> {
ExecutorShuffleInfo info = executors.get(new AppExecId(appId, exec));
if (info == null) {


@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle;
import java.util.EventListener;
/**
* Listener for receiving success or failure events when fetching meta of merged blocks.
*
* @since 3.2.0
*/
public interface MergedBlocksMetaListener extends EventListener {
/**
* Called after successfully receiving the meta of a merged block.
*
* @param shuffleId shuffle id.
* @param reduceId reduce id.
* @param meta contains meta information of a merged block.
*/
void onSuccess(int shuffleId, int reduceId, MergedBlockMeta meta);
/**
* Called when there is an exception while fetching the meta of a merged block.
*
* @param shuffleId shuffle id.
* @param reduceId reduce id.
* @param exception exception getting chunk counts.
*/
void onFailure(int shuffleId, int reduceId, Throwable exception);
}


@ -22,6 +22,7 @@ import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Set;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
@ -34,8 +35,10 @@ import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.shuffle.protocol.AbstractFetchShuffleBlocks;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.TransportConf;
@ -51,6 +54,8 @@ import org.apache.spark.network.util.TransportConf;
*/
public class OneForOneBlockFetcher {
private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);
private static final String SHUFFLE_BLOCK_PREFIX = "shuffle_";
private static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk_";
private final TransportClient client;
private final BlockTransferMessage message;
@ -88,74 +93,113 @@ public class OneForOneBlockFetcher {
if (blockIds.length == 0) {
throw new IllegalArgumentException("Zero-sized blockIds array");
}
if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
this.blockIds = new String[blockIds.length];
this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
} else {
this.blockIds = blockIds;
this.message = new OpenBlocks(appId, execId, blockIds);
}
}
private boolean isShuffleBlocks(String[] blockIds) {
for (String blockId : blockIds) {
if (!blockId.startsWith("shuffle_")) {
return false;
}
/**
* Check whether the block IDs are all shuffle block IDs. With push-based shuffle, a shuffle
* block ID could be either an unmerged shuffle block ID or a merged shuffle chunk ID. For a
* given stream of shuffle blocks to be fetched in one request, they would be either all
* unmerged shuffle blocks or all merged shuffle chunks.
* @param blockIds block ID array
* @return whether the array contains only shuffle block IDs
*/
private boolean areShuffleBlocksOrChunks(String[] blockIds) {
if (Arrays.stream(blockIds).anyMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX))) {
// It comes here because there is a blockId which doesn't have "shuffle_" prefix so we
// check if all the block ids are shuffle chunk Ids.
return Arrays.stream(blockIds).allMatch(blockId -> blockId.startsWith(SHUFFLE_CHUNK_PREFIX));
}
return true;
}
/** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
String appId,
String execId,
String[] blockIds) {
if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
} else {
return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
}
}
/**
* Create FetchShuffleBlocks message and rebuild internal blockIds by
* analyzing the pass in blockIds.
* Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
* analyzing the passed in blockIds.
*/
private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
String appId, String execId, String[] blockIds) {
private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
String appId,
String execId,
String[] blockIds,
boolean areMergedChunks) {
String[] firstBlock = splitBlockId(blockIds[0]);
int shuffleId = Integer.parseInt(firstBlock[1]);
boolean batchFetchEnabled = firstBlock.length == 5;
LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
// In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
// is reduceId.
LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
for (String blockId : blockIds) {
String[] blockIdParts = splitBlockId(blockId);
if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
", got:" + blockId);
}
long mapId = Long.parseLong(blockIdParts[2]);
if (!mapIdToBlocksInfo.containsKey(mapId)) {
mapIdToBlocksInfo.put(mapId, new BlocksInfo());
Number primaryId;
if (!areMergedChunks) {
primaryId = Long.parseLong(blockIdParts[2]);
} else {
primaryId = Integer.parseInt(blockIdParts[2]);
}
BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
blocksInfoByMapId.blockIds.add(blockId);
blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId,
id -> new BlocksInfo());
blocksInfoByPrimaryId.blockIds.add(blockId);
// If blockId is a regular shuffle block, then blockIdParts[3] = reduceId. If blockId is a
// shuffleChunk block, then blockIdParts[3] = chunkId
blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
if (batchFetchEnabled) {
// It comes here only if the blockId is a regular shuffle block not a shuffleChunk block.
// When we read continuous shuffle blocks in batch, we will reuse reduceIds in
// FetchShuffleBlocks to store the start and end reduce id for range
// [startReduceId, endReduceId).
assert(blockIdParts.length == 5);
blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
// blockIdParts[4] is the end reduce id for the batch range
blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
}
}
long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
int[][] reduceIdArr = new int[mapIds.length][];
// In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
// secondaryIds are chunkIds.
int[][] secondaryIdsArray = new int[primaryIdToBlocksInfo.size()][];
int blockIdIndex = 0;
for (int i = 0; i < mapIds.length; i++) {
BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
int secIndex = 0;
for (BlocksInfo blocksInfo: primaryIdToBlocksInfo.values()) {
secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfo.ids);
// The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
// because the shuffle data's return order should match the `blockIds`'s order to ensure
// blockId and data match.
for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
// The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
// FetchShuffleBlockChunks because the shuffle data's return order should match the
// `blockIds`'s order to ensure blockId and data match.
for (String blockId : blocksInfo.blockIds) {
this.blockIds[blockIdIndex++] = blockId;
}
}
assert(blockIdIndex == this.blockIds.length);
return new FetchShuffleBlocks(
appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
if (!areMergedChunks) {
long[] mapIds = Longs.toArray(primaryIds);
return new FetchShuffleBlocks(
appId, execId, shuffleId, mapIds, secondaryIdsArray, batchFetchEnabled);
} else {
int[] reduceIds = Ints.toArray(primaryIds);
return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, secondaryIdsArray);
}
}
/** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */
@ -163,7 +207,17 @@ public class OneForOneBlockFetcher {
String[] blockIdParts = blockId.split("_");
// For batch block id, the format contains shuffleId, mapId, begin reduceId, end reduceId.
// For single block id, the format contains shuffleId, mapId, reduceId.
if (blockIdParts.length < 4 || blockIdParts.length > 5 || !blockIdParts[0].equals("shuffle")) {
// For single block chunk id, the format contains shuffleId, reduceId, chunkId.
if (blockIdParts.length < 4 || blockIdParts.length > 5) {
throw new IllegalArgumentException(
"Unexpected shuffle block id format: " + blockId);
}
if (blockIdParts.length == 5 && !blockIdParts[0].equals("shuffle")) {
throw new IllegalArgumentException(
"Unexpected shuffle block id format: " + blockId);
}
if (blockIdParts.length == 4 &&
!(blockIdParts[0].equals("shuffle") || blockIdParts[0].equals("shuffleChunk"))) {
throw new IllegalArgumentException(
"Unexpected shuffle block id format: " + blockId);
}
@ -173,11 +227,15 @@ public class OneForOneBlockFetcher {
/** The reduceIds and blocks in a single mapId */
private class BlocksInfo {
final ArrayList<Integer> reduceIds;
/**
* For {@link FetchShuffleBlocks} message, the ids are reduceIds.
* For {@link FetchShuffleBlockChunks} message, the ids are chunkIds.
*/
final ArrayList<Integer> ids;
final ArrayList<String> blockIds;
BlocksInfo() {
this.reduceIds = new ArrayList<>();
this.ids = new ArrayList<>();
this.blockIds = new ArrayList<>();
}
}


@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle.protocol;
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;
import org.apache.spark.network.protocol.Encoders;
/**
* Base class for fetch shuffle blocks and chunks.
*
* @since 3.2.0
*/
public abstract class AbstractFetchShuffleBlocks extends BlockTransferMessage {
public final String appId;
public final String execId;
public final int shuffleId;
protected AbstractFetchShuffleBlocks(
String appId,
String execId,
int shuffleId) {
this.appId = appId;
this.execId = execId;
this.shuffleId = shuffleId;
}
public ToStringBuilder toStringHelper() {
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
.append("appId", appId)
.append("execId", execId)
.append("shuffleId", shuffleId);
}
/**
* Returns number of blocks in the request.
*/
public abstract int getNumBlocks();
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AbstractFetchShuffleBlocks that = (AbstractFetchShuffleBlocks) o;
return shuffleId == that.shuffleId
&& Objects.equal(appId, that.appId) && Objects.equal(execId, that.execId);
}
@Override
public int hashCode() {
int result = appId.hashCode();
result = 31 * result + execId.hashCode();
result = 31 * result + shuffleId;
return result;
}
@Override
public int encodedLength() {
return Encoders.Strings.encodedLength(appId)
+ Encoders.Strings.encodedLength(execId)
+ 4; /* encoded length of shuffleId */
}
@Override
public void encode(ByteBuf buf) {
Encoders.Strings.encode(buf, appId);
Encoders.Strings.encode(buf, execId);
buf.writeInt(shuffleId);
}
}

View file

@ -48,7 +48,8 @@ public abstract class BlockTransferMessage implements Encodable {
OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4),
HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8),
FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11),
PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14),
FETCH_SHUFFLE_BLOCK_CHUNKS(15);
private final byte id;
@ -82,6 +83,7 @@ public abstract class BlockTransferMessage implements Encodable {
case 12: return PushBlockStream.decode(buf);
case 13: return FinalizeShuffleMerge.decode(buf);
case 14: return MergeStatuses.decode(buf);
case 15: return FetchShuffleBlockChunks.decode(buf);
default: throw new IllegalArgumentException("Unknown message type: " + type);
}
}
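As a hedged sanity-check sketch (not test code from this patch), the new type id can be exercised end to end with the existing `toByteBuffer()` and `Decoder.fromByteBuffer()` helpers; the class name here is hypothetical.

import java.nio.ByteBuffer;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks;

// Hypothetical sketch; run with -ea to enable the asserts.
public class DecoderRoundTripExample {
  public static void main(String[] args) {
    FetchShuffleBlockChunks msg = new FetchShuffleBlockChunks(
      "app0", "exec1", 0, new int[] {0, 1}, new int[][] {{0, 1}, {0}});
    // toByteBuffer() prepends the one-byte type tag (15) that the decoder switches on.
    ByteBuffer wire = msg.toByteBuffer();
    BlockTransferMessage decoded = BlockTransferMessage.Decoder.fromByteBuffer(wire);
    assert decoded instanceof FetchShuffleBlockChunks;
    assert msg.equals(decoded);
  }
}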

View file

@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle.protocol;
import java.util.Arrays;
import io.netty.buffer.ByteBuf;
import org.apache.spark.network.protocol.Encoders;
// Needed by ScalaDoc. See SPARK-7726
import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
/**
* Request to read a set of block chunks. Returns {@link StreamHandle}.
*
* @since 3.2.0
*/
public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
// The length of reduceIds must equal chunkIds.length.
public final int[] reduceIds;
// The i-th int[] in chunkIds contains all the chunks for the i-th reduceId in reduceIds.
public final int[][] chunkIds;
public FetchShuffleBlockChunks(
String appId,
String execId,
int shuffleId,
int[] reduceIds,
int[][] chunkIds) {
super(appId, execId, shuffleId);
this.reduceIds = reduceIds;
this.chunkIds = chunkIds;
assert(reduceIds.length == chunkIds.length);
}
@Override
protected Type type() { return Type.FETCH_SHUFFLE_BLOCK_CHUNKS; }
@Override
public String toString() {
return toStringHelper()
.append("reduceIds", Arrays.toString(reduceIds))
.append("chunkIds", Arrays.deepToString(chunkIds))
.toString();
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
FetchShuffleBlockChunks that = (FetchShuffleBlockChunks) o;
if (!super.equals(that)) return false;
if (!Arrays.equals(reduceIds, that.reduceIds)) return false;
return Arrays.deepEquals(chunkIds, that.chunkIds);
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + Arrays.hashCode(reduceIds);
result = 31 * result + Arrays.deepHashCode(chunkIds);
return result;
}
@Override
public int encodedLength() {
int encodedLengthOfChunkIds = 0;
for (int[] ids: chunkIds) {
encodedLengthOfChunkIds += Encoders.IntArrays.encodedLength(ids);
}
return super.encodedLength()
+ Encoders.IntArrays.encodedLength(reduceIds)
+ 4 /* encoded length of chunkIds.size() */
+ encodedLengthOfChunkIds;
}
@Override
public void encode(ByteBuf buf) {
super.encode(buf);
Encoders.IntArrays.encode(buf, reduceIds);
// Even though reduceIds.length == chunkIds.length, we are explicitly setting the length in the
// interest of forward compatibility.
buf.writeInt(chunkIds.length);
for (int[] ids: chunkIds) {
Encoders.IntArrays.encode(buf, ids);
}
}
@Override
public int getNumBlocks() {
int numBlocks = 0;
for (int[] ids : chunkIds) {
numBlocks += ids.length;
}
return numBlocks;
}
public static FetchShuffleBlockChunks decode(ByteBuf buf) {
String appId = Encoders.Strings.decode(buf);
String execId = Encoders.Strings.decode(buf);
int shuffleId = buf.readInt();
int[] reduceIds = Encoders.IntArrays.decode(buf);
int chunkIdsLen = buf.readInt();
int[][] chunkIds = new int[chunkIdsLen][];
for (int i = 0; i < chunkIdsLen; i++) {
chunkIds[i] = Encoders.IntArrays.decode(buf);
}
return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, chunkIds);
}
}
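A hedged usage sketch (illustrative, not part of the patch): the reduceIds/chunkIds pairing mirrors the mapIds/reduceIds pairing of `FetchShuffleBlocks`, so the four chunk ids `shuffleChunk_0_0_0`, `shuffleChunk_0_0_1`, `shuffleChunk_0_1_0`, `shuffleChunk_0_1_1` collapse into a single request, matching the grouping exercised by the handler test further down. The class name below is hypothetical.

import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks;

// Hypothetical sketch; not part of the patch.
public class FetchShuffleBlockChunksExample {
  public static void main(String[] args) {
    FetchShuffleBlockChunks request = new FetchShuffleBlockChunks(
      "app0", "exec1", 0,
      new int[] {0, 1},               // reduceIds
      new int[][] {{0, 1}, {0, 1}});  // chunkIds, one int[] per reduceId
    // Covers shuffleChunk_0_0_0, shuffleChunk_0_0_1, shuffleChunk_0_1_0, shuffleChunk_0_1_1.
    System.out.println("numBlocks = " + request.getNumBlocks()); // prints 4
  }
}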

View file

@ -20,8 +20,6 @@ package org.apache.spark.network.shuffle.protocol;
import java.util.Arrays;
import io.netty.buffer.ByteBuf;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;
import org.apache.spark.network.protocol.Encoders;
@ -29,10 +27,7 @@ import org.apache.spark.network.protocol.Encoders;
import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
/** Request to read a set of blocks. Returns {@link StreamHandle}. */
public class FetchShuffleBlocks extends AbstractFetchShuffleBlocks {
// The length of mapIds must equal reduceIds.length. The i-th mapId in mapIds corresponds
// to the i-th int[] in reduceIds, which contains all the reduce ids for that map id.
public final long[] mapIds;
@ -50,9 +45,7 @@ public class FetchShuffleBlocks extends BlockTransferMessage {
long[] mapIds,
int[][] reduceIds,
boolean batchFetchEnabled) {
super(appId, execId, shuffleId);
this.mapIds = mapIds;
this.reduceIds = reduceIds;
assert(mapIds.length == reduceIds.length);
@ -69,10 +62,7 @@ public class FetchShuffleBlocks extends BlockTransferMessage {
@Override
public String toString() {
return toStringHelper()
.append("mapIds", Arrays.toString(mapIds))
.append("reduceIds", Arrays.deepToString(reduceIds))
.append("batchFetchEnabled", batchFetchEnabled)
@ -85,35 +75,40 @@ public class FetchShuffleBlocks extends BlockTransferMessage {
if (o == null || getClass() != o.getClass()) return false;
FetchShuffleBlocks that = (FetchShuffleBlocks) o;
if (!super.equals(that)) return false;
if (batchFetchEnabled != that.batchFetchEnabled) return false;
if (!Arrays.equals(mapIds, that.mapIds)) return false;
return Arrays.deepEquals(reduceIds, that.reduceIds);
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + Arrays.hashCode(mapIds);
result = 31 * result + Arrays.deepHashCode(reduceIds);
result = 31 * result + (batchFetchEnabled ? 1 : 0);
return result;
}
@Override
public int getNumBlocks() {
if (batchFetchEnabled) {
return mapIds.length;
}
int numBlocks = 0;
for (int[] ids : reduceIds) {
numBlocks += ids.length;
}
return numBlocks;
}
@Override
public int encodedLength() {
int encodedLengthOfReduceIds = 0;
for (int[] ids: reduceIds) {
encodedLengthOfReduceIds += Encoders.IntArrays.encodedLength(ids);
}
return super.encodedLength()
+ Encoders.LongArrays.encodedLength(mapIds)
+ 4 /* encoded length of reduceIds.size() */
+ encodedLengthOfReduceIds
@ -122,9 +117,7 @@ public class FetchShuffleBlocks extends BlockTransferMessage {
@Override
public void encode(ByteBuf buf) {
super.encode(buf);
Encoders.LongArrays.encode(buf, mapIds);
buf.writeInt(reduceIds.length);
for (int[] ids: reduceIds) {

View file

@ -34,13 +34,16 @@ import static org.mockito.Mockito.*;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.MergedBlockMetaResponseCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.protocol.MergedBlockMetaRequest;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks;
import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge;
import org.apache.spark.network.shuffle.protocol.MergeStatuses;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
@ -258,4 +261,113 @@ public class ExternalBlockHandlerSuite {
.get("finalizeShuffleMergeLatencyMillis");
assertEquals(1, finalizeShuffleMergeLatencyMillis.getCount());
}
@Test
public void testFetchMergedBlocksMeta() {
when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0)).thenReturn(
new MergedBlockMeta(1, mock(ManagedBuffer.class)));
when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 1)).thenReturn(
new MergedBlockMeta(3, mock(ManagedBuffer.class)));
when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 2)).thenReturn(
new MergedBlockMeta(5, mock(ManagedBuffer.class)));
int[] expectedCount = new int[]{1, 3, 5};
String appId = "app0";
long requestId = 0L;
for (int reduceId = 0; reduceId < 3; reduceId++) {
MergedBlockMetaRequest req = new MergedBlockMetaRequest(requestId++, appId, 0, reduceId);
MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class);
handler.getMergedBlockMetaReqHandler()
.receiveMergeBlockMetaReq(client, req, callback);
verify(mergedShuffleManager, times(1)).getMergedBlockMeta("app0", 0, reduceId);
ArgumentCaptor<Integer> numChunksResponse = ArgumentCaptor.forClass(Integer.class);
ArgumentCaptor<ManagedBuffer> chunkBitmapResponse =
ArgumentCaptor.forClass(ManagedBuffer.class);
verify(callback, times(1)).onSuccess(numChunksResponse.capture(),
chunkBitmapResponse.capture());
assertEquals("num chunks in merged block " + reduceId, expectedCount[reduceId],
numChunksResponse.getValue().intValue());
assertNotNull("chunks bitmap buffer " + reduceId, chunkBitmapResponse.getValue());
}
}
@Test
public void testOpenBlocksWithShuffleChunks() {
verifyBlockChunkFetches(true);
}
@Test
public void testFetchShuffleChunks() {
verifyBlockChunkFetches(false);
}
private void verifyBlockChunkFetches(boolean useOpenBlocks) {
RpcResponseCallback callback = mock(RpcResponseCallback.class);
ByteBuffer buffer;
if (useOpenBlocks) {
OpenBlocks openBlocks =
new OpenBlocks("app0", "exec1",
new String[] {"shuffleChunk_0_0_0", "shuffleChunk_0_0_1", "shuffleChunk_0_1_0",
"shuffleChunk_0_1_1"});
buffer = openBlocks.toByteBuffer();
} else {
FetchShuffleBlockChunks fetchChunks = new FetchShuffleBlockChunks(
"app0", "exec1", 0, new int[] {0, 1}, new int[][] {{0, 1}, {0, 1}});
buffer = fetchChunks.toByteBuffer();
}
ManagedBuffer[][] buffers = new ManagedBuffer[][] {
{
new NioManagedBuffer(ByteBuffer.wrap(new byte[5])),
new NioManagedBuffer(ByteBuffer.wrap(new byte[7]))
},
{
new NioManagedBuffer(ByteBuffer.wrap(new byte[5])),
new NioManagedBuffer(ByteBuffer.wrap(new byte[7]))
}
};
for (int reduceId = 0; reduceId < 2; reduceId++) {
for (int chunkId = 0; chunkId < 2; chunkId++) {
when(mergedShuffleManager.getMergedBlockData(
"app0", 0, reduceId, chunkId)).thenReturn(buffers[reduceId][chunkId]);
}
}
handler.receive(client, buffer, callback);
ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class);
verify(callback, times(1)).onSuccess(response.capture());
verify(callback, never()).onFailure(any());
StreamHandle handle =
(StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue());
assertEquals(4, handle.numChunks);
@SuppressWarnings("unchecked")
ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
(ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
verify(streamManager, times(1)).registerStream(any(), stream.capture(), any());
Iterator<ManagedBuffer> bufferIter = stream.getValue();
for (int reduceId = 0; reduceId < 2; reduceId++) {
for (int chunkId = 0; chunkId < 2; chunkId++) {
assertEquals(buffers[reduceId][chunkId], bufferIter.next());
}
}
assertFalse(bufferIter.hasNext());
verify(mergedShuffleManager, never()).getMergedBlockMeta(anyString(), anyInt(), anyInt());
verify(blockResolver, never()).getBlockData(
anyString(), anyString(), anyInt(), anyInt(), anyInt());
verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 0);
verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 1);
// Verify open block request latency metrics
Timer openBlockRequestLatencyMillis = (Timer) ((ExternalBlockHandler) handler)
.getAllMetrics()
.getMetrics()
.get("openBlockRequestLatencyMillis");
assertEquals(1, openBlockRequestLatencyMillis.getCount());
// Verify block transfer metrics
Meter blockTransferRateBytes = (Meter) ((ExternalBlockHandler) handler)
.getAllMetrics()
.getMetrics()
.get("blockTransferRateBytes");
assertEquals(24, blockTransferRateBytes.getCount());
}
}

View file

@ -27,8 +27,7 @@ import com.google.common.collect.Maps;
import io.netty.buffer.Unpooled;
import org.junit.Test;
import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
@ -46,6 +45,7 @@ import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.MapConfigProvider;
@ -243,6 +243,57 @@ public class OneForOneBlockFetcherSuite {
}
}
@Test
public void testShuffleBlockChunksFetch() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
blocks.put("shuffleChunk_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
blocks.put("shuffleChunk_0_0_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
blocks.put("shuffleChunk_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
BlockFetchingListener listener = fetchBlocks(blocks, blockIds,
new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
new int[][] {{ 0, 1, 2 }}), conf);
for (int i = 0; i < 3; i ++) {
verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_" + i,
blocks.get("shuffleChunk_0_0_" + i));
}
}
@Test
public void testShuffleBlockChunkFetchFailure() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
blocks.put("shuffleChunk_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
blocks.put("shuffleChunk_0_0_1", null);
blocks.put("shuffleChunk_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
BlockFetchingListener listener = fetchBlocks(blocks, blockIds,
new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[]{0}, new int[][]{{0, 1, 2}}),
conf);
verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0",
blocks.get("shuffleChunk_0_0_0"));
verify(listener, times(1)).onBlockFetchFailure(eq("shuffleChunk_0_0_1"), any());
verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_2",
blocks.get("shuffleChunk_0_0_2"));
}
@Test
public void testInvalidShuffleBlockIds() {
assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(),
new String[]{"shuffle_0_0"},
new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
new int[][] {{ 0 }}), conf));
assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(),
new String[]{"shuffleChunk_0_0_0_0_0"},
new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
new int[][] {{ 0 }}), conf));
assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(),
new String[]{"shuffleChunk_0_0_0_0"},
new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
new int[][] {{ 0 }}), conf));
}
/**
* Begins a fetch on the given set of blocks by mocking out the server side of the RPC which
* simply returns the given (BlockId, Block) pairs.

View file

@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle.protocol;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.*;
public class FetchShuffleBlockChunksSuite {
@Test
public void testFetchShuffleBlockChunksEncodeDecode() {
FetchShuffleBlockChunks shuffleBlockChunks =
new FetchShuffleBlockChunks("app0", "exec1", 0, new int[] {0}, new int[][] {{0, 1}});
Assert.assertEquals(2, shuffleBlockChunks.getNumBlocks());
int len = shuffleBlockChunks.encodedLength();
Assert.assertEquals(45, len);
ByteBuf buf = Unpooled.buffer(len);
shuffleBlockChunks.encode(buf);
FetchShuffleBlockChunks decoded = FetchShuffleBlockChunks.decode(buf);
assertEquals(shuffleBlockChunks, decoded);
}
}
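A hedged breakdown of the 45-byte figure asserted above, assuming the length-prefixed layout used by `Encoders` (a 4-byte length followed by the payload for strings and int arrays); the class name is hypothetical.

// Hypothetical arithmetic sketch; mirrors FetchShuffleBlockChunks.encodedLength().
public class EncodedLengthBreakdown {
  public static void main(String[] args) {
    int appId = 4 + "app0".length();   // 4-byte length prefix + UTF-8 bytes = 8
    int execId = 4 + "exec1".length(); // = 9
    int shuffleId = 4;                 // fixed-width int
    int reduceIds = 4 + 4 * 1;         // int[] {0} = 8
    int chunkIdsLen = 4;               // chunkIds.length, written explicitly
    int chunkIds0 = 4 + 4 * 2;         // int[] {0, 1} = 12
    System.out.println(appId + execId + shuffleId + reduceIds + chunkIdsLen + chunkIds0); // 45
  }
}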

View file

@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle.protocol;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.*;
public class FetchShuffleBlocksSuite {
@Test
public void testFetchShuffleBlockEncodeDecode() {
FetchShuffleBlocks fetchShuffleBlocks =
new FetchShuffleBlocks("app0", "exec1", 0, new long[] {0}, new int[][] {{0, 1}}, false);
Assert.assertEquals(2, fetchShuffleBlocks.getNumBlocks());
int len = fetchShuffleBlocks.encodedLength();
Assert.assertEquals(50, len);
ByteBuf buf = Unpooled.buffer(len);
fetchShuffleBlocks.encode(buf);
FetchShuffleBlocks decoded = FetchShuffleBlocks.decode(buf);
assertEquals(fetchShuffleBlocks, decoded);
}
}

View file

@ -62,7 +62,8 @@ class ExternalShuffleServiceMetricsSuite extends SparkFunSuite {
"registerExecutorRequestLatencyMillis",
"shuffle-server.usedDirectMemory",
"shuffle-server.usedHeapMemory",
"finalizeShuffleMergeLatencyMillis")
"finalizeShuffleMergeLatencyMillis",
"fetchMergedBlocksMetaLatencyMillis")
)
}
}

View file

@ -40,7 +40,8 @@ class YarnShuffleServiceMetricsSuite extends SparkFunSuite with Matchers {
val allMetrics = Set(
"openBlockRequestLatencyMillis", "registerExecutorRequestLatencyMillis",
"blockTransferRateBytes", "registeredExecutorsSize", "numActiveConnections",
"numCaughtExceptions", "finalizeShuffleMergeLatencyMillis")
"numCaughtExceptions", "finalizeShuffleMergeLatencyMillis",
"fetchMergedBlocksMetaLatencyMillis")
metrics.getMetrics.keySet().asScala should be (allMetrics)
}

View file

@ -412,7 +412,8 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
"registerExecutorRequestLatencyMillis",
"finalizeShuffleMergeLatencyMillis",
"shuffle-server.usedDirectMemory",
"shuffle-server.usedHeapMemory"
"shuffle-server.usedHeapMemory",
"fetchMergedBlocksMetaLatencyMillis"
))
}