Created a PruneDependency to properly assign dependency for
SplitsPruningRDD.
parent 45cd50d5fe
commit 636e912f32
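Why this fix matters: SplitsPruningRDD previously extended RDD[T](prev), which wires up a OneToOneDependency, i.e. child partition i is assumed to read parent partition i. After pruning, the surviving splits are re-indexed from 0, so that identity mapping points at the wrong parent splits. A standalone sketch of the two mappings (plain Scala, no Spark types; all names are illustrative):

object PruningMappingSketch {
  val parentSplits = Seq(0, 1, 2, 3)            // parent split indices
  val keep: Int => Boolean = i => i % 2 == 1    // pruning predicate: keep 1 and 3
  val kept = parentSplits.filter(keep)          // Seq(1, 3), re-indexed as 0 and 1

  // OneToOneDependency semantics: child partition i depends on parent partition i.
  def oneToOneParents(i: Int): Seq[Int] = Seq(i)

  // PruneDependency semantics: child partition i depends on the i-th kept split.
  def pruneParents(i: Int): Seq[Int] = Seq(kept(i))

  def main(args: Array[String]): Unit = {
    // Child partition 1 actually holds parent split 3, not parent split 1.
    println(oneToOneParents(1))  // List(1) -- the wrong parent under pruning
    println(pruneParents(1))     // List(3) -- the parent the task really needs
  }
}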
core/src/main/scala/spark/Dependency.scala

@@ -5,6 +5,7 @@ package spark
  */
 abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
 
+
 /**
  * Base class for dependencies where each partition of the parent RDD is used by at most one
  * partition of the child RDD. Narrow dependencies allow for pipelined execution.
@@ -12,12 +13,13 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
 abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
   /**
    * Get the parent partitions for a child partition.
-   * @param outputPartition a partition of the child RDD
+   * @param partitionId a partition of the child RDD
    * @return the partitions of the parent RDD that the child partition depends upon
    */
-  def getParents(outputPartition: Int): Seq[Int]
+  def getParents(partitionId: Int): Seq[Int]
 }
 
+
 /**
  * Represents a dependency on the output of a shuffle stage.
  * @param shuffleId the shuffle id
@@ -32,6 +34,7 @@ class ShuffleDependency[K, V](
   val shuffleId: Int = rdd.context.newShuffleId()
 }
 
+
 /**
  * Represents a one-to-one dependency between partitions of the parent and child RDDs.
  */
@@ -39,6 +42,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
   override def getParents(partitionId: Int) = List(partitionId)
 }
 
+
 /**
  * Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs.
  * @param rdd the parent RDD
@@ -48,7 +52,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
  */
 class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
   extends NarrowDependency[T](rdd) {
-  
+
   override def getParents(partitionId: Int) = {
     if (partitionId >= outStart && partitionId < outStart + length) {
       List(partitionId - outStart + inStart)
@@ -57,3 +61,17 @@ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
     }
   }
 }
+
+
+/**
+ * Represents a dependency between the SplitsPruningRDD and its parent. In this
+ * case, the child RDD contains a subset of splits of the parents'.
+ */
+class PruneDependency[T](rdd: RDD[T], @transient splitsFilterFunc: Int => Boolean)
+  extends NarrowDependency[T](rdd) {
+
+  @transient
+  val splits: Array[Split] = rdd.splits.filter(s => splitsFilterFunc(s.index))
+
+  override def getParents(partitionId: Int) = List(splits(partitionId).index)
+}
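As a sanity check on the logic the new PruneDependency adds, here is a minimal stub that mirrors it with a stand-in Split type (the stub types and demo values are assumptions for illustration, not Spark's definitions):

case class StubSplit(index: Int)

class StubPruneDependency(parentSplits: Array[StubSplit],
                          splitsFilterFunc: Int => Boolean) {
  // Same filtering as the committed code: keep only splits the predicate accepts.
  val splits: Array[StubSplit] = parentSplits.filter(s => splitsFilterFunc(s.index))

  // Child partition i depends on the original index of the i-th surviving split.
  def getParents(partitionId: Int): Seq[Int] = List(splits(partitionId).index)
}

object StubPruneDependencyDemo extends App {
  val parent = Array.tabulate(5)(i => StubSplit(i))   // parent splits 0..4
  val dep = new StubPruneDependency(parent, _ >= 3)   // keep splits 3 and 4
  assert(dep.getParents(0) == List(3))                // child 0 -> parent split 3
  assert(dep.getParents(1) == List(4))                // child 1 -> parent split 4
  println(dep.splits.map(_.index).mkString(", "))     // 3, 4
}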
core/src/main/scala/spark/rdd/SplitsPruningRDD.scala

@@ -1,6 +1,6 @@
 package spark.rdd
 
-import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext}
+import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext}
 
 /**
  * A RDD used to prune RDD splits so we can avoid launching tasks on
@@ -11,12 +11,12 @@ import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext}
 class SplitsPruningRDD[T: ClassManifest](
     prev: RDD[T],
     @transient splitsFilterFunc: Int => Boolean)
-  extends RDD[T](prev) {
+  extends RDD[T](prev.context, List(new PruneDependency(prev, splitsFilterFunc))) {
 
   @transient
-  val _splits: Array[Split] = prev.splits.filter(s => splitsFilterFunc(s.index))
+  val _splits: Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].splits
 
-  override def compute(split: Split, context: TaskContext) = prev.iterator(split, context)
+  override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context)
 
   override protected def getSplits = _splits
 
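A hedged usage sketch against the Spark API of this era (the SparkContext(master, jobName) constructor and ClassManifest-based RDDs; the app name, data, and predicate below are made up for illustration):

import spark.SparkContext
import spark.rdd.SplitsPruningRDD

object SplitsPruningExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "SplitsPruningExample")
    // 100 elements in 10 splits; suppose we know only split 7 can contain
    // the keys we care about (e.g. because the data is range-partitioned).
    val data = sc.parallelize(0 until 100, 10)
    val pruned = new SplitsPruningRDD(data, splitIndex => splitIndex == 7)
    // Only one task runs, over the single surviving split, and the
    // PruneDependency now tells the scheduler its true parent is split 7.
    println(pruned.collect().mkString(", "))  // 70, 71, ..., 79
    sc.stop()
  }
}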