diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 465cc1fa7d..64e354e2e3 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -23,6 +23,7 @@ import scala.util.Try
 
 import org.apache.spark.scheduler.{JobSucceeded, JobWaiter}
 import org.apache.spark.scheduler.JobFailed
+import org.apache.spark.rdd.RDD
 
 /**
@@ -170,14 +171,13 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
   }
 
   /**
-   * Executes some action enclosed in the closure. This execution of func is wrapped in a
-   * synchronized block to guarantee that this promise can only be cancelled when the task is
-   * waiting for
+   * Executes some action enclosed in the closure. To properly enable cancellation, the closure
+   * should use the runJob implementation in this promise; see takeAsync for an example.
    */
   def run(func: => T)(implicit executor: ExecutionContext): Unit = scala.concurrent.future {
     thread = Thread.currentThread
     try {
-      this.success(this.synchronized {
+      this.success({
         if (cancelled) {
           // This action has been cancelled before this thread even started running.
           throw new InterruptedException
@@ -191,6 +191,38 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
     }
   }
 
+  /**
+   * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext
+   * that additionally allows the job to be cancelled.
+   */
+  def runJob[T, U, R](
+      rdd: RDD[T],
+      processPartition: Iterator[T] => U,
+      partitions: Seq[Int],
+      partitionResultHandler: (Int, U) => Unit,
+      resultFunc: => R) {
+    // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
+    // call need to be in an atomic block.
+    val job = this.synchronized {
+      if (!cancelled) {
+        rdd.context.submitJob(rdd, processPartition, partitions, partitionResultHandler, resultFunc)
+      } else {
+        throw new SparkException("action has been cancelled")
+      }
+    }
+
+    // Wait for the job to complete. If the action is cancelled (with an interrupt),
+    // cancel the job and stop the execution. This is not in a synchronized block because
+    // Await.ready eventually waits on the monitor of the job's JobWaiter.
+    try {
+      Await.ready(job, Duration.Inf)
+    } catch {
+      case e: InterruptedException =>
+        job.cancel()
+        throw new SparkException("action has been cancelled")
+    }
+  }
+
   /**
    * Returns whether the promise has been cancelled.
    */
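Note: the intended pattern for the two methods above is that an async action wraps its whole
body in run and issues every job submission through runJob, so that both the cancelled check
and the blocking wait live inside the promise. A minimal sketch of such an action follows,
assuming only the CancellablePromise API added in this patch; countPartitionsAsync is a
hypothetical helper used for illustration, not part of the change.

    import scala.concurrent.ExecutionContext.Implicits.global

    import org.apache.spark.{CancellablePromise, FutureAction}
    import org.apache.spark.rdd.RDD

    object AsyncSketch {
      // Hypothetical action: counts how many partitions of the RDD have been processed.
      def countPartitionsAsync[T](rdd: RDD[T]): FutureAction[Int] = {
        val promise = new CancellablePromise[Int]
        promise.run {
          var finished = 0
          // runJob refuses to submit if the promise was already cancelled, and it
          // cancels the underlying Spark job if cancel() arrives while it waits.
          promise.runJob(rdd,
            (it: Iterator[T]) => it.size,                // per-partition work
            0 until rdd.partitions.size,                 // run on every partition
            (index: Int, count: Int) => finished += 1,   // driver-side result handler
            ())                                          // resultFunc, unused here
          finished
        }
        promise
      }
    }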
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 6806b8730b..579832427e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -20,8 +20,6 @@ package org.apache.spark.rdd
 import java.util.concurrent.atomic.AtomicLong
 
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.Await
-import scala.concurrent.duration.Duration
 import scala.concurrent.ExecutionContext.Implicits.global
 
 import org.apache.spark.{Logging, CancellablePromise, FutureAction}
@@ -90,22 +88,12 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with
         val left = num - buf.size
         val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
 
-        val job = self.context.submitJob(
-          self,
+        promise.runJob(self,
           (it: Iterator[T]) => it.take(left).toArray,
           p,
           (index: Int, data: Array[T]) => buf ++= data.take(num - buf.size),
           Unit)
 
-        // Wait for the job to complete. If the action is cancelled (with an interrupt),
-        // cancel the job and stop the execution.
-        try {
-          Await.result(job, Duration.Inf)
-        } catch {
-          case e: InterruptedException =>
-            job.cancel()
-            throw e
-        }
         partsScanned += numPartsToTry
       }
       buf.toSeq
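Note: with the change above, each pass of the takeAsync scan delegates the submit / wait /
cancel sequence to promise.runJob instead of inlining it, so cancellation is handled in one
place. From the caller's side the action behaves as before, except that cancel() now actually
stops the running job. A caller-side sketch, assuming an existing SparkContext named sc (the
numbers are arbitrary):

    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.future

    import org.apache.spark.SparkException
    import org.apache.spark.SparkContext._  // implicit conversion that provides takeAsync

    val action = sc.parallelize(1 to 100000, 4).takeAsync(50000)
    future {
      Thread.sleep(1000)  // let a few tasks launch, then cancel from another thread
      action.cancel()     // interrupts the promise's runner thread; runJob cancels the Spark job
    }
    // get() rethrows the cancellation as a SparkException("action has been cancelled").
    try { action.get() } catch { case e: SparkException => println(e.getMessage) }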
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 758670bdbf..029f24a51b 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -20,16 +20,13 @@ package org.apache.spark.rdd
 import java.util.concurrent.Semaphore
 import java.util.concurrent.atomic.AtomicInteger
 
-import scala.concurrent.Await
 import scala.concurrent.future
-import scala.concurrent.duration._
 import scala.concurrent.ExecutionContext.Implicits.global
 
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
 import org.apache.spark.SparkContext._
 import org.apache.spark.{SparkContext, SparkException, LocalSparkContext}
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
 import org.apache.spark.scheduler._
 
@@ -81,135 +78,154 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
     assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
   }
 
-//
-//  test("countAsync") {
-//    assert(zeroPartRdd.countAsync().get() === 0)
-//    assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000)
-//  }
-//
-//  test("collectAsync") {
-//    assert(zeroPartRdd.collectAsync().get() === Seq.empty)
-//
-//    // Note that we sort the collected output because the order is indeterministic.
-//    val collected = sc.parallelize(1 to 1000, 3).collectAsync().get().sorted
-//    assert(collected === (1 to 1000))
-//  }
-//
-//  test("foreachAsync") {
-//    zeroPartRdd.foreachAsync(i => Unit).get()
-//
-//    val accum = sc.accumulator(0)
-//    sc.parallelize(1 to 1000, 3).foreachAsync { i =>
-//      accum += 1
-//    }.get()
-//    assert(accum.value === 1000)
-//  }
-//
-//  test("foreachPartitionAsync") {
-//    zeroPartRdd.foreachPartitionAsync(iter => Unit).get()
-//
-//    val accum = sc.accumulator(0)
-//    sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter =>
-//      accum += 1
-//    }.get()
-//    assert(accum.value === 9)
-//  }
-//
-//  test("takeAsync") {
-//    def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) {
-//      // Note that we sort the collected output because the order is indeterministic.
-//      assert(rdd.takeAsync(num).get().size === input.take(num).size)
-//    }
-//    val input = Range(1, 1000)
-//
-//    var nums = sc.parallelize(input, 1)
-//    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
-//      testTake(nums, input, num)
-//    }
-//
-//    nums = sc.parallelize(input, 2)
-//    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
-//      testTake(nums, input, num)
-//    }
-//
-//    nums = sc.parallelize(input, 100)
-//    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
-//      testTake(nums, input, num)
-//    }
-//
-//    nums = sc.parallelize(input, 1000)
-//    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
-//      testTake(nums, input, num)
-//    }
-//  }
-//
-//  /**
-//   * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
-//   * of a successful job execution.
-//   */
-//  test("async success handling") {
-//    val f = sc.parallelize(1 to 10, 2).countAsync()
-//
-//    // This semaphore is used to make sure our final assert waits until onComplete / onSuccess
-//    // finishes execution.
-//    val sem = new Semaphore(0)
-//
-//    AsyncRDDActionsSuite.asyncSuccessHappened.set(0)
-//    f.onComplete {
-//      case scala.util.Success(res) =>
-//        AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
-//        sem.release()
-//      case scala.util.Failure(e) =>
-//        throw new Exception("Task should succeed")
-//        sem.release()
-//    }
-//    f.onSuccess { case a: Any =>
-//      AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
-//      sem.release()
-//    }
-//    f.onFailure { case t =>
-//      throw new Exception("Task should succeed")
-//    }
-//    assert(f.get() === 10)
-//    sem.acquire(2)
-//    assert(AsyncRDDActionsSuite.asyncSuccessHappened.get() === 2)
-//  }
-//
-//  /**
-//   * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
-//   * of a failed job execution.
-//   */
-//  test("async failure handling") {
-//    val f = sc.parallelize(1 to 10, 2).map { i =>
-//      throw new Exception("intentional"); i
-//    }.countAsync()
-//
-//    // This semaphore is used to make sure our final assert waits until onComplete / onFailure
-//    // finishes execution.
-//    val sem = new Semaphore(0)
-//
-//    AsyncRDDActionsSuite.asyncFailureHappend.set(0)
-//    f.onComplete {
-//      case scala.util.Success(res) =>
-//        throw new Exception("Task should fail")
-//        sem.release()
-//      case scala.util.Failure(e) =>
-//        AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
-//        sem.release()
-//    }
-//    f.onSuccess { case a: Any =>
-//      throw new Exception("Task should fail")
-//    }
-//    f.onFailure { case t =>
-//      AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
-//      sem.release()
-//    }
-//    intercept[SparkException] {
-//      f.get()
-//    }
-//    sem.acquire(2)
-//    assert(AsyncRDDActionsSuite.asyncFailureHappend.get() === 2)
-//  }
+  test("cancelling take action after some tasks have been launched") {
+    // Add a listener to release the semaphore once any tasks are launched.
+    val sem = new Semaphore(0)
+    sc.dagScheduler.addSparkListener(new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart) {
+        sem.release()
+      }
+    })
+    val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000)
+    future {
+      sem.acquire()
+      f.cancel()
+    }
+    val e = intercept[SparkException] { f.get() }
+    assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+  }
+
+  test("countAsync") {
+    assert(zeroPartRdd.countAsync().get() === 0)
+    assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000)
+  }
+
+  test("collectAsync") {
+    assert(zeroPartRdd.collectAsync().get() === Seq.empty)
+
+    // Note that we sort the collected output because the order is nondeterministic.
+    val collected = sc.parallelize(1 to 1000, 3).collectAsync().get().sorted
+    assert(collected === (1 to 1000))
+  }
+
+  test("foreachAsync") {
+    zeroPartRdd.foreachAsync(i => Unit).get()
+
+    val accum = sc.accumulator(0)
+    sc.parallelize(1 to 1000, 3).foreachAsync { i =>
+      accum += 1
+    }.get()
+    assert(accum.value === 1000)
+  }
+
+  test("foreachPartitionAsync") {
+    zeroPartRdd.foreachPartitionAsync(iter => Unit).get()
+
+    val accum = sc.accumulator(0)
+    sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter =>
+      accum += 1
+    }.get()
+    assert(accum.value === 9)
+  }
+
+  test("takeAsync") {
+    def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) {
+      // The output order is nondeterministic, so only the number of elements is checked.
+      val expected = input.take(num).size
+      val saw = rdd.takeAsync(num).get().size
+      assert(saw === expected, "incorrect result for rdd with %d partitions (expected %d, saw %d)"
+        .format(rdd.partitions.size, expected, saw))
+    }
+    val input = Range(1, 1000)
+
+    var rdd = sc.parallelize(input, 1)
+    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+      testTake(rdd, input, num)
+    }
+
+    rdd = sc.parallelize(input, 2)
+    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+      testTake(rdd, input, num)
+    }
+
+    rdd = sc.parallelize(input, 100)
+    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+      testTake(rdd, input, num)
+    }
+
+    rdd = sc.parallelize(input, 1000)
+    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+      testTake(rdd, input, num)
+    }
+  }
+
+  /**
+   * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
+   * of a successful job execution.
+   */
+  test("async success handling") {
+    val f = sc.parallelize(1 to 10, 2).countAsync()
+
+    // This semaphore is used to make sure our final assert waits until onComplete / onSuccess
+    // finishes execution.
+    val sem = new Semaphore(0)
+
+    AsyncRDDActionsSuite.asyncSuccessHappened.set(0)
+    f.onComplete {
+      case scala.util.Success(res) =>
+        AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
+        sem.release()
+      case scala.util.Failure(e) =>
+        sem.release()
+        throw new Exception("Task should succeed")
+    }
+    f.onSuccess { case a: Any =>
+      AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
+      sem.release()
+    }
+    f.onFailure { case t =>
+      throw new Exception("Task should succeed")
+    }
+    assert(f.get() === 10)
+    sem.acquire(2)
+    assert(AsyncRDDActionsSuite.asyncSuccessHappened.get() === 2)
+  }
+
+  /**
+   * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
+   * of a failed job execution.
+   */
+  test("async failure handling") {
+    val f = sc.parallelize(1 to 10, 2).map { i =>
+      throw new Exception("intentional"); i
+    }.countAsync()
+
+    // This semaphore is used to make sure our final assert waits until onComplete / onFailure
+    // finishes execution.
+    val sem = new Semaphore(0)
+
+    AsyncRDDActionsSuite.asyncFailureHappend.set(0)
+    f.onComplete {
+      case scala.util.Success(res) =>
+        sem.release()
+        throw new Exception("Task should fail")
+      case scala.util.Failure(e) =>
+        AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
+        sem.release()
+    }
+    f.onSuccess { case a: Any =>
+      throw new Exception("Task should fail")
+    }
+    f.onFailure { case t =>
+      AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
+      sem.release()
+    }
+    intercept[SparkException] {
+      f.get()
+    }
+    sem.acquire(2)
+    assert(AsyncRDDActionsSuite.asyncFailureHappend.get() === 2)
+  }
 }
 
 object AsyncRDDActionsSuite {
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 31f97fc139..d7e9ccafb3 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -106,7 +106,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
       }
     }
     visit(sums)
-    assert(deps.size === 2) // ShuffledRDD, ParallelCollection
+    assert(deps.size === 3) // ShuffledRDD, ParallelCollection, InterruptibleRDD.
   }
 
   test("join") {
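Note: the expected dependency count in the last hunk grows from 2 to 3 because job submission
now inserts an interruptible pass-through RDD between the job and the target RDD, adding one
level to the dependency chain. InterruptibleRDD itself is not shown in this diff; the sketch
below only illustrates the general shape of such a wrapper, under that assumption.

    import org.apache.spark.{Partition, TaskContext}
    import org.apache.spark.rdd.RDD

    // Sketch of a pass-through wrapper in the style of the InterruptibleRDD the test
    // refers to; the real implementation is not part of this diff.
    class PassThroughRDD[T: ClassManifest](prev: RDD[T]) extends RDD[T](prev) {
      override def getPartitions: Array[Partition] = prev.partitions
      override def compute(split: Partition, context: TaskContext): Iterator[T] =
        prev.iterator(split, context)  // an interruptible wrapper would also check a kill flag here
    }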