[SPARK-11762][NETWORK] Account for active streams when couting outstanding requests.
This way the timeout handling code can correctly close "hung" channels that are processing streams. Author: Marcelo Vanzin <vanzin@cloudera.com> Closes #9747 from vanzin/SPARK-11762.
This commit is contained in:
parent
5fd86e4fc2
commit
5231cd5aca
|
@ -30,13 +30,19 @@ import org.apache.spark.network.util.TransportFrameDecoder;
|
|||
*/
|
||||
class StreamInterceptor implements TransportFrameDecoder.Interceptor {
|
||||
|
||||
private final TransportResponseHandler handler;
|
||||
private final String streamId;
|
||||
private final long byteCount;
|
||||
private final StreamCallback callback;
|
||||
|
||||
private volatile long bytesRead;
|
||||
|
||||
StreamInterceptor(String streamId, long byteCount, StreamCallback callback) {
|
||||
StreamInterceptor(
|
||||
TransportResponseHandler handler,
|
||||
String streamId,
|
||||
long byteCount,
|
||||
StreamCallback callback) {
|
||||
this.handler = handler;
|
||||
this.streamId = streamId;
|
||||
this.byteCount = byteCount;
|
||||
this.callback = callback;
|
||||
|
@ -45,11 +51,13 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor {
|
|||
|
||||
@Override
|
||||
public void exceptionCaught(Throwable cause) throws Exception {
|
||||
handler.deactivateStream();
|
||||
callback.onFailure(streamId, cause);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelInactive() throws Exception {
|
||||
handler.deactivateStream();
|
||||
callback.onFailure(streamId, new ClosedChannelException());
|
||||
}
|
||||
|
||||
|
@ -65,8 +73,10 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor {
|
|||
RuntimeException re = new IllegalStateException(String.format(
|
||||
"Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead));
|
||||
callback.onFailure(streamId, re);
|
||||
handler.deactivateStream();
|
||||
throw re;
|
||||
} else if (bytesRead == byteCount) {
|
||||
handler.deactivateStream();
|
||||
callback.onComplete(streamId);
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentHashMap;
|
|||
import java.util.concurrent.ConcurrentLinkedQueue;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.netty.channel.Channel;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
@ -56,6 +57,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
|
|||
private final Map<Long, RpcResponseCallback> outstandingRpcs;
|
||||
|
||||
private final Queue<StreamCallback> streamCallbacks;
|
||||
private volatile boolean streamActive;
|
||||
|
||||
/** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
|
||||
private final AtomicLong timeOfLastRequestNs;
|
||||
|
@ -87,9 +89,15 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
|
|||
}
|
||||
|
||||
public void addStreamCallback(StreamCallback callback) {
|
||||
timeOfLastRequestNs.set(System.nanoTime());
|
||||
streamCallbacks.offer(callback);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public void deactivateStream() {
|
||||
streamActive = false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire the failure callback for all outstanding requests. This is called when we have an
|
||||
* uncaught exception or pre-mature connection termination.
|
||||
|
@ -177,14 +185,16 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
|
|||
StreamResponse resp = (StreamResponse) message;
|
||||
StreamCallback callback = streamCallbacks.poll();
|
||||
if (callback != null) {
|
||||
StreamInterceptor interceptor = new StreamInterceptor(resp.streamId, resp.byteCount,
|
||||
StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
|
||||
callback);
|
||||
try {
|
||||
TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
|
||||
channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
|
||||
frameDecoder.setInterceptor(interceptor);
|
||||
streamActive = true;
|
||||
} catch (Exception e) {
|
||||
logger.error("Error installing stream handler.", e);
|
||||
deactivateStream();
|
||||
}
|
||||
} else {
|
||||
logger.error("Could not find callback for StreamResponse.");
|
||||
|
@ -208,7 +218,8 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
|
|||
|
||||
/** Returns total number of outstanding requests (fetch requests + rpcs) */
|
||||
public int numOutstandingRequests() {
|
||||
return outstandingFetches.size() + outstandingRpcs.size();
|
||||
return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() +
|
||||
(streamActive ? 1 : 0);
|
||||
}
|
||||
|
||||
/** Returns the time in nanoseconds of when the last request was sent out. */
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.network;
|
||||
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import org.junit.Test;
|
||||
|
||||
|
@ -28,12 +29,16 @@ import static org.mockito.Mockito.*;
|
|||
import org.apache.spark.network.buffer.ManagedBuffer;
|
||||
import org.apache.spark.network.client.ChunkReceivedCallback;
|
||||
import org.apache.spark.network.client.RpcResponseCallback;
|
||||
import org.apache.spark.network.client.StreamCallback;
|
||||
import org.apache.spark.network.client.TransportResponseHandler;
|
||||
import org.apache.spark.network.protocol.ChunkFetchFailure;
|
||||
import org.apache.spark.network.protocol.ChunkFetchSuccess;
|
||||
import org.apache.spark.network.protocol.RpcFailure;
|
||||
import org.apache.spark.network.protocol.RpcResponse;
|
||||
import org.apache.spark.network.protocol.StreamChunkId;
|
||||
import org.apache.spark.network.protocol.StreamFailure;
|
||||
import org.apache.spark.network.protocol.StreamResponse;
|
||||
import org.apache.spark.network.util.TransportFrameDecoder;
|
||||
|
||||
public class TransportResponseHandlerSuite {
|
||||
@Test
|
||||
|
@ -112,4 +117,26 @@ public class TransportResponseHandlerSuite {
|
|||
verify(callback, times(1)).onFailure((Throwable) any());
|
||||
assertEquals(0, handler.numOutstandingRequests());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testActiveStreams() {
|
||||
Channel c = new LocalChannel();
|
||||
c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
|
||||
TransportResponseHandler handler = new TransportResponseHandler(c);
|
||||
|
||||
StreamResponse response = new StreamResponse("stream", 1234L, null);
|
||||
StreamCallback cb = mock(StreamCallback.class);
|
||||
handler.addStreamCallback(cb);
|
||||
assertEquals(1, handler.numOutstandingRequests());
|
||||
handler.handle(response);
|
||||
assertEquals(1, handler.numOutstandingRequests());
|
||||
handler.deactivateStream();
|
||||
assertEquals(0, handler.numOutstandingRequests());
|
||||
|
||||
StreamFailure failure = new StreamFailure("stream", "uh-oh");
|
||||
handler.addStreamCallback(cb);
|
||||
assertEquals(1, handler.numOutstandingRequests());
|
||||
handler.handle(failure);
|
||||
assertEquals(0, handler.numOutstandingRequests());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue