[SPARK-35671][SHUFFLE][CORE] Add support in the ESS to serve merged shuffle block meta and data to executors
### What changes were proposed in this pull request? This adds support in the ESS to serve merged shuffle block meta and data requests to executors. This change is needed for fetching remote merged shuffle data from the remote shuffle services. This is part of push-based shuffle SPIP [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602). This change introduces new messages between clients and the external shuffle service: 1. `MergedBlockMetaRequest`: The client sends this to external shuffle to get the meta information for a merged block. The response to this is one of these : - `MergedBlockMetaSuccess` : contains request id, number of chunks, and a `ManagedBuffer` which is a `FileSegmentBuffer` backed by the merged block meta file. - `RpcFailure`: this is sent back to client in case of failure. This is an existing message. 2. `FetchShuffleBlockChunks`: This is similar to `FetchShuffleBlocks` message but it is to fetch merged shuffle chunks instead of blocks. ### Why are the changes needed? These changes are needed for push-based shuffle. Refer to the SPIP in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602). ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added unit tests. The reference PR with the consolidated changes covering the complete implementation is also provided in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602). We have already verified the functionality and the improved performance as documented in the SPIP doc. Lead-authored-by: Chandni Singh chsingh@linkedin.com Co-authored-by: Min Shen mshen@linkedin.com Closes #32811 from otterc/SPARK-35671. Lead-authored-by: Chandni Singh <singh.chandni@gmail.com> Co-authored-by: Min Shen <mshen@linkedin.com> Co-authored-by: Chandni Singh <chsingh@linkedin.com> Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
This commit is contained in:
parent
af20474c67
commit
8ce1e344e5
|
@ -0,0 +1,31 @@
|
||||||
|
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.client;

/**
 * A basic callback. This is extended by {@link RpcResponseCallback} and
 * {@link MergedBlockMetaResponseCallback} so that both RpcRequests and MergedBlockMetaRequests
 * can be handled in {@link TransportResponseHandler} in a similar way.
 *
 * @since 3.2.0
 */
public interface BaseResponseCallback {

  /** Exception either propagated from server or raised on client side. */
  void onFailure(Throwable e);
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.network.client;
|
||||||
|
|
||||||
|
import org.apache.spark.network.buffer.ManagedBuffer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Callback for the result of a single
|
||||||
|
* {@link org.apache.spark.network.protocol.MergedBlockMetaRequest}.
|
||||||
|
*
|
||||||
|
* @since 3.2.0
|
||||||
|
*/
|
||||||
|
public interface MergedBlockMetaResponseCallback extends BaseResponseCallback {
|
||||||
|
/**
|
||||||
|
* Called upon receipt of a particular merged block meta.
|
||||||
|
*
|
||||||
|
* The given buffer will initially have a refcount of 1, but will be release()'d as soon as this
|
||||||
|
* call returns. You must therefore either retain() the buffer or copy its contents before
|
||||||
|
* returning.
|
||||||
|
*
|
||||||
|
* @param numChunks number of merged chunks in the merged block
|
||||||
|
* @param buffer the buffer contains an array of roaring bitmaps. The i-th roaring bitmap
|
||||||
|
* contains the mapIds that were merged to the i-th merged chunk.
|
||||||
|
*/
|
||||||
|
void onSuccess(int numChunks, ManagedBuffer buffer);
|
||||||
|
}
|
|
@ -23,7 +23,7 @@ import java.nio.ByteBuffer;
|
||||||
* Callback for the result of a single RPC. This will be invoked once with either success or
|
* Callback for the result of a single RPC. This will be invoked once with either success or
|
||||||
* failure.
|
* failure.
|
||||||
*/
|
*/
|
||||||
public interface RpcResponseCallback {
|
public interface RpcResponseCallback extends BaseResponseCallback {
|
||||||
/**
|
/**
|
||||||
* Successful serialized result from server.
|
* Successful serialized result from server.
|
||||||
*
|
*
|
||||||
|
@ -31,7 +31,4 @@ public interface RpcResponseCallback {
|
||||||
* Please copy the content of `response` if you want to use it after `onSuccess` returns.
|
* Please copy the content of `response` if you want to use it after `onSuccess` returns.
|
||||||
*/
|
*/
|
||||||
void onSuccess(ByteBuffer response);
|
void onSuccess(ByteBuffer response);
|
||||||
|
|
||||||
/** Exception either propagated from server or raised on client side. */
|
|
||||||
void onFailure(Throwable e);
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -200,6 +200,31 @@ public class TransportClient implements Closeable {
|
||||||
return requestId;
|
return requestId;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sends a MergedBlockMetaRequest message to the server. The response of this message is
|
||||||
|
* either a {@link MergedBlockMetaSuccess} or {@link RpcFailure}.
|
||||||
|
*
|
||||||
|
* @param appId applicationId.
|
||||||
|
* @param shuffleId shuffle id.
|
||||||
|
* @param reduceId reduce id.
|
||||||
|
* @param callback callback the handle the reply.
|
||||||
|
*/
|
||||||
|
public void sendMergedBlockMetaReq(
|
||||||
|
String appId,
|
||||||
|
int shuffleId,
|
||||||
|
int reduceId,
|
||||||
|
MergedBlockMetaResponseCallback callback) {
|
||||||
|
long requestId = requestId();
|
||||||
|
if (logger.isTraceEnabled()) {
|
||||||
|
logger.trace(
|
||||||
|
"Sending RPC {} to fetch merged block meta to {}", requestId, getRemoteAddress(channel));
|
||||||
|
}
|
||||||
|
handler.addRpcRequest(requestId, callback);
|
||||||
|
RpcChannelListener listener = new RpcChannelListener(requestId, callback);
|
||||||
|
channel.writeAndFlush(
|
||||||
|
new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId)).addListener(listener);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Send data to the remote end as a stream. This differs from stream() in that this is a request
|
* Send data to the remote end as a stream. This differs from stream() in that this is a request
|
||||||
* to *send* data to the remote end, not to receive it from the remote.
|
* to *send* data to the remote end, not to receive it from the remote.
|
||||||
|
@ -349,9 +374,9 @@ public class TransportClient implements Closeable {
|
||||||
|
|
||||||
private class RpcChannelListener extends StdChannelListener {
|
private class RpcChannelListener extends StdChannelListener {
|
||||||
final long rpcRequestId;
|
final long rpcRequestId;
|
||||||
final RpcResponseCallback callback;
|
final BaseResponseCallback callback;
|
||||||
|
|
||||||
RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) {
|
RpcChannelListener(long rpcRequestId, BaseResponseCallback callback) {
|
||||||
super("RPC " + rpcRequestId);
|
super("RPC " + rpcRequestId);
|
||||||
this.rpcRequestId = rpcRequestId;
|
this.rpcRequestId = rpcRequestId;
|
||||||
this.callback = callback;
|
this.callback = callback;
|
||||||
|
|
|
@ -33,6 +33,7 @@ import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import org.apache.spark.network.protocol.ChunkFetchFailure;
|
import org.apache.spark.network.protocol.ChunkFetchFailure;
|
||||||
import org.apache.spark.network.protocol.ChunkFetchSuccess;
|
import org.apache.spark.network.protocol.ChunkFetchSuccess;
|
||||||
|
import org.apache.spark.network.protocol.MergedBlockMetaSuccess;
|
||||||
import org.apache.spark.network.protocol.ResponseMessage;
|
import org.apache.spark.network.protocol.ResponseMessage;
|
||||||
import org.apache.spark.network.protocol.RpcFailure;
|
import org.apache.spark.network.protocol.RpcFailure;
|
||||||
import org.apache.spark.network.protocol.RpcResponse;
|
import org.apache.spark.network.protocol.RpcResponse;
|
||||||
|
@ -56,7 +57,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
|
||||||
|
|
||||||
private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
|
private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
|
||||||
|
|
||||||
private final Map<Long, RpcResponseCallback> outstandingRpcs;
|
private final Map<Long, BaseResponseCallback> outstandingRpcs;
|
||||||
|
|
||||||
private final Queue<Pair<String, StreamCallback>> streamCallbacks;
|
private final Queue<Pair<String, StreamCallback>> streamCallbacks;
|
||||||
private volatile boolean streamActive;
|
private volatile boolean streamActive;
|
||||||
|
@ -81,7 +82,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
|
||||||
outstandingFetches.remove(streamChunkId);
|
outstandingFetches.remove(streamChunkId);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void addRpcRequest(long requestId, RpcResponseCallback callback) {
|
public void addRpcRequest(long requestId, BaseResponseCallback callback) {
|
||||||
updateTimeOfLastRequest();
|
updateTimeOfLastRequest();
|
||||||
outstandingRpcs.put(requestId, callback);
|
outstandingRpcs.put(requestId, callback);
|
||||||
}
|
}
|
||||||
|
@ -112,7 +113,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
|
||||||
logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
|
logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
|
for (Map.Entry<Long, BaseResponseCallback> entry : outstandingRpcs.entrySet()) {
|
||||||
try {
|
try {
|
||||||
entry.getValue().onFailure(cause);
|
entry.getValue().onFailure(cause);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
|
@ -184,7 +185,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
|
||||||
}
|
}
|
||||||
} else if (message instanceof RpcResponse) {
|
} else if (message instanceof RpcResponse) {
|
||||||
RpcResponse resp = (RpcResponse) message;
|
RpcResponse resp = (RpcResponse) message;
|
||||||
RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
|
RpcResponseCallback listener = (RpcResponseCallback) outstandingRpcs.get(resp.requestId);
|
||||||
if (listener == null) {
|
if (listener == null) {
|
||||||
logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
|
logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
|
||||||
resp.requestId, getRemoteAddress(channel), resp.body().size());
|
resp.requestId, getRemoteAddress(channel), resp.body().size());
|
||||||
|
@ -199,7 +200,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
|
||||||
}
|
}
|
||||||
} else if (message instanceof RpcFailure) {
|
} else if (message instanceof RpcFailure) {
|
||||||
RpcFailure resp = (RpcFailure) message;
|
RpcFailure resp = (RpcFailure) message;
|
||||||
RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
|
BaseResponseCallback listener = outstandingRpcs.get(resp.requestId);
|
||||||
if (listener == null) {
|
if (listener == null) {
|
||||||
logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
|
logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
|
||||||
resp.requestId, getRemoteAddress(channel), resp.errorString);
|
resp.requestId, getRemoteAddress(channel), resp.errorString);
|
||||||
|
@ -207,6 +208,22 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
|
||||||
outstandingRpcs.remove(resp.requestId);
|
outstandingRpcs.remove(resp.requestId);
|
||||||
listener.onFailure(new RuntimeException(resp.errorString));
|
listener.onFailure(new RuntimeException(resp.errorString));
|
||||||
}
|
}
|
||||||
|
} else if (message instanceof MergedBlockMetaSuccess) {
|
||||||
|
MergedBlockMetaSuccess resp = (MergedBlockMetaSuccess) message;
|
||||||
|
try {
|
||||||
|
MergedBlockMetaResponseCallback listener =
|
||||||
|
(MergedBlockMetaResponseCallback) outstandingRpcs.get(resp.requestId);
|
||||||
|
if (listener == null) {
|
||||||
|
logger.warn(
|
||||||
|
"Ignoring response for MergedBlockMetaRequest {} from {} ({} bytes) since it is not"
|
||||||
|
+ " outstanding", resp.requestId, getRemoteAddress(channel), resp.body().size());
|
||||||
|
} else {
|
||||||
|
outstandingRpcs.remove(resp.requestId);
|
||||||
|
listener.onSuccess(resp.getNumChunks(), resp.body());
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
resp.body().release();
|
||||||
|
}
|
||||||
} else if (message instanceof StreamResponse) {
|
} else if (message instanceof StreamResponse) {
|
||||||
StreamResponse resp = (StreamResponse) message;
|
StreamResponse resp = (StreamResponse) message;
|
||||||
Pair<String, StreamCallback> entry = streamCallbacks.poll();
|
Pair<String, StreamCallback> entry = streamCallbacks.poll();
|
||||||
|
|
|
@ -138,4 +138,9 @@ class AuthRpcHandler extends AbstractAuthRpcHandler {
|
||||||
LOG.debug("Authorization successful for client {}.", channel.remoteAddress());
|
LOG.debug("Authorization successful for client {}.", channel.remoteAddress());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
|
||||||
|
return saslHandler.getMergedBlockMetaReqHandler();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,95 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.network.protocol;
|
||||||
|
|
||||||
|
import com.google.common.base.Objects;
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import org.apache.commons.lang3.builder.ToStringBuilder;
|
||||||
|
import org.apache.commons.lang3.builder.ToStringStyle;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Request to find the meta information for the specified merged block. The meta information
|
||||||
|
* contains the number of chunks in the merged blocks and the maps ids in each chunk.
|
||||||
|
*
|
||||||
|
* @since 3.2.0
|
||||||
|
*/
|
||||||
|
public class MergedBlockMetaRequest extends AbstractMessage implements RequestMessage {
|
||||||
|
public final long requestId;
|
||||||
|
public final String appId;
|
||||||
|
public final int shuffleId;
|
||||||
|
public final int reduceId;
|
||||||
|
|
||||||
|
public MergedBlockMetaRequest(long requestId, String appId, int shuffleId, int reduceId) {
|
||||||
|
super(null, false);
|
||||||
|
this.requestId = requestId;
|
||||||
|
this.appId = appId;
|
||||||
|
this.shuffleId = shuffleId;
|
||||||
|
this.reduceId = reduceId;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Type type() {
|
||||||
|
return Type.MergedBlockMetaRequest;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int encodedLength() {
|
||||||
|
return 8 + Encoders.Strings.encodedLength(appId) + 4 + 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void encode(ByteBuf buf) {
|
||||||
|
buf.writeLong(requestId);
|
||||||
|
Encoders.Strings.encode(buf, appId);
|
||||||
|
buf.writeInt(shuffleId);
|
||||||
|
buf.writeInt(reduceId);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static MergedBlockMetaRequest decode(ByteBuf buf) {
|
||||||
|
long requestId = buf.readLong();
|
||||||
|
String appId = Encoders.Strings.decode(buf);
|
||||||
|
int shuffleId = buf.readInt();
|
||||||
|
int reduceId = buf.readInt();
|
||||||
|
return new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hashCode(requestId, appId, shuffleId, reduceId);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object other) {
|
||||||
|
if (other instanceof MergedBlockMetaRequest) {
|
||||||
|
MergedBlockMetaRequest o = (MergedBlockMetaRequest) other;
|
||||||
|
return requestId == o.requestId && shuffleId == o.shuffleId && reduceId == o.reduceId
|
||||||
|
&& Objects.equal(appId, o.appId);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
|
||||||
|
.append("requestId", requestId)
|
||||||
|
.append("appId", appId)
|
||||||
|
.append("shuffleId", shuffleId)
|
||||||
|
.append("reduceId", reduceId)
|
||||||
|
.toString();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,92 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.network.protocol;
|
||||||
|
|
||||||
|
import com.google.common.base.Objects;
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import org.apache.commons.lang3.builder.ToStringBuilder;
|
||||||
|
import org.apache.commons.lang3.builder.ToStringStyle;
|
||||||
|
|
||||||
|
import org.apache.spark.network.buffer.ManagedBuffer;
|
||||||
|
import org.apache.spark.network.buffer.NettyManagedBuffer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Response to {@link MergedBlockMetaRequest} request.
|
||||||
|
* Note that the server-side encoding of this messages does NOT include the buffer itself.
|
||||||
|
*
|
||||||
|
* @since 3.2.0
|
||||||
|
*/
|
||||||
|
public class MergedBlockMetaSuccess extends AbstractResponseMessage {
|
||||||
|
public final long requestId;
|
||||||
|
public final int numChunks;
|
||||||
|
|
||||||
|
public MergedBlockMetaSuccess(
|
||||||
|
long requestId,
|
||||||
|
int numChunks,
|
||||||
|
ManagedBuffer chunkBitmapsBuffer) {
|
||||||
|
super(chunkBitmapsBuffer, true);
|
||||||
|
this.requestId = requestId;
|
||||||
|
this.numChunks = numChunks;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Type type() {
|
||||||
|
return Type.MergedBlockMetaSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hashCode(requestId, numChunks);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
|
||||||
|
.append("requestId", requestId).append("numChunks", numChunks).toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int encodedLength() {
|
||||||
|
return 8 + 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
|
||||||
|
@Override
|
||||||
|
public void encode(ByteBuf buf) {
|
||||||
|
buf.writeLong(requestId);
|
||||||
|
buf.writeInt(numChunks);
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNumChunks() {
|
||||||
|
return numChunks;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Decoding uses the given ByteBuf as our data, and will retain() it. */
|
||||||
|
public static MergedBlockMetaSuccess decode(ByteBuf buf) {
|
||||||
|
long requestId = buf.readLong();
|
||||||
|
int numChunks = buf.readInt();
|
||||||
|
buf.retain();
|
||||||
|
NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate());
|
||||||
|
return new MergedBlockMetaSuccess(requestId, numChunks, managedBuf);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ResponseMessage createFailureResponse(String error) {
|
||||||
|
return new RpcFailure(requestId, error);
|
||||||
|
}
|
||||||
|
}
|
|
@ -37,7 +37,8 @@ public interface Message extends Encodable {
|
||||||
ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
|
ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
|
||||||
RpcRequest(3), RpcResponse(4), RpcFailure(5),
|
RpcRequest(3), RpcResponse(4), RpcFailure(5),
|
||||||
StreamRequest(6), StreamResponse(7), StreamFailure(8),
|
StreamRequest(6), StreamResponse(7), StreamFailure(8),
|
||||||
OneWayMessage(9), UploadStream(10), User(-1);
|
OneWayMessage(9), UploadStream(10), MergedBlockMetaRequest(11), MergedBlockMetaSuccess(12),
|
||||||
|
User(-1);
|
||||||
|
|
||||||
private final byte id;
|
private final byte id;
|
||||||
|
|
||||||
|
@ -66,6 +67,8 @@ public interface Message extends Encodable {
|
||||||
case 8: return StreamFailure;
|
case 8: return StreamFailure;
|
||||||
case 9: return OneWayMessage;
|
case 9: return OneWayMessage;
|
||||||
case 10: return UploadStream;
|
case 10: return UploadStream;
|
||||||
|
case 11: return MergedBlockMetaRequest;
|
||||||
|
case 12: return MergedBlockMetaSuccess;
|
||||||
case -1: throw new IllegalArgumentException("User type messages cannot be decoded.");
|
case -1: throw new IllegalArgumentException("User type messages cannot be decoded.");
|
||||||
default: throw new IllegalArgumentException("Unknown message type: " + id);
|
default: throw new IllegalArgumentException("Unknown message type: " + id);
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,6 +83,12 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
|
||||||
case UploadStream:
|
case UploadStream:
|
||||||
return UploadStream.decode(in);
|
return UploadStream.decode(in);
|
||||||
|
|
||||||
|
case MergedBlockMetaRequest:
|
||||||
|
return MergedBlockMetaRequest.decode(in);
|
||||||
|
|
||||||
|
case MergedBlockMetaSuccess:
|
||||||
|
return MergedBlockMetaSuccess.decode(in);
|
||||||
|
|
||||||
default:
|
default:
|
||||||
throw new IllegalArgumentException("Unexpected message type: " + msgType);
|
throw new IllegalArgumentException("Unexpected message type: " + msgType);
|
||||||
}
|
}
|
||||||
|
|
|
@ -104,4 +104,9 @@ public abstract class AbstractAuthRpcHandler extends RpcHandler {
|
||||||
public boolean isAuthenticated() {
|
public boolean isAuthenticated() {
|
||||||
return isAuthenticated;
|
return isAuthenticated;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
|
||||||
|
return delegate.getMergedBlockMetaReqHandler();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,9 +22,11 @@ import java.nio.ByteBuffer;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import org.apache.spark.network.client.MergedBlockMetaResponseCallback;
|
||||||
import org.apache.spark.network.client.RpcResponseCallback;
|
import org.apache.spark.network.client.RpcResponseCallback;
|
||||||
import org.apache.spark.network.client.StreamCallbackWithID;
|
import org.apache.spark.network.client.StreamCallbackWithID;
|
||||||
import org.apache.spark.network.client.TransportClient;
|
import org.apache.spark.network.client.TransportClient;
|
||||||
|
import org.apache.spark.network.protocol.MergedBlockMetaRequest;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
|
* Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
|
||||||
|
@ -32,6 +34,8 @@ import org.apache.spark.network.client.TransportClient;
|
||||||
public abstract class RpcHandler {
|
public abstract class RpcHandler {
|
||||||
|
|
||||||
private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback();
|
private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback();
|
||||||
|
private static final MergedBlockMetaReqHandler NOOP_MERGED_BLOCK_META_REQ_HANDLER =
|
||||||
|
new NoopMergedBlockMetaReqHandler();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Receive a single RPC message. Any exception thrown while in this method will be sent back to
|
* Receive a single RPC message. Any exception thrown while in this method will be sent back to
|
||||||
|
@ -100,6 +104,10 @@ public abstract class RpcHandler {
|
||||||
receive(client, message, ONE_WAY_CALLBACK);
|
receive(client, message, ONE_WAY_CALLBACK);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
|
||||||
|
return NOOP_MERGED_BLOCK_META_REQ_HANDLER;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Invoked when the channel associated with the given client is active.
|
* Invoked when the channel associated with the given client is active.
|
||||||
*/
|
*/
|
||||||
|
@ -129,4 +137,40 @@ public abstract class RpcHandler {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handler for {@link MergedBlockMetaRequest}.
|
||||||
|
*
|
||||||
|
* @since 3.2.0
|
||||||
|
*/
|
||||||
|
public interface MergedBlockMetaReqHandler {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Receive a {@link MergedBlockMetaRequest}.
|
||||||
|
*
|
||||||
|
* @param client A channel client which enables the handler to make requests back to the sender
|
||||||
|
* of this RPC.
|
||||||
|
* @param mergedBlockMetaRequest Request for merged block meta.
|
||||||
|
* @param callback Callback which should be invoked exactly once upon success or failure.
|
||||||
|
*/
|
||||||
|
void receiveMergeBlockMetaReq(
|
||||||
|
TransportClient client,
|
||||||
|
MergedBlockMetaRequest mergedBlockMetaRequest,
|
||||||
|
MergedBlockMetaResponseCallback callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A Noop implementation of {@link MergedBlockMetaReqHandler}. This Noop implementation is used
|
||||||
|
* by all the RPC handlers which don't eventually delegate the {@link MergedBlockMetaRequest} to
|
||||||
|
* ExternalBlockHandler in the network-shuffle module.
|
||||||
|
*
|
||||||
|
* @since 3.2.0
|
||||||
|
*/
|
||||||
|
private static class NoopMergedBlockMetaReqHandler implements MergedBlockMetaReqHandler {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void receiveMergeBlockMetaReq(TransportClient client,
|
||||||
|
MergedBlockMetaRequest mergedBlockMetaRequest, MergedBlockMetaResponseCallback callback) {
|
||||||
|
// do nothing
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -113,6 +113,8 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
|
||||||
processStreamRequest((StreamRequest) request);
|
processStreamRequest((StreamRequest) request);
|
||||||
} else if (request instanceof UploadStream) {
|
} else if (request instanceof UploadStream) {
|
||||||
processStreamUpload((UploadStream) request);
|
processStreamUpload((UploadStream) request);
|
||||||
|
} else if (request instanceof MergedBlockMetaRequest) {
|
||||||
|
processMergedBlockMetaRequest((MergedBlockMetaRequest) request);
|
||||||
} else {
|
} else {
|
||||||
throw new IllegalArgumentException("Unknown request type: " + request);
|
throw new IllegalArgumentException("Unknown request type: " + request);
|
||||||
}
|
}
|
||||||
|
@ -260,6 +262,30 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void processMergedBlockMetaRequest(final MergedBlockMetaRequest req) {
|
||||||
|
try {
|
||||||
|
rpcHandler.getMergedBlockMetaReqHandler().receiveMergeBlockMetaReq(reverseClient, req,
|
||||||
|
new MergedBlockMetaResponseCallback() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onSuccess(int numChunks, ManagedBuffer buffer) {
|
||||||
|
logger.trace("Sending meta for request {} numChunks {}", req, numChunks);
|
||||||
|
respond(new MergedBlockMetaSuccess(req.requestId, numChunks, buffer));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onFailure(Throwable e) {
|
||||||
|
logger.trace("Failed to send meta for {}", req);
|
||||||
|
respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} catch (Exception e) {
|
||||||
|
logger.error("Error while invoking receiveMergeBlockMetaReq() for appId {} shuffleId {} "
|
||||||
|
+ "reduceId {}", req.appId, req.shuffleId, req.appId, e);
|
||||||
|
respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Responds to a single message with some Encodable object. If a failure occurs while sending,
|
* Responds to a single message with some Encodable object. If a failure occurs while sending,
|
||||||
* it will be logged and the channel closed.
|
* it will be logged and the channel closed.
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
package org.apache.spark.network;
|
package org.apache.spark.network;
|
||||||
|
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@ -24,16 +25,19 @@ import io.netty.channel.Channel;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import static org.junit.Assert.*;
|
||||||
import static org.mockito.Mockito.*;
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
import org.apache.commons.lang3.tuple.ImmutablePair;
|
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.apache.spark.network.buffer.ManagedBuffer;
|
import org.apache.spark.network.buffer.ManagedBuffer;
|
||||||
|
import org.apache.spark.network.client.RpcResponseCallback;
|
||||||
import org.apache.spark.network.client.TransportClient;
|
import org.apache.spark.network.client.TransportClient;
|
||||||
import org.apache.spark.network.protocol.*;
|
import org.apache.spark.network.protocol.*;
|
||||||
import org.apache.spark.network.server.NoOpRpcHandler;
|
import org.apache.spark.network.server.NoOpRpcHandler;
|
||||||
import org.apache.spark.network.server.OneForOneStreamManager;
|
import org.apache.spark.network.server.OneForOneStreamManager;
|
||||||
import org.apache.spark.network.server.RpcHandler;
|
import org.apache.spark.network.server.RpcHandler;
|
||||||
|
import org.apache.spark.network.server.StreamManager;
|
||||||
import org.apache.spark.network.server.TransportRequestHandler;
|
import org.apache.spark.network.server.TransportRequestHandler;
|
||||||
|
|
||||||
public class TransportRequestHandlerSuite {
|
public class TransportRequestHandlerSuite {
|
||||||
|
@ -109,4 +113,55 @@ public class TransportRequestHandlerSuite {
|
||||||
streamManager.connectionTerminated(channel);
|
streamManager.connectionTerminated(channel);
|
||||||
Assert.assertEquals(0, streamManager.numStreamStates());
|
Assert.assertEquals(0, streamManager.numStreamStates());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void handleMergedBlockMetaRequest() throws Exception {
|
||||||
|
RpcHandler.MergedBlockMetaReqHandler metaHandler = (client, request, callback) -> {
|
||||||
|
if (request.shuffleId != -1 && request.reduceId != -1) {
|
||||||
|
callback.onSuccess(2, mock(ManagedBuffer.class));
|
||||||
|
} else {
|
||||||
|
callback.onFailure(new RuntimeException("empty block"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
RpcHandler rpcHandler = new RpcHandler() {
|
||||||
|
@Override
|
||||||
|
public void receive(
|
||||||
|
TransportClient client,
|
||||||
|
ByteBuffer message,
|
||||||
|
RpcResponseCallback callback) {}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public StreamManager getStreamManager() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
|
||||||
|
return metaHandler;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Channel channel = mock(Channel.class);
|
||||||
|
List<Pair<Object, ExtendedChannelPromise>> responseAndPromisePairs = new ArrayList<>();
|
||||||
|
when(channel.writeAndFlush(any())).thenAnswer(invocationOnMock0 -> {
|
||||||
|
Object response = invocationOnMock0.getArguments()[0];
|
||||||
|
ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel);
|
||||||
|
responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture));
|
||||||
|
return channelFuture;
|
||||||
|
});
|
||||||
|
|
||||||
|
TransportClient reverseClient = mock(TransportClient.class);
|
||||||
|
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
|
||||||
|
rpcHandler, 2L, null);
|
||||||
|
MergedBlockMetaRequest validMetaReq = new MergedBlockMetaRequest(19, "app1", 0, 0);
|
||||||
|
requestHandler.handle(validMetaReq);
|
||||||
|
assertEquals(1, responseAndPromisePairs.size());
|
||||||
|
assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof MergedBlockMetaSuccess);
|
||||||
|
assertEquals(2,
|
||||||
|
((MergedBlockMetaSuccess) (responseAndPromisePairs.get(0).getLeft())).getNumChunks());
|
||||||
|
|
||||||
|
MergedBlockMetaRequest invalidMetaReq = new MergedBlockMetaRequest(21, "app1", -1, 1);
|
||||||
|
requestHandler.handle(invalidMetaReq);
|
||||||
|
assertEquals(2, responseAndPromisePairs.size());
|
||||||
|
assertTrue(responseAndPromisePairs.get(1).getLeft() instanceof RpcFailure);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,17 +23,20 @@ import java.nio.ByteBuffer;
|
||||||
import io.netty.channel.Channel;
|
import io.netty.channel.Channel;
|
||||||
import io.netty.channel.local.LocalChannel;
|
import io.netty.channel.local.LocalChannel;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.mockito.Mockito.*;
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
import org.apache.spark.network.buffer.NioManagedBuffer;
|
import org.apache.spark.network.buffer.NioManagedBuffer;
|
||||||
import org.apache.spark.network.client.ChunkReceivedCallback;
|
import org.apache.spark.network.client.ChunkReceivedCallback;
|
||||||
|
import org.apache.spark.network.client.MergedBlockMetaResponseCallback;
|
||||||
import org.apache.spark.network.client.RpcResponseCallback;
|
import org.apache.spark.network.client.RpcResponseCallback;
|
||||||
import org.apache.spark.network.client.StreamCallback;
|
import org.apache.spark.network.client.StreamCallback;
|
||||||
import org.apache.spark.network.client.TransportResponseHandler;
|
import org.apache.spark.network.client.TransportResponseHandler;
|
||||||
import org.apache.spark.network.protocol.ChunkFetchFailure;
|
import org.apache.spark.network.protocol.ChunkFetchFailure;
|
||||||
import org.apache.spark.network.protocol.ChunkFetchSuccess;
|
import org.apache.spark.network.protocol.ChunkFetchSuccess;
|
||||||
|
import org.apache.spark.network.protocol.MergedBlockMetaSuccess;
|
||||||
import org.apache.spark.network.protocol.RpcFailure;
|
import org.apache.spark.network.protocol.RpcFailure;
|
||||||
import org.apache.spark.network.protocol.RpcResponse;
|
import org.apache.spark.network.protocol.RpcResponse;
|
||||||
import org.apache.spark.network.protocol.StreamChunkId;
|
import org.apache.spark.network.protocol.StreamChunkId;
|
||||||
|
@ -167,4 +170,40 @@ public class TransportResponseHandlerSuite {
|
||||||
|
|
||||||
verify(cb).onFailure(eq("stream-1"), isA(IOException.class));
|
verify(cb).onFailure(eq("stream-1"), isA(IOException.class));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void handleSuccessfulMergedBlockMeta() throws Exception {
|
||||||
|
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
|
||||||
|
MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class);
|
||||||
|
handler.addRpcRequest(13, callback);
|
||||||
|
assertEquals(1, handler.numOutstandingRequests());
|
||||||
|
|
||||||
|
// This response should be ignored.
|
||||||
|
handler.handle(new MergedBlockMetaSuccess(22, 2,
|
||||||
|
new NioManagedBuffer(ByteBuffer.allocate(7))));
|
||||||
|
assertEquals(1, handler.numOutstandingRequests());
|
||||||
|
|
||||||
|
ByteBuffer resp = ByteBuffer.allocate(10);
|
||||||
|
handler.handle(new MergedBlockMetaSuccess(13, 2, new NioManagedBuffer(resp)));
|
||||||
|
ArgumentCaptor<NioManagedBuffer> bufferCaptor = ArgumentCaptor.forClass(NioManagedBuffer.class);
|
||||||
|
verify(callback, times(1)).onSuccess(eq(2), bufferCaptor.capture());
|
||||||
|
assertEquals(resp, bufferCaptor.getValue().nioByteBuffer());
|
||||||
|
assertEquals(0, handler.numOutstandingRequests());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void handleFailedMergedBlockMeta() throws Exception {
|
||||||
|
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
|
||||||
|
MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class);
|
||||||
|
handler.addRpcRequest(51, callback);
|
||||||
|
assertEquals(1, handler.numOutstandingRequests());
|
||||||
|
|
||||||
|
// This response should be ignored.
|
||||||
|
handler.handle(new RpcFailure(6, "failed"));
|
||||||
|
assertEquals(1, handler.numOutstandingRequests());
|
||||||
|
|
||||||
|
handler.handle(new RpcFailure(51, "failed"));
|
||||||
|
verify(callback, times(1)).onFailure(any());
|
||||||
|
assertEquals(0, handler.numOutstandingRequests());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,101 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.network.protocol;
|
||||||
|
|
||||||
|
import java.io.DataOutputStream;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileOutputStream;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.ByteBufAllocator;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.roaringbitmap.RoaringBitmap;
|
||||||
|
|
||||||
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
|
import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
|
||||||
|
import org.apache.spark.network.util.ByteArrayWritableChannel;
|
||||||
|
import org.apache.spark.network.util.TransportConf;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Test for {@link MergedBlockMetaSuccess}.
|
||||||
|
*/
|
||||||
|
public class MergedBlockMetaSuccessSuite {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMergedBlocksMetaEncodeDecode() throws Exception {
|
||||||
|
File chunkMetaFile = new File("target/mergedBlockMetaTest");
|
||||||
|
Files.deleteIfExists(chunkMetaFile.toPath());
|
||||||
|
RoaringBitmap chunk1 = new RoaringBitmap();
|
||||||
|
chunk1.add(1);
|
||||||
|
chunk1.add(3);
|
||||||
|
RoaringBitmap chunk2 = new RoaringBitmap();
|
||||||
|
chunk2.add(2);
|
||||||
|
chunk2.add(4);
|
||||||
|
RoaringBitmap[] expectedChunks = new RoaringBitmap[]{chunk1, chunk2};
|
||||||
|
try (DataOutputStream metaOutput = new DataOutputStream(new FileOutputStream(chunkMetaFile))) {
|
||||||
|
for (int i = 0; i < expectedChunks.length; i++) {
|
||||||
|
expectedChunks[i].serialize(metaOutput);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TransportConf conf = mock(TransportConf.class);
|
||||||
|
when(conf.lazyFileDescriptor()).thenReturn(false);
|
||||||
|
long requestId = 1L;
|
||||||
|
MergedBlockMetaSuccess expectedMeta = new MergedBlockMetaSuccess(requestId, 2,
|
||||||
|
new FileSegmentManagedBuffer(conf, chunkMetaFile, 0, chunkMetaFile.length()));
|
||||||
|
|
||||||
|
List<Object> out = Lists.newArrayList();
|
||||||
|
ChannelHandlerContext context = mock(ChannelHandlerContext.class);
|
||||||
|
when(context.alloc()).thenReturn(ByteBufAllocator.DEFAULT);
|
||||||
|
|
||||||
|
MessageEncoder.INSTANCE.encode(context, expectedMeta, out);
|
||||||
|
Assert.assertEquals(1, out.size());
|
||||||
|
MessageWithHeader msgWithHeader = (MessageWithHeader) out.remove(0);
|
||||||
|
|
||||||
|
ByteArrayWritableChannel writableChannel =
|
||||||
|
new ByteArrayWritableChannel((int) msgWithHeader.count());
|
||||||
|
while (msgWithHeader.transfered() < msgWithHeader.count()) {
|
||||||
|
msgWithHeader.transferTo(writableChannel, msgWithHeader.transfered());
|
||||||
|
}
|
||||||
|
ByteBuf messageBuf = Unpooled.wrappedBuffer(writableChannel.getData());
|
||||||
|
messageBuf.readLong(); // frame length
|
||||||
|
MessageDecoder.INSTANCE.decode(mock(ChannelHandlerContext.class), messageBuf, out);
|
||||||
|
Assert.assertEquals(1, out.size());
|
||||||
|
MergedBlockMetaSuccess decoded = (MergedBlockMetaSuccess) out.get(0);
|
||||||
|
Assert.assertEquals("merged block", expectedMeta.requestId, decoded.requestId);
|
||||||
|
Assert.assertEquals("num chunks", expectedMeta.getNumChunks(), decoded.getNumChunks());
|
||||||
|
|
||||||
|
ByteBuf responseBuf = Unpooled.wrappedBuffer(decoded.body().nioByteBuffer());
|
||||||
|
RoaringBitmap[] responseBitmaps = new RoaringBitmap[expectedMeta.getNumChunks()];
|
||||||
|
for (int i = 0; i < expectedMeta.getNumChunks(); i++) {
|
||||||
|
responseBitmaps[i] = Encoders.Bitmaps.decode(responseBuf);
|
||||||
|
}
|
||||||
|
Assert.assertEquals(
|
||||||
|
"num of roaring bitmaps", expectedMeta.getNumChunks(), responseBitmaps.length);
|
||||||
|
for (int i = 0; i < expectedMeta.getNumChunks(); i++) {
|
||||||
|
Assert.assertEquals("chunk bitmap " + i, expectedChunks[i], responseBitmaps[i]);
|
||||||
|
}
|
||||||
|
Files.delete(chunkMetaFile.toPath());
|
||||||
|
}
|
||||||
|
}
|
|
@ -178,4 +178,24 @@ public abstract class BlockStoreClient implements Closeable {
|
||||||
MergeFinalizerListener listener) {
|
MergeFinalizerListener listener) {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the meta information of a merged block from the remote shuffle service.
|
||||||
|
*
|
||||||
|
* @param host the host of the remote node.
|
||||||
|
* @param port the port of the remote node.
|
||||||
|
* @param shuffleId shuffle id.
|
||||||
|
* @param reduceId reduce id.
|
||||||
|
* @param listener the listener to receive chunk counts.
|
||||||
|
*
|
||||||
|
* @since 3.2.0
|
||||||
|
*/
|
||||||
|
public void getMergedBlockMeta(
|
||||||
|
String host,
|
||||||
|
int port,
|
||||||
|
int shuffleId,
|
||||||
|
int reduceId,
|
||||||
|
MergedBlocksMetaListener listener) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,12 +17,14 @@
|
||||||
|
|
||||||
package org.apache.spark.network.shuffle;
|
package org.apache.spark.network.shuffle;
|
||||||
|
|
||||||
|
import com.google.common.base.Preconditions;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
|
|
||||||
import com.codahale.metrics.Gauge;
|
import com.codahale.metrics.Gauge;
|
||||||
|
@ -31,14 +33,17 @@ import com.codahale.metrics.Metric;
|
||||||
import com.codahale.metrics.MetricSet;
|
import com.codahale.metrics.MetricSet;
|
||||||
import com.codahale.metrics.Timer;
|
import com.codahale.metrics.Timer;
|
||||||
import com.codahale.metrics.Counter;
|
import com.codahale.metrics.Counter;
|
||||||
|
import com.google.common.collect.Sets;
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
import org.apache.spark.network.client.StreamCallbackWithID;
|
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import org.apache.spark.network.buffer.ManagedBuffer;
|
import org.apache.spark.network.buffer.ManagedBuffer;
|
||||||
|
import org.apache.spark.network.client.MergedBlockMetaResponseCallback;
|
||||||
import org.apache.spark.network.client.RpcResponseCallback;
|
import org.apache.spark.network.client.RpcResponseCallback;
|
||||||
|
import org.apache.spark.network.client.StreamCallbackWithID;
|
||||||
import org.apache.spark.network.client.TransportClient;
|
import org.apache.spark.network.client.TransportClient;
|
||||||
|
import org.apache.spark.network.protocol.MergedBlockMetaRequest;
|
||||||
import org.apache.spark.network.server.OneForOneStreamManager;
|
import org.apache.spark.network.server.OneForOneStreamManager;
|
||||||
import org.apache.spark.network.server.RpcHandler;
|
import org.apache.spark.network.server.RpcHandler;
|
||||||
import org.apache.spark.network.server.StreamManager;
|
import org.apache.spark.network.server.StreamManager;
|
||||||
|
@ -54,8 +59,12 @@ import org.apache.spark.network.util.TransportConf;
|
||||||
* Blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk
|
* Blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk
|
||||||
* is equivalent to one block.
|
* is equivalent to one block.
|
||||||
*/
|
*/
|
||||||
public class ExternalBlockHandler extends RpcHandler {
|
public class ExternalBlockHandler extends RpcHandler
|
||||||
|
implements RpcHandler.MergedBlockMetaReqHandler {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(ExternalBlockHandler.class);
|
private static final Logger logger = LoggerFactory.getLogger(ExternalBlockHandler.class);
|
||||||
|
private static final String SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger";
|
||||||
|
private static final String SHUFFLE_BLOCK_ID = "shuffle";
|
||||||
|
private static final String SHUFFLE_CHUNK_ID = "shuffleChunk";
|
||||||
|
|
||||||
@VisibleForTesting
|
@VisibleForTesting
|
||||||
final ExternalShuffleBlockResolver blockManager;
|
final ExternalShuffleBlockResolver blockManager;
|
||||||
|
@ -128,24 +137,23 @@ public class ExternalBlockHandler extends RpcHandler {
|
||||||
BlockTransferMessage msgObj,
|
BlockTransferMessage msgObj,
|
||||||
TransportClient client,
|
TransportClient client,
|
||||||
RpcResponseCallback callback) {
|
RpcResponseCallback callback) {
|
||||||
if (msgObj instanceof FetchShuffleBlocks || msgObj instanceof OpenBlocks) {
|
if (msgObj instanceof AbstractFetchShuffleBlocks || msgObj instanceof OpenBlocks) {
|
||||||
final Timer.Context responseDelayContext = metrics.openBlockRequestLatencyMillis.time();
|
final Timer.Context responseDelayContext = metrics.openBlockRequestLatencyMillis.time();
|
||||||
try {
|
try {
|
||||||
int numBlockIds;
|
int numBlockIds;
|
||||||
long streamId;
|
long streamId;
|
||||||
if (msgObj instanceof FetchShuffleBlocks) {
|
if (msgObj instanceof AbstractFetchShuffleBlocks) {
|
||||||
FetchShuffleBlocks msg = (FetchShuffleBlocks) msgObj;
|
AbstractFetchShuffleBlocks msg = (AbstractFetchShuffleBlocks) msgObj;
|
||||||
checkAuth(client, msg.appId);
|
checkAuth(client, msg.appId);
|
||||||
numBlockIds = 0;
|
numBlockIds = ((AbstractFetchShuffleBlocks) msgObj).getNumBlocks();
|
||||||
if (msg.batchFetchEnabled) {
|
Iterator<ManagedBuffer> iterator;
|
||||||
numBlockIds = msg.mapIds.length;
|
if (msgObj instanceof FetchShuffleBlocks) {
|
||||||
|
iterator = new ShuffleManagedBufferIterator((FetchShuffleBlocks)msgObj);
|
||||||
} else {
|
} else {
|
||||||
for (int[] ids: msg.reduceIds) {
|
iterator = new ShuffleChunkManagedBufferIterator((FetchShuffleBlockChunks) msgObj);
|
||||||
numBlockIds += ids.length;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
streamId = streamManager.registerStream(client.getClientId(),
|
streamId = streamManager.registerStream(client.getClientId(), iterator,
|
||||||
new ShuffleManagedBufferIterator(msg), client.getChannel());
|
client.getChannel());
|
||||||
} else {
|
} else {
|
||||||
// For the compatibility with the old version, still keep the support for OpenBlocks.
|
// For the compatibility with the old version, still keep the support for OpenBlocks.
|
||||||
OpenBlocks msg = (OpenBlocks) msgObj;
|
OpenBlocks msg = (OpenBlocks) msgObj;
|
||||||
|
@ -189,9 +197,14 @@ public class ExternalBlockHandler extends RpcHandler {
|
||||||
} else if (msgObj instanceof GetLocalDirsForExecutors) {
|
} else if (msgObj instanceof GetLocalDirsForExecutors) {
|
||||||
GetLocalDirsForExecutors msg = (GetLocalDirsForExecutors) msgObj;
|
GetLocalDirsForExecutors msg = (GetLocalDirsForExecutors) msgObj;
|
||||||
checkAuth(client, msg.appId);
|
checkAuth(client, msg.appId);
|
||||||
Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, msg.execIds);
|
Set<String> execIdsForBlockResolver = Sets.newHashSet(msg.execIds);
|
||||||
|
boolean fetchMergedBlockDirs = execIdsForBlockResolver.remove(SHUFFLE_MERGER_IDENTIFIER);
|
||||||
|
Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId,
|
||||||
|
execIdsForBlockResolver);
|
||||||
|
if (fetchMergedBlockDirs) {
|
||||||
|
localDirs.put(SHUFFLE_MERGER_IDENTIFIER, mergeManager.getMergedBlockDirs(msg.appId));
|
||||||
|
}
|
||||||
callback.onSuccess(new LocalDirsForExecutors(localDirs).toByteBuffer());
|
callback.onSuccess(new LocalDirsForExecutors(localDirs).toByteBuffer());
|
||||||
|
|
||||||
} else if (msgObj instanceof FinalizeShuffleMerge) {
|
} else if (msgObj instanceof FinalizeShuffleMerge) {
|
||||||
final Timer.Context responseDelayContext =
|
final Timer.Context responseDelayContext =
|
||||||
metrics.finalizeShuffleMergeLatencyMillis.time();
|
metrics.finalizeShuffleMergeLatencyMillis.time();
|
||||||
|
@ -211,6 +224,32 @@ public class ExternalBlockHandler extends RpcHandler {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void receiveMergeBlockMetaReq(
|
||||||
|
TransportClient client,
|
||||||
|
MergedBlockMetaRequest metaRequest,
|
||||||
|
MergedBlockMetaResponseCallback callback) {
|
||||||
|
final Timer.Context responseDelayContext = metrics.fetchMergedBlocksMetaLatencyMillis.time();
|
||||||
|
try {
|
||||||
|
checkAuth(client, metaRequest.appId);
|
||||||
|
MergedBlockMeta mergedMeta =
|
||||||
|
mergeManager.getMergedBlockMeta(metaRequest.appId, metaRequest.shuffleId,
|
||||||
|
metaRequest.reduceId);
|
||||||
|
logger.debug(
|
||||||
|
"Merged block chunks appId {} shuffleId {} reduceId {} num-chunks : {} ",
|
||||||
|
metaRequest.appId, metaRequest.shuffleId, metaRequest.reduceId,
|
||||||
|
mergedMeta.getNumChunks());
|
||||||
|
callback.onSuccess(mergedMeta.getNumChunks(), mergedMeta.getChunksBitmapBuffer());
|
||||||
|
} finally {
|
||||||
|
responseDelayContext.stop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void exceptionCaught(Throwable cause, TransportClient client) {
|
public void exceptionCaught(Throwable cause, TransportClient client) {
|
||||||
metrics.caughtExceptions.inc();
|
metrics.caughtExceptions.inc();
|
||||||
|
@ -262,6 +301,8 @@ public class ExternalBlockHandler extends RpcHandler {
|
||||||
private final Timer openBlockRequestLatencyMillis = new Timer();
|
private final Timer openBlockRequestLatencyMillis = new Timer();
|
||||||
// Time latency for executor registration latency in ms
|
// Time latency for executor registration latency in ms
|
||||||
private final Timer registerExecutorRequestLatencyMillis = new Timer();
|
private final Timer registerExecutorRequestLatencyMillis = new Timer();
|
||||||
|
// Time latency for processing fetch merged blocks meta request latency in ms
|
||||||
|
private final Timer fetchMergedBlocksMetaLatencyMillis = new Timer();
|
||||||
// Time latency for processing finalize shuffle merge request latency in ms
|
// Time latency for processing finalize shuffle merge request latency in ms
|
||||||
private final Timer finalizeShuffleMergeLatencyMillis = new Timer();
|
private final Timer finalizeShuffleMergeLatencyMillis = new Timer();
|
||||||
// Block transfer rate in byte per second
|
// Block transfer rate in byte per second
|
||||||
|
@ -275,6 +316,7 @@ public class ExternalBlockHandler extends RpcHandler {
|
||||||
allMetrics = new HashMap<>();
|
allMetrics = new HashMap<>();
|
||||||
allMetrics.put("openBlockRequestLatencyMillis", openBlockRequestLatencyMillis);
|
allMetrics.put("openBlockRequestLatencyMillis", openBlockRequestLatencyMillis);
|
||||||
allMetrics.put("registerExecutorRequestLatencyMillis", registerExecutorRequestLatencyMillis);
|
allMetrics.put("registerExecutorRequestLatencyMillis", registerExecutorRequestLatencyMillis);
|
||||||
|
allMetrics.put("fetchMergedBlocksMetaLatencyMillis", fetchMergedBlocksMetaLatencyMillis);
|
||||||
allMetrics.put("finalizeShuffleMergeLatencyMillis", finalizeShuffleMergeLatencyMillis);
|
allMetrics.put("finalizeShuffleMergeLatencyMillis", finalizeShuffleMergeLatencyMillis);
|
||||||
allMetrics.put("blockTransferRateBytes", blockTransferRateBytes);
|
allMetrics.put("blockTransferRateBytes", blockTransferRateBytes);
|
||||||
allMetrics.put("registeredExecutorsSize",
|
allMetrics.put("registeredExecutorsSize",
|
||||||
|
@ -294,18 +336,26 @@ public class ExternalBlockHandler extends RpcHandler {
|
||||||
private int index = 0;
|
private int index = 0;
|
||||||
private final Function<Integer, ManagedBuffer> blockDataForIndexFn;
|
private final Function<Integer, ManagedBuffer> blockDataForIndexFn;
|
||||||
private final int size;
|
private final int size;
|
||||||
|
private boolean requestForMergedBlockChunks;
|
||||||
|
|
||||||
ManagedBufferIterator(OpenBlocks msg) {
|
ManagedBufferIterator(OpenBlocks msg) {
|
||||||
String appId = msg.appId;
|
String appId = msg.appId;
|
||||||
String execId = msg.execId;
|
String execId = msg.execId;
|
||||||
String[] blockIds = msg.blockIds;
|
String[] blockIds = msg.blockIds;
|
||||||
String[] blockId0Parts = blockIds[0].split("_");
|
String[] blockId0Parts = blockIds[0].split("_");
|
||||||
if (blockId0Parts.length == 4 && blockId0Parts[0].equals("shuffle")) {
|
if (blockId0Parts.length == 4 && blockId0Parts[0].equals(SHUFFLE_BLOCK_ID)) {
|
||||||
final int shuffleId = Integer.parseInt(blockId0Parts[1]);
|
final int shuffleId = Integer.parseInt(blockId0Parts[1]);
|
||||||
final int[] mapIdAndReduceIds = shuffleMapIdAndReduceIds(blockIds, shuffleId);
|
final int[] mapIdAndReduceIds = shuffleMapIdAndReduceIds(blockIds, shuffleId);
|
||||||
size = mapIdAndReduceIds.length;
|
size = mapIdAndReduceIds.length;
|
||||||
blockDataForIndexFn = index -> blockManager.getBlockData(appId, execId, shuffleId,
|
blockDataForIndexFn = index -> blockManager.getBlockData(appId, execId, shuffleId,
|
||||||
mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
|
mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
|
||||||
|
} else if (blockId0Parts.length == 4 && blockId0Parts[0].equals(SHUFFLE_CHUNK_ID)) {
|
||||||
|
requestForMergedBlockChunks = true;
|
||||||
|
final int shuffleId = Integer.parseInt(blockId0Parts[1]);
|
||||||
|
final int[] reduceIdAndChunkIds = shuffleMapIdAndReduceIds(blockIds, shuffleId);
|
||||||
|
size = reduceIdAndChunkIds.length;
|
||||||
|
blockDataForIndexFn = index -> mergeManager.getMergedBlockData(msg.appId, shuffleId,
|
||||||
|
reduceIdAndChunkIds[index], reduceIdAndChunkIds[index + 1]);
|
||||||
} else if (blockId0Parts.length == 3 && blockId0Parts[0].equals("rdd")) {
|
} else if (blockId0Parts.length == 3 && blockId0Parts[0].equals("rdd")) {
|
||||||
final int[] rddAndSplitIds = rddAndSplitIds(blockIds);
|
final int[] rddAndSplitIds = rddAndSplitIds(blockIds);
|
||||||
size = rddAndSplitIds.length;
|
size = rddAndSplitIds.length;
|
||||||
|
@ -330,20 +380,26 @@ public class ExternalBlockHandler extends RpcHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) {
|
private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) {
|
||||||
final int[] mapIdAndReduceIds = new int[2 * blockIds.length];
|
// For regular shuffle blocks, primaryId is mapId and secondaryIds are reduceIds.
|
||||||
|
// For shuffle chunks, primaryIds is reduceId and secondaryIds are chunkIds.
|
||||||
|
final int[] primaryIdAndSecondaryIds = new int[2 * blockIds.length];
|
||||||
for (int i = 0; i < blockIds.length; i++) {
|
for (int i = 0; i < blockIds.length; i++) {
|
||||||
String[] blockIdParts = blockIds[i].split("_");
|
String[] blockIdParts = blockIds[i].split("_");
|
||||||
if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
|
if (blockIdParts.length != 4
|
||||||
|
|| (!requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_BLOCK_ID))
|
||||||
|
|| (requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_CHUNK_ID))) {
|
||||||
throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
|
throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
|
||||||
}
|
}
|
||||||
if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
|
if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
|
||||||
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
|
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
|
||||||
", got:" + blockIds[i]);
|
", got:" + blockIds[i]);
|
||||||
}
|
}
|
||||||
mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
|
// For regular blocks, blockIdParts[2] is mapId. For chunks, it is reduceId.
|
||||||
mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
|
primaryIdAndSecondaryIds[2 * i] = Integer.parseInt(blockIdParts[2]);
|
||||||
|
// For regular blocks, blockIdParts[3] is reduceId. For chunks, it is chunkId.
|
||||||
|
primaryIdAndSecondaryIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
|
||||||
}
|
}
|
||||||
return mapIdAndReduceIds;
|
return primaryIdAndSecondaryIds;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -413,6 +469,47 @@ public class ExternalBlockHandler extends RpcHandler {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
|
||||||
|
|
||||||
|
private int reduceIdx = 0;
|
||||||
|
private int chunkIdx = 0;
|
||||||
|
|
||||||
|
private final String appId;
|
||||||
|
private final int shuffleId;
|
||||||
|
private final int[] reduceIds;
|
||||||
|
private final int[][] chunkIds;
|
||||||
|
|
||||||
|
ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
|
||||||
|
appId = msg.appId;
|
||||||
|
shuffleId = msg.shuffleId;
|
||||||
|
reduceIds = msg.reduceIds;
|
||||||
|
chunkIds = msg.chunkIds;
|
||||||
|
// reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
|
||||||
|
// must have non-empty reduceIds and chunkIds, see the checking logic in
|
||||||
|
// OneForOneBlockFetcher.
|
||||||
|
assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNext() {
|
||||||
|
return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ManagedBuffer next() {
|
||||||
|
ManagedBuffer block = Preconditions.checkNotNull(mergeManager.getMergedBlockData(
|
||||||
|
appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]));
|
||||||
|
if (chunkIdx < chunkIds[reduceIdx].length - 1) {
|
||||||
|
chunkIdx += 1;
|
||||||
|
} else {
|
||||||
|
chunkIdx = 0;
|
||||||
|
reduceIdx += 1;
|
||||||
|
}
|
||||||
|
metrics.blockTransferRateBytes.mark(block.size());
|
||||||
|
return block;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dummy implementation of merged shuffle file manager. Suitable for when push-based shuffle
|
* Dummy implementation of merged shuffle file manager. Suitable for when push-based shuffle
|
||||||
* is not enabled.
|
* is not enabled.
|
||||||
|
|
|
@ -31,6 +31,7 @@ import com.google.common.collect.Lists;
|
||||||
|
|
||||||
import org.apache.spark.network.TransportContext;
|
import org.apache.spark.network.TransportContext;
|
||||||
import org.apache.spark.network.buffer.ManagedBuffer;
|
import org.apache.spark.network.buffer.ManagedBuffer;
|
||||||
|
import org.apache.spark.network.client.MergedBlockMetaResponseCallback;
|
||||||
import org.apache.spark.network.client.RpcResponseCallback;
|
import org.apache.spark.network.client.RpcResponseCallback;
|
||||||
import org.apache.spark.network.client.TransportClient;
|
import org.apache.spark.network.client.TransportClient;
|
||||||
import org.apache.spark.network.client.TransportClientBootstrap;
|
import org.apache.spark.network.client.TransportClientBootstrap;
|
||||||
|
@ -187,6 +188,37 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void getMergedBlockMeta(
|
||||||
|
String host,
|
||||||
|
int port,
|
||||||
|
int shuffleId,
|
||||||
|
int reduceId,
|
||||||
|
MergedBlocksMetaListener listener) {
|
||||||
|
checkInit();
|
||||||
|
logger.debug("Get merged blocks meta from {}:{} for shuffleId {} reduceId {}", host, port,
|
||||||
|
shuffleId, reduceId);
|
||||||
|
try {
|
||||||
|
TransportClient client = clientFactory.createClient(host, port);
|
||||||
|
client.sendMergedBlockMetaReq(appId, shuffleId, reduceId,
|
||||||
|
new MergedBlockMetaResponseCallback() {
|
||||||
|
@Override
|
||||||
|
public void onSuccess(int numChunks, ManagedBuffer buffer) {
|
||||||
|
logger.trace("Successfully got merged block meta for shuffleId {} reduceId {}",
|
||||||
|
shuffleId, reduceId);
|
||||||
|
listener.onSuccess(shuffleId, reduceId, new MergedBlockMeta(numChunks, buffer));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onFailure(Throwable e) {
|
||||||
|
listener.onFailure(shuffleId, reduceId, e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} catch (Exception e) {
|
||||||
|
listener.onFailure(shuffleId, reduceId, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MetricSet shuffleMetrics() {
|
public MetricSet shuffleMetrics() {
|
||||||
checkInit();
|
checkInit();
|
||||||
|
|
|
@ -361,8 +361,8 @@ public class ExternalShuffleBlockResolver {
|
||||||
return numRemovedBlocks;
|
return numRemovedBlocks;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<String, String[]> getLocalDirs(String appId, String[] execIds) {
|
public Map<String, String[]> getLocalDirs(String appId, Set<String> execIds) {
|
||||||
return Arrays.stream(execIds)
|
return execIds.stream()
|
||||||
.map(exec -> {
|
.map(exec -> {
|
||||||
ExecutorShuffleInfo info = executors.get(new AppExecId(appId, exec));
|
ExecutorShuffleInfo info = executors.get(new AppExecId(appId, exec));
|
||||||
if (info == null) {
|
if (info == null) {
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.network.shuffle;
|
||||||
|
|
||||||
|
import java.util.EventListener;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Listener for receiving success or failure events when fetching meta of merged blocks.
|
||||||
|
*
|
||||||
|
* @since 3.2.0
|
||||||
|
*/
|
||||||
|
public interface MergedBlocksMetaListener extends EventListener {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called after successfully receiving the meta of a merged block.
|
||||||
|
*
|
||||||
|
* @param shuffleId shuffle id.
|
||||||
|
* @param reduceId reduce id.
|
||||||
|
* @param meta contains meta information of a merged block.
|
||||||
|
*/
|
||||||
|
void onSuccess(int shuffleId, int reduceId, MergedBlockMeta meta);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called when there is an exception while fetching the meta of a merged block.
|
||||||
|
*
|
||||||
|
* @param shuffleId shuffle id.
|
||||||
|
* @param reduceId reduce id.
|
||||||
|
* @param exception exception getting chunk counts.
|
||||||
|
*/
|
||||||
|
void onFailure(int shuffleId, int reduceId, Throwable exception);
|
||||||
|
}
|
|
@ -22,6 +22,7 @@ import java.nio.ByteBuffer;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
import com.google.common.primitives.Ints;
|
import com.google.common.primitives.Ints;
|
||||||
import com.google.common.primitives.Longs;
|
import com.google.common.primitives.Longs;
|
||||||
|
@ -34,8 +35,10 @@ import org.apache.spark.network.client.RpcResponseCallback;
|
||||||
import org.apache.spark.network.client.StreamCallback;
|
import org.apache.spark.network.client.StreamCallback;
|
||||||
import org.apache.spark.network.client.TransportClient;
|
import org.apache.spark.network.client.TransportClient;
|
||||||
import org.apache.spark.network.server.OneForOneStreamManager;
|
import org.apache.spark.network.server.OneForOneStreamManager;
|
||||||
|
import org.apache.spark.network.shuffle.protocol.AbstractFetchShuffleBlocks;
|
||||||
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
|
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
|
||||||
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
|
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
|
||||||
|
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks;
|
||||||
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
|
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
|
||||||
import org.apache.spark.network.shuffle.protocol.StreamHandle;
|
import org.apache.spark.network.shuffle.protocol.StreamHandle;
|
||||||
import org.apache.spark.network.util.TransportConf;
|
import org.apache.spark.network.util.TransportConf;
|
||||||
|
@ -51,6 +54,8 @@ import org.apache.spark.network.util.TransportConf;
|
||||||
*/
|
*/
|
||||||
public class OneForOneBlockFetcher {
|
public class OneForOneBlockFetcher {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);
|
private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);
|
||||||
|
private static final String SHUFFLE_BLOCK_PREFIX = "shuffle_";
|
||||||
|
private static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk_";
|
||||||
|
|
||||||
private final TransportClient client;
|
private final TransportClient client;
|
||||||
private final BlockTransferMessage message;
|
private final BlockTransferMessage message;
|
||||||
|
@ -88,74 +93,113 @@ public class OneForOneBlockFetcher {
|
||||||
if (blockIds.length == 0) {
|
if (blockIds.length == 0) {
|
||||||
throw new IllegalArgumentException("Zero-sized blockIds array");
|
throw new IllegalArgumentException("Zero-sized blockIds array");
|
||||||
}
|
}
|
||||||
if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
|
if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
|
||||||
this.blockIds = new String[blockIds.length];
|
this.blockIds = new String[blockIds.length];
|
||||||
this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
|
this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
|
||||||
} else {
|
} else {
|
||||||
this.blockIds = blockIds;
|
this.blockIds = blockIds;
|
||||||
this.message = new OpenBlocks(appId, execId, blockIds);
|
this.message = new OpenBlocks(appId, execId, blockIds);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean isShuffleBlocks(String[] blockIds) {
|
/**
|
||||||
for (String blockId : blockIds) {
|
* Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
|
||||||
if (!blockId.startsWith("shuffle_")) {
|
* the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
|
||||||
return false;
|
* IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
|
||||||
}
|
* all unmerged shuffle blocks or all merged shuffle chunks.
|
||||||
|
* @param blockIds block ID array
|
||||||
|
* @return whether the array contains only shuffle block IDs
|
||||||
|
*/
|
||||||
|
private boolean areShuffleBlocksOrChunks(String[] blockIds) {
|
||||||
|
if (Arrays.stream(blockIds).anyMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX))) {
|
||||||
|
// It comes here because there is a blockId which doesn't have "shuffle_" prefix so we
|
||||||
|
// check if all the block ids are shuffle chunk Ids.
|
||||||
|
return Arrays.stream(blockIds).allMatch(blockId -> blockId.startsWith(SHUFFLE_CHUNK_PREFIX));
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
|
||||||
|
private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
|
||||||
|
String appId,
|
||||||
|
String execId,
|
||||||
|
String[] blockIds) {
|
||||||
|
if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
|
||||||
|
return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
|
||||||
|
} else {
|
||||||
|
return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create FetchShuffleBlocks message and rebuild internal blockIds by
|
* Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
|
||||||
* analyzing the pass in blockIds.
|
* analyzing the passed in blockIds.
|
||||||
*/
|
*/
|
||||||
private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
|
private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
|
||||||
String appId, String execId, String[] blockIds) {
|
String appId,
|
||||||
|
String execId,
|
||||||
|
String[] blockIds,
|
||||||
|
boolean areMergedChunks) {
|
||||||
String[] firstBlock = splitBlockId(blockIds[0]);
|
String[] firstBlock = splitBlockId(blockIds[0]);
|
||||||
int shuffleId = Integer.parseInt(firstBlock[1]);
|
int shuffleId = Integer.parseInt(firstBlock[1]);
|
||||||
boolean batchFetchEnabled = firstBlock.length == 5;
|
boolean batchFetchEnabled = firstBlock.length == 5;
|
||||||
|
|
||||||
LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
|
// In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
|
||||||
|
// is reduceId.
|
||||||
|
LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
|
||||||
for (String blockId : blockIds) {
|
for (String blockId : blockIds) {
|
||||||
String[] blockIdParts = splitBlockId(blockId);
|
String[] blockIdParts = splitBlockId(blockId);
|
||||||
if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
|
if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
|
||||||
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
|
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
|
||||||
", got:" + blockId);
|
", got:" + blockId);
|
||||||
}
|
}
|
||||||
long mapId = Long.parseLong(blockIdParts[2]);
|
Number primaryId;
|
||||||
if (!mapIdToBlocksInfo.containsKey(mapId)) {
|
if (!areMergedChunks) {
|
||||||
mapIdToBlocksInfo.put(mapId, new BlocksInfo());
|
primaryId = Long.parseLong(blockIdParts[2]);
|
||||||
|
} else {
|
||||||
|
primaryId = Integer.parseInt(blockIdParts[2]);
|
||||||
}
|
}
|
||||||
BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
|
BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId,
|
||||||
blocksInfoByMapId.blockIds.add(blockId);
|
id -> new BlocksInfo());
|
||||||
blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
|
blocksInfoByPrimaryId.blockIds.add(blockId);
|
||||||
|
// If blockId is a regular shuffle block, then blockIdParts[3] = reduceId. If blockId is a
|
||||||
|
// shuffleChunk block, then blockIdParts[3] = chunkId
|
||||||
|
blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
|
||||||
if (batchFetchEnabled) {
|
if (batchFetchEnabled) {
|
||||||
|
// It comes here only if the blockId is a regular shuffle block not a shuffleChunk block.
|
||||||
// When we read continuous shuffle blocks in batch, we will reuse reduceIds in
|
// When we read continuous shuffle blocks in batch, we will reuse reduceIds in
|
||||||
// FetchShuffleBlocks to store the start and end reduce id for range
|
// FetchShuffleBlocks to store the start and end reduce id for range
|
||||||
// [startReduceId, endReduceId).
|
// [startReduceId, endReduceId).
|
||||||
assert(blockIdParts.length == 5);
|
assert(blockIdParts.length == 5);
|
||||||
blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
|
// blockIdParts[4] is the end reduce id for the batch range
|
||||||
|
blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
|
// In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
|
||||||
int[][] reduceIdArr = new int[mapIds.length][];
|
// secondaryIds are chunkIds.
|
||||||
|
int[][] secondaryIdsArray = new int[primaryIdToBlocksInfo.size()][];
|
||||||
int blockIdIndex = 0;
|
int blockIdIndex = 0;
|
||||||
for (int i = 0; i < mapIds.length; i++) {
|
int secIndex = 0;
|
||||||
BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
|
for (BlocksInfo blocksInfo: primaryIdToBlocksInfo.values()) {
|
||||||
reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
|
secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfo.ids);
|
||||||
|
|
||||||
// The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
|
// The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
|
||||||
// because the shuffle data's return order should match the `blockIds`'s order to ensure
|
// FetchShuffleBlockChunks because the shuffle data's return order should match the
|
||||||
// blockId and data match.
|
// `blockIds`'s order to ensure blockId and data match.
|
||||||
for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
|
for (String blockId : blocksInfo.blockIds) {
|
||||||
this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
|
this.blockIds[blockIdIndex++] = blockId;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert(blockIdIndex == this.blockIds.length);
|
assert(blockIdIndex == this.blockIds.length);
|
||||||
|
Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
|
||||||
return new FetchShuffleBlocks(
|
if (!areMergedChunks) {
|
||||||
appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
|
long[] mapIds = Longs.toArray(primaryIds);
|
||||||
|
return new FetchShuffleBlocks(
|
||||||
|
appId, execId, shuffleId, mapIds, secondaryIdsArray, batchFetchEnabled);
|
||||||
|
} else {
|
||||||
|
int[] reduceIds = Ints.toArray(primaryIds);
|
||||||
|
return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, secondaryIdsArray);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */
|
/** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */
|
||||||
|
@ -163,7 +207,17 @@ public class OneForOneBlockFetcher {
|
||||||
String[] blockIdParts = blockId.split("_");
|
String[] blockIdParts = blockId.split("_");
|
||||||
// For batch block id, the format contains shuffleId, mapId, begin reduceId, end reduceId.
|
// For batch block id, the format contains shuffleId, mapId, begin reduceId, end reduceId.
|
||||||
// For single block id, the format contains shuffleId, mapId, educeId.
|
// For single block id, the format contains shuffleId, mapId, educeId.
|
||||||
if (blockIdParts.length < 4 || blockIdParts.length > 5 || !blockIdParts[0].equals("shuffle")) {
|
// For single block chunk id, the format contains shuffleId, reduceId, chunkId.
|
||||||
|
if (blockIdParts.length < 4 || blockIdParts.length > 5) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Unexpected shuffle block id format: " + blockId);
|
||||||
|
}
|
||||||
|
if (blockIdParts.length == 5 && !blockIdParts[0].equals("shuffle")) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Unexpected shuffle block id format: " + blockId);
|
||||||
|
}
|
||||||
|
if (blockIdParts.length == 4 &&
|
||||||
|
!(blockIdParts[0].equals("shuffle") || blockIdParts[0].equals("shuffleChunk"))) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"Unexpected shuffle block id format: " + blockId);
|
"Unexpected shuffle block id format: " + blockId);
|
||||||
}
|
}
|
||||||
|
@ -173,11 +227,15 @@ public class OneForOneBlockFetcher {
|
||||||
/** The reduceIds and blocks in a single mapId */
|
/** The reduceIds and blocks in a single mapId */
|
||||||
private class BlocksInfo {
|
private class BlocksInfo {
|
||||||
|
|
||||||
final ArrayList<Integer> reduceIds;
|
/**
|
||||||
|
* For {@link FetchShuffleBlocks} message, the ids are reduceIds.
|
||||||
|
* For {@link FetchShuffleBlockChunks} message, the ids are chunkIds.
|
||||||
|
*/
|
||||||
|
final ArrayList<Integer> ids;
|
||||||
final ArrayList<String> blockIds;
|
final ArrayList<String> blockIds;
|
||||||
|
|
||||||
BlocksInfo() {
|
BlocksInfo() {
|
||||||
this.reduceIds = new ArrayList<>();
|
this.ids = new ArrayList<>();
|
||||||
this.blockIds = new ArrayList<>();
|
this.blockIds = new ArrayList<>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,88 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.network.shuffle.protocol;
|
||||||
|
|
||||||
|
import com.google.common.base.Objects;
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
|
||||||
|
import org.apache.commons.lang3.builder.ToStringBuilder;
|
||||||
|
import org.apache.commons.lang3.builder.ToStringStyle;
|
||||||
|
import org.apache.spark.network.protocol.Encoders;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Base class for fetch shuffle blocks and chunks.
|
||||||
|
*
|
||||||
|
* @since 3.2.0
|
||||||
|
*/
|
||||||
|
public abstract class AbstractFetchShuffleBlocks extends BlockTransferMessage {
|
||||||
|
public final String appId;
|
||||||
|
public final String execId;
|
||||||
|
public final int shuffleId;
|
||||||
|
|
||||||
|
protected AbstractFetchShuffleBlocks(
|
||||||
|
String appId,
|
||||||
|
String execId,
|
||||||
|
int shuffleId) {
|
||||||
|
this.appId = appId;
|
||||||
|
this.execId = execId;
|
||||||
|
this.shuffleId = shuffleId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ToStringBuilder toStringHelper() {
|
||||||
|
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
|
||||||
|
.append("appId", appId)
|
||||||
|
.append("execId", execId)
|
||||||
|
.append("shuffleId", shuffleId);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns number of blocks in the request.
|
||||||
|
*/
|
||||||
|
public abstract int getNumBlocks();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
AbstractFetchShuffleBlocks that = (AbstractFetchShuffleBlocks) o;
|
||||||
|
return shuffleId == that.shuffleId
|
||||||
|
&& Objects.equal(appId, that.appId) && Objects.equal(execId, that.execId);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
int result = appId.hashCode();
|
||||||
|
result = 31 * result + execId.hashCode();
|
||||||
|
result = 31 * result + shuffleId;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int encodedLength() {
|
||||||
|
return Encoders.Strings.encodedLength(appId)
|
||||||
|
+ Encoders.Strings.encodedLength(execId)
|
||||||
|
+ 4; /* encoded length of shuffleId */
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void encode(ByteBuf buf) {
|
||||||
|
Encoders.Strings.encode(buf, appId);
|
||||||
|
Encoders.Strings.encode(buf, execId);
|
||||||
|
buf.writeInt(shuffleId);
|
||||||
|
}
|
||||||
|
}
|
|
@ -48,7 +48,8 @@ public abstract class BlockTransferMessage implements Encodable {
|
||||||
OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4),
|
OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4),
|
||||||
HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8),
|
HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8),
|
||||||
FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11),
|
FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11),
|
||||||
PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14);
|
PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14),
|
||||||
|
FETCH_SHUFFLE_BLOCK_CHUNKS(15);
|
||||||
|
|
||||||
private final byte id;
|
private final byte id;
|
||||||
|
|
||||||
|
@ -82,6 +83,7 @@ public abstract class BlockTransferMessage implements Encodable {
|
||||||
case 12: return PushBlockStream.decode(buf);
|
case 12: return PushBlockStream.decode(buf);
|
||||||
case 13: return FinalizeShuffleMerge.decode(buf);
|
case 13: return FinalizeShuffleMerge.decode(buf);
|
||||||
case 14: return MergeStatuses.decode(buf);
|
case 14: return MergeStatuses.decode(buf);
|
||||||
|
case 15: return FetchShuffleBlockChunks.decode(buf);
|
||||||
default: throw new IllegalArgumentException("Unknown message type: " + type);
|
default: throw new IllegalArgumentException("Unknown message type: " + type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,128 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.network.shuffle.protocol;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
|
||||||
|
import org.apache.spark.network.protocol.Encoders;
|
||||||
|
|
||||||
|
// Needed by ScalaDoc. See SPARK-7726
|
||||||
|
import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Request to read a set of block chunks. Returns {@link StreamHandle}.
|
||||||
|
*
|
||||||
|
* @since 3.2.0
|
||||||
|
*/
|
||||||
|
public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
|
||||||
|
// The length of reduceIds must equal to chunkIds.size().
|
||||||
|
public final int[] reduceIds;
|
||||||
|
// The i-th int[] in chunkIds contains all the chunks for the i-th reduceId in reduceIds.
|
||||||
|
public final int[][] chunkIds;
|
||||||
|
|
||||||
|
public FetchShuffleBlockChunks(
|
||||||
|
String appId,
|
||||||
|
String execId,
|
||||||
|
int shuffleId,
|
||||||
|
int[] reduceIds,
|
||||||
|
int[][] chunkIds) {
|
||||||
|
super(appId, execId, shuffleId);
|
||||||
|
this.reduceIds = reduceIds;
|
||||||
|
this.chunkIds = chunkIds;
|
||||||
|
assert(reduceIds.length == chunkIds.length);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Type type() { return Type.FETCH_SHUFFLE_BLOCK_CHUNKS; }
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return toStringHelper()
|
||||||
|
.append("reduceIds", Arrays.toString(reduceIds))
|
||||||
|
.append("chunkIds", Arrays.deepToString(chunkIds))
|
||||||
|
.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
|
||||||
|
FetchShuffleBlockChunks that = (FetchShuffleBlockChunks) o;
|
||||||
|
if (!super.equals(that)) return false;
|
||||||
|
if (!Arrays.equals(reduceIds, that.reduceIds)) return false;
|
||||||
|
return Arrays.deepEquals(chunkIds, that.chunkIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
int result = super.hashCode();
|
||||||
|
result = 31 * result + Arrays.hashCode(reduceIds);
|
||||||
|
result = 31 * result + Arrays.deepHashCode(chunkIds);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int encodedLength() {
|
||||||
|
int encodedLengthOfChunkIds = 0;
|
||||||
|
for (int[] ids: chunkIds) {
|
||||||
|
encodedLengthOfChunkIds += Encoders.IntArrays.encodedLength(ids);
|
||||||
|
}
|
||||||
|
return super.encodedLength()
|
||||||
|
+ Encoders.IntArrays.encodedLength(reduceIds)
|
||||||
|
+ 4 /* encoded length of chunkIds.size() */
|
||||||
|
+ encodedLengthOfChunkIds;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void encode(ByteBuf buf) {
|
||||||
|
super.encode(buf);
|
||||||
|
Encoders.IntArrays.encode(buf, reduceIds);
|
||||||
|
// Even though reduceIds.length == chunkIds.length, we are explicitly setting the length in the
|
||||||
|
// interest of forward compatibility.
|
||||||
|
buf.writeInt(chunkIds.length);
|
||||||
|
for (int[] ids: chunkIds) {
|
||||||
|
Encoders.IntArrays.encode(buf, ids);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNumBlocks() {
|
||||||
|
int numBlocks = 0;
|
||||||
|
for (int[] ids : chunkIds) {
|
||||||
|
numBlocks += ids.length;
|
||||||
|
}
|
||||||
|
return numBlocks;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static FetchShuffleBlockChunks decode(ByteBuf buf) {
|
||||||
|
String appId = Encoders.Strings.decode(buf);
|
||||||
|
String execId = Encoders.Strings.decode(buf);
|
||||||
|
int shuffleId = buf.readInt();
|
||||||
|
int[] reduceIds = Encoders.IntArrays.decode(buf);
|
||||||
|
int chunkIdsLen = buf.readInt();
|
||||||
|
int[][] chunkIds = new int[chunkIdsLen][];
|
||||||
|
for (int i = 0; i < chunkIdsLen; i++) {
|
||||||
|
chunkIds[i] = Encoders.IntArrays.decode(buf);
|
||||||
|
}
|
||||||
|
return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, chunkIds);
|
||||||
|
}
|
||||||
|
}
|
|
@ -20,8 +20,6 @@ package org.apache.spark.network.shuffle.protocol;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import org.apache.commons.lang3.builder.ToStringBuilder;
|
|
||||||
import org.apache.commons.lang3.builder.ToStringStyle;
|
|
||||||
|
|
||||||
import org.apache.spark.network.protocol.Encoders;
|
import org.apache.spark.network.protocol.Encoders;
|
||||||
|
|
||||||
|
@ -29,10 +27,7 @@ import org.apache.spark.network.protocol.Encoders;
|
||||||
import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
|
import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
|
||||||
|
|
||||||
/** Request to read a set of blocks. Returns {@link StreamHandle}. */
|
/** Request to read a set of blocks. Returns {@link StreamHandle}. */
|
||||||
public class FetchShuffleBlocks extends BlockTransferMessage {
|
public class FetchShuffleBlocks extends AbstractFetchShuffleBlocks {
|
||||||
public final String appId;
|
|
||||||
public final String execId;
|
|
||||||
public final int shuffleId;
|
|
||||||
// The length of mapIds must equal to reduceIds.size(), for the i-th mapId in mapIds,
|
// The length of mapIds must equal to reduceIds.size(), for the i-th mapId in mapIds,
|
||||||
// it corresponds to the i-th int[] in reduceIds, which contains all reduce id for this map id.
|
// it corresponds to the i-th int[] in reduceIds, which contains all reduce id for this map id.
|
||||||
public final long[] mapIds;
|
public final long[] mapIds;
|
||||||
|
@ -50,9 +45,7 @@ public class FetchShuffleBlocks extends BlockTransferMessage {
|
||||||
long[] mapIds,
|
long[] mapIds,
|
||||||
int[][] reduceIds,
|
int[][] reduceIds,
|
||||||
boolean batchFetchEnabled) {
|
boolean batchFetchEnabled) {
|
||||||
this.appId = appId;
|
super(appId, execId, shuffleId);
|
||||||
this.execId = execId;
|
|
||||||
this.shuffleId = shuffleId;
|
|
||||||
this.mapIds = mapIds;
|
this.mapIds = mapIds;
|
||||||
this.reduceIds = reduceIds;
|
this.reduceIds = reduceIds;
|
||||||
assert(mapIds.length == reduceIds.length);
|
assert(mapIds.length == reduceIds.length);
|
||||||
|
@ -69,10 +62,7 @@ public class FetchShuffleBlocks extends BlockTransferMessage {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
|
return toStringHelper()
|
||||||
.append("appId", appId)
|
|
||||||
.append("execId", execId)
|
|
||||||
.append("shuffleId", shuffleId)
|
|
||||||
.append("mapIds", Arrays.toString(mapIds))
|
.append("mapIds", Arrays.toString(mapIds))
|
||||||
.append("reduceIds", Arrays.deepToString(reduceIds))
|
.append("reduceIds", Arrays.deepToString(reduceIds))
|
||||||
.append("batchFetchEnabled", batchFetchEnabled)
|
.append("batchFetchEnabled", batchFetchEnabled)
|
||||||
|
@ -85,35 +75,40 @@ public class FetchShuffleBlocks extends BlockTransferMessage {
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
|
||||||
FetchShuffleBlocks that = (FetchShuffleBlocks) o;
|
FetchShuffleBlocks that = (FetchShuffleBlocks) o;
|
||||||
|
if (!super.equals(that)) return false;
|
||||||
if (shuffleId != that.shuffleId) return false;
|
|
||||||
if (batchFetchEnabled != that.batchFetchEnabled) return false;
|
if (batchFetchEnabled != that.batchFetchEnabled) return false;
|
||||||
if (!appId.equals(that.appId)) return false;
|
|
||||||
if (!execId.equals(that.execId)) return false;
|
|
||||||
if (!Arrays.equals(mapIds, that.mapIds)) return false;
|
if (!Arrays.equals(mapIds, that.mapIds)) return false;
|
||||||
return Arrays.deepEquals(reduceIds, that.reduceIds);
|
return Arrays.deepEquals(reduceIds, that.reduceIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
int result = appId.hashCode();
|
int result = super.hashCode();
|
||||||
result = 31 * result + execId.hashCode();
|
|
||||||
result = 31 * result + shuffleId;
|
|
||||||
result = 31 * result + Arrays.hashCode(mapIds);
|
result = 31 * result + Arrays.hashCode(mapIds);
|
||||||
result = 31 * result + Arrays.deepHashCode(reduceIds);
|
result = 31 * result + Arrays.deepHashCode(reduceIds);
|
||||||
result = 31 * result + (batchFetchEnabled ? 1 : 0);
|
result = 31 * result + (batchFetchEnabled ? 1 : 0);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNumBlocks() {
|
||||||
|
if (batchFetchEnabled) {
|
||||||
|
return mapIds.length;
|
||||||
|
}
|
||||||
|
int numBlocks = 0;
|
||||||
|
for (int[] ids : reduceIds) {
|
||||||
|
numBlocks += ids.length;
|
||||||
|
}
|
||||||
|
return numBlocks;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int encodedLength() {
|
public int encodedLength() {
|
||||||
int encodedLengthOfReduceIds = 0;
|
int encodedLengthOfReduceIds = 0;
|
||||||
for (int[] ids: reduceIds) {
|
for (int[] ids: reduceIds) {
|
||||||
encodedLengthOfReduceIds += Encoders.IntArrays.encodedLength(ids);
|
encodedLengthOfReduceIds += Encoders.IntArrays.encodedLength(ids);
|
||||||
}
|
}
|
||||||
return Encoders.Strings.encodedLength(appId)
|
return super.encodedLength()
|
||||||
+ Encoders.Strings.encodedLength(execId)
|
|
||||||
+ 4 /* encoded length of shuffleId */
|
|
||||||
+ Encoders.LongArrays.encodedLength(mapIds)
|
+ Encoders.LongArrays.encodedLength(mapIds)
|
||||||
+ 4 /* encoded length of reduceIds.size() */
|
+ 4 /* encoded length of reduceIds.size() */
|
||||||
+ encodedLengthOfReduceIds
|
+ encodedLengthOfReduceIds
|
||||||
|
@ -122,9 +117,7 @@ public class FetchShuffleBlocks extends BlockTransferMessage {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void encode(ByteBuf buf) {
|
public void encode(ByteBuf buf) {
|
||||||
Encoders.Strings.encode(buf, appId);
|
super.encode(buf);
|
||||||
Encoders.Strings.encode(buf, execId);
|
|
||||||
buf.writeInt(shuffleId);
|
|
||||||
Encoders.LongArrays.encode(buf, mapIds);
|
Encoders.LongArrays.encode(buf, mapIds);
|
||||||
buf.writeInt(reduceIds.length);
|
buf.writeInt(reduceIds.length);
|
||||||
for (int[] ids: reduceIds) {
|
for (int[] ids: reduceIds) {
|
||||||
|
|
|
@ -34,13 +34,16 @@ import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
import org.apache.spark.network.buffer.ManagedBuffer;
|
import org.apache.spark.network.buffer.ManagedBuffer;
|
||||||
import org.apache.spark.network.buffer.NioManagedBuffer;
|
import org.apache.spark.network.buffer.NioManagedBuffer;
|
||||||
|
import org.apache.spark.network.client.MergedBlockMetaResponseCallback;
|
||||||
import org.apache.spark.network.client.RpcResponseCallback;
|
import org.apache.spark.network.client.RpcResponseCallback;
|
||||||
import org.apache.spark.network.client.TransportClient;
|
import org.apache.spark.network.client.TransportClient;
|
||||||
|
import org.apache.spark.network.protocol.MergedBlockMetaRequest;
|
||||||
import org.apache.spark.network.server.OneForOneStreamManager;
|
import org.apache.spark.network.server.OneForOneStreamManager;
|
||||||
import org.apache.spark.network.server.RpcHandler;
|
import org.apache.spark.network.server.RpcHandler;
|
||||||
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
|
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
|
||||||
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
|
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
|
||||||
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
|
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
|
||||||
|
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks;
|
||||||
import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge;
|
import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge;
|
||||||
import org.apache.spark.network.shuffle.protocol.MergeStatuses;
|
import org.apache.spark.network.shuffle.protocol.MergeStatuses;
|
||||||
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
|
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
|
||||||
|
@ -258,4 +261,113 @@ public class ExternalBlockHandlerSuite {
|
||||||
.get("finalizeShuffleMergeLatencyMillis");
|
.get("finalizeShuffleMergeLatencyMillis");
|
||||||
assertEquals(1, finalizeShuffleMergeLatencyMillis.getCount());
|
assertEquals(1, finalizeShuffleMergeLatencyMillis.getCount());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testFetchMergedBlocksMeta() {
|
||||||
|
when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0)).thenReturn(
|
||||||
|
new MergedBlockMeta(1, mock(ManagedBuffer.class)));
|
||||||
|
when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 1)).thenReturn(
|
||||||
|
new MergedBlockMeta(3, mock(ManagedBuffer.class)));
|
||||||
|
when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 2)).thenReturn(
|
||||||
|
new MergedBlockMeta(5, mock(ManagedBuffer.class)));
|
||||||
|
|
||||||
|
int[] expectedCount = new int[]{1, 3, 5};
|
||||||
|
String appId = "app0";
|
||||||
|
long requestId = 0L;
|
||||||
|
for (int reduceId = 0; reduceId < 3; reduceId++) {
|
||||||
|
MergedBlockMetaRequest req = new MergedBlockMetaRequest(requestId++, appId, 0, reduceId);
|
||||||
|
MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class);
|
||||||
|
handler.getMergedBlockMetaReqHandler()
|
||||||
|
.receiveMergeBlockMetaReq(client, req, callback);
|
||||||
|
verify(mergedShuffleManager, times(1)).getMergedBlockMeta("app0", 0, reduceId);
|
||||||
|
|
||||||
|
ArgumentCaptor<Integer> numChunksResponse = ArgumentCaptor.forClass(Integer.class);
|
||||||
|
ArgumentCaptor<ManagedBuffer> chunkBitmapResponse =
|
||||||
|
ArgumentCaptor.forClass(ManagedBuffer.class);
|
||||||
|
verify(callback, times(1)).onSuccess(numChunksResponse.capture(),
|
||||||
|
chunkBitmapResponse.capture());
|
||||||
|
assertEquals("num chunks in merged block " + reduceId, expectedCount[reduceId],
|
||||||
|
numChunksResponse.getValue().intValue());
|
||||||
|
assertNotNull("chunks bitmap buffer " + reduceId, chunkBitmapResponse.getValue());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testOpenBlocksWithShuffleChunks() {
|
||||||
|
verifyBlockChunkFetches(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testFetchShuffleChunks() {
|
||||||
|
verifyBlockChunkFetches(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void verifyBlockChunkFetches(boolean useOpenBlocks) {
|
||||||
|
RpcResponseCallback callback = mock(RpcResponseCallback.class);
|
||||||
|
ByteBuffer buffer;
|
||||||
|
if (useOpenBlocks) {
|
||||||
|
OpenBlocks openBlocks =
|
||||||
|
new OpenBlocks("app0", "exec1",
|
||||||
|
new String[] {"shuffleChunk_0_0_0", "shuffleChunk_0_0_1", "shuffleChunk_0_1_0",
|
||||||
|
"shuffleChunk_0_1_1"});
|
||||||
|
buffer = openBlocks.toByteBuffer();
|
||||||
|
} else {
|
||||||
|
FetchShuffleBlockChunks fetchChunks = new FetchShuffleBlockChunks(
|
||||||
|
"app0", "exec1", 0, new int[] {0, 1}, new int[][] {{0, 1}, {0, 1}});
|
||||||
|
buffer = fetchChunks.toByteBuffer();
|
||||||
|
}
|
||||||
|
ManagedBuffer[][] buffers = new ManagedBuffer[][] {
|
||||||
|
{
|
||||||
|
new NioManagedBuffer(ByteBuffer.wrap(new byte[5])),
|
||||||
|
new NioManagedBuffer(ByteBuffer.wrap(new byte[7]))
|
||||||
|
},
|
||||||
|
{
|
||||||
|
new NioManagedBuffer(ByteBuffer.wrap(new byte[5])),
|
||||||
|
new NioManagedBuffer(ByteBuffer.wrap(new byte[7]))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
for (int reduceId = 0; reduceId < 2; reduceId++) {
|
||||||
|
for (int chunkId = 0; chunkId < 2; chunkId++) {
|
||||||
|
when(mergedShuffleManager.getMergedBlockData(
|
||||||
|
"app0", 0, reduceId, chunkId)).thenReturn(buffers[reduceId][chunkId]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
handler.receive(client, buffer, callback);
|
||||||
|
ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class);
|
||||||
|
verify(callback, times(1)).onSuccess(response.capture());
|
||||||
|
verify(callback, never()).onFailure(any());
|
||||||
|
StreamHandle handle =
|
||||||
|
(StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue());
|
||||||
|
assertEquals(4, handle.numChunks);
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
|
||||||
|
(ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
|
||||||
|
verify(streamManager, times(1)).registerStream(any(), stream.capture(), any());
|
||||||
|
Iterator<ManagedBuffer> bufferIter = stream.getValue();
|
||||||
|
for (int reduceId = 0; reduceId < 2; reduceId++) {
|
||||||
|
for (int chunkId = 0; chunkId < 2; chunkId++) {
|
||||||
|
assertEquals(buffers[reduceId][chunkId], bufferIter.next());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assertFalse(bufferIter.hasNext());
|
||||||
|
verify(mergedShuffleManager, never()).getMergedBlockMeta(anyString(), anyInt(), anyInt());
|
||||||
|
verify(blockResolver, never()).getBlockData(
|
||||||
|
anyString(), anyString(), anyInt(), anyInt(), anyInt());
|
||||||
|
verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 0);
|
||||||
|
verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 1);
|
||||||
|
|
||||||
|
// Verify open block request latency metrics
|
||||||
|
Timer openBlockRequestLatencyMillis = (Timer) ((ExternalBlockHandler) handler)
|
||||||
|
.getAllMetrics()
|
||||||
|
.getMetrics()
|
||||||
|
.get("openBlockRequestLatencyMillis");
|
||||||
|
assertEquals(1, openBlockRequestLatencyMillis.getCount());
|
||||||
|
// Verify block transfer metrics
|
||||||
|
Meter blockTransferRateBytes = (Meter) ((ExternalBlockHandler) handler)
|
||||||
|
.getAllMetrics()
|
||||||
|
.getMetrics()
|
||||||
|
.get("blockTransferRateBytes");
|
||||||
|
assertEquals(24, blockTransferRateBytes.getCount());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,8 +27,7 @@ import com.google.common.collect.Maps;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.*;
|
||||||
import static org.junit.Assert.fail;
|
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
import static org.mockito.ArgumentMatchers.anyInt;
|
import static org.mockito.ArgumentMatchers.anyInt;
|
||||||
import static org.mockito.ArgumentMatchers.anyLong;
|
import static org.mockito.ArgumentMatchers.anyLong;
|
||||||
|
@ -46,6 +45,7 @@ import org.apache.spark.network.client.RpcResponseCallback;
|
||||||
import org.apache.spark.network.client.TransportClient;
|
import org.apache.spark.network.client.TransportClient;
|
||||||
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
|
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
|
||||||
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
|
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
|
||||||
|
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks;
|
||||||
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
|
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
|
||||||
import org.apache.spark.network.shuffle.protocol.StreamHandle;
|
import org.apache.spark.network.shuffle.protocol.StreamHandle;
|
||||||
import org.apache.spark.network.util.MapConfigProvider;
|
import org.apache.spark.network.util.MapConfigProvider;
|
||||||
|
@ -243,6 +243,57 @@ public class OneForOneBlockFetcherSuite {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testShuffleBlockChunksFetch() {
|
||||||
|
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
|
||||||
|
blocks.put("shuffleChunk_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
|
||||||
|
blocks.put("shuffleChunk_0_0_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
|
||||||
|
blocks.put("shuffleChunk_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
|
||||||
|
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
|
||||||
|
|
||||||
|
BlockFetchingListener listener = fetchBlocks(blocks, blockIds,
|
||||||
|
new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
|
||||||
|
new int[][] {{ 0, 1, 2 }}), conf);
|
||||||
|
for (int i = 0; i < 3; i ++) {
|
||||||
|
verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_" + i,
|
||||||
|
blocks.get("shuffleChunk_0_0_" + i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testShuffleBlockChunkFetchFailure() {
|
||||||
|
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
|
||||||
|
blocks.put("shuffleChunk_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
|
||||||
|
blocks.put("shuffleChunk_0_0_1", null);
|
||||||
|
blocks.put("shuffleChunk_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
|
||||||
|
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
|
||||||
|
|
||||||
|
BlockFetchingListener listener = fetchBlocks(blocks, blockIds,
|
||||||
|
new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[]{0}, new int[][]{{0, 1, 2}}),
|
||||||
|
conf);
|
||||||
|
verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_0",
|
||||||
|
blocks.get("shuffleChunk_0_0_0"));
|
||||||
|
verify(listener, times(1)).onBlockFetchFailure(eq("shuffleChunk_0_0_1"), any());
|
||||||
|
verify(listener, times(1)).onBlockFetchSuccess("shuffleChunk_0_0_2",
|
||||||
|
blocks.get("shuffleChunk_0_0_2"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInvalidShuffleBlockIds() {
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(),
|
||||||
|
new String[]{"shuffle_0_0"},
|
||||||
|
new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
|
||||||
|
new int[][] {{ 0 }}), conf));
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(),
|
||||||
|
new String[]{"shuffleChunk_0_0_0_0_0"},
|
||||||
|
new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
|
||||||
|
new int[][] {{ 0 }}), conf));
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> fetchBlocks(new LinkedHashMap<>(),
|
||||||
|
new String[]{"shuffleChunk_0_0_0_0"},
|
||||||
|
new FetchShuffleBlockChunks("app-id", "exec-id", 0, new int[] { 0 },
|
||||||
|
new int[][] {{ 0 }}), conf));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Begins a fetch on the given set of blocks by mocking out the server side of the RPC which
|
* Begins a fetch on the given set of blocks by mocking out the server side of the RPC which
|
||||||
* simply returns the given (BlockId, Block) pairs.
|
* simply returns the given (BlockId, Block) pairs.
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.network.shuffle.protocol;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
public class FetchShuffleBlockChunksSuite {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testFetchShuffleBlockChunksEncodeDecode() {
|
||||||
|
FetchShuffleBlockChunks shuffleBlockChunks =
|
||||||
|
new FetchShuffleBlockChunks("app0", "exec1", 0, new int[] {0}, new int[][] {{0, 1}});
|
||||||
|
Assert.assertEquals(2, shuffleBlockChunks.getNumBlocks());
|
||||||
|
int len = shuffleBlockChunks.encodedLength();
|
||||||
|
Assert.assertEquals(45, len);
|
||||||
|
ByteBuf buf = Unpooled.buffer(len);
|
||||||
|
shuffleBlockChunks.encode(buf);
|
||||||
|
|
||||||
|
FetchShuffleBlockChunks decoded = FetchShuffleBlockChunks.decode(buf);
|
||||||
|
assertEquals(shuffleBlockChunks, decoded);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,42 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.network.shuffle.protocol;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
public class FetchShuffleBlocksSuite {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testFetchShuffleBlockEncodeDecode() {
|
||||||
|
FetchShuffleBlocks fetchShuffleBlocks =
|
||||||
|
new FetchShuffleBlocks("app0", "exec1", 0, new long[] {0}, new int[][] {{0, 1}}, false);
|
||||||
|
Assert.assertEquals(2, fetchShuffleBlocks.getNumBlocks());
|
||||||
|
int len = fetchShuffleBlocks.encodedLength();
|
||||||
|
Assert.assertEquals(50, len);
|
||||||
|
ByteBuf buf = Unpooled.buffer(len);
|
||||||
|
fetchShuffleBlocks.encode(buf);
|
||||||
|
|
||||||
|
FetchShuffleBlocks decoded = FetchShuffleBlocks.decode(buf);
|
||||||
|
assertEquals(fetchShuffleBlocks, decoded);
|
||||||
|
}
|
||||||
|
}
|
|
@ -62,7 +62,8 @@ class ExternalShuffleServiceMetricsSuite extends SparkFunSuite {
|
||||||
"registerExecutorRequestLatencyMillis",
|
"registerExecutorRequestLatencyMillis",
|
||||||
"shuffle-server.usedDirectMemory",
|
"shuffle-server.usedDirectMemory",
|
||||||
"shuffle-server.usedHeapMemory",
|
"shuffle-server.usedHeapMemory",
|
||||||
"finalizeShuffleMergeLatencyMillis")
|
"finalizeShuffleMergeLatencyMillis",
|
||||||
|
"fetchMergedBlocksMetaLatencyMillis")
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,7 +40,8 @@ class YarnShuffleServiceMetricsSuite extends SparkFunSuite with Matchers {
|
||||||
val allMetrics = Set(
|
val allMetrics = Set(
|
||||||
"openBlockRequestLatencyMillis", "registerExecutorRequestLatencyMillis",
|
"openBlockRequestLatencyMillis", "registerExecutorRequestLatencyMillis",
|
||||||
"blockTransferRateBytes", "registeredExecutorsSize", "numActiveConnections",
|
"blockTransferRateBytes", "registeredExecutorsSize", "numActiveConnections",
|
||||||
"numCaughtExceptions", "finalizeShuffleMergeLatencyMillis")
|
"numCaughtExceptions", "finalizeShuffleMergeLatencyMillis",
|
||||||
|
"fetchMergedBlocksMetaLatencyMillis")
|
||||||
|
|
||||||
metrics.getMetrics.keySet().asScala should be (allMetrics)
|
metrics.getMetrics.keySet().asScala should be (allMetrics)
|
||||||
}
|
}
|
||||||
|
|
|
@ -412,7 +412,8 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
|
||||||
"registerExecutorRequestLatencyMillis",
|
"registerExecutorRequestLatencyMillis",
|
||||||
"finalizeShuffleMergeLatencyMillis",
|
"finalizeShuffleMergeLatencyMillis",
|
||||||
"shuffle-server.usedDirectMemory",
|
"shuffle-server.usedDirectMemory",
|
||||||
"shuffle-server.usedHeapMemory"
|
"shuffle-server.usedHeapMemory",
|
||||||
|
"fetchMergedBlocksMetaLatencyMillis"
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue