[SPARK-5214][Core] Add EventLoop and change DAGScheduler to an EventLoop

This PR adds a simple `EventLoop` and uses it to replace the Actor in DAGScheduler. `EventLoop` is a general-purpose class that supports posting events from multiple threads and handling them in a single event thread.

Author: zsxwing <zsxwing@gmail.com>

Closes #4016 from zsxwing/event-loop and squashes the following commits:

aefa1ce [zsxwing] Add protected to on*** methods
5cfac83 [zsxwing] Remove null check of eventProcessLoop
dba35b2 [zsxwing] Add a test that onReceive swallows InterruptException
460f7b3 [zsxwing] Use volatile instead of Atomic things in unit tests
227bf33 [zsxwing] Add a stop flag and some tests
37f79c6 [zsxwing] Fix docs
55fb6f6 [zsxwing] Add private[spark] to EventLoop
1f73eac [zsxwing] Fix the import order
3b2e59c [zsxwing] Add EventLoop and change DAGScheduler to an EventLoop
zsxwing 2015-01-19 18:15:51 -08:00 committed by Reynold Xin
parent 74de94ea6d
commit e69fb8c75a
4 changed files with 372 additions and 98 deletions
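To make the new abstraction concrete before the diffs, here is a minimal usage sketch based on the `EventLoop` class added by this commit. It is not part of the change: the event type and class names are invented for illustration, and a package under `org.apache.spark` is assumed only because `EventLoop` is declared `private[spark]`.

package org.apache.spark.examples // hypothetical package; EventLoop is private[spark]

import org.apache.spark.util.EventLoop

// A made-up event type. Callers may post() from any thread; onReceive always runs
// on the loop's single daemon event thread.
sealed trait DemoEvent
case class TaskDone(id: Int) extends DemoEvent
case object Shutdown extends DemoEvent

class DemoEventLoop extends EventLoop[DemoEvent]("demo-event-loop") {
  override def onReceive(event: DemoEvent): Unit = event match {
    case TaskDone(id) => println(s"task $id done")
    case Shutdown     => println("shutdown requested")
  }
  override def onError(e: Throwable): Unit = e.printStackTrace()
}

object DemoEventLoopApp {
  def main(args: Array[String]): Unit = {
    val loop = new DemoEventLoop
    loop.start()                                    // starts the daemon event thread
    (1 to 3).foreach(i => loop.post(TaskDone(i)))   // post is a plain queue put, safe from any thread
    loop.post(Shutdown)
    Thread.sleep(100)                               // give the event thread a moment (demo only)
    loop.stop()                                     // interrupts and joins the thread, then calls onStop
  }
}

This is the same shape `DAGSchedulerEventProcessLoop` takes in the first diff below: the scheduler posts events such as `JobSubmitted` from caller threads, and all handling happens on the single "dag-scheduler-event-loop" thread.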


@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.io.NotSerializableException
import java.util.Properties
import java.util.concurrent.{TimeUnit, Executors}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
@@ -28,8 +29,6 @@ import scala.language.postfixOps
import scala.reflect.ClassTag
import scala.util.control.NonFatal
import akka.actor._
import akka.actor.SupervisorStrategy.Stop
import akka.pattern.ask
import akka.util.Timeout
@@ -39,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._
import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils}
import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils}
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
/**
@@ -67,8 +66,6 @@ class DAGScheduler(
clock: Clock = SystemClock)
extends Logging {
import DAGScheduler._
def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
this(
sc,
@@ -112,14 +109,10 @@ class DAGScheduler(
// stray messages to detect.
private val failedEpoch = new HashMap[String, Long]
private val dagSchedulerActorSupervisor =
env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))
// A closure serializer that we reuse.
// This is only safe because DAGScheduler runs in a single thread.
private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
private[scheduler] var eventProcessActor: ActorRef = _
/** If enabled, we may run certain actions like take() and first() locally. */
private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false)
@@ -127,27 +120,20 @@ class DAGScheduler(
/** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)
private def initializeEventProcessActor() {
// blocking the thread until supervisor is started, which ensures eventProcessActor is
// not null before any job is submitted
implicit val timeout = Timeout(30 seconds)
val initEventActorReply =
dagSchedulerActorSupervisor ? Props(new DAGSchedulerEventProcessActor(this))
eventProcessActor = Await.result(initEventActorReply, timeout.duration).
asInstanceOf[ActorRef]
}
private val messageScheduler =
Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("dag-scheduler-message"))
initializeEventProcessActor()
private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
taskScheduler.setDAGScheduler(this)
// Called by TaskScheduler to report task's starting.
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventProcessActor ! BeginEvent(task, taskInfo)
eventProcessLoop.post(BeginEvent(task, taskInfo))
}
// Called to report that a task has completed and results are being fetched remotely.
def taskGettingResult(taskInfo: TaskInfo) {
eventProcessActor ! GettingResultEvent(taskInfo)
eventProcessLoop.post(GettingResultEvent(taskInfo))
}
// Called by TaskScheduler to report task completions or failures.
@@ -158,7 +144,8 @@ class DAGScheduler(
accumUpdates: Map[Long, Any],
taskInfo: TaskInfo,
taskMetrics: TaskMetrics) {
eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)
eventProcessLoop.post(
CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
}
/**
@@ -180,18 +167,18 @@ class DAGScheduler(
// Called by TaskScheduler when an executor fails.
def executorLost(execId: String) {
eventProcessActor ! ExecutorLost(execId)
eventProcessLoop.post(ExecutorLost(execId))
}
// Called by TaskScheduler when a host is added
def executorAdded(execId: String, host: String) {
eventProcessActor ! ExecutorAdded(execId, host)
eventProcessLoop.post(ExecutorAdded(execId, host))
}
// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
// cancellation of the job itself.
def taskSetFailed(taskSet: TaskSet, reason: String) {
eventProcessActor ! TaskSetFailed(taskSet, reason)
eventProcessLoop.post(TaskSetFailed(taskSet, reason))
}
private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
@@ -496,8 +483,8 @@ class DAGScheduler(
assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
eventProcessActor ! JobSubmitted(
jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties))
waiter
}
@@ -537,8 +524,8 @@ class DAGScheduler(
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
val jobId = nextJobId.getAndIncrement()
eventProcessActor ! JobSubmitted(
jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
listener.awaitResult() // Will throw an exception if the job fails
}
@@ -547,19 +534,19 @@ class DAGScheduler(
*/
def cancelJob(jobId: Int) {
logInfo("Asked to cancel job " + jobId)
eventProcessActor ! JobCancelled(jobId)
eventProcessLoop.post(JobCancelled(jobId))
}
def cancelJobGroup(groupId: String) {
logInfo("Asked to cancel job group " + groupId)
eventProcessActor ! JobGroupCancelled(groupId)
eventProcessLoop.post(JobGroupCancelled(groupId))
}
/**
* Cancel all jobs that are running or waiting in the queue.
*/
def cancelAllJobs() {
eventProcessActor ! AllJobsCancelled
eventProcessLoop.post(AllJobsCancelled)
}
private[scheduler] def doCancelAllJobs() {
@@ -575,7 +562,7 @@ class DAGScheduler(
* Cancel all jobs associated with a running or scheduled stage.
*/
def cancelStage(stageId: Int) {
eventProcessActor ! StageCancelled(stageId)
eventProcessLoop.post(StageCancelled(stageId))
}
/**
@@ -1063,16 +1050,15 @@ class DAGScheduler(
if (disallowStageRetryForTest) {
abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
} else if (failedStages.isEmpty && eventProcessActor != null) {
} else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled. eventProcessActor may be
// null during unit tests.
// in that case the event will already have been scheduled.
// TODO: Cancel running tasks in the stage
import env.actorSystem.dispatcher
logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
s"$failedStage (${failedStage.name}) due to fetch failure")
env.actorSystem.scheduler.scheduleOnce(
RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages)
messageScheduler.schedule(new Runnable {
override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
}, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
}
failedStages += failedStage
failedStages += mapStage
@@ -1330,40 +1316,21 @@ class DAGScheduler(
def stop() {
logInfo("Stopping DAGScheduler")
dagSchedulerActorSupervisor ! PoisonPill
eventProcessLoop.stop()
taskScheduler.stop()
}
// Start the event thread at the end of the constructor
eventProcessLoop.start()
}
private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler)
extends Actor with Logging {
override val supervisorStrategy =
OneForOneStrategy() {
case x: Exception =>
logError("eventProcesserActor failed; shutting down SparkContext", x)
try {
dagScheduler.doCancelAllJobs()
} catch {
case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
}
dagScheduler.sc.stop()
Stop
}
def receive = {
case p: Props => sender ! context.actorOf(p)
case _ => logWarning("received unknown message in DAGSchedulerActorSupervisor")
}
}
private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGScheduler)
extends Actor with Logging {
private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler)
extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging {
/**
* The main event loop of the DAG scheduler.
*/
def receive = {
override def onReceive(event: DAGSchedulerEvent): Unit = event match {
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite,
listener, properties)
@@ -1402,7 +1369,17 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule
dagScheduler.resubmitFailedStages()
}
override def postStop() {
override def onError(e: Throwable): Unit = {
logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e)
try {
dagScheduler.doCancelAllJobs()
} catch {
case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
}
dagScheduler.sc.stop()
}
override def onStop() {
// Cancel any active jobs in postStop hook
dagScheduler.cleanUpAfterSchedulerStop()
}
@@ -1412,9 +1389,5 @@ private[spark] object DAGScheduler {
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
val RESUBMIT_TIMEOUT = 200.milliseconds
// The time, in millis, to wake up between polls of the completion queue in order to potentially
// resubmit failed stages
val POLL_TIMEOUT = 10L
val RESUBMIT_TIMEOUT = 200
}
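In the hunk above, the Akka `env.actorSystem.scheduler.scheduleOnce` call is replaced by `messageScheduler`, a single-threaded `ScheduledExecutorService` that posts `ResubmitFailedStages` to the event loop after `RESUBMIT_TIMEOUT` milliseconds. A stripped-down sketch of that delayed-post pattern follows; it is not Spark code, and the string message and blocking queue stand in for the real event type and event loop.

import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit}

object ResubmitSketch {
  // 200 ms, matching RESUBMIT_TIMEOUT in object DAGScheduler above
  private val ResubmitTimeoutMs = 200L

  def main(args: Array[String]): Unit = {
    val events = new LinkedBlockingQueue[String]()
    val messageScheduler = Executors.newScheduledThreadPool(1)

    // Post the resubmit message once, after a short delay, rather than immediately:
    // fetch-failure events that arrive in the meantime are folded into a single
    // resubmit, which is the purpose of the timeout.
    messageScheduler.schedule(new Runnable {
      override def run(): Unit = events.put("ResubmitFailedStages")
    }, ResubmitTimeoutMs, TimeUnit.MILLISECONDS)

    println(events.take()) // blocks until the delayed message arrives
    messageScheduler.shutdown()
  }
}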


@@ -0,0 +1,124 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}
import scala.util.control.NonFatal
import org.apache.spark.Logging
/**
* An event loop to receive events from the caller and process all events in the event thread. It
* will start an exclusive event thread to process all events.
*
* Note: The event queue will grow indefinitely. So subclasses should make sure `onReceive` can
* handle events in time to avoid the potential OOM.
*/
private[spark] abstract class EventLoop[E](name: String) extends Logging {
private val eventQueue: BlockingQueue[E] = new LinkedBlockingDeque[E]()
private val stopped = new AtomicBoolean(false)
private val eventThread = new Thread(name) {
setDaemon(true)
override def run(): Unit = {
try {
while (!stopped.get) {
val event = eventQueue.take()
try {
onReceive(event)
} catch {
case NonFatal(e) => {
try {
onError(e)
} catch {
case NonFatal(e) => logError("Unexpected error in " + name, e)
}
}
}
}
} catch {
case ie: InterruptedException => // exit even if eventQueue is not empty
case NonFatal(e) => logError("Unexpected error in " + name, e)
}
}
}
def start(): Unit = {
if (stopped.get) {
throw new IllegalStateException(name + " has already been stopped")
}
// Call onStart before starting the event thread to make sure it happens before onReceive
onStart()
eventThread.start()
}
def stop(): Unit = {
if (stopped.compareAndSet(false, true)) {
eventThread.interrupt()
eventThread.join()
// Call onStop after the event thread exits to make sure onReceive happens before onStop
onStop()
} else {
// Keep quiet to allow calling `stop` multiple times.
}
}
/**
* Put the event into the event queue. The event thread will process it later.
*/
def post(event: E): Unit = {
eventQueue.put(event)
}
/**
* Return if the event thread has already been started but not yet stopped.
*/
def isActive: Boolean = eventThread.isAlive
/**
* Invoked when `start()` is called but before the event thread starts.
*/
protected def onStart(): Unit = {}
/**
* Invoked when `stop()` is called and the event thread exits.
*/
protected def onStop(): Unit = {}
/**
* Invoked in the event thread when polling events from the event queue.
*
* Note: Should avoid calling blocking actions in `onReceive`, or the event thread will be blocked
* and cannot process events in time. If you want to call some blocking actions, run them in
* another thread.
*/
protected def onReceive(event: E): Unit
/**
* Invoked if `onReceive` throws any non fatal error. Any non fatal error thrown from `onError`
* will be ignored.
*/
protected def onError(e: Throwable): Unit
}
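The lifecycle hooks above come with an ordering guarantee that subclasses can rely on: `onStart` runs before the event thread is started, and `onStop` runs only after `stop()` has interrupted and joined the event thread, so neither hook can race with `onReceive`. A small sketch of a subclass that uses this to manage a resource; the class name is hypothetical, and a package under `org.apache.spark` is again assumed because of `private[spark]`.

package org.apache.spark.examples // hypothetical package; EventLoop is private[spark]

import java.io.{BufferedWriter, FileWriter}

import org.apache.spark.util.EventLoop

// Appends each posted line to a log file. onStart opens the writer before any event
// can be processed; onStop closes it only after the event thread has exited, so
// onReceive never observes a closed writer.
class FileLoggingLoop(path: String) extends EventLoop[String]("file-logging-loop") {
  private var writer: BufferedWriter = _

  override def onStart(): Unit = writer = new BufferedWriter(new FileWriter(path, true))
  override def onReceive(line: String): Unit = { writer.write(line); writer.newLine() }
  override def onStop(): Unit = writer.close()
  override def onError(e: Throwable): Unit = e.printStackTrace()
}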


@@ -19,9 +19,8 @@ package org.apache.spark.scheduler
import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map}
import scala.language.reflectiveCalls
import scala.util.control.NonFatal
import akka.actor._
import akka.testkit.{ImplicitSender, TestKit, TestActorRef}
import org.scalatest.{BeforeAndAfter, FunSuiteLike}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
@@ -33,10 +32,16 @@ import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
import org.apache.spark.util.CallSite
import org.apache.spark.executor.TaskMetrics
class BuggyDAGEventProcessActor extends Actor {
val state = 0
def receive = {
case _ => throw new SparkException("error")
class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
extends DAGSchedulerEventProcessLoop(dagScheduler) {
override def post(event: DAGSchedulerEvent): Unit = {
try {
// Forward event to `onReceive` directly to avoid processing event asynchronously.
onReceive(event)
} catch {
case NonFatal(e) => onError(e)
}
}
}
@@ -65,8 +70,7 @@ class MyRDD(
class DAGSchedulerSuiteDummyException extends Exception
class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with FunSuiteLike
with ImplicitSender with BeforeAndAfter with LocalSparkContext with Timeouts {
class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSparkContext with Timeouts {
val conf = new SparkConf
/** Set of TaskSets the DAGScheduler has requested executed. */
@@ -113,7 +117,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
var mapOutputTracker: MapOutputTrackerMaster = null
var scheduler: DAGScheduler = null
var dagEventProcessTestActor: TestActorRef[DAGSchedulerEventProcessActor] = null
var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null
/**
* Set of cache locations to return from our mock BlockManagerMaster.
@@ -167,13 +171,11 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
runLocallyWithinThread(job)
}
}
dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor](
Props(classOf[DAGSchedulerEventProcessActor], scheduler))(system)
dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
}
override def afterAll() {
super.afterAll()
TestKit.shutdownActorSystem(system)
}
/**
@@ -190,7 +192,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
* DAGScheduler event loop.
*/
private def runEvent(event: DAGSchedulerEvent) {
dagEventProcessTestActor.receive(event)
dagEventProcessLoopTester.post(event)
}
/**
@@ -397,8 +399,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
runLocallyWithinThread(job)
}
}
dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor](
Props(classOf[DAGSchedulerEventProcessActor], noKillScheduler))(system)
dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(noKillScheduler)
val jobId = submit(new MyRDD(sc, 1, Nil), Array(0))
cancel(jobId)
// Because the job wasn't actually cancelled, we shouldn't have received a failure message.
@@ -726,18 +727,6 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
assert(sc.parallelize(1 to 10, 2).first() === 1)
}
test("DAGSchedulerActorSupervisor closes the SparkContext when EventProcessActor crashes") {
val actorSystem = ActorSystem("test")
val supervisor = actorSystem.actorOf(
Props(classOf[DAGSchedulerActorSupervisor], scheduler), "dagSupervisor")
supervisor ! Props[BuggyDAGEventProcessActor]
val child = expectMsgType[ActorRef]
watch(child)
child ! "hi"
expectMsgPF(){ case Terminated(child) => () }
assert(scheduler.sc.dagScheduler === null)
}
test("accumulator not calculated for resubmitted result stage") {
//just for register
val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam)


@@ -0,0 +1,188 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util
import java.util.concurrent.CountDownLatch
import scala.collection.mutable
import scala.concurrent.duration._
import scala.language.postfixOps
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts
import org.scalatest.FunSuite
class EventLoopSuite extends FunSuite with Timeouts {
test("EventLoop") {
val buffer = new mutable.ArrayBuffer[Int] with mutable.SynchronizedBuffer[Int]
val eventLoop = new EventLoop[Int]("test") {
override def onReceive(event: Int): Unit = {
buffer += event
}
override def onError(e: Throwable): Unit = {}
}
eventLoop.start()
(1 to 100).foreach(eventLoop.post)
eventually(timeout(5 seconds), interval(200 millis)) {
assert((1 to 100) === buffer.toSeq)
}
eventLoop.stop()
}
test("EventLoop: start and stop") {
val eventLoop = new EventLoop[Int]("test") {
override def onReceive(event: Int): Unit = {}
override def onError(e: Throwable): Unit = {}
}
assert(false === eventLoop.isActive)
eventLoop.start()
assert(true === eventLoop.isActive)
eventLoop.stop()
assert(false === eventLoop.isActive)
}
test("EventLoop: onError") {
val e = new RuntimeException("Oops")
@volatile var receivedError: Throwable = null
val eventLoop = new EventLoop[Int]("test") {
override def onReceive(event: Int): Unit = {
throw e
}
override def onError(e: Throwable): Unit = {
receivedError = e
}
}
eventLoop.start()
eventLoop.post(1)
eventually(timeout(5 seconds), interval(200 millis)) {
assert(e === receivedError)
}
eventLoop.stop()
}
test("EventLoop: error thrown from onError should not crash the event thread") {
val e = new RuntimeException("Oops")
@volatile var receivedError: Throwable = null
val eventLoop = new EventLoop[Int]("test") {
override def onReceive(event: Int): Unit = {
throw e
}
override def onError(e: Throwable): Unit = {
receivedError = e
throw new RuntimeException("Oops")
}
}
eventLoop.start()
eventLoop.post(1)
eventually(timeout(5 seconds), interval(200 millis)) {
assert(e === receivedError)
assert(eventLoop.isActive)
}
eventLoop.stop()
}
test("EventLoop: calling stop multiple times should only call onStop once") {
var onStopTimes = 0
val eventLoop = new EventLoop[Int]("test") {
override def onReceive(event: Int): Unit = {
}
override def onError(e: Throwable): Unit = {
}
override def onStop(): Unit = {
onStopTimes += 1
}
}
eventLoop.start()
eventLoop.stop()
eventLoop.stop()
eventLoop.stop()
assert(1 === onStopTimes)
}
test("EventLoop: post event in multiple threads") {
@volatile var receivedEventsCount = 0
val eventLoop = new EventLoop[Int]("test") {
override def onReceive(event: Int): Unit = {
receivedEventsCount += 1
}
override def onError(e: Throwable): Unit = {
}
}
eventLoop.start()
val threadNum = 5
val eventsFromEachThread = 100
(1 to threadNum).foreach { _ =>
new Thread() {
override def run(): Unit = {
(1 to eventsFromEachThread).foreach(eventLoop.post)
}
}.start()
}
eventually(timeout(5 seconds), interval(200 millis)) {
assert(threadNum * eventsFromEachThread === receivedEventsCount)
}
eventLoop.stop()
}
test("EventLoop: onReceive swallows InterruptException") {
val onReceiveLatch = new CountDownLatch(1)
val eventLoop = new EventLoop[Int]("test") {
override def onReceive(event: Int): Unit = {
onReceiveLatch.countDown()
try {
Thread.sleep(5000)
} catch {
case ie: InterruptedException => // swallow
}
}
override def onError(e: Throwable): Unit = {
}
}
eventLoop.start()
eventLoop.post(1)
failAfter(5 seconds) {
// Wait until we enter `onReceive`
onReceiveLatch.await()
eventLoop.stop()
}
assert(false === eventLoop.isActive)
}
}