Added pruneSplits method to RDD.
This commit is contained in:
parent
7c3a1bddb7
commit
eb222b7206
|
@ -40,6 +40,7 @@ import spark.rdd.MapPartitionsRDD
|
|||
import spark.rdd.MapPartitionsWithSplitRDD
|
||||
import spark.rdd.PipedRDD
|
||||
import spark.rdd.SampledRDD
|
||||
import spark.rdd.SplitsPruningRDD
|
||||
import spark.rdd.UnionRDD
|
||||
import spark.rdd.ZippedRDD
|
||||
import spark.storage.StorageLevel
|
||||
|
@ -543,6 +544,15 @@ abstract class RDD[T: ClassManifest](
|
|||
map(x => (f(x), x))
|
||||
}
|
||||
|
||||
/**
 * Returns an RDD that contains only the splits (partitions) of this RDD whose index is
 * accepted by `splitsFilterFunc`, so that Spark does not have to launch tasks on every split.
 *
 * Typical use: when the RDD is known to be range-partitioned and the execution DAG filters
 * on the key, the splits whose range cannot contain the key can be dropped up front.
 *
 * @param splitsFilterFunc predicate over a split's index; a split is kept when it returns true
 * @return a [[spark.rdd.SplitsPruningRDD]] built on top of this RDD
 */
def pruneSplits(splitsFilterFunc: Int => Boolean): RDD[T] = {
  new SplitsPruningRDD(this, splitsFilterFunc)
}
|
||||
|
||||
/** A private method for tests, to look at the contents of each partition */
|
||||
private[spark] def collectPartitions(): Array[Array[T]] = {
|
||||
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
|
||||
|
|
24
core/src/main/scala/spark/rdd/SplitsPruningRDD.scala
Normal file
24
core/src/main/scala/spark/rdd/SplitsPruningRDD.scala
Normal file
|
@ -0,0 +1,24 @@
|
|||
package spark.rdd
|
||||
|
||||
import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext}
|
||||
|
||||
/**
 * A split of a [[SplitsPruningRDD]]. It wraps one surviving split of the parent RDD and
 * assigns it a new, dense index, so that the i-th split of the pruned RDD reports index i
 * instead of the index it had in the parent.
 *
 * @param idx         position of this split within the pruned RDD
 * @param parentSplit the parent split whose data this split exposes
 */
private[spark] class SplitsPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split {
  override val index: Int = idx
}

/**
 * Narrow dependency between a pruned RDD and its parent: each child split depends on exactly
 * one parent split, namely the one it wraps.
 *
 * @param rdd              the parent RDD
 * @param splitsFilterFunc predicate over a parent split's index; true means "keep the split"
 */
private[spark] class PruneDependency[T](rdd: RDD[T], @transient splitsFilterFunc: Int => Boolean)
  extends spark.NarrowDependency[T](rdd) {

  // Surviving parent splits, re-indexed 0..n-1 in their original relative order.
  @transient
  val splits: Array[Split] = rdd.splits
    .filter(s => splitsFilterFunc(s.index))
    .zipWithIndex
    .map { case (parent, idx) => new SplitsPruningRDDSplit(idx, parent): Split }

  /** Maps a split index of the pruned RDD back to the index of the parent split it wraps. */
  override def getParents(partitionId: Int): Seq[Int] =
    List(splits(partitionId).asInstanceOf[SplitsPruningRDDSplit].parentSplit.index)
}

/**
 * A RDD used to prune RDD splits so we can avoid launching tasks on all splits. An example
 * use case: if we know the RDD is partitioned by range, and the execution DAG has a filter
 * on the key, we can avoid launching tasks on splits that don't have the range covering the
 * key.
 *
 * Surviving splits are wrapped in [[SplitsPruningRDDSplit]] so that their indices are dense
 * (0 until number-of-kept-splits), and the dependency on the parent is a [[PruneDependency]]
 * that maps each kept split back to the right parent split. Previously the parent's split
 * objects were reused as-is, so a kept split could sit at position 0 while still reporting
 * its old index (e.g. 9), and the one-to-one dependency pointed at the wrong parent split.
 *
 * The parent's partitioner is deliberately NOT propagated: after pruning, the number of
 * splits no longer matches the partitioner's view of the data.
 * NOTE(review): presumably downstream operations treat "same partitioner" as "same split
 * layout" -- confirm before re-introducing it.
 *
 * @param prev             the parent RDD whose splits are pruned
 * @param splitsFilterFunc predicate over a parent split's index; true means "keep the split"
 */
class SplitsPruningRDD[T: ClassManifest](
    prev: RDD[T],
    @transient splitsFilterFunc: Int => Boolean)
  extends RDD[T](prev.context, List(new PruneDependency(prev, splitsFilterFunc))) {

  // The kept splits are computed once by the dependency; reuse them here.
  @transient
  val _splits: Array[Split] = dependencies.head.asInstanceOf[PruneDependency[T]].splits

  /** Computes a pruned split by delegating to the parent RDD on the wrapped parent split. */
  override def compute(split: Split, context: TaskContext): Iterator[T] =
    prev.iterator(split.asInstanceOf[SplitsPruningRDDSplit].parentSplit, context)

  override protected def getSplits = _splits
}
|
|
@ -1,11 +1,9 @@
|
|||
package spark
|
||||
|
||||
import scala.collection.mutable.HashMap
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.BeforeAndAfter
|
||||
|
||||
import org.scalatest.{BeforeAndAfter, FunSuite}
|
||||
import spark.SparkContext._
|
||||
import spark.rdd.CoalescedRDD
|
||||
import SparkContext._
|
||||
|
||||
class RDDSuite extends FunSuite with BeforeAndAfter {
|
||||
|
||||
|
@ -104,7 +102,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
|
|||
}
|
||||
|
||||
test("caching with failures") {
|
||||
sc = new SparkContext("local", "test")
|
||||
sc = new SparkContext("local", "test")
|
||||
val onlySplit = new Split { override def index: Int = 0 }
|
||||
var shouldFail = true
|
||||
val rdd = new RDD[Int](sc, Nil) {
|
||||
|
@ -136,8 +134,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
|
|||
List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10)))
|
||||
|
||||
// Check that the narrow dependency is also specified correctly
|
||||
assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === List(0, 1, 2, 3, 4))
|
||||
assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9))
|
||||
assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList ===
|
||||
List(0, 1, 2, 3, 4))
|
||||
assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList ===
|
||||
List(5, 6, 7, 8, 9))
|
||||
|
||||
val coalesced2 = new CoalescedRDD(data, 3)
|
||||
assert(coalesced2.collect().toList === (1 to 10).toList)
|
||||
|
@ -168,4 +168,12 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
|
|||
nums.zip(sc.parallelize(1 to 4, 1)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
test("split pruning") {
  sc = new SparkContext("local", "test")
  // Ten elements spread over ten splits: split i (0-based) holds the single value i + 1.
  val data = sc.parallelize(1 to 10, 10)
  // Split indices start at 0, so keeping only indices > 8 leaves just the 10th split.
  val prunedData = data.pruneSplits(splitNum => splitNum > 8).collect
  // Exactly one element must survive, and it must be 10.
  assert(prunedData.toList === List(10))
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue