[SPARK-21253][CORE] Fix a bug where StreamCallback may not be notified if network errors happen
## What changes were proposed in this pull request?

If a network error happens before a StreamResponse/StreamFailure event is processed, `StreamCallback.onFailure` won't be called. This PR fixes `failOutstandingRequests` to also notify outstanding StreamCallbacks.

## How was this patch tested?

The new unit tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #18472 from zsxwing/fix-stream-2.
This commit is contained in:
parent
f9151bebca
commit
4996c53949
|
@ -179,7 +179,7 @@ public class TransportClient implements Closeable {
|
|||
// written to the socket atomically, so that callbacks are called in the right order
|
||||
// when responses arrive.
|
||||
synchronized (this) {
|
||||
handler.addStreamCallback(callback);
|
||||
handler.addStreamCallback(streamId, callback);
|
||||
channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> {
|
||||
if (future.isSuccess()) {
|
||||
long timeTaken = System.currentTimeMillis() - startTime;
|
||||
|
|
|
@ -24,6 +24,8 @@ import java.util.concurrent.ConcurrentHashMap;
|
|||
import java.util.concurrent.ConcurrentLinkedQueue;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
import scala.Tuple2;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.netty.channel.Channel;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -56,7 +58,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
|
|||
|
||||
private final Map<Long, RpcResponseCallback> outstandingRpcs;
|
||||
|
||||
private final Queue<StreamCallback> streamCallbacks;
|
||||
private final Queue<Tuple2<String, StreamCallback>> streamCallbacks;
|
||||
private volatile boolean streamActive;
|
||||
|
||||
/** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
|
||||
|
@ -88,9 +90,9 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
|
|||
outstandingRpcs.remove(requestId);
|
||||
}
|
||||
|
||||
public void addStreamCallback(StreamCallback callback) {
|
||||
public void addStreamCallback(String streamId, StreamCallback callback) {
|
||||
timeOfLastRequestNs.set(System.nanoTime());
|
||||
streamCallbacks.offer(callback);
|
||||
streamCallbacks.offer(Tuple2.apply(streamId, callback));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
|
@ -104,15 +106,31 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
|
|||
*/
|
||||
private void failOutstandingRequests(Throwable cause) {
|
||||
for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
|
||||
entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
|
||||
try {
|
||||
entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
|
||||
} catch (Exception e) {
|
||||
logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
|
||||
}
|
||||
}
|
||||
for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
|
||||
entry.getValue().onFailure(cause);
|
||||
try {
|
||||
entry.getValue().onFailure(cause);
|
||||
} catch (Exception e) {
|
||||
logger.warn("RpcResponseCallback.onFailure throws exception", e);
|
||||
}
|
||||
}
|
||||
for (Tuple2<String, StreamCallback> entry : streamCallbacks) {
|
||||
try {
|
||||
entry._2().onFailure(entry._1(), cause);
|
||||
} catch (Exception e) {
|
||||
logger.warn("StreamCallback.onFailure throws exception", e);
|
||||
}
|
||||
}
|
||||
|
||||
// It's OK if new fetches appear, as they will fail immediately.
|
||||
outstandingFetches.clear();
|
||||
outstandingRpcs.clear();
|
||||
streamCallbacks.clear();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -190,8 +208,9 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
|
|||
}
|
||||
} else if (message instanceof StreamResponse) {
|
||||
StreamResponse resp = (StreamResponse) message;
|
||||
StreamCallback callback = streamCallbacks.poll();
|
||||
if (callback != null) {
|
||||
Tuple2<String, StreamCallback> entry = streamCallbacks.poll();
|
||||
if (entry != null) {
|
||||
StreamCallback callback = entry._2();
|
||||
if (resp.byteCount > 0) {
|
||||
StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
|
||||
callback);
|
||||
|
@ -216,8 +235,9 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
|
|||
}
|
||||
} else if (message instanceof StreamFailure) {
|
||||
StreamFailure resp = (StreamFailure) message;
|
||||
StreamCallback callback = streamCallbacks.poll();
|
||||
if (callback != null) {
|
||||
Tuple2<String, StreamCallback> entry = streamCallbacks.poll();
|
||||
if (entry != null) {
|
||||
StreamCallback callback = entry._2();
|
||||
try {
|
||||
callback.onFailure(resp.streamId, new RuntimeException(resp.error));
|
||||
} catch (IOException ioe) {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.network;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
import io.netty.channel.Channel;
|
||||
|
@ -127,7 +128,7 @@ public class TransportResponseHandlerSuite {
|
|||
|
||||
StreamResponse response = new StreamResponse("stream", 1234L, null);
|
||||
StreamCallback cb = mock(StreamCallback.class);
|
||||
handler.addStreamCallback(cb);
|
||||
handler.addStreamCallback("stream", cb);
|
||||
assertEquals(1, handler.numOutstandingRequests());
|
||||
handler.handle(response);
|
||||
assertEquals(1, handler.numOutstandingRequests());
|
||||
|
@ -135,9 +136,35 @@ public class TransportResponseHandlerSuite {
|
|||
assertEquals(0, handler.numOutstandingRequests());
|
||||
|
||||
StreamFailure failure = new StreamFailure("stream", "uh-oh");
|
||||
handler.addStreamCallback(cb);
|
||||
handler.addStreamCallback("stream", cb);
|
||||
assertEquals(1, handler.numOutstandingRequests());
|
||||
handler.handle(failure);
|
||||
assertEquals(0, handler.numOutstandingRequests());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void failOutstandingStreamCallbackOnClose() throws Exception {
|
||||
Channel c = new LocalChannel();
|
||||
c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
|
||||
TransportResponseHandler handler = new TransportResponseHandler(c);
|
||||
|
||||
StreamCallback cb = mock(StreamCallback.class);
|
||||
handler.addStreamCallback("stream-1", cb);
|
||||
handler.channelInactive();
|
||||
|
||||
verify(cb).onFailure(eq("stream-1"), isA(IOException.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void failOutstandingStreamCallbackOnException() throws Exception {
|
||||
Channel c = new LocalChannel();
|
||||
c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
|
||||
TransportResponseHandler handler = new TransportResponseHandler(c);
|
||||
|
||||
StreamCallback cb = mock(StreamCallback.class);
|
||||
handler.addStreamCallback("stream-1", cb);
|
||||
handler.exceptionCaught(new IOException("Oops!"));
|
||||
|
||||
verify(cb).onFailure(eq("stream-1"), isA(IOException.class));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue