[SPARK-25045][CORE] Make RDDBarrier.mapPartitions similar to RDD.mapPartitions
## What changes were proposed in this pull request?

The signature of the function passed to `RDDBarrier.mapPartitions()` differs from that of `RDD.mapPartitions`: the latter does not take a `TaskContext`. We should make the two function signatures the same to avoid confusion and misuse. This PR proposes the following API changes:

- In `RDDBarrier`, migrate `mapPartitions` from
  ```
  def mapPartitions[S: ClassTag](
      f: (Iterator[T], BarrierTaskContext) => Iterator[S],
      preservesPartitioning: Boolean = false): RDD[S]
  ```
  to
  ```
  def mapPartitions[S: ClassTag](
      f: Iterator[T] => Iterator[S],
      preservesPartitioning: Boolean = false): RDD[S]
  ```
- Add a new static method to obtain a `BarrierTaskContext`:
  ```
  object BarrierTaskContext {
    def get(): BarrierTaskContext
  }
  ```

## How was this patch tested?

Existing test cases.

Author: Xingbo Jiang <xingbo.jiang@databricks.com>

Closes #22026 from jiangxb1987/mapPartitions.
commit d90f1336d8 (parent 66699c5c30)
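To make the user-facing change concrete, here is a minimal, self-contained sketch of the new-style usage. The object name, app name, master setting, and per-partition logic below are illustrative only and are not part of the patch:

```scala
import org.apache.spark.{BarrierTaskContext, SparkConf, SparkContext}

object BarrierMapPartitionsExample {
  def main(args: Array[String]): Unit = {
    // Local master with enough slots for every barrier task (4 partitions -> 4 slots).
    val sc = new SparkContext(
      new SparkConf().setAppName("barrier-map-partitions-example").setMaster("local[4]"))
    val rdd = sc.parallelize(1 to 10, 4)

    // New signature: the closure takes only Iterator[T], exactly like RDD.mapPartitions.
    // The barrier task context is obtained via the new BarrierTaskContext.get() accessor.
    val rdd2 = rdd.barrier().mapPartitions { iter =>
      val context = BarrierTaskContext.get()
      // Tag each element with the id of the partition it came from.
      iter.map(x => (context.partitionId(), x))
    }

    rdd2.collect().foreach(println)
    sc.stop()
  }
}
```

Under the old signature the same call would have been written `rdd.barrier().mapPartitions { (iter, context) => ... }`, with the context passed in as a second closure argument.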
@@ -72,7 +72,8 @@ class BarrierTaskContext(
  * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it
  * shall lead to timeout of the function call.
  * {{{
- *   rdd.barrier().mapPartitions { (iter, context) =>
+ *   rdd.barrier().mapPartitions { iter =>
+ *       val context = BarrierTaskContext.get()
  *       if (context.partitionId() == 0) {
  *           // Do nothing.
  *       } else {
@@ -85,7 +86,8 @@ class BarrierTaskContext(
  * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the
  * second function call.
  * {{{
- *   rdd.barrier().mapPartitions { (iter, context) =>
+ *   rdd.barrier().mapPartitions { iter =>
+ *       val context = BarrierTaskContext.get()
  *       try {
  *           // Do something that might throw an Exception.
  *           doSomething()
@@ -152,3 +154,11 @@ class BarrierTaskContext(
     addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_))
   }
 }
+
+object BarrierTaskContext {
+  /**
+   * Return the currently active BarrierTaskContext. This can be called inside of user functions to
+   * access contextual information about running barrier tasks.
+   */
+  def get(): BarrierTaskContext = TaskContext.get().asInstanceOf[BarrierTaskContext]
+}

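The two scaladoc warnings above boil down to one rule: every task in the barrier stage must reach `barrier()` exactly once, unconditionally and outside any try-catch. A minimal sketch of that pattern with the new API follows; it assumes `rdd` is an already-created RDD whose partition count does not exceed the available task slots, as in the examples above, and the per-partition work is illustrative:

```scala
import org.apache.spark.BarrierTaskContext

// `rdd` is assumed to be in scope, e.g. sc.parallelize(1 to 10, 4).
val synced = rdd.barrier().mapPartitions { iter =>
  val context = BarrierTaskContext.get()
  val partition = iter.toArray  // finish the per-partition work first ...
  context.barrier()             // ... then every task performs the global sync exactly once
  partition.iterator
}
```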
@@ -28,7 +28,7 @@ class RDDBarrier[T: ClassTag](rdd: RDD[T]) {
 
   /**
    * :: Experimental ::
-   * Maps partitions together with a provided [[org.apache.spark.BarrierTaskContext]].
+   * Generate a new barrier RDD by applying a function to each partitions of the prev RDD.
    *
    * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
    * should be `false` unless `rdd` is a pair RDD and the input function doesn't modify the keys.
@@ -36,13 +36,12 @@ class RDDBarrier[T: ClassTag](rdd: RDD[T]) {
   @Experimental
   @Since("2.4.0")
   def mapPartitions[S: ClassTag](
-      f: (Iterator[T], BarrierTaskContext) => Iterator[S],
+      f: Iterator[T] => Iterator[S],
       preservesPartitioning: Boolean = false): RDD[S] = rdd.withScope {
     val cleanedF = rdd.sparkContext.clean(f)
     new MapPartitionsRDD(
       rdd,
-      (context: TaskContext, index: Int, iter: Iterator[T]) =>
-        cleanedF(iter, context.asInstanceOf[BarrierTaskContext]),
+      (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(iter),
       preservesPartitioning,
       isFromBarrier = true
     )

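A note on the implementation above: `MapPartitionsRDD` still hands its closure a `TaskContext`, but the barrier wrapper now simply ignores it, so user code that needs the context must call `BarrierTaskContext.get()`, which downcasts the thread-local `TaskContext.get()`. A simplified, stand-alone sketch of that adapter is shown below; the helper name `adapt` is hypothetical and not part of the Spark API:

```scala
import org.apache.spark.TaskContext

// Hypothetical stand-alone version of the adapter in the diff above: the user function `f`
// only ever sees the iterator; the TaskContext supplied by MapPartitionsRDD is dropped.
def adapt[T, S](f: Iterator[T] => Iterator[S]): (TaskContext, Int, Iterator[T]) => Iterator[S] =
  (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)
```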
@@ -61,7 +61,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
     val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1)
     val rdd = prunedRdd
       .barrier()
-      .mapPartitions((iter, context) => iter)
+      .mapPartitions(iter => iter)
     testSubmitJob(sc, rdd,
       message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
   }
@@ -71,7 +71,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
     val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1)
     val rdd = prunedRdd
       .barrier()
-      .mapPartitions((iter, context) => iter)
+      .mapPartitions(iter => iter)
       .repartition(2)
       .map(x => x + 1)
     testSubmitJob(sc, rdd,
@@ -84,7 +84,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
     val rdd = prunedRdd
       .repartition(2)
       .barrier()
-      .mapPartitions((iter, context) => iter)
+      .mapPartitions(iter => iter)
     // Should be able to submit job and run successfully.
     val result = rdd.collect().sorted
     assert(result === Seq(6, 7, 8, 9, 10))
@@ -94,7 +94,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
     sc = createSparkContext()
     val rdd = sc.parallelize(1 to 10, 4)
       .barrier()
-      .mapPartitions((iter, context) => iter)
+      .mapPartitions(iter => iter)
     testSubmitJob(sc, rdd, Some(Seq(1, 3)),
       message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
   }
@@ -103,7 +103,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
     sc = createSparkContext()
     val rdd1 = sc.parallelize(1 to 10, 2)
       .barrier()
-      .mapPartitions((iter, context) => iter)
+      .mapPartitions(iter => iter)
     val rdd2 = sc.parallelize(1 to 20, 2)
     val rdd3 = rdd1
       .union(rdd2)
@@ -117,7 +117,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
     sc = createSparkContext()
     val rdd = sc.parallelize(1 to 10, 4)
       .barrier()
-      .mapPartitions((iter, context) => iter)
+      .mapPartitions(iter => iter)
       .coalesce(1)
     // Fail the job on submit because the barrier RDD requires to run on 4 tasks, but the stage
     // only launches 1 task.
@@ -129,10 +129,10 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
     sc = createSparkContext()
     val rdd1 = sc.parallelize(1 to 10, 4)
       .barrier()
-      .mapPartitions((iter, context) => iter)
+      .mapPartitions(iter => iter)
     val rdd2 = sc.parallelize(11 to 20, 4)
       .barrier()
-      .mapPartitions((iter, context) => iter)
+      .mapPartitions(iter => iter)
     val rdd3 = rdd1
       .zip(rdd2)
       .map(x => x._1 + x._2)
@@ -144,7 +144,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
     sc = createSparkContext()
     val rdd1 = sc.parallelize(1 to 10, 4)
       .barrier()
-      .mapPartitions((iter, context) => iter)
+      .mapPartitions(iter => iter)
     val rdd2 = sc.parallelize(11 to 20, 4)
     val rdd3 = rdd1
       .zip(rdd2)
@@ -164,7 +164,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
 
     val rdd = sc.parallelize(1 to 10, 4)
       .barrier()
-      .mapPartitions((iter, context) => iter)
+      .mapPartitions(iter => iter)
     testSubmitJob(sc, rdd,
       message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION)
   }
@@ -179,7 +179,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext
 
     val rdd = sc.parallelize(1 to 10, 4)
       .barrier()
-      .mapPartitions((iter, context) => iter)
+      .mapPartitions(iter => iter)
       .repartition(2)
       .map(x => x + 1)
     testSubmitJob(sc, rdd,

@@ -632,7 +632,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
     val conf = new SparkConf().setAppName("test").setMaster("local[2]")
     sc = new SparkContext(conf)
     val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2)
-    val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
+    val rdd2 = rdd.barrier().mapPartitions { it =>
+      val context = BarrierTaskContext.get()
       // If we don't get the expected taskInfos, the job shall abort due to stage failure.
       if (context.getTaskInfos().length != 2) {
         throw new SparkException("Expected taksInfos length is 2, actual length is " +
@@ -654,7 +655,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
       .setAppName("test-cluster")
     sc = new SparkContext(conf)
     val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2)
-    val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
+    val rdd2 = rdd.barrier().mapPartitions { it =>
+      val context = BarrierTaskContext.get()
       // If we don't get the expected taskInfos, the job shall abort due to stage failure.
       if (context.getTaskInfos().length != 2) {
         throw new SparkException("Expected taksInfos length is 2, actual length is " +

@@ -25,19 +25,19 @@ class RDDBarrierSuite extends SparkFunSuite with SharedSparkContext {
     val rdd = sc.parallelize(1 to 10, 4)
     assert(rdd.isBarrier() === false)
 
-    val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter)
+    val rdd2 = rdd.barrier().mapPartitions(iter => iter)
     assert(rdd2.isBarrier() === true)
   }
 
   test("create an RDDBarrier in the middle of a chain of RDDs") {
     val rdd = sc.parallelize(1 to 10, 4).map(x => x * 2)
-    val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter).map(x => (x, x + 1))
+    val rdd2 = rdd.barrier().mapPartitions(iter => iter).map(x => (x, x + 1))
     assert(rdd2.isBarrier() === true)
   }
 
   test("RDDBarrier with shuffle") {
     val rdd = sc.parallelize(1 to 10, 4)
-    val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter).repartition(2)
+    val rdd2 = rdd.barrier().mapPartitions(iter => iter).repartition(2)
     assert(rdd2.isBarrier() === false)
   }
 }

@@ -31,7 +31,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
       .setAppName("test-cluster")
     sc = new SparkContext(conf)
     val rdd = sc.makeRDD(1 to 10, 4)
-    val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
+    val rdd2 = rdd.barrier().mapPartitions { it =>
+      val context = BarrierTaskContext.get()
       // Sleep for a random time before global sync.
       Thread.sleep(Random.nextInt(1000))
       context.barrier()
@@ -49,7 +50,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
       .setAppName("test-cluster")
     sc = new SparkContext(conf)
     val rdd = sc.makeRDD(1 to 10, 4)
-    val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
+    val rdd2 = rdd.barrier().mapPartitions { it =>
+      val context = BarrierTaskContext.get()
       // Sleep for a random time before global sync.
       Thread.sleep(Random.nextInt(1000))
       context.barrier()
@@ -79,7 +81,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
       .setAppName("test-cluster")
     sc = new SparkContext(conf)
     val rdd = sc.makeRDD(1 to 10, 4)
-    val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
+    val rdd2 = rdd.barrier().mapPartitions { it =>
+      val context = BarrierTaskContext.get()
       // Task 3 shall sleep 2000ms to ensure barrier() call timeout
       if (context.taskAttemptId == 3) {
         Thread.sleep(2000)
@@ -103,7 +106,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
       .setAppName("test-cluster")
     sc = new SparkContext(conf)
     val rdd = sc.makeRDD(1 to 10, 4)
-    val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
+    val rdd2 = rdd.barrier().mapPartitions { it =>
+      val context = BarrierTaskContext.get()
       if (context.taskAttemptId != 0) {
         context.barrier()
       }
@@ -125,7 +129,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
       .setAppName("test-cluster")
     sc = new SparkContext(conf)
     val rdd = sc.makeRDD(1 to 10, 4)
-    val rdd2 = rdd.barrier().mapPartitions { (it, context) =>
+    val rdd2 = rdd.barrier().mapPartitions { it =>
+      val context = BarrierTaskContext.get()
       try {
         if (context.taskAttemptId == 0) {
           // Due to some non-obvious reason, the code can trigger an Exception and skip the

@@ -1062,7 +1062,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
   }
 
   test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by FetchFailure") {
-    val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions((it, context) => it)
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
@@ -1091,7 +1091,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
   }
 
   test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by TaskKilled") {
-    val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions((it, context) => it)
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)