diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
index ae91bc9cfd..480b52652d 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -21,6 +21,8 @@ import java.util.ArrayList;
 import java.util.List;
 
 import io.netty.channel.Channel;
+import io.netty.channel.ChannelPipeline;
+import io.netty.channel.EventLoopGroup;
 import io.netty.channel.socket.SocketChannel;
 import io.netty.handler.timeout.IdleStateHandler;
 import org.slf4j.Logger;
@@ -32,11 +34,13 @@ import org.apache.spark.network.client.TransportClientFactory;
 import org.apache.spark.network.client.TransportResponseHandler;
 import org.apache.spark.network.protocol.MessageDecoder;
 import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.server.ChunkFetchRequestHandler;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.TransportChannelHandler;
 import org.apache.spark.network.server.TransportRequestHandler;
 import org.apache.spark.network.server.TransportServer;
 import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.util.IOMode;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
 import org.apache.spark.network.util.TransportFrameDecoder;
@@ -61,6 +65,7 @@ public class TransportContext {
   private final TransportConf conf;
   private final RpcHandler rpcHandler;
   private final boolean closeIdleConnections;
+  private final boolean isClientOnly;
 
   /**
    * Force to create MessageEncoder and MessageDecoder so that we can make sure they will be created
@@ -77,17 +82,54 @@ public class TransportContext {
   private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
   private static final MessageDecoder DECODER = MessageDecoder.INSTANCE;
 
+  // Separate thread pool for handling ChunkFetchRequest messages. This helps restrict the
+  // max number of TransportServer worker threads that are blocked writing responses to
+  // ChunkFetchRequest messages back to the client via the underlying channel.
+  private static EventLoopGroup chunkFetchWorkers;
+
   public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
-    this(conf, rpcHandler, false);
+    this(conf, rpcHandler, false, false);
   }
 
   public TransportContext(
       TransportConf conf,
       RpcHandler rpcHandler,
       boolean closeIdleConnections) {
+    this(conf, rpcHandler, closeIdleConnections, false);
+  }
+
+  /**
+   * Enables TransportContext initialization for underlying client and server.
+   *
+   * @param conf TransportConf
+   * @param rpcHandler RpcHandler responsible for handling requests and responses.
+   * @param closeIdleConnections Close idle connections if it is set to true.
+   * @param isClientOnly Whether this TransportContext is used only by a client. This matters
+   *                     mainly when external shuffle is enabled: it stops shuffle clients from
+   *                     creating the extra event loop and subsequent thread pool used to handle
+   *                     chunked fetch requests.
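+   *
+   * <p>A hypothetical usage sketch (illustrative only, mirroring the ExternalShuffleClient
+   * change at the end of this patch): a client-only context is created as
+   * <pre>{@code
+   * TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true, true);
+   * }</pre>
+   * which skips creating the shuffle-chunk-fetch-handler event loop.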
+   */
+  public TransportContext(
+      TransportConf conf,
+      RpcHandler rpcHandler,
+      boolean closeIdleConnections,
+      boolean isClientOnly) {
     this.conf = conf;
     this.rpcHandler = rpcHandler;
     this.closeIdleConnections = closeIdleConnections;
+    this.isClientOnly = isClientOnly;
+
+    synchronized (TransportContext.class) {
+      if (chunkFetchWorkers == null &&
+          conf.getModuleName() != null &&
+          conf.getModuleName().equalsIgnoreCase("shuffle") &&
+          !isClientOnly) {
+        chunkFetchWorkers = NettyUtils.createEventLoop(
+          IOMode.valueOf(conf.ioMode()),
+          conf.chunkFetchHandlerThreads(),
+          "shuffle-chunk-fetch-handler");
+      }
+    }
   }
 
   /**
@@ -144,14 +186,23 @@
       RpcHandler channelRpcHandler) {
     try {
       TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
-      channel.pipeline()
+      ChunkFetchRequestHandler chunkFetchHandler =
+        createChunkFetchHandler(channelHandler, channelRpcHandler);
+      ChannelPipeline pipeline = channel.pipeline()
         .addLast("encoder", ENCODER)
         .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
         .addLast("decoder", DECODER)
-        .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
+        .addLast("idleStateHandler",
+          new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
         // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
         // would require more logic to guarantee if this were not part of the same event loop.
         .addLast("handler", channelHandler);
+      // Use a separate EventLoopGroup to handle ChunkFetchRequest messages for shuffle RPCs.
+      if (conf.getModuleName() != null &&
+          conf.getModuleName().equalsIgnoreCase("shuffle") &&
+          !isClientOnly) {
+        pipeline.addLast(chunkFetchWorkers, "chunkFetchHandler", chunkFetchHandler);
+      }
       return channelHandler;
     } catch (RuntimeException e) {
       logger.error("Error while initializing Netty pipeline", e);
@@ -173,5 +224,14 @@
       conf.connectionTimeoutMs(), closeIdleConnections);
   }
 
+  /**
+   * Creates the dedicated ChannelHandler for ChunkFetchRequest messages.
+   */
+  private ChunkFetchRequestHandler createChunkFetchHandler(TransportChannelHandler channelHandler,
+      RpcHandler rpcHandler) {
+    return new ChunkFetchRequestHandler(channelHandler.getClient(),
+      rpcHandler.getStreamManager(), conf.maxChunksBeingTransferred());
+  }
+
   public TransportConf getConf() { return conf; }
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java
new file mode 100644
index 0000000000..f08d8b0f98
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.net.SocketAddress;
+
+import com.google.common.base.Throwables;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.SimpleChannelInboundHandler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.Encodable;
+
+import static org.apache.spark.network.util.NettyUtils.*;
+
+/**
+ * A dedicated ChannelHandler for processing ChunkFetchRequest messages. When sending responses
+ * to ChunkFetchRequest messages, the thread performing the I/O on the underlying channel can
+ * block due to disk contention. If several hundred clients send ChunkFetchRequest messages to
+ * the server at the same time, they could occupy all the threads in TransportServer's default
+ * EventLoopGroup, each waiting on disk reads before the block data can be sent back to the
+ * client as part of a ChunkFetchSuccess message. That would leave no threads to process other
+ * RPC messages, which take much less time to handle, and could cause clients to time out while
+ * performing SASL authentication, registering executors, or waiting for the response to an
+ * OpenBlocks message.
+ */
+public class ChunkFetchRequestHandler extends SimpleChannelInboundHandler<ChunkFetchRequest> {
+  private static final Logger logger = LoggerFactory.getLogger(ChunkFetchRequestHandler.class);
+
+  private final TransportClient client;
+  private final StreamManager streamManager;
+  /** The max number of chunks being transferred and not finished yet.
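+   * Once this limit is reached, an incoming fetch request causes the channel to be closed;
+   * see the guard at the top of {@code channelRead0} below.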
+   */
+  private final long maxChunksBeingTransferred;
+
+  public ChunkFetchRequestHandler(
+      TransportClient client,
+      StreamManager streamManager,
+      Long maxChunksBeingTransferred) {
+    this.client = client;
+    this.streamManager = streamManager;
+    this.maxChunksBeingTransferred = maxChunksBeingTransferred;
+  }
+
+  @Override
+  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+    logger.warn("Exception in connection from " + getRemoteAddress(ctx.channel()), cause);
+    ctx.close();
+  }
+
+  @Override
+  protected void channelRead0(
+      ChannelHandlerContext ctx,
+      final ChunkFetchRequest msg) throws Exception {
+    Channel channel = ctx.channel();
+    if (logger.isTraceEnabled()) {
+      logger.trace("Received req from {} to fetch block {}", getRemoteAddress(channel),
+        msg.streamChunkId);
+    }
+    long chunksBeingTransferred = streamManager.chunksBeingTransferred();
+    if (chunksBeingTransferred >= maxChunksBeingTransferred) {
+      logger.warn("The number of chunks being transferred {} is above {}, close the connection.",
+        chunksBeingTransferred, maxChunksBeingTransferred);
+      channel.close();
+      return;
+    }
+    ManagedBuffer buf;
+    try {
+      streamManager.checkAuthorization(client, msg.streamChunkId.streamId);
+      streamManager.registerChannel(channel, msg.streamChunkId.streamId);
+      buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex);
+    } catch (Exception e) {
+      logger.error(String.format("Error opening block %s for request from %s",
+        msg.streamChunkId, getRemoteAddress(channel)), e);
+      respond(channel, new ChunkFetchFailure(msg.streamChunkId,
+        Throwables.getStackTraceAsString(e)));
+      return;
+    }
+
+    streamManager.chunkBeingSent(msg.streamChunkId.streamId);
+    respond(channel, new ChunkFetchSuccess(msg.streamChunkId, buf)).addListener(
+      (ChannelFutureListener) future -> streamManager.chunkSent(msg.streamChunkId.streamId));
+  }
+
+  /**
+   * The invocation of channel.writeAndFlush is async, and the actual I/O on the channel is
+   * handled by the EventLoop the channel is registered to. So even though we process the
+   * ChunkFetchRequest in a separate thread pool, the actual I/O, which is the potentially
+   * blocking call that could deplete server handler threads, is still performed by
+   * TransportServer's default EventLoopGroup. To throttle the max number of threads handling
+   * channel I/O for ChunkFetchRequest responses, the thread calling channel.writeAndFlush
+   * waits for the response to finish sending by invoking await(). This throttles the rate at
+   * which threads from the dedicated ChunkFetchRequest EventLoopGroup submit channel I/O
+   * requests to TransportServer's default EventLoopGroup, thus making sure that we can reserve
+   * some threads in TransportServer's default EventLoopGroup for handling other RPC messages.
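+   *
+   * <p>In essence, the pattern below reduces to the following simplified sketch:
+   * <pre>{@code
+   * // runs on a shuffle-chunk-fetch-handler thread, not the channel's event loop
+   * channel.writeAndFlush(result).await();  // block until the write completes
+   * }</pre>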
+ */ + private ChannelFuture respond( + final Channel channel, + final Encodable result) throws InterruptedException { + final SocketAddress remoteAddress = channel.remoteAddress(); + return channel.writeAndFlush(result).await().addListener((ChannelFutureListener) future -> { + if (future.isSuccess()) { + logger.trace("Sent result {} to client {}", result, remoteAddress); + } else { + logger.error(String.format("Error sending result %s to %s; closing connection", + result, remoteAddress), future.cause()); + channel.close(); + } + }); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 56782a8327..c824a7b0d4 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -18,7 +18,7 @@ package org.apache.spark.network.server; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.timeout.IdleState; import io.netty.handler.timeout.IdleStateEvent; import org.slf4j.Logger; @@ -26,6 +26,8 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.ResponseMessage; import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; @@ -47,7 +49,7 @@ import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not * timeout if the client is continuously sending but getting no responses, for simplicity. */ -public class TransportChannelHandler extends ChannelInboundHandlerAdapter { +public class TransportChannelHandler extends SimpleChannelInboundHandler { private static final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); private final TransportClient client; @@ -112,8 +114,21 @@ public class TransportChannelHandler extends ChannelInboundHandlerAdapter { super.channelInactive(ctx); } + /** + * Overwrite acceptInboundMessage to properly delegate ChunkFetchRequest messages + * to ChunkFetchRequestHandler. 
+   */
   @Override
-  public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception {
+  public boolean acceptInboundMessage(Object msg) throws Exception {
+    if (msg instanceof ChunkFetchRequest) {
+      return false;
+    } else {
+      return super.acceptInboundMessage(msg);
+    }
+  }
+
+  @Override
+  public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
     if (request instanceof RequestMessage) {
       requestHandler.handle((RequestMessage) request);
     } else if (request instanceof ResponseMessage) {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 9fac96dbe4..3e089b4cae 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -24,6 +24,7 @@ import java.nio.ByteBuffer;
 
 import com.google.common.base.Throwables;
 import io.netty.channel.Channel;
 import io.netty.channel.ChannelFuture;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -97,9 +98,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
 
   @Override
   public void handle(RequestMessage request) {
-    if (request instanceof ChunkFetchRequest) {
-      processFetchRequest((ChunkFetchRequest) request);
-    } else if (request instanceof RpcRequest) {
+    if (request instanceof RpcRequest) {
       processRpcRequest((RpcRequest) request);
     } else if (request instanceof OneWayMessage) {
       processOneWayMessage((OneWayMessage) request);
@@ -112,36 +111,6 @@
     }
   }
 
-  private void processFetchRequest(final ChunkFetchRequest req) {
-    if (logger.isTraceEnabled()) {
-      logger.trace("Received req from {} to fetch block {}", getRemoteAddress(channel),
-        req.streamChunkId);
-    }
-    long chunksBeingTransferred = streamManager.chunksBeingTransferred();
-    if (chunksBeingTransferred >= maxChunksBeingTransferred) {
-      logger.warn("The number of chunks being transferred {} is above {}, close the connection.",
-        chunksBeingTransferred, maxChunksBeingTransferred);
-      channel.close();
-      return;
-    }
-    ManagedBuffer buf;
-    try {
-      streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
-      streamManager.registerChannel(channel, req.streamChunkId.streamId);
-      buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
-    } catch (Exception e) {
-      logger.error(String.format("Error opening block %s for request from %s",
-        req.streamChunkId, getRemoteAddress(channel)), e);
-      respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e)));
-      return;
-    }
-
-    streamManager.chunkBeingSent(req.streamChunkId.streamId);
-    respond(new ChunkFetchSuccess(req.streamChunkId, buf)).addListener(future -> {
-      streamManager.chunkSent(req.streamChunkId.streamId);
-    });
-  }
-
   private void processStreamRequest(final StreamRequest req) {
     if (logger.isTraceEnabled()) {
       logger.trace("Received req from {} to fetch stream {}", getRemoteAddress(channel),
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 34e4bb5912..6d5cccd20b 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -21,6 +21,7 @@ import java.util.Locale;
 import java.util.Properties;
 
 import com.google.common.primitives.Ints;
+import io.netty.util.NettyRuntime;
 
 /**
  * A central location that tracks all the settings we expose to users.
@@ -281,4 +282,31 @@
   public long maxChunksBeingTransferred() {
     return conf.getLong("spark.shuffle.maxChunksBeingTransferred", Long.MAX_VALUE);
   }
+
+  /**
+   * Percentage of io.serverThreads used by netty to process ChunkFetchRequest.
+   * The shuffle server uses a separate EventLoopGroup to process ChunkFetchRequest messages.
+   * Even though the async writeAndFlush on the underlying channel is still handled by
+   * {@link org.apache.spark.network.server.TransportServer}'s default EventLoopGroup
+   * that's registered with the Channel, waiting inside the ChunkFetchRequest handler
+   * threads for responses to finish sending puts a limit on the max number of threads
+   * from TransportServer's default EventLoopGroup that are consumed by writing responses
+   * to ChunkFetchRequest messages, which are I/O intensive and could take a long time to
+   * process due to disk contention. By configuring a slightly higher number of shuffle
+   * server threads, we are able to reserve some threads for handling other RPC messages,
+   * making clients less likely to time out when sending RPC messages to the shuffle
+   * server. Defaults to 0. The percentage is applied to io.serverThreads if it is set,
+   * and to 2 * #cores otherwise; e.g. 90 means 90% of io.serverThreads or 90% of
+   * 2 * #cores (that is, 0.9 * io.serverThreads or 0.9 * 2 * #cores).
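+   *
+   * <p>A worked example with hypothetical numbers: with io.serverThreads = 40 and
+   * spark.shuffle.server.chunkFetchHandlerThreadsPercent = 90, the dedicated EventLoopGroup
+   * is sized at (40 * 90) / 100 = 36 threads, leaving the remaining default server threads
+   * free for other RPC messages.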
+ */ + +package org.apache.spark.network; + +import io.netty.channel.ChannelHandlerContext; +import java.util.ArrayList; +import java.util.List; + +import io.netty.channel.Channel; +import org.apache.spark.network.server.ChunkFetchRequestHandler; +import org.junit.Test; + +import static org.mockito.Mockito.*; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.*; +import org.apache.spark.network.server.NoOpRpcHandler; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; + +public class ChunkFetchRequestHandlerSuite { + + @Test + public void handleChunkFetchRequest() throws Exception { + RpcHandler rpcHandler = new NoOpRpcHandler(); + OneForOneStreamManager streamManager = (OneForOneStreamManager) (rpcHandler.getStreamManager()); + Channel channel = mock(Channel.class); + ChannelHandlerContext context = mock(ChannelHandlerContext.class); + when(context.channel()) + .thenAnswer(invocationOnMock0 -> { + return channel; + }); + List> responseAndPromisePairs = + new ArrayList<>(); + when(channel.writeAndFlush(any())) + .thenAnswer(invocationOnMock0 -> { + Object response = invocationOnMock0.getArguments()[0]; + ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel); + responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture)); + return channelFuture; + }); + + // Prepare the stream. + List managedBuffers = new ArrayList<>(); + managedBuffers.add(new TestManagedBuffer(10)); + managedBuffers.add(new TestManagedBuffer(20)); + managedBuffers.add(new TestManagedBuffer(30)); + managedBuffers.add(new TestManagedBuffer(40)); + long streamId = streamManager.registerStream("test-app", managedBuffers.iterator()); + streamManager.registerChannel(channel, streamId); + TransportClient reverseClient = mock(TransportClient.class); + ChunkFetchRequestHandler requestHandler = new ChunkFetchRequestHandler(reverseClient, + rpcHandler.getStreamManager(), 2L); + + RequestMessage request0 = new ChunkFetchRequest(new StreamChunkId(streamId, 0)); + requestHandler.channelRead(context, request0); + assert responseAndPromisePairs.size() == 1; + assert responseAndPromisePairs.get(0).getLeft() instanceof ChunkFetchSuccess; + assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(0).getLeft())).body() == + managedBuffers.get(0); + + RequestMessage request1 = new ChunkFetchRequest(new StreamChunkId(streamId, 1)); + requestHandler.channelRead(context, request1); + assert responseAndPromisePairs.size() == 2; + assert responseAndPromisePairs.get(1).getLeft() instanceof ChunkFetchSuccess; + assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(1).getLeft())).body() == + managedBuffers.get(1); + + // Finish flushing the response for request0. 
+    responseAndPromisePairs.get(0).getRight().finish(true);
+
+    RequestMessage request2 = new ChunkFetchRequest(new StreamChunkId(streamId, 2));
+    requestHandler.channelRead(context, request2);
+    assert responseAndPromisePairs.size() == 3;
+    assert responseAndPromisePairs.get(2).getLeft() instanceof ChunkFetchSuccess;
+    assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(2).getLeft())).body() ==
+      managedBuffers.get(2);
+
+    RequestMessage request3 = new ChunkFetchRequest(new StreamChunkId(streamId, 3));
+    requestHandler.channelRead(context, request3);
+    verify(channel, times(1)).close();
+    assert responseAndPromisePairs.size() == 3;
+  }
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ExtendedChannelPromise.java b/common/network-common/src/test/java/org/apache/spark/network/ExtendedChannelPromise.java
new file mode 100644
index 0000000000..573ffd627a
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/ExtendedChannelPromise.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.network; + +import java.util.ArrayList; +import java.util.List; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; + +class ExtendedChannelPromise extends DefaultChannelPromise { + + private List>> listeners = new ArrayList<>(); + private boolean success; + + ExtendedChannelPromise(Channel channel) { + super(channel); + success = false; + } + + @Override + public ChannelPromise addListener( + GenericFutureListener> listener) { + @SuppressWarnings("unchecked") + GenericFutureListener> gfListener = + (GenericFutureListener>) listener; + listeners.add(gfListener); + return super.addListener(listener); + } + + @Override + public boolean isSuccess() { + return success; + } + + @Override + public ChannelPromise await() throws InterruptedException { + return this; + } + + public void finish(boolean success) { + this.success = success; + listeners.forEach(listener -> { + try { + listener.operationComplete(this); + } catch (Exception e) { + // do nothing + } + }); + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java index 2656cbee95..ad640415a8 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -21,10 +21,6 @@ import java.util.ArrayList; import java.util.List; import io.netty.channel.Channel; -import io.netty.channel.ChannelPromise; -import io.netty.channel.DefaultChannelPromise; -import io.netty.util.concurrent.Future; -import io.netty.util.concurrent.GenericFutureListener; import org.junit.Test; import static org.mockito.Mockito.*; @@ -42,7 +38,7 @@ import org.apache.spark.network.server.TransportRequestHandler; public class TransportRequestHandlerSuite { @Test - public void handleFetchRequestAndStreamRequest() throws Exception { + public void handleStreamRequest() throws Exception { RpcHandler rpcHandler = new NoOpRpcHandler(); OneForOneStreamManager streamManager = (OneForOneStreamManager) (rpcHandler.getStreamManager()); Channel channel = mock(Channel.class); @@ -68,18 +64,18 @@ public class TransportRequestHandlerSuite { TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient, rpcHandler, 2L); - RequestMessage request0 = new ChunkFetchRequest(new StreamChunkId(streamId, 0)); + RequestMessage request0 = new StreamRequest(String.format("%d_%d", streamId, 0)); requestHandler.handle(request0); assert responseAndPromisePairs.size() == 1; - assert responseAndPromisePairs.get(0).getLeft() instanceof ChunkFetchSuccess; - assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(0).getLeft())).body() == + assert responseAndPromisePairs.get(0).getLeft() instanceof StreamResponse; + assert ((StreamResponse) (responseAndPromisePairs.get(0).getLeft())).body() == managedBuffers.get(0); - RequestMessage request1 = new ChunkFetchRequest(new StreamChunkId(streamId, 1)); + RequestMessage request1 = new StreamRequest(String.format("%d_%d", streamId, 1)); requestHandler.handle(request1); assert responseAndPromisePairs.size() == 2; - assert responseAndPromisePairs.get(1).getLeft() instanceof ChunkFetchSuccess; - assert ((ChunkFetchSuccess) 
+    assert responseAndPromisePairs.get(1).getLeft() instanceof StreamResponse;
+    assert ((StreamResponse) (responseAndPromisePairs.get(1).getLeft())).body() ==
       managedBuffers.get(1);
 
     // Finish flushing the response for request0.
@@ -99,41 +95,4 @@
     verify(channel, times(1)).close();
     assert responseAndPromisePairs.size() == 3;
   }
-
-  private class ExtendedChannelPromise extends DefaultChannelPromise {
-
-    private List<GenericFutureListener<Future<Void>>> listeners = new ArrayList<>();
-    private boolean success;
-
-    ExtendedChannelPromise(Channel channel) {
-      super(channel);
-      success = false;
-    }
-
-    @Override
-    public ChannelPromise addListener(
-        GenericFutureListener<? extends Future<? super Void>> listener) {
-      @SuppressWarnings("unchecked")
-      GenericFutureListener<Future<Void>> gfListener =
-          (GenericFutureListener<Future<Void>>) listener;
-      listeners.add(gfListener);
-      return super.addListener(listener);
-    }
-
-    @Override
-    public boolean isSuccess() {
-      return success;
-    }
-
-    public void finish(boolean success) {
-      this.success = success;
-      listeners.forEach(listener -> {
-        try {
-          listener.operationComplete(this);
-        } catch (Exception e) {
-          // do nothing
-        }
-      });
-    }
-  }
 }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 9a2cf0f953..e49e27ab5a 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -76,7 +76,7 @@ public class ExternalShuffleClient extends ShuffleClient {
   @Override
   public void init(String appId) {
     this.appId = appId;
-    TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true);
+    TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true, true);
     List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
     if (authEnabled) {
       bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder));