From 4996c53949376153f9ebdc74524fed7226968808 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 30 Jun 2017 10:56:48 +0800 Subject: [PATCH] [SPARK-21253][CORE] Fix a bug that StreamCallback may not be notified if network errors happen ## What changes were proposed in this pull request? If a network error happens before processing StreamResponse/StreamFailure events, StreamCallback.onFailure won't be called. This PR fixes `failOutstandingRequests` to also notify outstanding StreamCallbacks. ## How was this patch tested? The new unit tests. Author: Shixiong Zhu Closes #18472 from zsxwing/fix-stream-2. --- .../spark/network/client/TransportClient.java | 2 +- .../client/TransportResponseHandler.java | 38 ++++++++++++++----- .../TransportResponseHandlerSuite.java | 31 ++++++++++++++- 3 files changed, 59 insertions(+), 12 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index a6f527c118..8f354ad78b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -179,7 +179,7 @@ public class TransportClient implements Closeable { // written to the socket atomically, so that callbacks are called in the right order // when responses arrive. synchronized (this) { - handler.addStreamCallback(callback); + handler.addStreamCallback(streamId, callback); channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 41bead546c..be9f18203c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -24,6 +24,8 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; +import scala.Tuple2; + import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; @@ -56,7 +58,7 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; - private final Queue streamCallbacks; + private final Queue> streamCallbacks; private volatile boolean streamActive; /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ @@ -88,9 +90,9 @@ public class TransportResponseHandler extends MessageHandler { outstandingRpcs.remove(requestId); } - public void addStreamCallback(StreamCallback callback) { + public void addStreamCallback(String streamId, StreamCallback callback) { timeOfLastRequestNs.set(System.nanoTime()); - streamCallbacks.offer(callback); + streamCallbacks.offer(Tuple2.apply(streamId, callback)); } @VisibleForTesting @@ -104,15 +106,31 @@ public class TransportResponseHandler extends MessageHandler { */ private void failOutstandingRequests(Throwable cause) { for (Map.Entry entry : outstandingFetches.entrySet()) { - entry.getValue().onFailure(entry.getKey().chunkIndex, cause); + try { + entry.getValue().onFailure(entry.getKey().chunkIndex, cause); + } catch (Exception e) { + logger.warn("ChunkReceivedCallback.onFailure throws exception", e); + } } for (Map.Entry entry : outstandingRpcs.entrySet()) { - entry.getValue().onFailure(cause); + try { + entry.getValue().onFailure(cause); + } catch (Exception e) { + logger.warn("RpcResponseCallback.onFailure throws exception", e); + } + } + for (Tuple2 entry : streamCallbacks) { + try { + entry._2().onFailure(entry._1(), cause); + } catch (Exception e) { + logger.warn("StreamCallback.onFailure throws exception", e); + } } // It's OK if new fetches appear, as they will fail immediately. outstandingFetches.clear(); outstandingRpcs.clear(); + streamCallbacks.clear(); } @Override @@ -190,8 +208,9 @@ public class TransportResponseHandler extends MessageHandler { } } else if (message instanceof StreamResponse) { StreamResponse resp = (StreamResponse) message; - StreamCallback callback = streamCallbacks.poll(); - if (callback != null) { + Tuple2 entry = streamCallbacks.poll(); + if (entry != null) { + StreamCallback callback = entry._2(); if (resp.byteCount > 0) { StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, callback); @@ -216,8 +235,9 @@ public class TransportResponseHandler extends MessageHandler { } } else if (message instanceof StreamFailure) { StreamFailure resp = (StreamFailure) message; - StreamCallback callback = streamCallbacks.poll(); - if (callback != null) { + Tuple2 entry = streamCallbacks.poll(); + if (entry != null) { + StreamCallback callback = entry._2(); try { callback.onFailure(resp.streamId, new RuntimeException(resp.error)); } catch (IOException ioe) { diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 09fc80d12d..b4032c4c3f 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.io.IOException; import java.nio.ByteBuffer; import io.netty.channel.Channel; @@ -127,7 +128,7 @@ public class TransportResponseHandlerSuite { StreamResponse response = new StreamResponse("stream", 1234L, null); StreamCallback cb = mock(StreamCallback.class); - handler.addStreamCallback(cb); + handler.addStreamCallback("stream", cb); assertEquals(1, handler.numOutstandingRequests()); handler.handle(response); assertEquals(1, handler.numOutstandingRequests()); @@ -135,9 +136,35 @@ public class TransportResponseHandlerSuite { assertEquals(0, handler.numOutstandingRequests()); StreamFailure failure = new StreamFailure("stream", "uh-oh"); - handler.addStreamCallback(cb); + handler.addStreamCallback("stream", cb); assertEquals(1, handler.numOutstandingRequests()); handler.handle(failure); assertEquals(0, handler.numOutstandingRequests()); } + + @Test + public void failOutstandingStreamCallbackOnClose() throws Exception { + Channel c = new LocalChannel(); + c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + TransportResponseHandler handler = new TransportResponseHandler(c); + + StreamCallback cb = mock(StreamCallback.class); + handler.addStreamCallback("stream-1", cb); + handler.channelInactive(); + + verify(cb).onFailure(eq("stream-1"), isA(IOException.class)); + } + + @Test + public void failOutstandingStreamCallbackOnException() throws Exception { + Channel c = new LocalChannel(); + c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + TransportResponseHandler handler = new TransportResponseHandler(c); + + StreamCallback cb = mock(StreamCallback.class); + handler.addStreamCallback("stream-1", cb); + handler.exceptionCaught(new IOException("Oops!")); + + verify(cb).onFailure(eq("stream-1"), isA(IOException.class)); + } }