Added pruneSplits method to RDD.

Reynold Xin 2013-01-23 15:29:02 -08:00
parent 7c3a1bddb7
commit eb222b7206
3 changed files with 49 additions and 7 deletions

View file

@@ -40,6 +40,7 @@ import spark.rdd.MapPartitionsRDD
 import spark.rdd.MapPartitionsWithSplitRDD
 import spark.rdd.PipedRDD
 import spark.rdd.SampledRDD
+import spark.rdd.SplitsPruningRDD
 import spark.rdd.UnionRDD
 import spark.rdd.ZippedRDD
 import spark.storage.StorageLevel
@@ -543,6 +544,15 @@ abstract class RDD[T: ClassManifest](
     map(x => (f(x), x))
   }

+  /**
+   * Prune splits (partitions) so Spark can avoid launching tasks on
+   * all splits. An example use case: if we know the RDD is partitioned by range,
+   * and the execution DAG has a filter on the key, we can avoid launching tasks
+   * on splits whose range does not cover the key.
+   */
+  def pruneSplits(splitsFilterFunc: Int => Boolean): RDD[T] =
+    new SplitsPruningRDD(this, splitsFilterFunc)
+
   /** A private method for tests, to look at the contents of each partition */
   private[spark] def collectPartitions(): Array[Array[T]] = {
     sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
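
For context, a minimal usage sketch of the new method (hypothetical, not part of this commit; it assumes a local SparkContext and that parallelize slices a range evenly across splits):

// Hypothetical sketch: 1 to 100 over 10 splits, so split i holds
// (10 * i + 1) to (10 * i + 10).
val sc = new SparkContext("local", "prune-example")
val data = sc.parallelize(1 to 100, 10)
// If downstream work only needs values above 90, that is split 9 alone,
// so Spark launches one task instead of ten.
val pruned = data.pruneSplits(splitIndex => splitIndex == 9)
println(pruned.collect().mkString(", "))  // 91, 92, ..., 100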

View file

@@ -0,0 +1,24 @@
+package spark.rdd
+
+import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext}
+
+/**
+ * An RDD used to prune RDD splits so we can avoid launching tasks on
+ * all splits. An example use case: if we know the RDD is partitioned by range,
+ * and the execution DAG has a filter on the key, we can avoid launching tasks
+ * on splits whose range does not cover the key.
+ */
+class SplitsPruningRDD[T: ClassManifest](
+    prev: RDD[T],
+    @transient splitsFilterFunc: Int => Boolean)
+  extends RDD[T](prev) {
+
+  @transient
+  val _splits: Array[Split] = prev.splits.filter(s => splitsFilterFunc(s.index))
+
+  override def compute(split: Split, context: TaskContext) = prev.iterator(split, context)
+
+  override protected def getSplits = _splits
+
+  override val partitioner = prev.partitioner
+}
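
The pruning itself is just an index-based filter over the parent's split array. A standalone illustration of that one step (plain Scala; FakeSplit is a hypothetical stand-in for spark.Split, which exposes an index):

// Mirrors prev.splits.filter(s => splitsFilterFunc(s.index)) from above.
case class FakeSplit(index: Int)
val splits = Array.tabulate(10)(i => FakeSplit(i))
val splitsFilterFunc: Int => Boolean = _ > 8
val pruned = splits.filter(s => splitsFilterFunc(s.index))
println(pruned.toList)  // List(FakeSplit(9))

Marking _splits @transient presumably keeps the filtered array out of task serialization; it only needs to exist on the driver, where getSplits is called.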

View file

@@ -1,11 +1,9 @@
 package spark

 import scala.collection.mutable.HashMap

-import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
+import org.scalatest.{BeforeAndAfter, FunSuite}
 import spark.SparkContext._
 import spark.rdd.CoalescedRDD
-import SparkContext._

 class RDDSuite extends FunSuite with BeforeAndAfter {
@@ -136,8 +134,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
       List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10)))

     // Check that the narrow dependency is also specified correctly
-    assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === List(0, 1, 2, 3, 4))
-    assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9))
+    assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList ===
+      List(0, 1, 2, 3, 4))
+    assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList ===
+      List(5, 6, 7, 8, 9))

     val coalesced2 = new CoalescedRDD(data, 3)
     assert(coalesced2.collect().toList === (1 to 10).toList)
@@ -168,4 +168,12 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
       nums.zip(sc.parallelize(1 to 4, 1)).collect()
     }
   }
+
+  test("split pruning") {
+    sc = new SparkContext("local", "test")
+    val data = sc.parallelize(1 to 10, 10)
+    // Note that split numbers start from 0, so > 8 keeps only the 10th partition.
+    val prunedData = data.pruneSplits(splitNum => splitNum > 8).collect
+    assert(prunedData.size == 1 && prunedData(0) == 10)
+  }
 }
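
To connect the test back to the docstring's motivating case, a hedged sketch of range-based pruning (hypothetical, not from this commit; it assumes the era's spark.RangePartitioner and the partitionBy method from PairRDDFunctions, via the SparkContext._ implicits):

import spark.{RangePartitioner, SparkContext}
import spark.SparkContext._

val sc = new SparkContext("local", "range-prune")
val pairs = sc.parallelize((1 to 100).map(i => (i, i.toString)))
// Range-partition into 10 splits; each split covers a contiguous key range.
val partitioned = pairs.partitionBy(new RangePartitioner(10, pairs))
// A filter on one key only needs the split whose range covers that key.
val key = 42
val target = partitioned.partitioner.get.getPartition(key)
val result = partitioned.pruneSplits(_ == target).filter(_._1 == key).collect()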