Fixed the deadlock situation in multi-job actions and added more unit tests.
This commit is contained in:
parent
0353f74a9a
commit
3bd2890d2b
|
@ -23,6 +23,7 @@ import scala.util.Try
|
|||
|
||||
import org.apache.spark.scheduler.{JobSucceeded, JobWaiter}
|
||||
import org.apache.spark.scheduler.JobFailed
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
|
||||
/**
|
||||
|
@ -170,14 +171,13 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
|
|||
}
|
||||
|
||||
/**
|
||||
* Executes some action enclosed in the closure. This execution of func is wrapped in a
|
||||
* synchronized block to guarantee that this promise can only be cancelled when the task is
|
||||
* waiting for
|
||||
* Executes some action enclosed in the closure. To properly enable cancellation, the closure
|
||||
* should use runJob implementation in this promise. See takeAsync for example.
|
||||
*/
|
||||
def run(func: => T)(implicit executor: ExecutionContext): Unit = scala.concurrent.future {
|
||||
thread = Thread.currentThread
|
||||
try {
|
||||
this.success(this.synchronized {
|
||||
this.success({
|
||||
if (cancelled) {
|
||||
// This action has been cancelled before this thread even started running.
|
||||
throw new InterruptedException
|
||||
|
@ -191,6 +191,38 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Runs a Spark job and blocks the calling thread until the job finishes. This is a wrapper
 * around the job submission provided by SparkContext (reached via `rdd.context.submitJob`)
 * to enable cancellation.
 *
 * All five arguments are forwarded unchanged to `submitJob`.
 *
 * @param rdd the RDD whose context the job is submitted through
 * @param processPartition function forwarded to `submitJob` together with `partitions`
 * @param partitions partition indices forwarded to `submitJob`
 * @param partitionResultHandler callback forwarded to `submitJob`
 * @param resultFunc by-name result expression forwarded to `submitJob`
 * @throws SparkException if the action was already cancelled when submission was attempted,
 *                        or if the waiting thread is interrupted (the submitted job is then
 *                        cancelled first)
 */
def runJob[T, U, R](
    rdd: RDD[T],
    processPartition: Iterator[T] => U,
    partitions: Seq[Int],
    partitionResultHandler: (Int, U) => Unit,
    resultFunc: => R) {
  // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
  // command need to be in an atomic block.
  val job = this.synchronized {
    if (!cancelled) {
      rdd.context.submitJob(rdd, processPartition, partitions, partitionResultHandler, resultFunc)
    } else {
      throw new SparkException("action has been cancelled")
    }
  }

  // Wait for the job to complete. If the action is cancelled (with an interrupt),
  // cancel the job and stop the execution. This is not in a synchronized block because
  // waiting on the job eventually waits on the monitor in FutureJob.jobWaiter.
  try {
    // Await.result (rather than Await.ready) so that a failed job rethrows its failure here
    // instead of being mistaken for a job that merely completed. The value is discarded.
    Await.result(job, Duration.Inf)
  } catch {
    case e: InterruptedException =>
      job.cancel()
      throw new SparkException("action has been cancelled")
  }
}
|
||||
|
||||
/**
|
||||
* Returns whether the promise has been cancelled.
|
||||
*/
|
||||
|
|
|
@ -20,8 +20,6 @@ package org.apache.spark.rdd
|
|||
import java.util.concurrent.atomic.AtomicLong
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.concurrent.Await
|
||||
import scala.concurrent.duration.Duration
|
||||
import scala.concurrent.ExecutionContext.Implicits.global
|
||||
|
||||
import org.apache.spark.{Logging, CancellablePromise, FutureAction}
|
||||
|
@ -90,22 +88,12 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with
|
|||
val left = num - buf.size
|
||||
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
|
||||
|
||||
val job = self.context.submitJob(
|
||||
self,
|
||||
promise.runJob(self,
|
||||
(it: Iterator[T]) => it.take(left).toArray,
|
||||
p,
|
||||
(index: Int, data: Array[T]) => buf ++= data.take(num - buf.size),
|
||||
Unit)
|
||||
|
||||
// Wait for the job to complete. If the action is cancelled (with an interrupt),
|
||||
// cancel the job and stop the execution.
|
||||
try {
|
||||
Await.result(job, Duration.Inf)
|
||||
} catch {
|
||||
case e: InterruptedException =>
|
||||
job.cancel()
|
||||
throw e
|
||||
}
|
||||
partsScanned += numPartsToTry
|
||||
}
|
||||
buf.toSeq
|
||||
|
|
|
@ -20,16 +20,13 @@ package org.apache.spark.rdd
|
|||
import java.util.concurrent.Semaphore
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
|
||||
import scala.concurrent.Await
|
||||
import scala.concurrent.future
|
||||
import scala.concurrent.duration._
|
||||
import scala.concurrent.ExecutionContext.Implicits.global
|
||||
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
|
||||
import org.apache.spark.SparkContext._
|
||||
import org.apache.spark.{SparkContext, SparkException, LocalSparkContext}
|
||||
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
|
||||
import org.apache.spark.scheduler._
|
||||
|
||||
|
||||
|
@ -81,135 +78,154 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
|
|||
assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
|
||||
}
|
||||
|
||||
//
|
||||
// test("countAsync") {
|
||||
// assert(zeroPartRdd.countAsync().get() === 0)
|
||||
// assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000)
|
||||
// }
|
||||
//
|
||||
// test("collectAsync") {
|
||||
// assert(zeroPartRdd.collectAsync().get() === Seq.empty)
|
||||
//
|
||||
// // Note that we sort the collected output because the order is indeterministic.
|
||||
// val collected = sc.parallelize(1 to 1000, 3).collectAsync().get().sorted
|
||||
// assert(collected === (1 to 1000))
|
||||
// }
|
||||
//
|
||||
// test("foreachAsync") {
|
||||
// zeroPartRdd.foreachAsync(i => Unit).get()
|
||||
//
|
||||
// val accum = sc.accumulator(0)
|
||||
// sc.parallelize(1 to 1000, 3).foreachAsync { i =>
|
||||
// accum += 1
|
||||
// }.get()
|
||||
// assert(accum.value === 1000)
|
||||
// }
|
||||
//
|
||||
// test("foreachPartitionAsync") {
|
||||
// zeroPartRdd.foreachPartitionAsync(iter => Unit).get()
|
||||
//
|
||||
// val accum = sc.accumulator(0)
|
||||
// sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter =>
|
||||
// accum += 1
|
||||
// }.get()
|
||||
// assert(accum.value === 9)
|
||||
// }
|
||||
//
|
||||
// test("takeAsync") {
|
||||
// def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) {
|
||||
// // Note that we sort the collected output because the order is indeterministic.
|
||||
// assert(rdd.takeAsync(num).get().size === input.take(num).size)
|
||||
// }
|
||||
// val input = Range(1, 1000)
|
||||
//
|
||||
// var nums = sc.parallelize(input, 1)
|
||||
// for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
|
||||
// testTake(nums, input, num)
|
||||
// }
|
||||
//
|
||||
// nums = sc.parallelize(input, 2)
|
||||
// for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
|
||||
// testTake(nums, input, num)
|
||||
// }
|
||||
//
|
||||
// nums = sc.parallelize(input, 100)
|
||||
// for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
|
||||
// testTake(nums, input, num)
|
||||
// }
|
||||
//
|
||||
// nums = sc.parallelize(input, 1000)
|
||||
// for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
|
||||
// testTake(nums, input, num)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
|
||||
// * of a successful job execution.
|
||||
// */
|
||||
// test("async success handling") {
|
||||
// val f = sc.parallelize(1 to 10, 2).countAsync()
|
||||
//
|
||||
// // This semaphore is used to make sure our final assert waits until onComplete / onSuccess
|
||||
// // finishes execution.
|
||||
// val sem = new Semaphore(0)
|
||||
//
|
||||
// AsyncRDDActionsSuite.asyncSuccessHappened.set(0)
|
||||
// f.onComplete {
|
||||
// case scala.util.Success(res) =>
|
||||
// AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
|
||||
// sem.release()
|
||||
// case scala.util.Failure(e) =>
|
||||
// throw new Exception("Task should succeed")
|
||||
// sem.release()
|
||||
// }
|
||||
// f.onSuccess { case a: Any =>
|
||||
// AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
|
||||
// sem.release()
|
||||
// }
|
||||
// f.onFailure { case t =>
|
||||
// throw new Exception("Task should succeed")
|
||||
// }
|
||||
// assert(f.get() === 10)
|
||||
// sem.acquire(2)
|
||||
// assert(AsyncRDDActionsSuite.asyncSuccessHappened.get() === 2)
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
|
||||
// * of a failed job execution.
|
||||
// */
|
||||
// test("async failure handling") {
|
||||
// val f = sc.parallelize(1 to 10, 2).map { i =>
|
||||
// throw new Exception("intentional"); i
|
||||
// }.countAsync()
|
||||
//
|
||||
// // This semaphore is used to make sure our final assert waits until onComplete / onFailure
|
||||
// // finishes execution.
|
||||
// val sem = new Semaphore(0)
|
||||
//
|
||||
// AsyncRDDActionsSuite.asyncFailureHappend.set(0)
|
||||
// f.onComplete {
|
||||
// case scala.util.Success(res) =>
|
||||
// throw new Exception("Task should fail")
|
||||
// sem.release()
|
||||
// case scala.util.Failure(e) =>
|
||||
// AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
|
||||
// sem.release()
|
||||
// }
|
||||
// f.onSuccess { case a: Any =>
|
||||
// throw new Exception("Task should fail")
|
||||
// }
|
||||
// f.onFailure { case t =>
|
||||
// AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
|
||||
// sem.release()
|
||||
// }
|
||||
// intercept[SparkException] {
|
||||
// f.get()
|
||||
// }
|
||||
// sem.acquire(2)
|
||||
// assert(AsyncRDDActionsSuite.asyncFailureHappend.get() === 2)
|
||||
// }
|
||||
test("cancelling take action after some tasks have been launched") {
  // Starts with zero permits; the listener below hands one out every time a task starts.
  val taskStarted = new Semaphore(0)
  val releaseOnTaskStart = new SparkListener {
    override def onTaskStart(taskStart: SparkListenerTaskStart) {
      taskStarted.release()
    }
  }
  sc.dagScheduler.addSparkListener(releaseOnTaskStart)

  // Each element sleeps briefly so the take is still running when the cancel arrives.
  val slowRdd = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }
  val action = slowRdd.takeAsync(5000)

  // Cancel from another thread, but only once a task start has been observed.
  future {
    taskStarted.acquire()
    action.cancel()
  }

  val error = intercept[SparkException] { action.get() }
  val message = error.getMessage
  assert(message.contains("cancelled") || message.contains("killed"))
}
|
||||
|
||||
test("countAsync") {
  // An RDD without partitions counts to zero.
  val emptyCount = zeroPartRdd.countAsync().get()
  assert(emptyCount === 0)

  // 10000 elements spread over 5 partitions.
  val fullCount = sc.parallelize(1 to 10000, 5).countAsync().get()
  assert(fullCount === 10000)
}
|
||||
|
||||
test("collectAsync") {
  val emptyResult = zeroPartRdd.collectAsync().get()
  assert(emptyResult === Seq.empty)

  // The collected order is not deterministic, so sort before comparing.
  val sortedResult = sc.parallelize(1 to 1000, 3).collectAsync().get().sorted
  assert(sortedResult === (1 to 1000))
}
|
||||
|
||||
test("foreachAsync") {
  // The action on the zero-partition RDD must still complete.
  val emptyAction = zeroPartRdd.foreachAsync(_ => Unit)
  emptyAction.get()

  // Bump the accumulator once per element and check the total after the action completes.
  val elementsSeen = sc.accumulator(0)
  val action = sc.parallelize(1 to 1000, 3).foreachAsync { element =>
    elementsSeen += 1
  }
  action.get()
  assert(elementsSeen.value === 1000)
}
|
||||
|
||||
test("foreachPartitionAsync") {
  // The action on the zero-partition RDD must still complete.
  val emptyAction = zeroPartRdd.foreachPartitionAsync(_ => Unit)
  emptyAction.get()

  // Bump the accumulator once per partition: 9 partitions, so 9 increments are expected.
  val partitionsSeen = sc.accumulator(0)
  val action = sc.parallelize(1 to 1000, 9).foreachPartitionAsync { part =>
    partitionsSeen += 1
  }
  action.get()
  assert(partitionsSeen.value === 9)
}
|
||||
|
||||
test("takeAsync") {
  val input = Range(1, 1000)

  // Only the number of returned elements is compared against a local take of the same size;
  // the element values themselves are not checked.
  def checkTake(rdd: RDD[Int], num: Int) {
    val expected = input.take(num).size
    val saw = rdd.takeAsync(num).get().size
    assert(saw == expected, "incorrect result for rdd with %d partitions (expected %d, saw %d)"
      .format(rdd.partitions.size, expected, saw))
  }

  // (number of partitions, take sizes to try against an RDD with that many partitions)
  val cases = Seq(
    1 -> Seq(0, 1, 999, 1000),
    2 -> Seq(0, 1, 3, 500, 501, 999, 1000),
    100 -> Seq(0, 1, 500, 501, 999, 1000),
    1000 -> Seq(0, 1, 3, 999, 1000))

  cases.foreach { case (numPartitions, nums) =>
    val rdd = sc.parallelize(input, numPartitions)
    nums.foreach { num => checkTake(rdd, num) }
  }
}
|
||||
|
||||
/**
 * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
 * of a successful job execution: onComplete (with a Success) and onSuccess must each run
 * once, which is checked through a shared counter.
 */
test("async success handling") {
  val f = sc.parallelize(1 to 10, 2).countAsync()

  // This semaphore is used to make sure our final assert waits until onComplete / onSuccess
  // finishes execution.
  val sem = new Semaphore(0)

  AsyncRDDActionsSuite.asyncSuccessHappened.set(0)
  f.onComplete {
    case scala.util.Success(res) =>
      AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
      sem.release()
    case scala.util.Failure(e) =>
      // Release before throwing; a statement placed after the throw can never run.
      sem.release()
      throw new Exception("Task should succeed")
  }
  f.onSuccess { case a: Any =>
    AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
    sem.release()
  }
  f.onFailure { case t =>
    throw new Exception("Task should succeed")
  }
  assert(f.get() === 10)
  // Wait for both success callbacks (onComplete and onSuccess) before checking the counter.
  sem.acquire(2)
  assert(AsyncRDDActionsSuite.asyncSuccessHappened.get() === 2)
}
|
||||
|
||||
/**
 * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
 * of a failed job execution: onComplete (with a Failure) and onFailure must each run
 * once, which is checked through a shared counter.
 */
test("async failure handling") {
  // Every element throws, so the job is expected to fail. The trailing `i` is unreachable
  // and only gives the closure a non-Nothing result type.
  val f = sc.parallelize(1 to 10, 2).map { i =>
    throw new Exception("intentional"); i
  }.countAsync()

  // This semaphore is used to make sure our final assert waits until onComplete / onFailure
  // finishes execution.
  val sem = new Semaphore(0)

  AsyncRDDActionsSuite.asyncFailureHappend.set(0)
  f.onComplete {
    case scala.util.Success(res) =>
      // Release before throwing; a statement placed after the throw can never run.
      sem.release()
      throw new Exception("Task should fail")
    case scala.util.Failure(e) =>
      AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
      sem.release()
  }
  f.onSuccess { case a: Any =>
    throw new Exception("Task should fail")
  }
  f.onFailure { case t =>
    AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
    sem.release()
  }
  intercept[SparkException] {
    f.get()
  }
  // Wait for both failure callbacks (onComplete and onFailure) before checking the counter.
  sem.acquire(2)
  assert(AsyncRDDActionsSuite.asyncFailureHappend.get() === 2)
}
|
||||
}
|
||||
|
||||
object AsyncRDDActionsSuite {
|
||||
|
|
|
@ -106,7 +106,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
|
|||
}
|
||||
}
|
||||
visit(sums)
|
||||
assert(deps.size === 2) // ShuffledRDD, ParallelCollection
|
||||
assert(deps.size === 3) // ShuffledRDD, ParallelCollection, InterruptibleRDD.
|
||||
}
|
||||
|
||||
test("join") {
|
||||
|
|
Loading…
Reference in a new issue