diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index c9ef9f918f..24c436a504 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -27,6 +27,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; import com.codahale.metrics.MetricSet; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.Lists; @@ -61,6 +62,7 @@ public class TransportClientFactory implements Closeable { private static class ClientPool { TransportClient[] clients; Object[] locks; + volatile long lastConnectionFailed; ClientPool(int size) { clients = new TransportClient[size]; @@ -68,6 +70,7 @@ public class TransportClientFactory implements Closeable { for (int i = 0; i < size; i++) { locks[i] = new Object(); } + lastConnectionFailed = 0; } } @@ -86,6 +89,7 @@ public class TransportClientFactory implements Closeable { private EventLoopGroup workerGroup; private final PooledByteBufAllocator pooledAllocator; private final NettyMemoryMetrics metrics; + private final int fastFailTimeWindow; public TransportClientFactory( TransportContext context, @@ -112,6 +116,7 @@ public class TransportClientFactory implements Closeable { } this.metrics = new NettyMemoryMetrics( this.pooledAllocator, conf.getModuleName() + "-client", conf); + fastFailTimeWindow = (int)(conf.ioRetryWaitTimeMs() * 0.95); } public MetricSet getAllMetrics() { @@ -121,18 +126,27 @@ public class TransportClientFactory implements Closeable { /** * Create a {@link TransportClient} connecting to the given remote host / port. 
* - * We maintains an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer) + * We maintain an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer) * and randomly picks one to use. If no client was previously created in the randomly selected * spot, this function creates a new client and places it there. * + * If the fastFail parameter is true, fail immediately when the last attempt to the same address + * failed within the fast fail time window (95 percent of the IO retry wait time). The + * assumption is the caller will handle retrying. + * * Prior to the creation of a new TransportClient, we will execute all * {@link TransportClientBootstrap}s that are registered with this factory. * * This blocks until a connection is successfully established and fully bootstrapped. * * Concurrency: This method is safe to call from multiple threads. + * + * @param remoteHost remote address host + * @param remotePort remote address port + * @param fastFail whether this call should fail immediately when the last attempt to the same + * address failed within the last fast fail time window. */ - public TransportClient createClient(String remoteHost, int remotePort) + public TransportClient createClient(String remoteHost, int remotePort, boolean fastFail) throws IOException, InterruptedException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. @@ -192,11 +206,30 @@ public class TransportClientFactory implements Closeable { logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress); } } - clientPool.clients[clientIndex] = createClient(resolvedAddress); + // If fast fail is requested and the last connection attempt failed within the fast fail time + // window, fail this connection directly. 
+ if (fastFail && System.currentTimeMillis() - clientPool.lastConnectionFailed < + fastFailTimeWindow) { + throw new IOException( + String.format("Connecting to %s failed in the last %s ms, fail this connection directly", + resolvedAddress, fastFailTimeWindow)); + } + try { + clientPool.clients[clientIndex] = createClient(resolvedAddress); + clientPool.lastConnectionFailed = 0; + } catch (IOException e) { + clientPool.lastConnectionFailed = System.currentTimeMillis(); + throw e; + } return clientPool.clients[clientIndex]; } } + public TransportClient createClient(String remoteHost, int remotePort) + throws IOException, InterruptedException { + return createClient(remoteHost, remotePort, false); + } + /** * Create a completely new {@link TransportClient} to the given remote host / port. * This connection is not pooled. @@ -210,7 +243,8 @@ public class TransportClientFactory implements Closeable { } /** Create a completely new {@link TransportClient} to the remote address. */ - private TransportClient createClient(InetSocketAddress address) + @VisibleForTesting + TransportClient createClient(InetSocketAddress address) throws IOException, InterruptedException { logger.debug("Creating new connection to {}", address); diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java similarity index 88% rename from common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java rename to common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java index 9b76981c31..ea0ac51589 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.network; +package org.apache.spark.network.client; import java.io.IOException; import java.util.Collections; @@ -29,14 +29,16 @@ import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.Assert; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertTrue; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; @@ -224,4 +226,24 @@ public class TransportClientFactorySuite { factory.close(); factory.createClient(TestUtils.getLocalHost(), server1.getPort()); } + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void fastFailConnectionInTimeWindow() throws IOException, InterruptedException { + TransportClientFactory factory = context.createClientFactory(); + TransportServer server = context.createServer(); + int unreachablePort = server.getPort(); + server.close(); + try { + factory.createClient(TestUtils.getLocalHost(), unreachablePort, true); + } catch (Exception e) { + assertTrue(e instanceof IOException); + } + expectedException.expect(IOException.class); + expectedException.expectMessage("fail this connection directly"); + factory.createClient(TestUtils.getLocalHost(), unreachablePort, true); + expectedException = ExpectedException.none(); + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java 
index d6185f089d..51dc333726 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -101,11 +101,12 @@ public class ExternalBlockStoreClient extends BlockStoreClient { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { + int maxRetries = conf.maxIORetries(); RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = (blockIds1, listener1) -> { // Unless this client is closed. if (clientFactory != null) { - TransportClient client = clientFactory.createClient(host, port); + TransportClient client = clientFactory.createClient(host, port, maxRetries > 0); new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1, conf, downloadFileManager).start(); } else { @@ -113,7 +114,6 @@ public class ExternalBlockStoreClient extends BlockStoreClient { } }; - int maxRetries = conf.maxIORetries(); if (maxRetries > 0) { // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's // a bug in this code. We should remove the if statement once we're sure of the stability. 
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index ffb696029a..3de7377f99 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -115,11 +115,12 @@ private[spark] class NettyBlockTransferService( tempFileManager: DownloadFileManager): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { + val maxRetries = transportConf.maxIORetries() val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener): Unit = { try { - val client = clientFactory.createClient(host, port) + val client = clientFactory.createClient(host, port, maxRetries > 0) new OneForOneBlockFetcher(client, appId, execId, blockIds, listener, transportConf, tempFileManager).start() } catch { @@ -136,7 +137,6 @@ private[spark] class NettyBlockTransferService( } } - val maxRetries = transportConf.maxIORetries() if (maxRetries > 0) { // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's // a bug in this code. We should remove the if statement once we're sure of the stability. diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index edddf88a28..c804102e4a 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -105,7 +105,7 @@ class NettyBlockTransferServiceSuite // This is used to touch an IOException during fetching block. 
when(client.sendRpc(any(), any())).thenAnswer(_ => {throw new IOException()}) var createClientCount = 0 - when(clientFactory.createClient(any(), any())).thenAnswer(_ => { + when(clientFactory.createClient(any(), any(), any())).thenAnswer(_ => { createClientCount += 1 client })