[SPARK-11762][NETWORK] Account for active streams when couting outstanding requests.

This way the timeout handling code can correctly close "hung" channels that are
processing streams.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #9747 from vanzin/SPARK-11762.
This commit is contained in:
Marcelo Vanzin 2015-11-23 10:45:23 -08:00
parent 5fd86e4fc2
commit 5231cd5aca
3 changed files with 51 additions and 3 deletions

View file

@ -30,13 +30,19 @@ import org.apache.spark.network.util.TransportFrameDecoder;
*/
class StreamInterceptor implements TransportFrameDecoder.Interceptor {
private final TransportResponseHandler handler;
private final String streamId;
private final long byteCount;
private final StreamCallback callback;
private volatile long bytesRead;
StreamInterceptor(String streamId, long byteCount, StreamCallback callback) {
StreamInterceptor(
TransportResponseHandler handler,
String streamId,
long byteCount,
StreamCallback callback) {
this.handler = handler;
this.streamId = streamId;
this.byteCount = byteCount;
this.callback = callback;
@ -45,11 +51,13 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor {
@Override
public void exceptionCaught(Throwable cause) throws Exception {
handler.deactivateStream();
callback.onFailure(streamId, cause);
}
@Override
public void channelInactive() throws Exception {
handler.deactivateStream();
callback.onFailure(streamId, new ClosedChannelException());
}
@ -65,8 +73,10 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor {
RuntimeException re = new IllegalStateException(String.format(
"Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead));
callback.onFailure(streamId, re);
handler.deactivateStream();
throw re;
} else if (bytesRead == byteCount) {
handler.deactivateStream();
callback.onComplete(streamId);
}

View file

@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -56,6 +57,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
private final Map<Long, RpcResponseCallback> outstandingRpcs;
private final Queue<StreamCallback> streamCallbacks;
private volatile boolean streamActive;
/** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
private final AtomicLong timeOfLastRequestNs;
@ -87,9 +89,15 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
}
public void addStreamCallback(StreamCallback callback) {
timeOfLastRequestNs.set(System.nanoTime());
streamCallbacks.offer(callback);
}
@VisibleForTesting
public void deactivateStream() {
streamActive = false;
}
/**
* Fire the failure callback for all outstanding requests. This is called when we have an
* uncaught exception or pre-mature connection termination.
@ -177,14 +185,16 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
StreamResponse resp = (StreamResponse) message;
StreamCallback callback = streamCallbacks.poll();
if (callback != null) {
StreamInterceptor interceptor = new StreamInterceptor(resp.streamId, resp.byteCount,
StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
callback);
try {
TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
frameDecoder.setInterceptor(interceptor);
streamActive = true;
} catch (Exception e) {
logger.error("Error installing stream handler.", e);
deactivateStream();
}
} else {
logger.error("Could not find callback for StreamResponse.");
@ -208,7 +218,8 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
/** Returns total number of outstanding requests (fetch requests + rpcs) */
public int numOutstandingRequests() {
return outstandingFetches.size() + outstandingRpcs.size();
return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() +
(streamActive ? 1 : 0);
}
/** Returns the time in nanoseconds of when the last request was sent out. */

View file

@ -17,6 +17,7 @@
package org.apache.spark.network;
import io.netty.channel.Channel;
import io.netty.channel.local.LocalChannel;
import org.junit.Test;
@ -28,12 +29,16 @@ import static org.mockito.Mockito.*;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallback;
import org.apache.spark.network.client.TransportResponseHandler;
import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
import org.apache.spark.network.protocol.StreamFailure;
import org.apache.spark.network.protocol.StreamResponse;
import org.apache.spark.network.util.TransportFrameDecoder;
public class TransportResponseHandlerSuite {
@Test
@ -112,4 +117,26 @@ public class TransportResponseHandlerSuite {
verify(callback, times(1)).onFailure((Throwable) any());
assertEquals(0, handler.numOutstandingRequests());
}
@Test
public void testActiveStreams() {
Channel c = new LocalChannel();
c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
TransportResponseHandler handler = new TransportResponseHandler(c);
StreamResponse response = new StreamResponse("stream", 1234L, null);
StreamCallback cb = mock(StreamCallback.class);
handler.addStreamCallback(cb);
assertEquals(1, handler.numOutstandingRequests());
handler.handle(response);
assertEquals(1, handler.numOutstandingRequests());
handler.deactivateStream();
assertEquals(0, handler.numOutstandingRequests());
StreamFailure failure = new StreamFailure("stream", "uh-oh");
handler.addStreamCallback(cb);
assertEquals(1, handler.numOutstandingRequests());
handler.handle(failure);
assertEquals(0, handler.numOutstandingRequests());
}
}