Initial work on union operation.

Matei Zaharia 2010-06-18 12:54:33 -07:00
parent b54198819e
commit 323571a177


@@ -72,6 +72,11 @@ abstract class RDD[T: ClassManifest, Split](
  def count(): Long =
    try { map(x => 1L).reduce(_+_) }
    catch { case e: UnsupportedOperationException => 0L }

  def union[OtherSplit](other: RDD[T, OtherSplit]) =
    new UnionRDD(sc, this, other)

  def ++[OtherSplit](other: RDD[T, OtherSplit]) = this.union(other)
}

@serializable
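In use, the two new operators compose like this (a hedged sketch, not part of the commit; it assumes two already-constructed RDDs `a` and `b` with the same element type, created on the same SparkContext that UnionRDD captures as `sc`):

  // Assumed setup: a: RDD[String, SplitA] and b: RDD[String, SplitB] already exist.
  val both = a.union(b)   // builds a UnionRDD over the splits of a followed by b
  val same = a ++ b       // ++ simply forwards to union
  println(both.count())   // count() now runs over the splits of both parents

Note that `other` may have a different split type (`OtherSplit`) than `this`, which is why union is parameterized on it.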
@@ -196,3 +201,35 @@ private object CachedRDD {
  // Remembers which splits are currently being loaded (on workers)
  val loading = new HashSet[String]
}

@serializable
abstract class UnionSplit[T: ClassManifest] {
  def iterator(): Iterator[T]
  def prefers(offer: SlaveOffer): Boolean
}

@serializable
class UnionSplitImpl[T: ClassManifest, Split](
    rdd: RDD[T, Split], split: Split)
  extends UnionSplit[T] {
  override def iterator() = rdd.iterator(split)
  override def prefers(offer: SlaveOffer) = rdd.prefers(split, offer)
}

@serializable
class UnionRDD[T: ClassManifest, Split1, Split2](
    sc: SparkContext, rdd1: RDD[T, Split1], rdd2: RDD[T, Split2])
  extends RDD[T, UnionSplit[T]](sc) {
  @transient val splits_ : Array[UnionSplit[T]] = {
    val a1 = rdd1.splits.map(s => new UnionSplitImpl(rdd1, s))
    val a2 = rdd2.splits.map(s => new UnionSplitImpl(rdd2, s))
    (a1 ++ a2).toArray
  }

  override def splits = splits_

  override def iterator(s: UnionSplit[T]): Iterator[T] = s.iterator()

  override def prefers(s: UnionSplit[T], o: SlaveOffer) = s.prefers(o)
}
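A note on the design (editorial, not from the commit): UnionRDD flattens both parents' split arrays into one, wrapping each parent split in a UnionSplitImpl that closes over its parent RDD, so iterator() and prefers() just delegate back to that parent. Because the concrete Split types are hidden behind the abstract UnionSplit[T], rdd1 and rdd2 may use different split types. The delegation idea can be illustrated standalone with placeholder types (a sketch only; Piece, Source, and PieceImpl are invented names, not Spark code):

  // Two sources with different "split" key types (Int vs String) exposed through
  // one common Piece[T], mirroring how UnionSplitImpl erases Split behind UnionSplit[T].
  trait Piece[T] { def iterator(): Iterator[T] }

  class Source[T, S](parts: Map[S, Seq[T]]) {
    def splits: Seq[S] = parts.keys.toSeq
    def read(split: S): Iterator[T] = parts(split).iterator
  }

  class PieceImpl[T, S](src: Source[T, S], split: S) extends Piece[T] {
    def iterator(): Iterator[T] = src.read(split)   // delegate, like UnionSplitImpl.iterator
  }

  object UnionSketch extends App {
    val a = new Source[Int, Int](Map(0 -> Seq(1, 2), 1 -> Seq(3)))
    val b = new Source[Int, String](Map("only" -> Seq(4, 5)))
    val pieces: Seq[Piece[Int]] =
      a.splits.map(s => new PieceImpl(a, s)) ++ b.splits.map(s => new PieceImpl(b, s))
    println(pieces.flatMap(_.iterator()).toList)    // List(1, 2, 3, 4, 5)
  }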