[SPARK-25045][CORE] Make RDDBarrier.mapParititions similar to RDD.mapPartitions

## What changes were proposed in this pull request?

Signature of the function passed to `RDDBarrier.mapPartitions()` is different from that of `RDD.mapPartitions`. The later doesn’t take a `TaskContext`. We shall make the function signature the same to avoid confusion and misusage.

This PR proposes the following API changes:
- In `RDDBarrier`, migrate `mapPartitions` from
   ```
        def mapPartitions[S: ClassTag](
            f: (Iterator[T], BarrierTaskContext) => Iterator[S],
            preservesPartitioning: Boolean = false): RDD[S]
        }
   ```
    to
   ```
        def mapPartitions[S: ClassTag](
            f: Iterator[T] => Iterator[S],
            preservesPartitioning: Boolean = false): RDD[S]
        }
   ```
- Add new static method to get a `BarrierTaskContext`:
   ```
        object BarrierTaskContext {
           def get(): BarrierTaskContext
        }
   ```

## How was this patch tested?

Existing test cases.

Author: Xingbo Jiang <xingbo.jiang@databricks.com>

Closes #22026 from jiangxb1987/mapPartitions.
This commit is contained in:
Xingbo Jiang 2018-08-07 17:32:41 -07:00 committed by Xiangrui Meng
parent 66699c5c30
commit d90f1336d8
7 changed files with 45 additions and 29 deletions

View file

@ -72,7 +72,8 @@ class BarrierTaskContext(
* 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it
* shall lead to timeout of the function call.
* {{{
* rdd.barrier().mapPartitions { (iter, context) =>
* rdd.barrier().mapPartitions { iter =>
* val context = BarrierTaskContext.get()
* if (context.partitionId() == 0) {
* // Do nothing.
* } else {
@ -85,7 +86,8 @@ class BarrierTaskContext(
* 2. Include barrier() function in a try-catch code block, this may lead to timeout of the
* second function call.
* {{{
* rdd.barrier().mapPartitions { (iter, context) =>
* rdd.barrier().mapPartitions { iter =>
* val context = BarrierTaskContext.get()
* try {
* // Do something that might throw an Exception.
* doSomething()
@ -152,3 +154,11 @@ class BarrierTaskContext(
addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_))
}
}
object BarrierTaskContext {
/**
* Return the currently active BarrierTaskContext. This can be called inside of user functions to
* access contextual information about running barrier tasks.
*/
def get(): BarrierTaskContext = TaskContext.get().asInstanceOf[BarrierTaskContext]
}

View file

@ -28,7 +28,7 @@ class RDDBarrier[T: ClassTag](rdd: RDD[T]) {
/**
* :: Experimental ::
* Maps partitions together with a provided [[org.apache.spark.BarrierTaskContext]].
* Generate a new barrier RDD by applying a function to each partitions of the prev RDD.
*
* `preservesPartitioning` indicates whether the input function preserves the partitioner, which
* should be `false` unless `rdd` is a pair RDD and the input function doesn't modify the keys.
@ -36,13 +36,12 @@ class RDDBarrier[T: ClassTag](rdd: RDD[T]) {
@Experimental
@Since("2.4.0")
def mapPartitions[S: ClassTag](
f: (Iterator[T], BarrierTaskContext) => Iterator[S],
f: Iterator[T] => Iterator[S],
preservesPartitioning: Boolean = false): RDD[S] = rdd.withScope {
val cleanedF = rdd.sparkContext.clean(f)
new MapPartitionsRDD(
rdd,
(context: TaskContext, index: Int, iter: Iterator[T]) =>
cleanedF(iter, context.asInstanceOf[BarrierTaskContext]),
(context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(iter),
preservesPartitioning,
isFromBarrier = true
)

View file

@ -61,7 +61,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1)
val rdd = prunedRdd
.barrier()
.mapPartitions((iter, context) => iter)
.mapPartitions(iter => iter)
testSubmitJob(sc, rdd,
message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
}
@ -71,7 +71,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1)
val rdd = prunedRdd
.barrier()
.mapPartitions((iter, context) => iter)
.mapPartitions(iter => iter)
.repartition(2)
.map(x => x + 1)
testSubmitJob(sc, rdd,
@ -84,7 +84,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
val rdd = prunedRdd
.repartition(2)
.barrier()
.mapPartitions((iter, context) => iter)
.mapPartitions(iter => iter)
// Should be able to submit job and run successfully.
val result = rdd.collect().sorted
assert(result === Seq(6, 7, 8, 9, 10))
@ -94,7 +94,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
sc = createSparkContext()
val rdd = sc.parallelize(1 to 10, 4)
.barrier()
.mapPartitions((iter, context) => iter)
.mapPartitions(iter => iter)
testSubmitJob(sc, rdd, Some(Seq(1, 3)),
message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
}
@ -103,7 +103,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
sc = createSparkContext()
val rdd1 = sc.parallelize(1 to 10, 2)
.barrier()
.mapPartitions((iter, context) => iter)
.mapPartitions(iter => iter)
val rdd2 = sc.parallelize(1 to 20, 2)
val rdd3 = rdd1
.union(rdd2)
@ -117,7 +117,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
sc = createSparkContext()
val rdd = sc.parallelize(1 to 10, 4)
.barrier()
.mapPartitions((iter, context) => iter)
.mapPartitions(iter => iter)
.coalesce(1)
// Fail the job on submit because the barrier RDD requires to run on 4 tasks, but the stage
// only launches 1 task.
@ -129,10 +129,10 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
sc = createSparkContext()
val rdd1 = sc.parallelize(1 to 10, 4)
.barrier()
.mapPartitions((iter, context) => iter)
.mapPartitions(iter => iter)
val rdd2 = sc.parallelize(11 to 20, 4)
.barrier()
.mapPartitions((iter, context) => iter)
.mapPartitions(iter => iter)
val rdd3 = rdd1
.zip(rdd2)
.map(x => x._1 + x._2)
@ -144,7 +144,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
sc = createSparkContext()
val rdd1 = sc.parallelize(1 to 10, 4)
.barrier()
.mapPartitions((iter, context) => iter)
.mapPartitions(iter => iter)
val rdd2 = sc.parallelize(11 to 20, 4)
val rdd3 = rdd1
.zip(rdd2)
@ -164,7 +164,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
val rdd = sc.parallelize(1 to 10, 4)
.barrier()
.mapPartitions((iter, context) => iter)
.mapPartitions(iter => iter)
testSubmitJob(sc, rdd,
message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION)
}
@ -179,7 +179,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
val rdd = sc.parallelize(1 to 10, 4)
.barrier()
.mapPartitions((iter, context) => iter)
.mapPartitions(iter => iter)
.repartition(2)
.map(x => x + 1)
testSubmitJob(sc, rdd,

View file

@ -632,7 +632,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
val conf = new SparkConf().setAppName("test").setMaster("local[2]")
sc = new SparkContext(conf)
val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2)
val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
// If we don't get the expected taskInfos, the job shall abort due to stage failure.
if (context.getTaskInfos().length != 2) {
throw new SparkException("Expected taksInfos length is 2, actual length is " +
@ -654,7 +655,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
.setAppName("test-cluster")
sc = new SparkContext(conf)
val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2)
val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
// If we don't get the expected taskInfos, the job shall abort due to stage failure.
if (context.getTaskInfos().length != 2) {
throw new SparkException("Expected taksInfos length is 2, actual length is " +

View file

@ -25,19 +25,19 @@ class RDDBarrierSuite extends SparkFunSuite with SharedSparkContext {
val rdd = sc.parallelize(1 to 10, 4)
assert(rdd.isBarrier() === false)
val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter)
val rdd2 = rdd.barrier().mapPartitions(iter => iter)
assert(rdd2.isBarrier() === true)
}
test("create an RDDBarrier in the middle of a chain of RDDs") {
val rdd = sc.parallelize(1 to 10, 4).map(x => x * 2)
val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter).map(x => (x, x + 1))
val rdd2 = rdd.barrier().mapPartitions(iter => iter).map(x => (x, x + 1))
assert(rdd2.isBarrier() === true)
}
test("RDDBarrier with shuffle") {
val rdd = sc.parallelize(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter).repartition(2)
val rdd2 = rdd.barrier().mapPartitions(iter => iter).repartition(2)
assert(rdd2.isBarrier() === false)
}
}

View file

@ -31,7 +31,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
.setAppName("test-cluster")
sc = new SparkContext(conf)
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
// Sleep for a random time before global sync.
Thread.sleep(Random.nextInt(1000))
context.barrier()
@ -49,7 +50,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
.setAppName("test-cluster")
sc = new SparkContext(conf)
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
// Sleep for a random time before global sync.
Thread.sleep(Random.nextInt(1000))
context.barrier()
@ -79,7 +81,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
.setAppName("test-cluster")
sc = new SparkContext(conf)
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
// Task 3 shall sleep 2000ms to ensure barrier() call timeout
if (context.taskAttemptId == 3) {
Thread.sleep(2000)
@ -103,7 +106,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
.setAppName("test-cluster")
sc = new SparkContext(conf)
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
if (context.taskAttemptId != 0) {
context.barrier()
}
@ -125,7 +129,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
.setAppName("test-cluster")
sc = new SparkContext(conf)
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
try {
if (context.taskAttemptId == 0) {
// Due to some non-obvious reason, the code can trigger an Exception and skip the

View file

@ -1062,7 +1062,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
}
test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by FetchFailure") {
val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions((it, context) => it)
val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
val shuffleId = shuffleDep.shuffleId
val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
@ -1091,7 +1091,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
}
test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by TaskKilled") {
val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions((it, context) => it)
val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
val shuffleId = shuffleDep.shuffleId
val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)