From eb222b720647c9e92a867c591cc4914b9a6cb5c1 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Wed, 23 Jan 2013 15:29:02 -0800
Subject: [PATCH] Added pruneSplits method to RDD.

---
 core/src/main/scala/spark/RDD.scala            | 10 ++++++++
 .../scala/spark/rdd/SplitsPruningRDD.scala     | 24 +++++++++++++++++++
 core/src/test/scala/spark/RDDSuite.scala       | 22 +++++++++++------
 3 files changed, 49 insertions(+), 7 deletions(-)
 create mode 100644 core/src/main/scala/spark/rdd/SplitsPruningRDD.scala

diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index e0d2eabb1d..3d93ff33bb 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -40,6 +40,7 @@ import spark.rdd.MapPartitionsRDD
 import spark.rdd.MapPartitionsWithSplitRDD
 import spark.rdd.PipedRDD
 import spark.rdd.SampledRDD
+import spark.rdd.SplitsPruningRDD
 import spark.rdd.UnionRDD
 import spark.rdd.ZippedRDD
 import spark.storage.StorageLevel
@@ -543,6 +544,15 @@ abstract class RDD[T: ClassManifest](
     map(x => (f(x), x))
   }

+  /**
+   * Prune splits (partitions) so Spark can avoid launching tasks on
+   * all splits. An example use case: if we know the RDD is range-partitioned
+   * and the execution DAG has a filter on the key, we can avoid launching
+   * tasks on splits whose range does not cover the key.
+   */
+  def pruneSplits(splitsFilterFunc: Int => Boolean): RDD[T] =
+    new SplitsPruningRDD(this, splitsFilterFunc)
+
   /** A private method for tests, to look at the contents of each partition */
   private[spark] def collectPartitions(): Array[Array[T]] = {
     sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
diff --git a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala
new file mode 100644
index 0000000000..74e10265fc
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala
@@ -0,0 +1,24 @@
+package spark.rdd
+
+import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext}
+
+/**
+ * An RDD used to prune splits so we can avoid launching tasks on
+ * all splits. An example use case: if we know the RDD is range-partitioned
+ * and the execution DAG has a filter on the key, we can avoid launching
+ * tasks on splits whose range does not cover the key.
+ */
+class SplitsPruningRDD[T: ClassManifest](
+    prev: RDD[T],
+    @transient splitsFilterFunc: Int => Boolean)
+  extends RDD[T](prev) {
+
+  @transient
+  val _splits: Array[Split] = prev.splits.filter(s => splitsFilterFunc(s.index))
+
+  override def compute(split: Split, context: TaskContext) = prev.iterator(split, context)
+
+  override protected def getSplits = _splits
+
+  override val partitioner = prev.partitioner
+}
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index db217f8482..03aa2845f4 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -1,11 +1,9 @@
 package spark

 import scala.collection.mutable.HashMap
-import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
-
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import spark.SparkContext._
 import spark.rdd.CoalescedRDD
-import SparkContext._

 class RDDSuite extends FunSuite with BeforeAndAfter {

@@ -104,7 +102,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
   }

   test("caching with failures") {
-    sc = new SparkContext("local", "test")
+    sc = new SparkContext("local", "test")
     val onlySplit = new Split { override def index: Int = 0 }
     var shouldFail = true
     val rdd = new RDD[Int](sc, Nil) {
@@ -136,8 +134,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
       List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10)))

     // Check that the narrow dependency is also specified correctly
-    assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === List(0, 1, 2, 3, 4))
-    assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9))
+    assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList ===
+      List(0, 1, 2, 3, 4))
+    assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList ===
+      List(5, 6, 7, 8, 9))

     val coalesced2 = new CoalescedRDD(data, 3)
     assert(coalesced2.collect().toList === (1 to 10).toList)
@@ -168,4 +168,12 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
       nums.zip(sc.parallelize(1 to 4, 1)).collect()
     }
   }
+
+  test("split pruning") {
+    sc = new SparkContext("local", "test")
+    val data = sc.parallelize(1 to 10, 10)
+    // Note that split numbers start from 0, so > 8 means only the 10th partition is left.
+    val prunedData = data.pruneSplits(splitNum => splitNum > 8).collect
+    assert(prunedData.size == 1 && prunedData(0) == 10)
+  }
 }
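
A minimal usage sketch of the new API (not part of the patch), assuming a
SparkContext named sc; the keyInSplit helper is hypothetical and stands in for
whatever per-split range metadata a caller would track:

    // 10 splits, each holding one element of 1 to 10 in order,
    // so split i holds the value i + 1.
    val data = sc.parallelize(1 to 10, 10)

    // Hypothetical range check: the caller knows split i covers only the
    // key i + 1, so for the key 10 every split except split 9 is pruned.
    def keyInSplit(splitIndex: Int, key: Int): Boolean = splitIndex + 1 == key

    // Tasks are launched only on the surviving splits; here, just one.
    val matches = data.pruneSplits(i => keyInSplit(i, 10)).collect()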