[SPARK-27021][CORE] Cleanup of Netty event loop group for shuffle chunk fetch requests

## What changes were proposed in this pull request?

Creating a Netty `EventLoopGroup` creates a new thread pool for handling its events. To stop the threads of that pool, the event loop group must be shut down; this is done properly for transport servers and clients, for example by calling `shutdownGracefully()` (for details see the `close()` methods of `TransportClientFactory` and `TransportServer`). However, there is a separate event loop group for shuffle chunk fetch requests, which sits in the pipeline to handle fetch requests (shared between the client and the server) and is owned by the `TransportContext`, and this group was never shut down. This PR makes `TransportContext` implement `Closeable` and shuts the group down in its `close()` method.
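
For illustration, here is a minimal sketch of the ownership pattern the fix applies, written against plain Netty; the class and field names are hypothetical, not Spark's actual code:

```
import java.io.Closeable;

import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;

// Hypothetical owner of an event loop group: creating the group starts a
// thread pool, so the owner must shut the group down or its threads leak.
public class EventLoopOwner implements Closeable {
  private final EventLoopGroup workers = new NioEventLoopGroup(2);

  @Override
  public void close() {
    // Initiates an orderly shutdown of the group's threads; before this fix,
    // TransportContext had no equivalent call for its chunk fetch workers.
    workers.shutdownGracefully();
  }
}
```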

## How was this patch tested?

With the existing unit tests.

This leak exists in production systems as well, but its effect is most pronounced in the unit tests.

Checking the core unit test logs before the PR:
```
$ grep "LEAK IN SUITE" unit-tests.log | grep -o shuffle-chunk-fetch-handler | wc -l
381
```

And after the PR, without whitelisting in the thread audit and with an extra `await` after `chunkFetchWorkers.shutdownGracefully()`:
```
$ grep "LEAK IN SUITE" unit-tests.log | grep -o shuffle-chunk-fetch-handler | wc -l
0
```
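
The extra `await` mentioned above was only a temporary verification aid, roughly as in this hypothetical standalone sketch (not part of the committed change):

```
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;

public class ShutdownAwaitSketch {
  public static void main(String[] args) throws InterruptedException {
    EventLoopGroup chunkFetchWorkers = new NioEventLoopGroup(1);
    // shutdownGracefully() only initiates the shutdown; await() blocks until
    // the worker threads have actually terminated, so the thread audit can
    // no longer observe them. The committed close() does not block.
    chunkFetchWorkers.shutdownGracefully().await();
  }
}
```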

Closes #23930 from attilapiros/SPARK-27021.

Authored-by: “attilapiros” <piros.attila.zsolt@gmail.com>
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
parent c27caead43
commit 5668c42edf
19 changed files with 225 additions and 170 deletions

TransportContext.java

@@ -17,6 +17,7 @@
package org.apache.spark.network;
import java.io.Closeable;
import java.util.ArrayList;
import java.util.List;
@@ -60,13 +61,12 @@ import org.apache.spark.network.util.TransportFrameDecoder;
* channel. As each TransportChannelHandler contains a TransportClient, this enables server
* processes to send messages back to the client on an existing channel.
*/
public class TransportContext {
public class TransportContext implements Closeable {
private static final Logger logger = LoggerFactory.getLogger(TransportContext.class);
private final TransportConf conf;
private final RpcHandler rpcHandler;
private final boolean closeIdleConnections;
private final boolean isClientOnly;
// Number of registered connections to the shuffle service
private Counter registeredConnections = new Counter();
@@ -120,7 +120,6 @@ public class TransportContext {
this.conf = conf;
this.rpcHandler = rpcHandler;
this.closeIdleConnections = closeIdleConnections;
this.isClientOnly = isClientOnly;
if (conf.getModuleName() != null &&
conf.getModuleName().equalsIgnoreCase("shuffle") &&
@@ -200,9 +199,7 @@ public class TransportContext {
// would require more logic to guarantee if this were not part of the same event loop.
.addLast("handler", channelHandler);
// Use a separate EventLoopGroup to handle ChunkFetchRequest messages for shuffle rpcs.
if (conf.getModuleName() != null &&
conf.getModuleName().equalsIgnoreCase("shuffle")
&& !isClientOnly) {
if (chunkFetchWorkers != null) {
pipeline.addLast(chunkFetchWorkers, "chunkFetchHandler", chunkFetchHandler);
}
return channelHandler;
@@ -240,4 +237,10 @@ public class TransportContext {
public Counter getRegisteredConnections() {
return registeredConnections;
}
public void close() {
if (chunkFetchWorkers != null) {
chunkFetchWorkers.shutdownGracefully();
}
}
}
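
Since `TransportContext` is now `Closeable`, callers can scope it with try-with-resources, as several of the updated suites below do. A minimal usage sketch follows; the `NoOpRpcHandler`-based setup mirrors the test suites, and the exact imports should be treated as assumptions:

```
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class ContextLifecycleSketch {
  public static void main(String[] args) throws Exception {
    TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
    // Resources close in reverse order: the factory first, then the context,
    // which now also shuts down the chunk fetch workers if they were created.
    try (TransportContext context = new TransportContext(conf, new NoOpRpcHandler());
         TransportClientFactory factory = context.createClientFactory()) {
      // ... create clients and exercise the transport layer ...
    }
  }
}
```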

ChunkFetchIntegrationSuite.java

@@ -56,6 +56,7 @@ public class ChunkFetchIntegrationSuite {
static final int BUFFER_CHUNK_INDEX = 0;
static final int FILE_CHUNK_INDEX = 1;
static TransportContext context;
static TransportServer server;
static TransportClientFactory clientFactory;
static StreamManager streamManager;
@@ -117,7 +118,7 @@ public class ChunkFetchIntegrationSuite {
return streamManager;
}
};
TransportContext context = new TransportContext(conf, handler);
context = new TransportContext(conf, handler);
server = context.createServer();
clientFactory = context.createClientFactory();
}
@@ -127,6 +128,7 @@ public class ChunkFetchIntegrationSuite {
bufferChunk.release();
server.close();
clientFactory.close();
context.close();
testFile.delete();
}

RequestTimeoutIntegrationSuite.java

@@ -48,6 +48,7 @@ import java.util.concurrent.TimeUnit;
*/
public class RequestTimeoutIntegrationSuite {
private TransportContext context;
private TransportServer server;
private TransportClientFactory clientFactory;
@@ -79,6 +80,9 @@ public class RequestTimeoutIntegrationSuite {
if (clientFactory != null) {
clientFactory.close();
}
if (context != null) {
context.close();
}
}
// Basic suite: First request completes quickly, and second waits for longer than network timeout.
@@ -106,7 +110,7 @@
}
};
TransportContext context = new TransportContext(conf, handler);
context = new TransportContext(conf, handler);
server = context.createServer();
clientFactory = context.createClientFactory();
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
@@ -153,7 +157,7 @@
}
};
TransportContext context = new TransportContext(conf, handler);
context = new TransportContext(conf, handler);
server = context.createServer();
clientFactory = context.createClientFactory();
@@ -204,7 +208,7 @@
}
};
TransportContext context = new TransportContext(conf, handler);
context = new TransportContext(conf, handler);
server = context.createServer();
clientFactory = context.createClientFactory();
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());

RpcIntegrationSuite.java

@@ -44,6 +44,7 @@ import org.apache.spark.network.util.TransportConf;
public class RpcIntegrationSuite {
static TransportConf conf;
static TransportContext context;
static TransportServer server;
static TransportClientFactory clientFactory;
static RpcHandler rpcHandler;
@@ -90,7 +91,7 @@ public class RpcIntegrationSuite {
@Override
public StreamManager getStreamManager() { return new OneForOneStreamManager(); }
};
TransportContext context = new TransportContext(conf, rpcHandler);
context = new TransportContext(conf, rpcHandler);
server = context.createServer();
clientFactory = context.createClientFactory();
oneWayMsgs = new ArrayList<>();
@@ -160,6 +161,7 @@ public class RpcIntegrationSuite {
public static void tearDown() {
server.close();
clientFactory.close();
context.close();
testData.cleanup();
}

StreamSuite.java

@@ -51,6 +51,7 @@ public class StreamSuite {
private static final String[] STREAMS = StreamTestHelper.STREAMS;
private static StreamTestHelper testData;
private static TransportContext context;
private static TransportServer server;
private static TransportClientFactory clientFactory;
@@ -93,7 +94,7 @@ public class StreamSuite {
return streamManager;
}
};
TransportContext context = new TransportContext(conf, handler);
context = new TransportContext(conf, handler);
server = context.createServer();
clientFactory = context.createClientFactory();
}
@@ -103,6 +104,7 @@ public class StreamSuite {
server.close();
clientFactory.close();
testData.cleanup();
context.close();
}
@Test

TransportClientFactorySuite.java

@@ -64,6 +64,7 @@ public class TransportClientFactorySuite {
public void tearDown() {
JavaUtils.closeQuietly(server1);
JavaUtils.closeQuietly(server2);
JavaUtils.closeQuietly(context);
}
/**
@@ -80,49 +81,50 @@ public class TransportClientFactorySuite {
TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap));
RpcHandler rpcHandler = new NoOpRpcHandler();
TransportContext context = new TransportContext(conf, rpcHandler);
TransportClientFactory factory = context.createClientFactory();
Set<TransportClient> clients = Collections.synchronizedSet(
new HashSet<TransportClient>());
try (TransportContext context = new TransportContext(conf, rpcHandler)) {
TransportClientFactory factory = context.createClientFactory();
Set<TransportClient> clients = Collections.synchronizedSet(
new HashSet<TransportClient>());
AtomicInteger failed = new AtomicInteger();
Thread[] attempts = new Thread[maxConnections * 10];
AtomicInteger failed = new AtomicInteger();
Thread[] attempts = new Thread[maxConnections * 10];
// Launch a bunch of threads to create new clients.
for (int i = 0; i < attempts.length; i++) {
attempts[i] = new Thread(() -> {
try {
TransportClient client =
factory.createClient(TestUtils.getLocalHost(), server1.getPort());
assertTrue(client.isActive());
clients.add(client);
} catch (IOException e) {
failed.incrementAndGet();
} catch (InterruptedException e) {
throw new RuntimeException(e);
// Launch a bunch of threads to create new clients.
for (int i = 0; i < attempts.length; i++) {
attempts[i] = new Thread(() -> {
try {
TransportClient client =
factory.createClient(TestUtils.getLocalHost(), server1.getPort());
assertTrue(client.isActive());
clients.add(client);
} catch (IOException e) {
failed.incrementAndGet();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
});
if (concurrent) {
attempts[i].start();
} else {
attempts[i].run();
}
});
if (concurrent) {
attempts[i].start();
} else {
attempts[i].run();
}
// Wait until all the threads complete.
for (Thread attempt : attempts) {
attempt.join();
}
Assert.assertEquals(0, failed.get());
Assert.assertEquals(clients.size(), maxConnections);
for (TransportClient client : clients) {
client.close();
}
factory.close();
}
// Wait until all the threads complete.
for (Thread attempt : attempts) {
attempt.join();
}
Assert.assertEquals(0, failed.get());
Assert.assertEquals(clients.size(), maxConnections);
for (TransportClient client : clients) {
client.close();
}
factory.close();
}
@Test
@@ -204,8 +206,8 @@ public class TransportClientFactorySuite {
throw new UnsupportedOperationException();
}
});
TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true);
try (TransportClientFactory factory = context.createClientFactory()) {
try (TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true);
TransportClientFactory factory = context.createClientFactory()) {
TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
assertTrue(c1.isActive());
long expiredTime = System.currentTimeMillis() + 10000; // 10 seconds

AuthIntegrationSuite.java

@@ -196,6 +196,9 @@ public class AuthIntegrationSuite {
if (server != null) {
server.close();
}
if (ctx != null) {
ctx.close();
}
}
private SecretKeyHolder createKeyHolder(String secret) {

SparkSaslSuite.java

@@ -365,6 +365,7 @@ public class SparkSaslSuite {
final TransportClient client;
final TransportServer server;
final TransportContext ctx;
private final boolean encrypt;
private final boolean disableClientEncryption;
@@ -396,7 +397,7 @@
when(keyHolder.getSaslUser(anyString())).thenReturn("user");
when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
TransportContext ctx = new TransportContext(conf, rpcHandler);
this.ctx = new TransportContext(conf, rpcHandler);
this.checker = new EncryptionCheckerBootstrap(SaslEncryption.ENCRYPTION_HANDLER_NAME);
@@ -431,6 +432,9 @@
if (server != null) {
server.close();
}
if (ctx != null) {
ctx.close();
}
}
}

NettyMemoryMetricsSuite.java

@@ -60,11 +60,14 @@ public class NettyMemoryMetricsSuite {
JavaUtils.closeQuietly(clientFactory);
clientFactory = null;
}
if (server != null) {
JavaUtils.closeQuietly(server);
server = null;
}
if (context != null) {
JavaUtils.closeQuietly(context);
context = null;
}
}
@Test

SaslIntegrationSuite.java

@@ -91,6 +91,7 @@ public class SaslIntegrationSuite {
@AfterClass
public static void afterAll() {
server.close();
context.close();
}
@After
@@ -153,13 +154,14 @@
@Test
public void testNoSaslServer() {
RpcHandler handler = new TestRpcHandler();
TransportContext context = new TransportContext(conf, handler);
clientFactory = context.createClientFactory(
Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
try (TransportServer server = context.createServer()) {
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation"));
try (TransportContext context = new TransportContext(conf, handler)) {
clientFactory = context.createClientFactory(
Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
try (TransportServer server = context.createServer()) {
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation"));
}
}
}
@@ -174,18 +176,15 @@ public class SaslIntegrationSuite {
ExternalShuffleBlockHandler blockHandler = new ExternalShuffleBlockHandler(
new OneForOneStreamManager(), blockResolver);
TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
TransportContext blockServerContext = new TransportContext(conf, blockHandler);
TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
TransportClient client1 = null;
TransportClient client2 = null;
TransportClientFactory clientFactory2 = null;
try {
try (
TransportContext blockServerContext = new TransportContext(conf, blockHandler);
TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
// Create a client, and make a request to fetch blocks from a different app.
clientFactory = blockServerContext.createClientFactory(
TransportClientFactory clientFactory1 = blockServerContext.createClientFactory(
Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
client1 = clientFactory.createClient(TestUtils.getLocalHost(),
blockServer.getPort());
TransportClient client1 = clientFactory1.createClient(
TestUtils.getLocalHost(), blockServer.getPort())) {
AtomicReference<Throwable> exception = new AtomicReference<>();
@@ -223,41 +222,33 @@ public class SaslIntegrationSuite {
StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
long streamId = stream.streamId;
// Create a second client, authenticated with a different app ID, and try to read from
// the stream created for the previous app.
clientFactory2 = blockServerContext.createClientFactory(
Arrays.asList(new SaslClientBootstrap(conf, "app-2", secretKeyHolder)));
client2 = clientFactory2.createClient(TestUtils.getLocalHost(),
blockServer.getPort());
try (
// Create a second client, authenticated with a different app ID, and try to read from
// the stream created for the previous app.
TransportClientFactory clientFactory2 = blockServerContext.createClientFactory(
Arrays.asList(new SaslClientBootstrap(conf, "app-2", secretKeyHolder)));
TransportClient client2 = clientFactory2.createClient(
TestUtils.getLocalHost(), blockServer.getPort())
) {
CountDownLatch chunkReceivedLatch = new CountDownLatch(1);
ChunkReceivedCallback callback = new ChunkReceivedCallback() {
@Override
public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
chunkReceivedLatch.countDown();
}
CountDownLatch chunkReceivedLatch = new CountDownLatch(1);
ChunkReceivedCallback callback = new ChunkReceivedCallback() {
@Override
public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
chunkReceivedLatch.countDown();
}
@Override
public void onFailure(int chunkIndex, Throwable t) {
exception.set(t);
chunkReceivedLatch.countDown();
}
};
@Override
public void onFailure(int chunkIndex, Throwable t) {
exception.set(t);
chunkReceivedLatch.countDown();
}
};
exception.set(null);
client2.fetchChunk(streamId, 0, callback);
chunkReceivedLatch.await();
checkSecurityException(exception.get());
} finally {
if (client1 != null) {
client1.close();
exception.set(null);
client2.fetchChunk(streamId, 0, callback);
chunkReceivedLatch.await();
checkSecurityException(exception.get());
}
if (client2 != null) {
client2.close();
}
if (clientFactory2 != null) {
clientFactory2.close();
}
blockServer.close();
}
}

ExternalShuffleIntegrationSuite.java

@@ -58,6 +58,7 @@ public class ExternalShuffleIntegrationSuite {
static ExternalShuffleBlockHandler handler;
static TransportServer server;
static TransportConf conf;
static TransportContext transportContext;
static byte[][] exec0Blocks = new byte[][] {
new byte[123],
@@ -87,7 +88,7 @@ public class ExternalShuffleIntegrationSuite {
conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
handler = new ExternalShuffleBlockHandler(conf, null);
TransportContext transportContext = new TransportContext(conf, handler);
transportContext = new TransportContext(conf, handler);
server = transportContext.createServer();
}
@@ -95,6 +96,7 @@
public static void afterAll() {
dataContext0.cleanup();
server.close();
transportContext.close();
}
@After

ExternalShuffleSecuritySuite.java

@@ -41,14 +41,14 @@ public class ExternalShuffleSecuritySuite {
TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
TransportServer server;
TransportContext transportContext;
@Before
public void beforeEach() throws IOException {
TransportContext context =
new TransportContext(conf, new ExternalShuffleBlockHandler(conf, null));
transportContext = new TransportContext(conf, new ExternalShuffleBlockHandler(conf, null));
TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf,
new TestSecretKeyHolder("my-app-id", "secret"));
this.server = context.createServer(Arrays.asList(bootstrap));
this.server = transportContext.createServer(Arrays.asList(bootstrap));
}
@After
@@ -57,6 +57,10 @@ public class ExternalShuffleSecuritySuite {
server.close();
server = null;
}
if (transportContext != null) {
transportContext.close();
transportContext = null;
}
}
@Test

YarnShuffleService.java

@@ -113,6 +113,8 @@ public class YarnShuffleService extends AuxiliaryService {
// The actual server that serves shuffle files
private TransportServer shuffleServer = null;
private TransportContext transportContext = null;
private Configuration _conf = null;
// The recovery path used to shuffle service recovery
@@ -184,7 +186,7 @@ public class YarnShuffleService extends AuxiliaryService {
int port = conf.getInt(
SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT);
TransportContext transportContext = new TransportContext(transportConf, blockHandler);
transportContext = new TransportContext(transportConf, blockHandler);
shuffleServer = transportContext.createServer(port, bootstraps);
// the port should normally be fixed, but for tests its useful to find an open port
port = shuffleServer.getPort();
@@ -318,6 +320,9 @@
if (shuffleServer != null) {
shuffleServer.close();
}
if (transportContext != null) {
transportContext.close();
}
if (blockHandler != null) {
blockHandler.close();
}

ExternalShuffleService.scala

@@ -52,8 +52,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityManager)
private val transportConf =
SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0)
private val blockHandler = newShuffleBlockHandler(transportConf)
private val transportContext: TransportContext =
new TransportContext(transportConf, blockHandler, true)
private var transportContext: TransportContext = _
private var server: TransportServer = _
@@ -82,6 +81,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityManager)
} else {
Nil
}
transportContext = new TransportContext(transportConf, blockHandler, true)
server = transportContext.createServer(port, bootstraps.asJava)
shuffleServiceSource.registerMetricSet(server.getAllMetrics)
@@ -107,6 +107,10 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityManager)
server.close()
server = null
}
if (transportContext != null) {
transportContext.close()
transportContext = null
}
}
}

NettyBlockTransferService.scala

@@ -182,5 +182,8 @@ private[spark] class NettyBlockTransferService(
if (clientFactory != null) {
clientFactory.close()
}
if (transportContext != null) {
transportContext.close()
}
}
}

NettyRpcEnv.scala

@@ -315,6 +315,9 @@ private[netty] class NettyRpcEnv(
if (fileDownloadFactory != null) {
fileDownloadFactory.close()
}
if (transportContext != null) {
transportContext.close()
}
}
override def deserialize[T](deserializationAction: () => T): T = {

ExternalShuffleServiceSuite.scala

@@ -24,6 +24,7 @@ import org.apache.spark.network.TransportContext
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.server.TransportServer
import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalShuffleClient}
import org.apache.spark.util.Utils
/**
* This suite creates an external shuffle server and routes all shuffle fetches through it.
@@ -33,13 +34,14 @@ import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalShuffleClient}
*/
class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll {
var server: TransportServer = _
var transportContext: TransportContext = _
var rpcHandler: ExternalShuffleBlockHandler = _
override def beforeAll() {
super.beforeAll()
val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2)
rpcHandler = new ExternalShuffleBlockHandler(transportConf, null)
val transportContext = new TransportContext(transportConf, rpcHandler)
transportContext = new TransportContext(transportConf, rpcHandler)
server = transportContext.createServer()
conf.set(config.SHUFFLE_MANAGER, "sort")
@@ -48,11 +50,16 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll {
}
override def afterAll() {
try {
Utils.tryLogNonFatalError{
server.close()
} finally {
super.afterAll()
}
Utils.tryLogNonFatalError{
rpcHandler.close()
}
Utils.tryLogNonFatalError{
transportContext.close()
}
super.afterAll()
}
// This test ensures that the external shuffle service is actually in use for the other tests.

ThreadAudit.scala

@@ -55,18 +55,26 @@ trait ThreadAudit extends Logging {
* creates event loops. One is wrapped inside
* [[org.apache.spark.network.server.TransportServer]]
* the other one is inside [[org.apache.spark.network.client.TransportClient]].
* The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]].
* Manually checked and all of them stopped properly.
* Calling [[SparkContext#stop]] will shut down the thread pool of this event group
* asynchronously. In each case proper stopping is checked manually.
*/
"rpc-client.*",
"rpc-server.*",
/**
* During [[org.apache.spark.network.TransportContext]] construction a separate event loop could
* be created for handling ChunkFetchRequest.
* Calling [[org.apache.spark.network.TransportContext#close]] will shut down the thread pool
* of this event group asynchronously. In each case proper stopping is checked manually.
*/
"shuffle-chunk-fetch-handler.*",
/**
* During [[SparkContext]] creation BlockManager creates event loops. One is wrapped inside
* [[org.apache.spark.network.server.TransportServer]]
* the other one is inside [[org.apache.spark.network.client.TransportClient]].
* The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]].
* Manually checked and all of them stopped properly.
* Calling [[SparkContext#stop]] will shut down the thread pool of this event group
* asynchronously. In each case proper stopping is checked manually.
*/
"shuffle-client.*",
"shuffle-server.*"

BlockManagerSuite.scala

@@ -895,6 +895,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach
val store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
serializerManager, conf, memoryManager, mapOutputTracker,
shuffleManager, transfer, securityMgr, 0)
allStores += store
store.initialize("app-id")
// The put should fail since a1 is not serializable.
@@ -1360,74 +1361,76 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach
val tryAgainExecutor = "tryAgainExecutor"
val succeedingExecutor = "succeedingExecutor"
// a server which delays response 50ms and must try twice for success.
def newShuffleServer(port: Int): (TransportServer, Int) = {
val failure = new Exception(tryAgainMsg)
val success = ByteBuffer.wrap(new Array[Byte](0))
val failure = new Exception(tryAgainMsg)
val success = ByteBuffer.wrap(new Array[Byte](0))
var secondExecutorFailedOnce = false
var thirdExecutorFailedOnce = false
var secondExecutorFailedOnce = false
var thirdExecutorFailedOnce = false
val handler = new NoOpRpcHandler {
override def receive(
client: TransportClient,
message: ByteBuffer,
callback: RpcResponseCallback): Unit = {
val msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message)
msgObj match {
val handler = new NoOpRpcHandler {
override def receive(
client: TransportClient,
message: ByteBuffer,
callback: RpcResponseCallback): Unit = {
val msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message)
msgObj match {
case exec: RegisterExecutor if exec.execId == timingoutExecutor =>
() // No reply to generate client-side timeout
case exec: RegisterExecutor if exec.execId == timingoutExecutor =>
() // No reply to generate client-side timeout
case exec: RegisterExecutor
if exec.execId == tryAgainExecutor && !secondExecutorFailedOnce =>
secondExecutorFailedOnce = true
callback.onFailure(failure)
case exec: RegisterExecutor
if exec.execId == tryAgainExecutor && !secondExecutorFailedOnce =>
secondExecutorFailedOnce = true
callback.onFailure(failure)
case exec: RegisterExecutor if exec.execId == tryAgainExecutor =>
callback.onSuccess(success)
case exec: RegisterExecutor if exec.execId == tryAgainExecutor =>
callback.onSuccess(success)
case exec: RegisterExecutor
if exec.execId == succeedingExecutor && !thirdExecutorFailedOnce =>
thirdExecutorFailedOnce = true
callback.onFailure(failure)
case exec: RegisterExecutor
if exec.execId == succeedingExecutor && !thirdExecutorFailedOnce =>
thirdExecutorFailedOnce = true
callback.onFailure(failure)
case exec: RegisterExecutor if exec.execId == succeedingExecutor =>
callback.onSuccess(success)
case exec: RegisterExecutor if exec.execId == succeedingExecutor =>
callback.onSuccess(success)
}
}
}
val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 0)
val transCtx = new TransportContext(transConf, handler, true)
(transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port)
}
val candidatePort = RandomUtils.nextInt(1024, 65536)
val (server, shufflePort) = Utils.startServiceOnPort(candidatePort,
newShuffleServer, conf, "ShuffleServer")
val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 0)
conf.set(SHUFFLE_SERVICE_ENABLED.key, "true")
conf.set(SHUFFLE_SERVICE_PORT.key, shufflePort.toString)
conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "40")
conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1")
var e = intercept[SparkException] {
makeBlockManager(8000, timingoutExecutor)
}.getMessage
assert(e.contains("TimeoutException"))
Utils.tryWithResource(new TransportContext(transConf, handler, true)) { transCtx =>
// a server which delays response 50ms and must try twice for success.
def newShuffleServer(port: Int): (TransportServer, Int) = {
(transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port)
}
conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000")
conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1")
e = intercept[SparkException] {
makeBlockManager(8000, tryAgainExecutor)
}.getMessage
assert(e.contains(tryAgainMsg))
val candidatePort = RandomUtils.nextInt(1024, 65536)
val (server, shufflePort) = Utils.startServiceOnPort(candidatePort,
newShuffleServer, conf, "ShuffleServer")
conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000")
conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "2")
makeBlockManager(8000, succeedingExecutor)
server.close()
conf.set(SHUFFLE_SERVICE_ENABLED.key, "true")
conf.set(SHUFFLE_SERVICE_PORT.key, shufflePort.toString)
conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "40")
conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1")
var e = intercept[SparkException] {
makeBlockManager(8000, timingoutExecutor)
}.getMessage
assert(e.contains("TimeoutException"))
conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000")
conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1")
e = intercept[SparkException] {
makeBlockManager(8000, tryAgainExecutor)
}.getMessage
assert(e.contains(tryAgainMsg))
conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000")
conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "2")
makeBlockManager(8000, succeedingExecutor)
server.close()
}
}
test("fetch remote block to local disk if block size is larger than threshold") {