Fixed the deadlock situation in multi-job actions and added more unit tests.

Reynold Xin 2013-10-10 12:07:09 -07:00
parent 0353f74a9a
commit 3bd2890d2b
4 changed files with 186 additions and 150 deletions


@@ -23,6 +23,7 @@ import scala.util.Try
import org.apache.spark.scheduler.{JobSucceeded, JobWaiter}
import org.apache.spark.scheduler.JobFailed
import org.apache.spark.rdd.RDD
/**
@@ -170,14 +171,13 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
}
/**
- * Executes some action enclosed in the closure. This execution of func is wrapped in a
- * synchronized block to guarantee that this promise can only be cancelled when the task is
- * waiting for
+ * Executes some action enclosed in the closure. To properly enable cancellation, the closure
+ * should use runJob implementation in this promise. See takeAsync for example.
*/
def run(func: => T)(implicit executor: ExecutionContext): Unit = scala.concurrent.future {
thread = Thread.currentThread
try {
- this.success(this.synchronized {
+ this.success({
if (cancelled) {
// This action has been cancelled before this thread even started running.
throw new InterruptedException
@@ -191,6 +191,38 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
}
}
/**
* Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext
* to enable cancellation.
*/
def runJob[T, U, R](
rdd: RDD[T],
processPartition: Iterator[T] => U,
partitions: Seq[Int],
partitionResultHandler: (Int, U) => Unit,
resultFunc: => R) {
// If the action hasn't been cancelled yet, submit the job. The check and the submitJob
// command need to be in an atomic block.
val job = this.synchronized {
if (!cancelled) {
rdd.context.submitJob(rdd, processPartition, partitions, partitionResultHandler, resultFunc)
} else {
throw new SparkException("action has been cancelled")
}
}
// Wait for the job to complete. If the action is cancelled (with an interrupt),
// cancel the job and stop the execution. This is not in a synchronized block because
// Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
try {
Await.ready(job, Duration.Inf)
} catch {
case e: InterruptedException =>
job.cancel()
throw new SparkException("action has been cancelled")
}
}
/**
* Returns whether the promise has been cancelled.
*/
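The essence of the fix in this file is a locking discipline: the cancelled-check and the job submission happen atomically under the promise's monitor, while the potentially long wait on the job happens outside it. Below is a minimal, self-contained sketch of that pattern using hypothetical stand-ins (SketchAction, FakeJob) rather than the real Spark classes.

    import java.util.concurrent.CountDownLatch

    // Trivial stand-in for a submitted job: completes when finish() or cancel() is called.
    class FakeJob {
      private val done = new CountDownLatch(1)
      def finish(): Unit = done.countDown()
      def cancel(): Unit = done.countDown()
      def awaitCompletion(): Unit = done.await() // may block for a long time
    }

    class SketchAction {
      private var cancelled = false
      private var job: Option[FakeJob] = None

      def run(): Unit = {
        // The cancelled-check and the submission form one atomic step: a concurrent
        // cancel() either sees no job yet (and only flips the flag) or sees the
        // submitted job and can cancel it.
        val submitted = this.synchronized {
          if (cancelled) throw new InterruptedException("action has been cancelled")
          val j = new FakeJob
          job = Some(j)
          j
        }
        // The potentially long wait happens outside the lock, so cancel() can still
        // acquire the monitor while this thread is blocked here.
        submitted.awaitCompletion()
      }

      def cancel(): Unit = this.synchronized {
        cancelled = true
        job.foreach(_.cancel())
      }
    }

If run() instead held the monitor for the entire wait, as the removed this.success(this.synchronized { ... }) wrapper effectively did, a concurrent cancel() could never acquire it, which is the deadlock the commit message refers to.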


@@ -20,8 +20,6 @@ package org.apache.spark.rdd
import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.ArrayBuffer
- import scala.concurrent.Await
- import scala.concurrent.duration.Duration
import scala.concurrent.ExecutionContext.Implicits.global
import org.apache.spark.{Logging, CancellablePromise, FutureAction}
@@ -90,22 +88,12 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with
val left = num - buf.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
- val job = self.context.submitJob(
- self,
+ promise.runJob(self,
(it: Iterator[T]) => it.take(left).toArray,
p,
(index: Int, data: Array[T]) => buf ++= data.take(num - buf.size),
Unit)
- // Wait for the job to complete. If the action is cancelled (with an interrupt),
- // cancel the job and stop the execution.
- try {
- Await.result(job, Duration.Inf)
- } catch {
- case e: InterruptedException =>
- job.cancel()
- throw e
- }
partsScanned += numPartsToTry
}
buf.toSeq
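For context, the hunk above sits inside a loop that scans partitions in progressively larger batches until enough elements have been collected; this commit only changes how each batch's job is submitted and awaited. The following is a pure-Scala sketch of that scan, with an illustrative batch-growth rule (the real heuristic in takeAsync is more involved), so the control flow can be read in one place; TakeSketch and incrementalTake are made-up names.

    object TakeSketch {
      // Collect up to `num` elements by scanning the partitions a batch at a time,
      // widening the batch whenever the buffer is still short.
      def incrementalTake[T](partitions: Seq[Seq[T]], num: Int): Seq[T] = {
        val buf = scala.collection.mutable.ArrayBuffer.empty[T]
        var partsScanned = 0
        var numPartsToTry = 1
        while (buf.size < num && partsScanned < partitions.size) {
          val batch = partsScanned until math.min(partsScanned + numPartsToTry, partitions.size)
          for (p <- batch) {
            val left = num - buf.size
            buf ++= partitions(p).take(left)
          }
          partsScanned += numPartsToTry
          numPartsToTry *= 4 // illustrative growth rule only
        }
        buf.toSeq
      }
    }

    // Example: TakeSketch.incrementalTake(Seq(Seq(1, 2), Seq(3, 4, 5), Seq(6)), 4) == Seq(1, 2, 3, 4)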


@@ -20,16 +20,13 @@ package org.apache.spark.rdd
import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.Await
import scala.concurrent.future
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.spark.SparkContext._
import org.apache.spark.{SparkContext, SparkException, LocalSparkContext}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
import org.apache.spark.scheduler._
@@ -81,135 +78,154 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
}
//
// test("countAsync") {
// assert(zeroPartRdd.countAsync().get() === 0)
// assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000)
// }
//
// test("collectAsync") {
// assert(zeroPartRdd.collectAsync().get() === Seq.empty)
//
// // Note that we sort the collected output because the order is indeterministic.
// val collected = sc.parallelize(1 to 1000, 3).collectAsync().get().sorted
// assert(collected === (1 to 1000))
// }
//
// test("foreachAsync") {
// zeroPartRdd.foreachAsync(i => Unit).get()
//
// val accum = sc.accumulator(0)
// sc.parallelize(1 to 1000, 3).foreachAsync { i =>
// accum += 1
// }.get()
// assert(accum.value === 1000)
// }
//
// test("foreachPartitionAsync") {
// zeroPartRdd.foreachPartitionAsync(iter => Unit).get()
//
// val accum = sc.accumulator(0)
// sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter =>
// accum += 1
// }.get()
// assert(accum.value === 9)
// }
//
// test("takeAsync") {
// def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) {
// // Note that we sort the collected output because the order is indeterministic.
// assert(rdd.takeAsync(num).get().size === input.take(num).size)
// }
// val input = Range(1, 1000)
//
// var nums = sc.parallelize(input, 1)
// for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
// testTake(nums, input, num)
// }
//
// nums = sc.parallelize(input, 2)
// for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
// testTake(nums, input, num)
// }
//
// nums = sc.parallelize(input, 100)
// for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
// testTake(nums, input, num)
// }
//
// nums = sc.parallelize(input, 1000)
// for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
// testTake(nums, input, num)
// }
// }
//
// /**
// * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
// * of a successful job execution.
// */
// test("async success handling") {
// val f = sc.parallelize(1 to 10, 2).countAsync()
//
// // This semaphore is used to make sure our final assert waits until onComplete / onSuccess
// // finishes execution.
// val sem = new Semaphore(0)
//
// AsyncRDDActionsSuite.asyncSuccessHappened.set(0)
// f.onComplete {
// case scala.util.Success(res) =>
// AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
// sem.release()
// case scala.util.Failure(e) =>
// throw new Exception("Task should succeed")
// sem.release()
// }
// f.onSuccess { case a: Any =>
// AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
// sem.release()
// }
// f.onFailure { case t =>
// throw new Exception("Task should succeed")
// }
// assert(f.get() === 10)
// sem.acquire(2)
// assert(AsyncRDDActionsSuite.asyncSuccessHappened.get() === 2)
// }
//
// /**
// * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
// * of a failed job execution.
// */
// test("async failure handling") {
// val f = sc.parallelize(1 to 10, 2).map { i =>
// throw new Exception("intentional"); i
// }.countAsync()
//
// // This semaphore is used to make sure our final assert waits until onComplete / onFailure
// // finishes execution.
// val sem = new Semaphore(0)
//
// AsyncRDDActionsSuite.asyncFailureHappend.set(0)
// f.onComplete {
// case scala.util.Success(res) =>
// throw new Exception("Task should fail")
// sem.release()
// case scala.util.Failure(e) =>
// AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
// sem.release()
// }
// f.onSuccess { case a: Any =>
// throw new Exception("Task should fail")
// }
// f.onFailure { case t =>
// AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
// sem.release()
// }
// intercept[SparkException] {
// f.get()
// }
// sem.acquire(2)
// assert(AsyncRDDActionsSuite.asyncFailureHappend.get() === 2)
// }
test("cancelling take action after some tasks have been launched") {
// Add a listener to release the semaphore once any tasks are launched.
val sem = new Semaphore(0)
sc.dagScheduler.addSparkListener(new SparkListener {
override def onTaskStart(taskStart: SparkListenerTaskStart) {
sem.release()
}
})
val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000)
future {
sem.acquire()
f.cancel()
}
val e = intercept[SparkException] { f.get() }
assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
}
test("countAsync") {
assert(zeroPartRdd.countAsync().get() === 0)
assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000)
}
test("collectAsync") {
assert(zeroPartRdd.collectAsync().get() === Seq.empty)
// Note that we sort the collected output because the order is indeterministic.
val collected = sc.parallelize(1 to 1000, 3).collectAsync().get().sorted
assert(collected === (1 to 1000))
}
test("foreachAsync") {
zeroPartRdd.foreachAsync(i => Unit).get()
val accum = sc.accumulator(0)
sc.parallelize(1 to 1000, 3).foreachAsync { i =>
accum += 1
}.get()
assert(accum.value === 1000)
}
test("foreachPartitionAsync") {
zeroPartRdd.foreachPartitionAsync(iter => Unit).get()
val accum = sc.accumulator(0)
sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter =>
accum += 1
}.get()
assert(accum.value === 9)
}
test("takeAsync") {
def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) {
// Note that we sort the collected output because the order is indeterministic.
val expected = input.take(num).size
val saw = rdd.takeAsync(num).get().size
assert(saw == expected, "incorrect result for rdd with %d partitions (expected %d, saw %d)"
.format(rdd.partitions.size, expected, saw))
}
val input = Range(1, 1000)
var rdd = sc.parallelize(input, 1)
for (num <- Seq(0, 1, 999, 1000)) {
testTake(rdd, input, num)
}
rdd = sc.parallelize(input, 2)
for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
testTake(rdd, input, num)
}
rdd = sc.parallelize(input, 100)
for (num <- Seq(0, 1, 500, 501, 999, 1000)) {
testTake(rdd, input, num)
}
rdd = sc.parallelize(input, 1000)
for (num <- Seq(0, 1, 3, 999, 1000)) {
testTake(rdd, input, num)
}
}
/**
* Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
* of a successful job execution.
*/
test("async success handling") {
val f = sc.parallelize(1 to 10, 2).countAsync()
// This semaphore is used to make sure our final assert waits until onComplete / onSuccess
// finishes execution.
val sem = new Semaphore(0)
AsyncRDDActionsSuite.asyncSuccessHappened.set(0)
f.onComplete {
case scala.util.Success(res) =>
AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
sem.release()
case scala.util.Failure(e) =>
throw new Exception("Task should succeed")
sem.release()
}
f.onSuccess { case a: Any =>
AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
sem.release()
}
f.onFailure { case t =>
throw new Exception("Task should succeed")
}
assert(f.get() === 10)
sem.acquire(2)
assert(AsyncRDDActionsSuite.asyncSuccessHappened.get() === 2)
}
/**
* Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
* of a failed job execution.
*/
test("async failure handling") {
val f = sc.parallelize(1 to 10, 2).map { i =>
throw new Exception("intentional"); i
}.countAsync()
// This semaphore is used to make sure our final assert waits until onComplete / onFailure
// finishes execution.
val sem = new Semaphore(0)
AsyncRDDActionsSuite.asyncFailureHappend.set(0)
f.onComplete {
case scala.util.Success(res) =>
throw new Exception("Task should fail")
sem.release()
case scala.util.Failure(e) =>
AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
sem.release()
}
f.onSuccess { case a: Any =>
throw new Exception("Task should fail")
}
f.onFailure { case t =>
AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
sem.release()
}
intercept[SparkException] {
f.get()
}
sem.acquire(2)
assert(AsyncRDDActionsSuite.asyncFailureHappend.get() === 2)
}
}
object AsyncRDDActionsSuite {
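Only the object's opening line appears in the diff; its body is unchanged and therefore not shown. Judging purely from how the tests above use it (set, incrementAndGet, get, and the AtomicInteger import), it presumably holds two shared counters along these lines; this is a hypothetical reconstruction, not the actual source.

    object AsyncRDDActionsSuite {
      import java.util.concurrent.atomic.AtomicInteger
      // Shared counters referenced by the "async success handling" and
      // "async failure handling" tests above (hypothetical reconstruction).
      val asyncSuccessHappened = new AtomicInteger(0)
      val asyncFailureHappend = new AtomicInteger(0)
    }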


@@ -106,7 +106,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
}
}
visit(sums)
- assert(deps.size === 2) // ShuffledRDD, ParallelCollection
+ assert(deps.size === 3) // ShuffledRDD, ParallelCollection, InterruptibleRDD.
}
test("join") {