spark-instrumented-optimizer/core/src/main/scala/spark/CoGroupedRDD.scala

95 lines
2.9 KiB
Scala
Raw Normal View History

2011-03-06 22:27:03 -05:00
package spark
import java.net.URL
import java.io.EOFException
import java.io.ObjectInputStream
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
/**
 * Describes how one parent RDD contributes to a single split of a cogroup.
 * The variants visible in this file are NarrowCoGroupSplitDep and ShuffleCoGroupSplitDep.
 */
@serializable
sealed trait CoGroupSplitDep
/**
 * Dependency on one specific split of a parent RDD.
 *
 * @param rdd   the parent RDD
 * @param split the split of `rdd` associated with this dependency
 */
case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
/**
 * Dependency on the output of a shuffle.
 *
 * @param shuffleId identifier of the shuffle this dependency refers to
 */
case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
/**
 * A Split that carries, alongside its index, one CoGroupSplitDep per dependency.
 *
 * @param idx  index of this split; also used as its hash code
 * @param deps the dependencies attached to this split
 */
@serializable
class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split {

  /** Index of this split, exactly as supplied to the constructor. */
  override val index: Int = idx

  /** The hash code is the split index itself. */
  override def hashCode(): Int = idx
}
/**
 * Aggregator that collects values of any type into an ArrayBuffer[Any].
 *
 * The three functions handed to the Aggregator constructor are, in order:
 *  1. wrap a single value in a fresh one-element buffer;
 *  2. append a value to an existing buffer in place (`+=`);
 *  3. concatenate two buffers with `++`.
 *
 * NOTE(review): these presumably correspond to Aggregator's create-combiner, merge-value and
 * merge-combiners arguments -- confirm against the Aggregator definition.
 */
@serializable
class CoGroupAggregator extends Aggregator[Any, Any, ArrayBuffer[Any]] (
  { x => ArrayBuffer(x) },
  { (b, x) => b += x },
  { (b1, b2) => b1 ++ b2 }
)
2011-03-07 02:38:16 -05:00
/**
 * An RDD that cogroups several parent RDDs of pairs: every element it produces pairs a key with
 * one sequence of values per parent RDD, positioned in the same order as the parents in `rdds`.
 *
 * A parent whose partitioner already equals `part` is read through a one-to-one dependency;
 * every other parent is read through a shuffle dependency that uses a CoGroupAggregator.
 *
 * @param rdds parent RDDs; must be non-empty because the context is taken from `rdds.head`
 * @param part partitioner of this RDD; its `numPartitions` determines the number of splits
 * @tparam K   key type of the produced pairs
 */
class CoGroupedRDD[K](rdds: Seq[RDD[(_, _)]], part: Partitioner)
  extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {

  // Shared by every shuffle dependency created below; must be initialised before `dependencies`.
  val aggr = new CoGroupAggregator

  /** One dependency per parent RDD, in the same order as `rdds`. */
  override val dependencies = {
    val deps = new ArrayBuffer[Dependency[_]]
    for (rdd <- rdds) {
      if (rdd.partitioner == Some(part)) {
        logInfo("Adding one-to-one dependency with " + rdd)
        deps += new OneToOneDependency(rdd)
      } else {
        logInfo("Adding shuffle dependency with " + rdd)
        deps += new ShuffleDependency[Any, Any, ArrayBuffer[Any]](
          context.newShuffleId, rdd, aggr, part)
      }
    }
    deps.toList
  }

  /**
   * The splits of this RDD, one per partition of `part`. Split `i` records, for each parent,
   * either the id of its shuffle or the parent's own split `i`.
   * Reads `dependencies`, so it must be initialised after it.
   */
  @transient val splits_ : Array[Split] = {
    Array.tabulate[Split](part.numPartitions) { i =>
      new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) =>
        dependencies(j) match {
          case s: ShuffleDependency[_, _, _] =>
            new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep
          case _ =>
            new NarrowCoGroupSplitDep(r, r.splits(i)): CoGroupSplitDep
        }
      }.toList)
    }
  }

  override def splits = splits_

  override val partitioner = Some(part)

  /** This RDD reports no preferred locations for any split. */
  override def preferredLocations(s: Split) = Nil

  /**
   * Builds the cogrouped contents of one split.
   *
   * @param s the split to compute; cast to CoGroupSplit, so any other Split type fails the cast
   * @return an iterator over (key, per-parent value sequences) pairs
   */
  override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = {
    val split = s.asInstanceOf[CoGroupSplit]
    val map = new HashMap[K, Seq[ArrayBuffer[Any]]]

    // Returns the per-parent buffers for `k`, creating `rdds.size` empty buffers on first use.
    def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
      map.getOrElseUpdate(k, Array.fill(rdds.size)(new ArrayBuffer[Any]))
    }

    for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
      case NarrowCoGroupSplitDep(rdd, itsSplit) => {
        // Read them from the parent.
        // `k: K` is an unchecked type test because K is erased at runtime.
        for ((k: K, v) <- rdd.iterator(itsSplit)) {
          getSeq(k)(depNum) += v
        }
      }
      case ShuffleCoGroupSplitDep(shuffleId) => {
        // Read map outputs of shuffle; each fetched group is appended to this parent's buffer.
        def mergePair(k: K, vs: Seq[Any]): Unit = {
          getSeq(k)(depNum) ++= vs
        }
        new SimpleShuffleFetcher().fetch[K, Seq[Any]](shuffleId, split.index, mergePair)
      }
    }
    map.iterator
  }
}