Support preservesPartitioning in RDD.zipPartitions

Author: Ankur Dave
Date:   2013-11-23 02:32:37 -08:00
parent 18ce7e940b
commit ad56ae7bfd
2 changed files with 32 additions and 10 deletions

core/src/main/scala/org/apache/spark/rdd/RDD.scala

@@ -545,20 +545,35 @@ abstract class RDD[T: ClassManifest](
    * *same number of partitions*, but does *not* require them to have the same number
    * of elements in each partition.
    */
+  def zipPartitions[B: ClassManifest, V: ClassManifest]
+      (rdd2: RDD[B], preservesPartitioning: Boolean)
+      (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] =
+    new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, preservesPartitioning)
+
   def zipPartitions[B: ClassManifest, V: ClassManifest]
       (rdd2: RDD[B])
       (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] =
-    new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2)
+    new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, false)
 
+  def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest]
+      (rdd2: RDD[B], rdd3: RDD[C], preservesPartitioning: Boolean)
+      (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] =
+    new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, preservesPartitioning)
+
   def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest]
       (rdd2: RDD[B], rdd3: RDD[C])
       (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] =
-    new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3)
+    new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, false)
 
+  def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest]
+      (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D], preservesPartitioning: Boolean)
+      (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] =
+    new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, preservesPartitioning)
+
   def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest]
       (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D])
       (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] =
-    new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4)
+    new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, false)
 
   // Actions (launch a job to return a value to the user program)
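For context, a minimal usage sketch of the new overload added above. The SparkContext setup, the RDDs a and b, and the partitioner part are hypothetical and not part of the commit; only the zipPartitions(rdd2, preservesPartitioning) signature comes from the diff. When both inputs are already partitioned the same way and the zip function keeps keys in their partitions, passing preservesPartitioning = true lets the result keep that partitioner:

import org.apache.spark.{HashPartitioner, SparkContext}
import org.apache.spark.SparkContext._

val sc = new SparkContext("local", "zipPartitionsExample")
val part = new HashPartitioner(4)

// Two hypothetical pair RDDs co-partitioned by the same partitioner.
val a = sc.parallelize(Seq(1 -> "a", 2 -> "b", 3 -> "c")).partitionBy(part)
val b = sc.parallelize(Seq(1 -> 1.0, 2 -> 2.0, 3 -> 3.0)).partitionBy(part)

// New overload: declare that the zip function keeps keys in their partitions,
// so the result can advertise `part` instead of no partitioner.
val zipped = a.zipPartitions(b, preservesPartitioning = true) { (iterA, iterB) =>
  iterA.zip(iterB).map { case ((k, s), (_, d)) => (k, (s, d)) }
}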

core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala

@@ -39,9 +39,13 @@ private[spark] class ZippedPartitionsPartition(
 abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
     sc: SparkContext,
-    var rdds: Seq[RDD[_]])
+    var rdds: Seq[RDD[_]],
+    preservesPartitioning: Boolean = false)
   extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) {
 
+  override val partitioner =
+    if (preservesPartitioning) firstParent[Any].partitioner else None
+
   override def getPartitions: Array[Partition] = {
     val sizes = rdds.map(x => x.partitions.size)
     if (!sizes.forall(x => x == sizes(0))) {

@@ -76,8 +80,9 @@ class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]
     sc: SparkContext,
     f: (Iterator[A], Iterator[B]) => Iterator[V],
     var rdd1: RDD[A],
-    var rdd2: RDD[B])
-  extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) {
+    var rdd2: RDD[B],
+    preservesPartitioning: Boolean = false)
+  extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) {
 
   override def compute(s: Partition, context: TaskContext): Iterator[V] = {
     val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions

@@ -97,8 +102,9 @@ class ZippedPartitionsRDD3
     f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
     var rdd1: RDD[A],
     var rdd2: RDD[B],
-    var rdd3: RDD[C])
-  extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) {
+    var rdd3: RDD[C],
+    preservesPartitioning: Boolean = false)
+  extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) {
 
   override def compute(s: Partition, context: TaskContext): Iterator[V] = {
     val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions

@@ -122,8 +128,9 @@ class ZippedPartitionsRDD4
     var rdd1: RDD[A],
     var rdd2: RDD[B],
     var rdd3: RDD[C],
-    var rdd4: RDD[D])
-  extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) {
+    var rdd4: RDD[D],
+    preservesPartitioning: Boolean = false)
+  extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) {
 
   override def compute(s: Partition, context: TaskContext): Iterator[V] = {
     val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
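The base-class change above is what makes the flag observable: the zipped RDD reports the first parent's partitioner only when preservesPartitioning is set, and None otherwise, so a later join or reduceByKey against a co-partitioned RDD can skip a shuffle. Continuing the hypothetical sketch from the first file (a, b, part, zipped are assumptions, not part of the commit):

// With the flag set, the result keeps the first parent's partitioner;
// the old two-argument overload still reports no partitioner.
assert(zipped.partitioner == Some(part))

val zippedDefault = a.zipPartitions(b) { (iterA, iterB) =>
  iterA.zip(iterB).map { case ((k, s), (_, d)) => (k, (s, d)) }
}
assert(zippedDefault.partitioner.isEmpty)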