[SPARK-29398][CORE] Support dedicated thread pools for RPC endpoints
The current RPC backend in Spark supports single- and multi-threaded message delivery to endpoints, but they all share the same underlying thread pool. So an RPC endpoint that blocks a dispatcher thread can negatively affect other endpoints. This can be more pronounced with configurations that limit the number of RPC dispatch threads based on configuration and / or running environment. And exposing the RPC layer to other code (for example with something like SPARK-29396) could make it easy to affect normal Spark operation with a badly written RPC handler. This change adds a new RPC endpoint type that tells the RPC env to create dedicated dispatch threads, so that those effects are minimised. Other endpoints will still need CPU to process their messages, but they won't be able to actively block the dispatch thread of these isolated endpoints. As part of the change, I've changed the most important Spark endpoints (the driver, executor and block manager endpoints) to be isolated from others. This means a couple of extra threads are created on the driver and executor for these endpoints. Tested with existing unit tests, which hammer the RPC system extensively, and also by running applications on a cluster (with a prototype of SPARK-29396). Closes #26059 from vanzin/SPARK-29398. Authored-by: Marcelo Vanzin <vanzin@cloudera.com> Signed-off-by: Imran Rashid <irashid@cloudera.com>
This commit is contained in:
parent
f800fa3831
commit
2f0a38cb50
|
@ -51,7 +51,7 @@ private[spark] class CoarseGrainedExecutorBackend(
|
|||
userClassPath: Seq[URL],
|
||||
env: SparkEnv,
|
||||
resourcesFileOpt: Option[String])
|
||||
extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
|
||||
extends IsolatedRpcEndpoint with ExecutorBackend with Logging {
|
||||
|
||||
private implicit val formats = DefaultFormats
|
||||
|
||||
|
|
|
@ -146,3 +146,19 @@ private[spark] trait RpcEndpoint {
|
|||
* [[ThreadSafeRpcEndpoint]] for different messages.
|
||||
*/
|
||||
private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint
|
||||
|
||||
/**
|
||||
* An endpoint that uses a dedicated thread pool for delivering messages.
|
||||
*/
|
||||
private[spark] trait IsolatedRpcEndpoint extends RpcEndpoint {
|
||||
|
||||
/**
|
||||
* How many threads to use for delivering messages. By default, use a single thread.
|
||||
*
|
||||
* Note that requesting more than one thread means that the endpoint should be able to handle
|
||||
* messages arriving from many threads at once, and all the things that entails (including
|
||||
* messages being delivered to the endpoint out of order).
|
||||
*/
|
||||
def threadCount(): Int = 1
|
||||
|
||||
}
|
||||
|
|
|
@ -17,20 +17,16 @@
|
|||
|
||||
package org.apache.spark.rpc.netty
|
||||
|
||||
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
|
||||
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, CountDownLatch}
|
||||
import javax.annotation.concurrent.GuardedBy
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.concurrent.Promise
|
||||
import scala.util.control.NonFatal
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkContext, SparkException}
|
||||
import org.apache.spark.SparkException
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.internal.config.EXECUTOR_ID
|
||||
import org.apache.spark.internal.config.Network.RPC_NETTY_DISPATCHER_NUM_THREADS
|
||||
import org.apache.spark.network.client.RpcResponseCallback
|
||||
import org.apache.spark.rpc._
|
||||
import org.apache.spark.util.ThreadUtils
|
||||
|
||||
/**
|
||||
* A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
|
||||
|
@ -40,20 +36,23 @@ import org.apache.spark.util.ThreadUtils
|
|||
*/
|
||||
private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {
|
||||
|
||||
private class EndpointData(
|
||||
val name: String,
|
||||
val endpoint: RpcEndpoint,
|
||||
val ref: NettyRpcEndpointRef) {
|
||||
val inbox = new Inbox(ref, endpoint)
|
||||
}
|
||||
|
||||
private val endpoints: ConcurrentMap[String, EndpointData] =
|
||||
new ConcurrentHashMap[String, EndpointData]
|
||||
private val endpoints: ConcurrentMap[String, MessageLoop] =
|
||||
new ConcurrentHashMap[String, MessageLoop]
|
||||
private val endpointRefs: ConcurrentMap[RpcEndpoint, RpcEndpointRef] =
|
||||
new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
|
||||
|
||||
// Track the receivers whose inboxes may contain messages.
|
||||
private val receivers = new LinkedBlockingQueue[EndpointData]
|
||||
private val shutdownLatch = new CountDownLatch(1)
|
||||
private lazy val sharedLoop = new SharedMessageLoop(nettyEnv.conf, this, numUsableCores)
|
||||
|
||||
private def getMessageLoop(name: String, endpoint: RpcEndpoint): MessageLoop = {
|
||||
endpoint match {
|
||||
case e: IsolatedRpcEndpoint =>
|
||||
new DedicatedMessageLoop(name, e, this)
|
||||
case _ =>
|
||||
sharedLoop.register(name, endpoint)
|
||||
sharedLoop
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced
|
||||
|
@ -69,13 +68,11 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
|
|||
if (stopped) {
|
||||
throw new IllegalStateException("RpcEnv has been stopped")
|
||||
}
|
||||
if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
|
||||
if (endpoints.putIfAbsent(name, getMessageLoop(name, endpoint)) != null) {
|
||||
throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
|
||||
}
|
||||
val data = endpoints.get(name)
|
||||
endpointRefs.put(data.endpoint, data.ref)
|
||||
receivers.offer(data) // for the OnStart message
|
||||
}
|
||||
endpointRefs.put(endpoint, endpointRef)
|
||||
endpointRef
|
||||
}
|
||||
|
||||
|
@ -85,10 +82,9 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
|
|||
|
||||
// Should be idempotent
|
||||
private def unregisterRpcEndpoint(name: String): Unit = {
|
||||
val data = endpoints.remove(name)
|
||||
if (data != null) {
|
||||
data.inbox.stop()
|
||||
receivers.offer(data) // for the OnStop message
|
||||
val loop = endpoints.remove(name)
|
||||
if (loop != null) {
|
||||
loop.unregister(name)
|
||||
}
|
||||
// Don't clean `endpointRefs` here because it's possible that some messages are being processed
|
||||
// now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via
|
||||
|
@ -155,14 +151,13 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
|
|||
message: InboxMessage,
|
||||
callbackIfStopped: (Exception) => Unit): Unit = {
|
||||
val error = synchronized {
|
||||
val data = endpoints.get(endpointName)
|
||||
val loop = endpoints.get(endpointName)
|
||||
if (stopped) {
|
||||
Some(new RpcEnvStoppedException())
|
||||
} else if (data == null) {
|
||||
} else if (loop == null) {
|
||||
Some(new SparkException(s"Could not find $endpointName."))
|
||||
} else {
|
||||
data.inbox.post(message)
|
||||
receivers.offer(data)
|
||||
loop.post(endpointName, message)
|
||||
None
|
||||
}
|
||||
}
|
||||
|
@ -177,15 +172,23 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
|
|||
}
|
||||
stopped = true
|
||||
}
|
||||
// Stop all endpoints. This will queue all endpoints for processing by the message loops.
|
||||
endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)
|
||||
// Enqueue a message that tells the message loops to stop.
|
||||
receivers.offer(PoisonPill)
|
||||
threadpool.shutdown()
|
||||
var stopSharedLoop = false
|
||||
endpoints.asScala.foreach { case (name, loop) =>
|
||||
unregisterRpcEndpoint(name)
|
||||
if (!loop.isInstanceOf[SharedMessageLoop]) {
|
||||
loop.stop()
|
||||
} else {
|
||||
stopSharedLoop = true
|
||||
}
|
||||
}
|
||||
if (stopSharedLoop) {
|
||||
sharedLoop.stop()
|
||||
}
|
||||
shutdownLatch.countDown()
|
||||
}
|
||||
|
||||
def awaitTermination(): Unit = {
|
||||
threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
|
||||
shutdownLatch.await()
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -194,61 +197,4 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
|
|||
def verify(name: String): Boolean = {
|
||||
endpoints.containsKey(name)
|
||||
}
|
||||
|
||||
private def getNumOfThreads(conf: SparkConf): Int = {
|
||||
val availableCores =
|
||||
if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
|
||||
|
||||
val modNumThreads = conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS)
|
||||
.getOrElse(math.max(2, availableCores))
|
||||
|
||||
conf.get(EXECUTOR_ID).map { id =>
|
||||
val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor"
|
||||
conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads)
|
||||
}.getOrElse(modNumThreads)
|
||||
}
|
||||
|
||||
/** Thread pool used for dispatching messages. */
|
||||
private val threadpool: ThreadPoolExecutor = {
|
||||
val numThreads = getNumOfThreads(nettyEnv.conf)
|
||||
val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
|
||||
for (i <- 0 until numThreads) {
|
||||
pool.execute(new MessageLoop)
|
||||
}
|
||||
pool
|
||||
}
|
||||
|
||||
/** Message loop used for dispatching messages. */
|
||||
private class MessageLoop extends Runnable {
|
||||
override def run(): Unit = {
|
||||
try {
|
||||
while (true) {
|
||||
try {
|
||||
val data = receivers.take()
|
||||
if (data == PoisonPill) {
|
||||
// Put PoisonPill back so that other MessageLoops can see it.
|
||||
receivers.offer(PoisonPill)
|
||||
return
|
||||
}
|
||||
data.inbox.process(Dispatcher.this)
|
||||
} catch {
|
||||
case NonFatal(e) => logError(e.getMessage, e)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
case _: InterruptedException => // exit
|
||||
case t: Throwable =>
|
||||
try {
|
||||
// Re-submit a MessageLoop so that Dispatcher will still work if
|
||||
// UncaughtExceptionHandler decides to not kill JVM.
|
||||
threadpool.execute(new MessageLoop)
|
||||
} finally {
|
||||
throw t
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** A poison endpoint that indicates MessageLoop should exit its message loop. */
|
||||
private val PoisonPill = new EndpointData(null, null, null)
|
||||
}
|
||||
|
|
|
@ -54,9 +54,7 @@ private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteA
|
|||
/**
|
||||
* An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
|
||||
*/
|
||||
private[netty] class Inbox(
|
||||
val endpointRef: NettyRpcEndpointRef,
|
||||
val endpoint: RpcEndpoint)
|
||||
private[netty] class Inbox(val endpointName: String, val endpoint: RpcEndpoint)
|
||||
extends Logging {
|
||||
|
||||
inbox => // Give this an alias so we can use it more clearly in closures.
|
||||
|
@ -195,7 +193,7 @@ private[netty] class Inbox(
|
|||
* Exposed for testing.
|
||||
*/
|
||||
protected def onDrop(message: InboxMessage): Unit = {
|
||||
logWarning(s"Drop $message because $endpointRef is stopped")
|
||||
logWarning(s"Drop $message because endpoint $endpointName is stopped")
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
194
core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala
Normal file
194
core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala
Normal file
|
@ -0,0 +1,194 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.rpc.netty
|
||||
|
||||
import java.util.concurrent._
|
||||
|
||||
import scala.util.control.NonFatal
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.internal.config.EXECUTOR_ID
|
||||
import org.apache.spark.internal.config.Network._
|
||||
import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcEndpoint}
|
||||
import org.apache.spark.util.ThreadUtils
|
||||
|
||||
/**
|
||||
* A message loop used by [[Dispatcher]] to deliver messages to endpoints.
|
||||
*/
|
||||
private sealed abstract class MessageLoop(dispatcher: Dispatcher) extends Logging {
|
||||
|
||||
// List of inboxes with pending messages, to be processed by the message loop.
|
||||
private val active = new LinkedBlockingQueue[Inbox]()
|
||||
|
||||
// Message loop task; should be run in all threads of the message loop's pool.
|
||||
protected val receiveLoopRunnable = new Runnable() {
|
||||
override def run(): Unit = receiveLoop()
|
||||
}
|
||||
|
||||
protected val threadpool: ExecutorService
|
||||
|
||||
private var stopped = false
|
||||
|
||||
def post(endpointName: String, message: InboxMessage): Unit
|
||||
|
||||
def unregister(name: String): Unit
|
||||
|
||||
def stop(): Unit = {
|
||||
synchronized {
|
||||
if (!stopped) {
|
||||
setActive(MessageLoop.PoisonPill)
|
||||
threadpool.shutdown()
|
||||
stopped = true
|
||||
}
|
||||
}
|
||||
threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
|
||||
}
|
||||
|
||||
protected final def setActive(inbox: Inbox): Unit = active.offer(inbox)
|
||||
|
||||
private def receiveLoop(): Unit = {
|
||||
try {
|
||||
while (true) {
|
||||
try {
|
||||
val inbox = active.take()
|
||||
if (inbox == MessageLoop.PoisonPill) {
|
||||
// Put PoisonPill back so that other threads can see it.
|
||||
setActive(MessageLoop.PoisonPill)
|
||||
return
|
||||
}
|
||||
inbox.process(dispatcher)
|
||||
} catch {
|
||||
case NonFatal(e) => logError(e.getMessage, e)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
case _: InterruptedException => // exit
|
||||
case t: Throwable =>
|
||||
try {
|
||||
// Re-submit a receive task so that message delivery will still work if
|
||||
// UncaughtExceptionHandler decides to not kill JVM.
|
||||
threadpool.execute(receiveLoopRunnable)
|
||||
} finally {
|
||||
throw t
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private object MessageLoop {
|
||||
/** A poison inbox that indicates the message loop should stop processing messages. */
|
||||
val PoisonPill = new Inbox(null, null)
|
||||
}
|
||||
|
||||
/**
|
||||
* A message loop that serves multiple RPC endpoints, using a shared thread pool.
|
||||
*/
|
||||
private class SharedMessageLoop(
|
||||
conf: SparkConf,
|
||||
dispatcher: Dispatcher,
|
||||
numUsableCores: Int)
|
||||
extends MessageLoop(dispatcher) {
|
||||
|
||||
private val endpoints = new ConcurrentHashMap[String, Inbox]()
|
||||
|
||||
private def getNumOfThreads(conf: SparkConf): Int = {
|
||||
val availableCores =
|
||||
if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
|
||||
|
||||
val modNumThreads = conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS)
|
||||
.getOrElse(math.max(2, availableCores))
|
||||
|
||||
conf.get(EXECUTOR_ID).map { id =>
|
||||
val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor"
|
||||
conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads)
|
||||
}.getOrElse(modNumThreads)
|
||||
}
|
||||
|
||||
/** Thread pool used for dispatching messages. */
|
||||
override protected val threadpool: ThreadPoolExecutor = {
|
||||
val numThreads = getNumOfThreads(conf)
|
||||
val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
|
||||
for (i <- 0 until numThreads) {
|
||||
pool.execute(receiveLoopRunnable)
|
||||
}
|
||||
pool
|
||||
}
|
||||
|
||||
override def post(endpointName: String, message: InboxMessage): Unit = {
|
||||
val inbox = endpoints.get(endpointName)
|
||||
inbox.post(message)
|
||||
setActive(inbox)
|
||||
}
|
||||
|
||||
override def unregister(name: String): Unit = {
|
||||
val inbox = endpoints.remove(name)
|
||||
if (inbox != null) {
|
||||
inbox.stop()
|
||||
// Mark active to handle the OnStop message.
|
||||
setActive(inbox)
|
||||
}
|
||||
}
|
||||
|
||||
def register(name: String, endpoint: RpcEndpoint): Unit = {
|
||||
val inbox = new Inbox(name, endpoint)
|
||||
endpoints.put(name, inbox)
|
||||
// Mark active to handle the OnStart message.
|
||||
setActive(inbox)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A message loop that is dedicated to a single RPC endpoint.
|
||||
*/
|
||||
private class DedicatedMessageLoop(
|
||||
name: String,
|
||||
endpoint: IsolatedRpcEndpoint,
|
||||
dispatcher: Dispatcher)
|
||||
extends MessageLoop(dispatcher) {
|
||||
|
||||
private val inbox = new Inbox(name, endpoint)
|
||||
|
||||
override protected val threadpool = if (endpoint.threadCount() > 1) {
|
||||
ThreadUtils.newDaemonCachedThreadPool(s"dispatcher-$name", endpoint.threadCount())
|
||||
} else {
|
||||
ThreadUtils.newDaemonSingleThreadExecutor(s"dispatcher-$name")
|
||||
}
|
||||
|
||||
(1 to endpoint.threadCount()).foreach { _ =>
|
||||
threadpool.submit(receiveLoopRunnable)
|
||||
}
|
||||
|
||||
// Mark active to handle the OnStart message.
|
||||
setActive(inbox)
|
||||
|
||||
override def post(endpointName: String, message: InboxMessage): Unit = {
|
||||
require(endpointName == name)
|
||||
inbox.post(message)
|
||||
setActive(inbox)
|
||||
}
|
||||
|
||||
override def unregister(endpointName: String): Unit = synchronized {
|
||||
require(endpointName == name)
|
||||
inbox.stop()
|
||||
// Mark active to handle the OnStop message.
|
||||
setActive(inbox)
|
||||
setActive(MessageLoop.PoisonPill)
|
||||
threadpool.shutdown()
|
||||
}
|
||||
}
|
|
@ -111,7 +111,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
|
|||
private val reviveThread =
|
||||
ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread")
|
||||
|
||||
class DriverEndpoint extends ThreadSafeRpcEndpoint with Logging {
|
||||
class DriverEndpoint extends IsolatedRpcEndpoint with Logging {
|
||||
|
||||
override val rpcEnv: RpcEnv = CoarseGrainedSchedulerBackend.this.rpcEnv
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ import org.apache.spark.SparkConf
|
|||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.internal.{config, Logging}
|
||||
import org.apache.spark.network.shuffle.ExternalBlockStoreClient
|
||||
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
|
||||
import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv}
|
||||
import org.apache.spark.scheduler._
|
||||
import org.apache.spark.storage.BlockManagerMessages._
|
||||
import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils}
|
||||
|
@ -46,7 +46,7 @@ class BlockManagerMasterEndpoint(
|
|||
conf: SparkConf,
|
||||
listenerBus: LiveListenerBus,
|
||||
externalBlockStoreClient: Option[ExternalBlockStoreClient])
|
||||
extends ThreadSafeRpcEndpoint with Logging {
|
||||
extends IsolatedRpcEndpoint with Logging {
|
||||
|
||||
// Mapping from block manager id to the block manager's information.
|
||||
private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]
|
||||
|
|
|
@ -21,7 +21,7 @@ import scala.concurrent.{ExecutionContext, Future}
|
|||
|
||||
import org.apache.spark.{MapOutputTracker, SparkEnv}
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
|
||||
import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEnv}
|
||||
import org.apache.spark.storage.BlockManagerMessages._
|
||||
import org.apache.spark.util.{ThreadUtils, Utils}
|
||||
|
||||
|
@ -34,7 +34,7 @@ class BlockManagerSlaveEndpoint(
|
|||
override val rpcEnv: RpcEnv,
|
||||
blockManager: BlockManager,
|
||||
mapOutputTracker: MapOutputTracker)
|
||||
extends ThreadSafeRpcEndpoint with Logging {
|
||||
extends IsolatedRpcEndpoint with Logging {
|
||||
|
||||
private val asyncThreadPool =
|
||||
ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool", 100)
|
||||
|
|
|
@ -36,7 +36,6 @@ import org.scalatest.concurrent.Eventually._
|
|||
import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite}
|
||||
import org.apache.spark.deploy.SparkHadoopUtil
|
||||
import org.apache.spark.internal.config._
|
||||
import org.apache.spark.internal.config.Network
|
||||
import org.apache.spark.util.{ThreadUtils, Utils}
|
||||
|
||||
/**
|
||||
|
@ -954,6 +953,40 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
|
|||
verify(endpoint, never()).onDisconnected(any())
|
||||
verify(endpoint, never()).onNetworkError(any(), any())
|
||||
}
|
||||
|
||||
test("isolated endpoints") {
|
||||
val latch = new CountDownLatch(1)
|
||||
val singleThreadedEnv = createRpcEnv(
|
||||
new SparkConf().set(Network.RPC_NETTY_DISPATCHER_NUM_THREADS, 1), "singleThread", 0)
|
||||
try {
|
||||
val blockingEndpoint = singleThreadedEnv.setupEndpoint("blocking", new IsolatedRpcEndpoint {
|
||||
override val rpcEnv: RpcEnv = singleThreadedEnv
|
||||
|
||||
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
|
||||
case m =>
|
||||
latch.await()
|
||||
context.reply(m)
|
||||
}
|
||||
})
|
||||
|
||||
val nonBlockingEndpoint = singleThreadedEnv.setupEndpoint("non-blocking", new RpcEndpoint {
|
||||
override val rpcEnv: RpcEnv = singleThreadedEnv
|
||||
|
||||
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
|
||||
case m => context.reply(m)
|
||||
}
|
||||
})
|
||||
|
||||
val to = new RpcTimeout(5.seconds, "test-timeout")
|
||||
val blockingFuture = blockingEndpoint.ask[String]("hi", to)
|
||||
assert(nonBlockingEndpoint.askSync[String]("hello", to) === "hello")
|
||||
latch.countDown()
|
||||
assert(ThreadUtils.awaitResult(blockingFuture, 5.seconds) === "hi")
|
||||
} finally {
|
||||
latch.countDown()
|
||||
singleThreadedEnv.shutdown()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class UnserializableClass
|
||||
|
|
|
@ -29,12 +29,9 @@ class InboxSuite extends SparkFunSuite {
|
|||
|
||||
test("post") {
|
||||
val endpoint = new TestRpcEndpoint
|
||||
val endpointRef = mock(classOf[NettyRpcEndpointRef])
|
||||
when(endpointRef.name).thenReturn("hello")
|
||||
|
||||
val dispatcher = mock(classOf[Dispatcher])
|
||||
|
||||
val inbox = new Inbox(endpointRef, endpoint)
|
||||
val inbox = new Inbox("name", endpoint)
|
||||
val message = OneWayMessage(null, "hi")
|
||||
inbox.post(message)
|
||||
inbox.process(dispatcher)
|
||||
|
@ -51,10 +48,9 @@ class InboxSuite extends SparkFunSuite {
|
|||
|
||||
test("post: with reply") {
|
||||
val endpoint = new TestRpcEndpoint
|
||||
val endpointRef = mock(classOf[NettyRpcEndpointRef])
|
||||
val dispatcher = mock(classOf[Dispatcher])
|
||||
|
||||
val inbox = new Inbox(endpointRef, endpoint)
|
||||
val inbox = new Inbox("name", endpoint)
|
||||
val message = RpcMessage(null, "hi", null)
|
||||
inbox.post(message)
|
||||
inbox.process(dispatcher)
|
||||
|
@ -65,13 +61,10 @@ class InboxSuite extends SparkFunSuite {
|
|||
|
||||
test("post: multiple threads") {
|
||||
val endpoint = new TestRpcEndpoint
|
||||
val endpointRef = mock(classOf[NettyRpcEndpointRef])
|
||||
when(endpointRef.name).thenReturn("hello")
|
||||
|
||||
val dispatcher = mock(classOf[Dispatcher])
|
||||
|
||||
val numDroppedMessages = new AtomicInteger(0)
|
||||
val inbox = new Inbox(endpointRef, endpoint) {
|
||||
val inbox = new Inbox("name", endpoint) {
|
||||
override def onDrop(message: InboxMessage): Unit = {
|
||||
numDroppedMessages.incrementAndGet()
|
||||
}
|
||||
|
@ -107,12 +100,10 @@ class InboxSuite extends SparkFunSuite {
|
|||
|
||||
test("post: Associated") {
|
||||
val endpoint = new TestRpcEndpoint
|
||||
val endpointRef = mock(classOf[NettyRpcEndpointRef])
|
||||
val dispatcher = mock(classOf[Dispatcher])
|
||||
|
||||
val remoteAddress = RpcAddress("localhost", 11111)
|
||||
|
||||
val inbox = new Inbox(endpointRef, endpoint)
|
||||
val inbox = new Inbox("name", endpoint)
|
||||
inbox.post(RemoteProcessConnected(remoteAddress))
|
||||
inbox.process(dispatcher)
|
||||
|
||||
|
@ -121,12 +112,11 @@ class InboxSuite extends SparkFunSuite {
|
|||
|
||||
test("post: Disassociated") {
|
||||
val endpoint = new TestRpcEndpoint
|
||||
val endpointRef = mock(classOf[NettyRpcEndpointRef])
|
||||
val dispatcher = mock(classOf[Dispatcher])
|
||||
|
||||
val remoteAddress = RpcAddress("localhost", 11111)
|
||||
|
||||
val inbox = new Inbox(endpointRef, endpoint)
|
||||
val inbox = new Inbox("name", endpoint)
|
||||
inbox.post(RemoteProcessDisconnected(remoteAddress))
|
||||
inbox.process(dispatcher)
|
||||
|
||||
|
@ -135,13 +125,12 @@ class InboxSuite extends SparkFunSuite {
|
|||
|
||||
test("post: AssociationError") {
|
||||
val endpoint = new TestRpcEndpoint
|
||||
val endpointRef = mock(classOf[NettyRpcEndpointRef])
|
||||
val dispatcher = mock(classOf[Dispatcher])
|
||||
|
||||
val remoteAddress = RpcAddress("localhost", 11111)
|
||||
val cause = new RuntimeException("Oops")
|
||||
|
||||
val inbox = new Inbox(endpointRef, endpoint)
|
||||
val inbox = new Inbox("name", endpoint)
|
||||
inbox.post(RemoteProcessConnectionError(cause, remoteAddress))
|
||||
inbox.process(dispatcher)
|
||||
|
||||
|
|
Loading…
Reference in a new issue