[SPARK-21253][CORE] Fix a bug where StreamCallback may not be notified if network errors happen

## What changes were proposed in this pull request?

If a network error happens before the StreamResponse/StreamFailure event is processed, StreamCallback.onFailure won't be called.

This PR fixes `failOutstandingRequests` to also notify outstanding StreamCallbacks.

## How was this patch tested?

The new unit tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #18472 from zsxwing/fix-stream-2.
This commit is contained in:
Shixiong Zhu 2017-06-30 10:56:48 +08:00 committed by Wenchen Fan
parent f9151bebca
commit 4996c53949
3 changed files with 59 additions and 12 deletions

View file

@ -179,7 +179,7 @@ public class TransportClient implements Closeable {
// written to the socket atomically, so that callbacks are called in the right order
// when responses arrive.
synchronized (this) {
handler.addStreamCallback(callback);
handler.addStreamCallback(streamId, callback);
channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> {
if (future.isSuccess()) {
long timeTaken = System.currentTimeMillis() - startTime;

View file

@ -24,6 +24,8 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
import scala.Tuple2;
import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.Channel;
import org.slf4j.Logger;
@ -56,7 +58,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
private final Map<Long, RpcResponseCallback> outstandingRpcs;
private final Queue<StreamCallback> streamCallbacks;
private final Queue<Tuple2<String, StreamCallback>> streamCallbacks;
private volatile boolean streamActive;
/** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
@ -88,9 +90,9 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
outstandingRpcs.remove(requestId);
}
public void addStreamCallback(StreamCallback callback) {
public void addStreamCallback(String streamId, StreamCallback callback) {
// Stamp the current time as the time of the most recent request.
timeOfLastRequestNs.set(System.nanoTime());
streamCallbacks.offer(callback);
// Queue the callback paired with its stream id, so the id is still available when the
// queued entry is later read back.
streamCallbacks.offer(Tuple2.apply(streamId, callback));
}
@VisibleForTesting
@ -104,15 +106,31 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
*/
private void failOutstandingRequests(Throwable cause) {
// Notify every outstanding chunk fetch, passing the chunk index taken from the map key.
for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
// The try/catch logs an exception thrown by the callback instead of letting it escape
// this loop.
try {
entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
} catch (Exception e) {
logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
}
}
// Notify every outstanding RPC callback with the same cause.
for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
entry.getValue().onFailure(cause);
// As above: an exception thrown by the callback is logged rather than propagated.
try {
entry.getValue().onFailure(cause);
} catch (Exception e) {
logger.warn("RpcResponseCallback.onFailure throws exception", e);
}
}
// Notify every queued stream callback, passing the stream id stored next to it in the tuple.
for (Tuple2<String, StreamCallback> entry : streamCallbacks) {
try {
entry._2().onFailure(entry._1(), cause);
} catch (Exception e) {
logger.warn("StreamCallback.onFailure throws exception", e);
}
}
// It's OK if new fetches appear, as they will fail immediately.
outstandingFetches.clear();
outstandingRpcs.clear();
streamCallbacks.clear();
}
@Override
@ -190,8 +208,9 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
}
} else if (message instanceof StreamResponse) {
StreamResponse resp = (StreamResponse) message;
StreamCallback callback = streamCallbacks.poll();
if (callback != null) {
Tuple2<String, StreamCallback> entry = streamCallbacks.poll();
if (entry != null) {
StreamCallback callback = entry._2();
if (resp.byteCount > 0) {
StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
callback);
@ -216,8 +235,9 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
}
} else if (message instanceof StreamFailure) {
StreamFailure resp = (StreamFailure) message;
StreamCallback callback = streamCallbacks.poll();
if (callback != null) {
Tuple2<String, StreamCallback> entry = streamCallbacks.poll();
if (entry != null) {
StreamCallback callback = entry._2();
try {
callback.onFailure(resp.streamId, new RuntimeException(resp.error));
} catch (IOException ioe) {

View file

@ -17,6 +17,7 @@
package org.apache.spark.network;
import java.io.IOException;
import java.nio.ByteBuffer;
import io.netty.channel.Channel;
@ -127,7 +128,7 @@ public class TransportResponseHandlerSuite {
StreamResponse response = new StreamResponse("stream", 1234L, null);
StreamCallback cb = mock(StreamCallback.class);
handler.addStreamCallback(cb);
handler.addStreamCallback("stream", cb);
assertEquals(1, handler.numOutstandingRequests());
handler.handle(response);
assertEquals(1, handler.numOutstandingRequests());
@ -135,9 +136,35 @@ public class TransportResponseHandlerSuite {
assertEquals(0, handler.numOutstandingRequests());
StreamFailure failure = new StreamFailure("stream", "uh-oh");
handler.addStreamCallback(cb);
handler.addStreamCallback("stream", cb);
assertEquals(1, handler.numOutstandingRequests());
handler.handle(failure);
assertEquals(0, handler.numOutstandingRequests());
}
@Test
public void failOutstandingStreamCallbackOnClose() throws Exception {
  // Arrange: a response handler over a local channel with the frame decoder in its pipeline.
  Channel channel = new LocalChannel();
  channel.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
  TransportResponseHandler responseHandler = new TransportResponseHandler(channel);

  // Register a stream callback that never gets a response.
  StreamCallback pendingCallback = mock(StreamCallback.class);
  responseHandler.addStreamCallback("stream-1", pendingCallback);

  // Act: signal that the channel went inactive.
  responseHandler.channelInactive();

  // Assert: the pending callback was failed for its own stream id with an IOException.
  verify(pendingCallback).onFailure(eq("stream-1"), isA(IOException.class));
}
@Test
public void failOutstandingStreamCallbackOnException() throws Exception {
  // Arrange: a response handler over a local channel with the frame decoder in its pipeline.
  Channel channel = new LocalChannel();
  channel.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
  TransportResponseHandler responseHandler = new TransportResponseHandler(channel);

  // Register a stream callback that never gets a response.
  StreamCallback pendingCallback = mock(StreamCallback.class);
  responseHandler.addStreamCallback("stream-1", pendingCallback);

  // Act: report an exception to the handler.
  responseHandler.exceptionCaught(new IOException("Oops!"));

  // Assert: the pending callback was failed for its own stream id with an IOException.
  verify(pendingCallback).onFailure(eq("stream-1"), isA(IOException.class));
}
}