[SPARK-29398][CORE] Support dedicated thread pools for RPC endpoints

The current RPC backend in Spark supports single- and multi-threaded
message delivery to endpoints, but all endpoints share the same underlying
thread pool. So an RPC endpoint that blocks a dispatcher thread can
negatively affect every other endpoint.

This is more pronounced when the number of RPC dispatch threads is small,
whether because of explicit configuration or the running environment. And
exposing the RPC layer to other code (for example with something like
SPARK-29396) would make it easy for a badly written RPC handler to disrupt
normal Spark operation.
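
As a point of reference (not part of this change), the shared pool's size can
be pinned via configuration; a minimal sketch using the key that
getNumOfThreads() reads further down in this diff:

    import org.apache.spark.SparkConf

    // With a cap this low, a single blocking handler previously stalled
    // message delivery for every endpoint sharing the dispatcher pool.
    val conf = new SparkConf()
      .set("spark.rpc.netty.dispatcher.numThreads", "2")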

This change adds a new RPC endpoint type that tells the RPC env to create
dedicated dispatch threads for it, so that those effects are minimised.
Other endpoints still need CPU time to process their messages, but they can
no longer actively block the dispatch threads of these isolated endpoints.
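
For illustration, an endpoint opts into isolation simply by extending the new
trait; this is a hedged sketch (the endpoint name and handler logic are
invented for the example):

    import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcEnv}

    // Hypothetical endpoint: it gets its own dispatch thread, so a slow or
    // blocking handler only ties up that thread, never the shared pool.
    class AuditLogEndpoint(override val rpcEnv: RpcEnv) extends IsolatedRpcEndpoint {
      override def receive: PartialFunction[Any, Unit] = {
        case event => // potentially slow work; other endpoints keep dispatching
      }
    }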

As part of the change, I've changed the most important Spark endpoints
(the driver, executor and block manager endpoints) to be isolated from
others. This means a couple of extra threads are created on the driver
and executor for these endpoints.

Tested with existing unit tests, which hammer the RPC system extensively,
and also by running applications on a cluster (with a prototype of
SPARK-29396).

Closes #26059 from vanzin/SPARK-29398.

Authored-by: Marcelo Vanzin <vanzin@cloudera.com>
Signed-off-by: Imran Rashid <irashid@cloudera.com>
Marcelo Vanzin 2019-10-17 13:14:32 -05:00 committed by Imran Rashid
parent f800fa3831
commit 2f0a38cb50
10 changed files with 296 additions and 120 deletions

core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala

@@ -51,7 +51,7 @@ private[spark] class CoarseGrainedExecutorBackend(
     userClassPath: Seq[URL],
     env: SparkEnv,
     resourcesFileOpt: Option[String])
-  extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
+  extends IsolatedRpcEndpoint with ExecutorBackend with Logging {

   private implicit val formats = DefaultFormats

core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala

@@ -146,3 +146,19 @@ private[spark] trait RpcEndpoint {
  * [[ThreadSafeRpcEndpoint]] for different messages.
  */
 private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint
+
+/**
+ * An endpoint that uses a dedicated thread pool for delivering messages.
+ */
+private[spark] trait IsolatedRpcEndpoint extends RpcEndpoint {
+
+  /**
+   * How many threads to use for delivering messages. By default, use a single thread.
+   *
+   * Note that requesting more than one thread means that the endpoint should be able to handle
+   * messages arriving from many threads at once, and all the things that entails (including
+   * messages being delivered to the endpoint out of order).
+   */
+  def threadCount(): Int = 1
+
+}
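
As a hedged sketch of the multi-threaded case the comment above warns about:
an endpoint overriding threadCount() must keep its state thread-safe, since
handlers may run concurrently and see messages out of order (names below are
hypothetical):

    import java.util.concurrent.atomic.AtomicLong

    import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEnv}

    // Four delivery threads; shared state uses atomics instead of plain vars.
    class CounterEndpoint(override val rpcEnv: RpcEnv) extends IsolatedRpcEndpoint {
      private val count = new AtomicLong(0)

      override def threadCount(): Int = 4

      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case "inc" => context.reply(count.incrementAndGet())
      }
    }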

core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala

@@ -17,20 +17,16 @@
 package org.apache.spark.rpc.netty

-import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, CountDownLatch}
 import javax.annotation.concurrent.GuardedBy

 import scala.collection.JavaConverters._
 import scala.concurrent.Promise
-import scala.util.control.NonFatal

-import org.apache.spark.{SparkConf, SparkContext, SparkException}
+import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.EXECUTOR_ID
-import org.apache.spark.internal.config.Network.RPC_NETTY_DISPATCHER_NUM_THREADS
 import org.apache.spark.network.client.RpcResponseCallback
 import org.apache.spark.rpc._
-import org.apache.spark.util.ThreadUtils

 /**
  * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
@@ -40,20 +36,23 @@ import org.apache.spark.util.ThreadUtils
  */
 private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {

-  private class EndpointData(
-      val name: String,
-      val endpoint: RpcEndpoint,
-      val ref: NettyRpcEndpointRef) {
-    val inbox = new Inbox(ref, endpoint)
-  }
-
-  private val endpoints: ConcurrentMap[String, EndpointData] =
-    new ConcurrentHashMap[String, EndpointData]
+  private val endpoints: ConcurrentMap[String, MessageLoop] =
+    new ConcurrentHashMap[String, MessageLoop]
   private val endpointRefs: ConcurrentMap[RpcEndpoint, RpcEndpointRef] =
     new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]

-  // Track the receivers whose inboxes may contain messages.
-  private val receivers = new LinkedBlockingQueue[EndpointData]
+  private val shutdownLatch = new CountDownLatch(1)
+  private lazy val sharedLoop = new SharedMessageLoop(nettyEnv.conf, this, numUsableCores)
+
+  private def getMessageLoop(name: String, endpoint: RpcEndpoint): MessageLoop = {
+    endpoint match {
+      case e: IsolatedRpcEndpoint =>
+        new DedicatedMessageLoop(name, e, this)
+      case _ =>
+        sharedLoop.register(name, endpoint)
+        sharedLoop
+    }
+  }

   /**
    * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced
@@ -69,13 +68,11 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {
       if (stopped) {
         throw new IllegalStateException("RpcEnv has been stopped")
       }
-      if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
+      if (endpoints.putIfAbsent(name, getMessageLoop(name, endpoint)) != null) {
         throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
       }
-      val data = endpoints.get(name)
-      endpointRefs.put(data.endpoint, data.ref)
-      receivers.offer(data)  // for the OnStart message
     }
+    endpointRefs.put(endpoint, endpointRef)
     endpointRef
   }
@@ -85,10 +82,9 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {

   // Should be idempotent
   private def unregisterRpcEndpoint(name: String): Unit = {
-    val data = endpoints.remove(name)
-    if (data != null) {
-      data.inbox.stop()
-      receivers.offer(data)  // for the OnStop message
+    val loop = endpoints.remove(name)
+    if (loop != null) {
+      loop.unregister(name)
     }
     // Don't clean `endpointRefs` here because it's possible that some messages are being processed
     // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via
@@ -155,14 +151,13 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {
       message: InboxMessage,
       callbackIfStopped: (Exception) => Unit): Unit = {
     val error = synchronized {
-      val data = endpoints.get(endpointName)
+      val loop = endpoints.get(endpointName)
       if (stopped) {
         Some(new RpcEnvStoppedException())
-      } else if (data == null) {
+      } else if (loop == null) {
         Some(new SparkException(s"Could not find $endpointName."))
       } else {
-        data.inbox.post(message)
-        receivers.offer(data)
+        loop.post(endpointName, message)
         None
       }
     }
@@ -177,15 +172,23 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {
       }
       stopped = true
     }
-    // Stop all endpoints. This will queue all endpoints for processing by the message loops.
-    endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)
-    // Enqueue a message that tells the message loops to stop.
-    receivers.offer(PoisonPill)
-    threadpool.shutdown()
+    var stopSharedLoop = false
+    endpoints.asScala.foreach { case (name, loop) =>
+      unregisterRpcEndpoint(name)
+      if (!loop.isInstanceOf[SharedMessageLoop]) {
+        loop.stop()
+      } else {
+        stopSharedLoop = true
+      }
+    }
+    if (stopSharedLoop) {
+      sharedLoop.stop()
+    }
+    shutdownLatch.countDown()
   }

   def awaitTermination(): Unit = {
-    threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
+    shutdownLatch.await()
   }

   /**
@@ -194,61 +197,4 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {
   def verify(name: String): Boolean = {
     endpoints.containsKey(name)
   }
-
-  private def getNumOfThreads(conf: SparkConf): Int = {
-    val availableCores =
-      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
-
-    val modNumThreads = conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS)
-      .getOrElse(math.max(2, availableCores))
-
-    conf.get(EXECUTOR_ID).map { id =>
-      val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor"
-      conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads)
-    }.getOrElse(modNumThreads)
-  }
-
-  /** Thread pool used for dispatching messages. */
-  private val threadpool: ThreadPoolExecutor = {
-    val numThreads = getNumOfThreads(nettyEnv.conf)
-    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
-    for (i <- 0 until numThreads) {
-      pool.execute(new MessageLoop)
-    }
-    pool
-  }
-
-  /** Message loop used for dispatching messages. */
-  private class MessageLoop extends Runnable {
-    override def run(): Unit = {
-      try {
-        while (true) {
-          try {
-            val data = receivers.take()
-            if (data == PoisonPill) {
-              // Put PoisonPill back so that other MessageLoops can see it.
-              receivers.offer(PoisonPill)
-              return
-            }
-            data.inbox.process(Dispatcher.this)
-          } catch {
-            case NonFatal(e) => logError(e.getMessage, e)
-          }
-        }
-      } catch {
-        case _: InterruptedException => // exit
-        case t: Throwable =>
-          try {
-            // Re-submit a MessageLoop so that Dispatcher will still work if
-            // UncaughtExceptionHandler decides to not kill JVM.
-            threadpool.execute(new MessageLoop)
-          } finally {
-            throw t
-          }
-      }
-    }
-  }
-
-  /** A poison endpoint that indicates MessageLoop should exit its message loop. */
-  private val PoisonPill = new EndpointData(null, null, null)
 }
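
The poison-pill idiom deleted above is not lost; it reappears in the new
MessageLoop file below. As a self-contained sketch of the idiom itself (plain
JDK types, no Spark; all names here are invented for the example):

    import java.util.concurrent.LinkedBlockingQueue

    object PoisonPillSketch {
      // A dedicated sentinel; reference equality avoids clashing with real work.
      private val PoisonPill = new Object

      def main(args: Array[String]): Unit = {
        val queue = new LinkedBlockingQueue[AnyRef]()
        val worker = new Thread(() => {
          var running = true
          while (running) {
            queue.take() match {
              case p if p eq PoisonPill =>
                queue.offer(PoisonPill) // re-offer so sibling workers also see it
                running = false
              case item =>
                println(s"processed $item")
            }
          }
        })
        worker.start()
        Seq("a", "b", "c").foreach(queue.put)
        queue.put(PoisonPill)
        worker.join()
      }
    }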

core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala

@@ -54,9 +54,7 @@ private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteAddress: RpcAddress)

 /**
  * An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
  */
-private[netty] class Inbox(
-    val endpointRef: NettyRpcEndpointRef,
-    val endpoint: RpcEndpoint)
+private[netty] class Inbox(val endpointName: String, val endpoint: RpcEndpoint)
   extends Logging {

   inbox =>  // Give this an alias so we can use it more clearly in closures.
@@ -195,7 +193,7 @@ private[netty] class Inbox(
    * Exposed for testing.
    */
   protected def onDrop(message: InboxMessage): Unit = {
-    logWarning(s"Drop $message because $endpointRef is stopped")
+    logWarning(s"Drop $message because endpoint $endpointName is stopped")
   }

   /**

core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala (new file)

@@ -0,0 +1,194 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.rpc.netty

import java.util.concurrent._

import scala.util.control.NonFatal

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.EXECUTOR_ID
import org.apache.spark.internal.config.Network._
import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcEndpoint}
import org.apache.spark.util.ThreadUtils

/**
 * A message loop used by [[Dispatcher]] to deliver messages to endpoints.
 */
private sealed abstract class MessageLoop(dispatcher: Dispatcher) extends Logging {

  // List of inboxes with pending messages, to be processed by the message loop.
  private val active = new LinkedBlockingQueue[Inbox]()

  // Message loop task; should be run in all threads of the message loop's pool.
  protected val receiveLoopRunnable = new Runnable() {
    override def run(): Unit = receiveLoop()
  }

  protected val threadpool: ExecutorService

  private var stopped = false

  def post(endpointName: String, message: InboxMessage): Unit

  def unregister(name: String): Unit

  def stop(): Unit = {
    synchronized {
      if (!stopped) {
        setActive(MessageLoop.PoisonPill)
        threadpool.shutdown()
        stopped = true
      }
    }
    threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
  }

  protected final def setActive(inbox: Inbox): Unit = active.offer(inbox)

  private def receiveLoop(): Unit = {
    try {
      while (true) {
        try {
          val inbox = active.take()
          if (inbox == MessageLoop.PoisonPill) {
            // Put PoisonPill back so that other threads can see it.
            setActive(MessageLoop.PoisonPill)
            return
          }
          inbox.process(dispatcher)
        } catch {
          case NonFatal(e) => logError(e.getMessage, e)
        }
      }
    } catch {
      case _: InterruptedException => // exit
      case t: Throwable =>
        try {
          // Re-submit a receive task so that message delivery will still work if
          // UncaughtExceptionHandler decides to not kill JVM.
          threadpool.execute(receiveLoopRunnable)
        } finally {
          throw t
        }
    }
  }
}

private object MessageLoop {
  /** A poison inbox that indicates the message loop should stop processing messages. */
  val PoisonPill = new Inbox(null, null)
}

/**
 * A message loop that serves multiple RPC endpoints, using a shared thread pool.
 */
private class SharedMessageLoop(
    conf: SparkConf,
    dispatcher: Dispatcher,
    numUsableCores: Int)
  extends MessageLoop(dispatcher) {

  private val endpoints = new ConcurrentHashMap[String, Inbox]()

  private def getNumOfThreads(conf: SparkConf): Int = {
    val availableCores =
      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()

    val modNumThreads = conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS)
      .getOrElse(math.max(2, availableCores))

    conf.get(EXECUTOR_ID).map { id =>
      val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor"
      conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads)
    }.getOrElse(modNumThreads)
  }

  /** Thread pool used for dispatching messages. */
  override protected val threadpool: ThreadPoolExecutor = {
    val numThreads = getNumOfThreads(conf)
    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
    for (i <- 0 until numThreads) {
      pool.execute(receiveLoopRunnable)
    }
    pool
  }

  override def post(endpointName: String, message: InboxMessage): Unit = {
    val inbox = endpoints.get(endpointName)
    inbox.post(message)
    setActive(inbox)
  }

  override def unregister(name: String): Unit = {
    val inbox = endpoints.remove(name)
    if (inbox != null) {
      inbox.stop()
      // Mark active to handle the OnStop message.
      setActive(inbox)
    }
  }

  def register(name: String, endpoint: RpcEndpoint): Unit = {
    val inbox = new Inbox(name, endpoint)
    endpoints.put(name, inbox)
    // Mark active to handle the OnStart message.
    setActive(inbox)
  }
}

/**
 * A message loop that is dedicated to a single RPC endpoint.
 */
private class DedicatedMessageLoop(
    name: String,
    endpoint: IsolatedRpcEndpoint,
    dispatcher: Dispatcher)
  extends MessageLoop(dispatcher) {

  private val inbox = new Inbox(name, endpoint)

  override protected val threadpool = if (endpoint.threadCount() > 1) {
    ThreadUtils.newDaemonCachedThreadPool(s"dispatcher-$name", endpoint.threadCount())
  } else {
    ThreadUtils.newDaemonSingleThreadExecutor(s"dispatcher-$name")
  }

  (1 to endpoint.threadCount()).foreach { _ =>
    threadpool.submit(receiveLoopRunnable)
  }

  // Mark active to handle the OnStart message.
  setActive(inbox)

  override def post(endpointName: String, message: InboxMessage): Unit = {
    require(endpointName == name)
    inbox.post(message)
    setActive(inbox)
  }

  override def unregister(endpointName: String): Unit = synchronized {
    require(endpointName == name)
    inbox.stop()

    // Mark active to handle the OnStop message.
    setActive(inbox)
    setActive(MessageLoop.PoisonPill)
    threadpool.shutdown()
  }
}
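
To make the isolation property concrete, here is a stand-alone sketch
(hypothetical names, plain JDK executors) of the one-pool-per-endpoint layout
that DedicatedMessageLoop uses: blocking one consumer cannot starve the other.

    import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit}

    object DedicatedLoopSketch {
      def main(args: Array[String]): Unit = {
        // Two "endpoints", each owning a single-thread pool.
        val slowQueue = new LinkedBlockingQueue[String]()
        val fastQueue = new LinkedBlockingQueue[String]()
        val slowPool = Executors.newSingleThreadExecutor()
        val fastPool = Executors.newSingleThreadExecutor()

        slowPool.execute(() => {
          val msg = slowQueue.take()
          Thread.sleep(5000) // a badly behaved handler; only its own thread blocks
          println(s"slow endpoint handled $msg")
        })
        fastPool.execute(() => println(s"fast endpoint handled ${fastQueue.take()}"))

        slowQueue.put("ping")
        fastQueue.put("ping") // serviced immediately despite the blocked slow endpoint

        fastPool.shutdown()
        slowPool.shutdown()
        slowPool.awaitTermination(10, TimeUnit.SECONDS)
      }
    }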

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala

@@ -111,7 +111,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv)
   private val reviveThread =
     ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread")

-  class DriverEndpoint extends ThreadSafeRpcEndpoint with Logging {
+  class DriverEndpoint extends IsolatedRpcEndpoint with Logging {

     override val rpcEnv: RpcEnv = CoarseGrainedSchedulerBackend.this.rpcEnv

core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala

@@ -30,7 +30,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.network.shuffle.ExternalBlockStoreClient
-import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv}
 import org.apache.spark.scheduler._
 import org.apache.spark.storage.BlockManagerMessages._
 import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils}
@@ -46,7 +46,7 @@ class BlockManagerMasterEndpoint(
     conf: SparkConf,
     listenerBus: LiveListenerBus,
     externalBlockStoreClient: Option[ExternalBlockStoreClient])
-  extends ThreadSafeRpcEndpoint with Logging {
+  extends IsolatedRpcEndpoint with Logging {

   // Mapping from block manager id to the block manager's information.
   private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]

core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala

@@ -21,7 +21,7 @@ import scala.concurrent.{ExecutionContext, Future}

 import org.apache.spark.{MapOutputTracker, SparkEnv}
 import org.apache.spark.internal.Logging
-import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEnv}
 import org.apache.spark.storage.BlockManagerMessages._
 import org.apache.spark.util.{ThreadUtils, Utils}
@@ -34,7 +34,7 @@ class BlockManagerSlaveEndpoint(
     override val rpcEnv: RpcEnv,
     blockManager: BlockManager,
     mapOutputTracker: MapOutputTracker)
-  extends ThreadSafeRpcEndpoint with Logging {
+  extends IsolatedRpcEndpoint with Logging {

   private val asyncThreadPool =
     ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool", 100)

core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala

@@ -36,7 +36,6 @@ import org.scalatest.concurrent.Eventually._
 import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.config._
-import org.apache.spark.internal.config.Network
 import org.apache.spark.util.{ThreadUtils, Utils}
/**
@@ -954,6 +953,40 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     verify(endpoint, never()).onDisconnected(any())
     verify(endpoint, never()).onNetworkError(any(), any())
   }
+
+  test("isolated endpoints") {
+    val latch = new CountDownLatch(1)
+    val singleThreadedEnv = createRpcEnv(
+      new SparkConf().set(Network.RPC_NETTY_DISPATCHER_NUM_THREADS, 1), "singleThread", 0)
+    try {
+      val blockingEndpoint = singleThreadedEnv.setupEndpoint("blocking", new IsolatedRpcEndpoint {
+        override val rpcEnv: RpcEnv = singleThreadedEnv
+
+        override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+          case m =>
+            latch.await()
+            context.reply(m)
+        }
+      })
+
+      val nonBlockingEndpoint = singleThreadedEnv.setupEndpoint("non-blocking", new RpcEndpoint {
+        override val rpcEnv: RpcEnv = singleThreadedEnv
+
+        override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+          case m => context.reply(m)
+        }
+      })
+
+      val to = new RpcTimeout(5.seconds, "test-timeout")
+      val blockingFuture = blockingEndpoint.ask[String]("hi", to)
+      assert(nonBlockingEndpoint.askSync[String]("hello", to) === "hello")
+      latch.countDown()
+      assert(ThreadUtils.awaitResult(blockingFuture, 5.seconds) === "hi")
+    } finally {
+      latch.countDown()
+      singleThreadedEnv.shutdown()
+    }
+  }
 }

 class UnserializableClass

core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala

@@ -29,12 +29,9 @@ class InboxSuite extends SparkFunSuite {

   test("post") {
     val endpoint = new TestRpcEndpoint
-    val endpointRef = mock(classOf[NettyRpcEndpointRef])
-    when(endpointRef.name).thenReturn("hello")
-
     val dispatcher = mock(classOf[Dispatcher])

-    val inbox = new Inbox(endpointRef, endpoint)
+    val inbox = new Inbox("name", endpoint)
     val message = OneWayMessage(null, "hi")
     inbox.post(message)
     inbox.process(dispatcher)
@@ -51,10 +48,9 @@ class InboxSuite extends SparkFunSuite {

   test("post: with reply") {
     val endpoint = new TestRpcEndpoint
-    val endpointRef = mock(classOf[NettyRpcEndpointRef])
     val dispatcher = mock(classOf[Dispatcher])

-    val inbox = new Inbox(endpointRef, endpoint)
+    val inbox = new Inbox("name", endpoint)
     val message = RpcMessage(null, "hi", null)
     inbox.post(message)
     inbox.process(dispatcher)
@@ -65,13 +61,10 @@ class InboxSuite extends SparkFunSuite {

   test("post: multiple threads") {
     val endpoint = new TestRpcEndpoint
-    val endpointRef = mock(classOf[NettyRpcEndpointRef])
-    when(endpointRef.name).thenReturn("hello")
-
     val dispatcher = mock(classOf[Dispatcher])

     val numDroppedMessages = new AtomicInteger(0)
-    val inbox = new Inbox(endpointRef, endpoint) {
+    val inbox = new Inbox("name", endpoint) {
       override def onDrop(message: InboxMessage): Unit = {
         numDroppedMessages.incrementAndGet()
       }
@@ -107,12 +100,10 @@ class InboxSuite extends SparkFunSuite {

   test("post: Associated") {
     val endpoint = new TestRpcEndpoint
-    val endpointRef = mock(classOf[NettyRpcEndpointRef])
     val dispatcher = mock(classOf[Dispatcher])

     val remoteAddress = RpcAddress("localhost", 11111)

-    val inbox = new Inbox(endpointRef, endpoint)
+    val inbox = new Inbox("name", endpoint)
     inbox.post(RemoteProcessConnected(remoteAddress))
     inbox.process(dispatcher)
@@ -121,12 +112,11 @@ class InboxSuite extends SparkFunSuite {

   test("post: Disassociated") {
     val endpoint = new TestRpcEndpoint
-    val endpointRef = mock(classOf[NettyRpcEndpointRef])
     val dispatcher = mock(classOf[Dispatcher])

     val remoteAddress = RpcAddress("localhost", 11111)

-    val inbox = new Inbox(endpointRef, endpoint)
+    val inbox = new Inbox("name", endpoint)
     inbox.post(RemoteProcessDisconnected(remoteAddress))
     inbox.process(dispatcher)
@@ -135,13 +125,12 @@ class InboxSuite extends SparkFunSuite {

   test("post: AssociationError") {
     val endpoint = new TestRpcEndpoint
-    val endpointRef = mock(classOf[NettyRpcEndpointRef])
     val dispatcher = mock(classOf[Dispatcher])

     val remoteAddress = RpcAddress("localhost", 11111)
     val cause = new RuntimeException("Oops")

-    val inbox = new Inbox(endpointRef, endpoint)
+    val inbox = new Inbox("name", endpoint)
     inbox.post(RemoteProcessConnectionError(cause, remoteAddress))
     inbox.process(dispatcher)