95 lines
2.9 KiB
Scala
95 lines
2.9 KiB
Scala
package spark
|
|
|
|
import java.net.URL
|
|
import java.io.EOFException
|
|
import java.io.ObjectInputStream
|
|
import scala.collection.mutable.ArrayBuffer
|
|
import scala.collection.mutable.HashMap
|
|
|
|
@serializable
|
|
sealed trait CoGroupSplitDep
|
|
case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
|
|
case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
|
|
|
|
@serializable
|
|
class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep])
|
|
extends Split {
|
|
override val index = idx
|
|
override def hashCode(): Int = idx
|
|
}
|
|
|
|
@serializable
|
|
class CoGroupAggregator extends Aggregator[Any, Any, ArrayBuffer[Any]] (
|
|
{ x => ArrayBuffer(x) },
|
|
{ (b, x) => b += x },
|
|
{ (b1, b2) => b1 ++ b2 }
|
|
)
|
|
|
|
class CoGroupedRDD[K](rdds: Seq[RDD[(_, _)]], part: Partitioner)
|
|
extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {
|
|
val aggr = new CoGroupAggregator
|
|
|
|
override val dependencies = {
|
|
val deps = new ArrayBuffer[Dependency[_]]
|
|
for ((rdd, index) <- rdds.zipWithIndex) {
|
|
if (rdd.partitioner == Some(part)) {
|
|
logInfo("Adding one-to-one dependency with " + rdd)
|
|
deps += new OneToOneDependency(rdd)
|
|
} else {
|
|
logInfo("Adding shuffle dependency with " + rdd)
|
|
deps += new ShuffleDependency[Any, Any, ArrayBuffer[Any]](
|
|
context.newShuffleId, rdd, aggr, part)
|
|
}
|
|
}
|
|
deps.toList
|
|
}
|
|
|
|
@transient val splits_ : Array[Split] = {
|
|
val firstRdd = rdds.head
|
|
val array = new Array[Split](part.numPartitions)
|
|
for (i <- 0 until array.size) {
|
|
array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) =>
|
|
dependencies(j) match {
|
|
case s: ShuffleDependency[_, _, _] =>
|
|
new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep
|
|
case _ =>
|
|
new NarrowCoGroupSplitDep(r, r.splits(i)): CoGroupSplitDep
|
|
}
|
|
}.toList)
|
|
}
|
|
array
|
|
}
|
|
|
|
override def splits = splits_
|
|
|
|
override val partitioner = Some(part)
|
|
|
|
override def preferredLocations(s: Split) = Nil
|
|
|
|
override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = {
|
|
val split = s.asInstanceOf[CoGroupSplit]
|
|
val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
|
|
def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
|
|
map.getOrElseUpdate(k, Array.fill(rdds.size)(new ArrayBuffer[Any]))
|
|
}
|
|
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
|
|
case NarrowCoGroupSplitDep(rdd, itsSplit) => {
|
|
// Read them from the parent
|
|
for ((k: K, v) <- rdd.iterator(itsSplit)) {
|
|
getSeq(k)(depNum) += v
|
|
}
|
|
}
|
|
case ShuffleCoGroupSplitDep(shuffleId) => {
|
|
// Read map outputs of shuffle
|
|
def mergePair(k: K, vs: Seq[Any]) {
|
|
val mySeq = getSeq(k)
|
|
for (v <- vs)
|
|
mySeq(depNum) += v
|
|
}
|
|
new SimpleShuffleFetcher().fetch[K, Seq[Any]](shuffleId, split.index, mergePair)
|
|
}
|
|
}
|
|
map.iterator
|
|
}
|
|
}
|