[SPARK-7655][Core][SQL] Remove 'scala.concurrent.ExecutionContext.Implicits.global' in 'ask' and 'BroadcastHashJoin'
Both `AkkaRpcEndpointRef.ask` and `BroadcastHashJoin` use `scala.concurrent.ExecutionContext.Implicits.global`. However, the tasks in `BroadcastHashJoin` are usually long-running and can occupy all threads in `global`, so `ask` never gets a chance to process the replies.
For `ask`, actually the tasks are very simple, so we can use `MoreExecutors.sameThreadExecutor()`. For `BroadcastHashJoin`, it's better to use `ThreadUtils.newDaemonCachedThreadPool`.
Author: zsxwing <zsxwing@gmail.com>
Closes #6200 from zsxwing/SPARK-7655-2 and squashes the following commits:
cfdc605 [zsxwing] Remove redundant import and minor doc fix
cf83153 [zsxwing] Add "sameThread" and "newDaemonCachedThreadPool with maxThreadNumber" to ThreadUtils
08ad0ee [zsxwing] Remove 'scala.concurrent.ExecutionContext.Implicits.global' in 'ask' and 'BroadcastHashJoin'
(cherry picked from commit 47e7ffe36b
)
Signed-off-by: Reynold Xin <rxin@databricks.com>
This commit is contained in:
parent
e7607e5cbc
commit
ad5b0b1ce2
|
@ -29,9 +29,11 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add
|
||||||
import akka.event.Logging.Error
|
import akka.event.Logging.Error
|
||||||
import akka.pattern.{ask => akkaAsk}
|
import akka.pattern.{ask => akkaAsk}
|
||||||
import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
|
import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
|
||||||
|
import com.google.common.util.concurrent.MoreExecutors
|
||||||
|
|
||||||
import org.apache.spark.{SparkException, Logging, SparkConf}
|
import org.apache.spark.{SparkException, Logging, SparkConf}
|
||||||
import org.apache.spark.rpc._
|
import org.apache.spark.rpc._
|
||||||
import org.apache.spark.util.{ActorLogReceive, AkkaUtils}
|
import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A RpcEnv implementation based on Akka.
|
* A RpcEnv implementation based on Akka.
|
||||||
|
@ -294,8 +296,8 @@ private[akka] class AkkaRpcEndpointRef(
|
||||||
}
|
}
|
||||||
|
|
||||||
override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = {
|
override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = {
|
||||||
import scala.concurrent.ExecutionContext.Implicits.global
|
|
||||||
actorRef.ask(AkkaMessage(message, true))(timeout).flatMap {
|
actorRef.ask(AkkaMessage(message, true))(timeout).flatMap {
|
||||||
|
// The function will run in the calling thread, so it should be short and never block.
|
||||||
case msg @ AkkaMessage(message, reply) =>
|
case msg @ AkkaMessage(message, reply) =>
|
||||||
if (reply) {
|
if (reply) {
|
||||||
logError(s"Receive $msg but the sender cannot reply")
|
logError(s"Receive $msg but the sender cannot reply")
|
||||||
|
@ -305,7 +307,7 @@ private[akka] class AkkaRpcEndpointRef(
|
||||||
}
|
}
|
||||||
case AkkaFailure(e) =>
|
case AkkaFailure(e) =>
|
||||||
Future.failed(e)
|
Future.failed(e)
|
||||||
}.mapTo[T]
|
}(ThreadUtils.sameThread).mapTo[T]
|
||||||
}
|
}
|
||||||
|
|
||||||
override def toString: String = s"${getClass.getSimpleName}($actorRef)"
|
override def toString: String = s"${getClass.getSimpleName}($actorRef)"
|
||||||
|
|
|
@ -20,10 +20,22 @@ package org.apache.spark.util
|
||||||
|
|
||||||
import java.util.concurrent._
|
import java.util.concurrent._
|
||||||
|
|
||||||
import com.google.common.util.concurrent.ThreadFactoryBuilder
|
import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}
|
||||||
|
|
||||||
|
import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
|
||||||
|
|
||||||
private[spark] object ThreadUtils {
|
private[spark] object ThreadUtils {
|
||||||
|
|
||||||
|
// Single shared instance: Guava's same-thread executor adapted to a Scala ExecutionContext.
private val sameThreadExecutionContext = {
  val callerRunsExecutor = MoreExecutors.sameThreadExecutor()
  ExecutionContext.fromExecutorService(callerRunsExecutor)
}

/**
 * An `ExecutionContextExecutor` that runs every task directly in the thread that calls
 * `execute/submit`, without handing it off to another thread.
 *
 * Because the submitting thread does the work itself, callers must only run short,
 * non-blocking tasks in this `ExecutionContextExecutor`.
 */
def sameThread: ExecutionContextExecutor = sameThreadExecutionContext
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a thread factory that names threads with a prefix and also sets the threads to daemon.
|
* Create a thread factory that names threads with a prefix and also sets the threads to daemon.
|
||||||
*/
|
*/
|
||||||
|
@ -40,6 +52,16 @@ private[spark] object ThreadUtils {
|
||||||
Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
|
Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Create a cached thread pool whose max number of threads is `maxThreadNumber`. Thread names
 * are formatted as prefix-ID, where ID is a unique, sequentially assigned integer.
 *
 * Tasks submitted while all `maxThreadNumber` threads are busy wait in an unbounded queue
 * until a thread becomes free; they are not rejected. Threads that stay idle for 60 seconds
 * are terminated, so an unused pool shrinks back to zero threads.
 *
 * @param prefix prefix used for the names of the threads created by this pool
 * @param maxThreadNumber upper bound on the number of threads running tasks concurrently
 * @return the configured `ThreadPoolExecutor`
 */
def newDaemonCachedThreadPool(prefix: String, maxThreadNumber: Int): ThreadPoolExecutor = {
  val threadFactory = namedThreadFactory(prefix)
  val threadPool = new ThreadPoolExecutor(
    // corePoolSize: a ThreadPoolExecutor only queues once this many threads exist, so it must
    // equal the desired maximum for the pool to grow up to `maxThreadNumber`.
    maxThreadNumber,
    // maximumPoolSize: never exceeded in practice because the queue below is unbounded.
    maxThreadNumber,
    60L,
    TimeUnit.SECONDS,
    // A SynchronousQueue would make `execute` throw RejectedExecutionException as soon as all
    // threads are busy; an unbounded queue makes the extra tasks wait instead.
    new LinkedBlockingQueue[Runnable],
    threadFactory)
  // Let core threads time out too, preserving the "cached" behaviour of releasing idle threads.
  threadPool.allowCoreThreadTimeOut(true)
  threadPool
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a
|
* Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a
|
||||||
* unique, sequentially assigned integer.
|
* unique, sequentially assigned integer.
|
||||||
|
|
|
@ -20,6 +20,9 @@ package org.apache.spark.util
|
||||||
|
|
||||||
import java.util.concurrent.{CountDownLatch, TimeUnit}
|
import java.util.concurrent.{CountDownLatch, TimeUnit}
|
||||||
|
|
||||||
|
import scala.concurrent.{Await, Future}
|
||||||
|
import scala.concurrent.duration._
|
||||||
|
|
||||||
import org.scalatest.FunSuite
|
import org.scalatest.FunSuite
|
||||||
|
|
||||||
class ThreadUtilsSuite extends FunSuite {
|
class ThreadUtilsSuite extends FunSuite {
|
||||||
|
@ -54,4 +57,13 @@ class ThreadUtilsSuite extends FunSuite {
|
||||||
executor.shutdownNow()
|
executor.shutdownNow()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A Future run on ThreadUtils.sameThread must execute on the thread that created it.
test("sameThread") {
  val expectedThreadName = Thread.currentThread().getName()
  val nameSeenByTask = Future(Thread.currentThread().getName())(ThreadUtils.sameThread)
  val observedThreadName = Await.result(nameSeenByTask, 10.seconds)
  assert(observedThreadName === expectedThreadName)
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,10 +18,10 @@
|
||||||
package org.apache.spark.sql.execution.joins
|
package org.apache.spark.sql.execution.joins
|
||||||
|
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
|
import org.apache.spark.util.ThreadUtils
|
||||||
|
|
||||||
import scala.concurrent._
|
import scala.concurrent._
|
||||||
import scala.concurrent.duration._
|
import scala.concurrent.duration._
|
||||||
import scala.concurrent.ExecutionContext.Implicits.global
|
|
||||||
|
|
||||||
import org.apache.spark.annotation.DeveloperApi
|
import org.apache.spark.annotation.DeveloperApi
|
||||||
import org.apache.spark.sql.catalyst.expressions.{Row, Expression}
|
import org.apache.spark.sql.catalyst.expressions.{Row, Expression}
|
||||||
|
@ -64,7 +64,7 @@ case class BroadcastHashJoin(
|
||||||
val input: Array[Row] = buildPlan.execute().map(_.copy()).collect()
|
val input: Array[Row] = buildPlan.execute().map(_.copy()).collect()
|
||||||
val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
|
val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
|
||||||
sparkContext.broadcast(hashed)
|
sparkContext.broadcast(hashed)
|
||||||
}
|
}(BroadcastHashJoin.broadcastHashJoinExecutionContext)
|
||||||
|
|
||||||
protected override def doExecute(): RDD[Row] = {
|
protected override def doExecute(): RDD[Row] = {
|
||||||
val broadcastRelation = Await.result(broadcastFuture, timeout)
|
val broadcastRelation = Await.result(broadcastFuture, timeout)
|
||||||
|
@ -74,3 +74,9 @@ case class BroadcastHashJoin(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
object BroadcastHashJoin {

  // Dedicated execution context for building broadcast relations, so these long-running
  // tasks do not run on `scala.concurrent.ExecutionContext.Implicits.global`.
  private val broadcastHashJoinExecutionContext = {
    val broadcastPool = ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 1024)
    ExecutionContext.fromExecutorService(broadcastPool)
  }
}
|
||||||
|
|
Loading…
Reference in a new issue