spark-instrumented-optimizer/core/src/main/scala/spark/CoGroupedRDD.scala

package spark

import java.net.URL
import java.io.EOFException
import java.io.ObjectInputStream
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
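
// A dependency of one CoGroupSplit on a parent RDD: either a narrow dependency on a
// specific parent split, or a shuffle identified by its shuffle ID.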
@serializable
sealed trait CoGroupSplitDep
case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
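
// A split (partition) of a CoGroupedRDD; holds one CoGroupSplitDep per parent RDD.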
@serializable
class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep])
extends Split {
  override val index = idx
  override def hashCode(): Int = idx
}
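
// Collects all values for a key into an ArrayBuffer; used for parents that require a shuffle.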
@serializable
class CoGroupAggregator extends Aggregator[Any, Any, ArrayBuffer[Any]] (
  { x => ArrayBuffer(x) },
  { (b, x) => b += x },
  { (b1, b2) => b1 ++ b2 }
)
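
// An RDD that co-groups a sequence of (key, value) parent RDDs by key using the given
// partitioner: for each key it yields one Seq of values per parent RDD.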
class CoGroupedRDD[K](rdds: Seq[RDD[(_, _)]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.first.context) with Logging {
  val aggr = new CoGroupAggregator
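
  // One dependency per parent RDD: one-to-one if the parent is already partitioned by
  // part, otherwise a shuffle dependency that repartitions it by key.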
  override val dependencies = {
    val deps = new ArrayBuffer[Dependency[_]]
    for ((rdd, index) <- rdds.zipWithIndex) {
      if (rdd.partitioner == Some(part)) {
        logInfo("Adding one-to-one dependency with " + rdd)
        deps += new OneToOneDependency(rdd)
      } else {
        logInfo("Adding shuffle dependency with " + rdd)
        deps += new ShuffleDependency[Any, Any, ArrayBuffer[Any]](
          context.newShuffleId, rdd, aggr, part)
      }
    }
    deps.toList
  }
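
  // One CoGroupSplit per output partition, recording for each parent either the
  // narrow parent split to read or the shuffle whose outputs to fetch.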
  @transient val splits_ : Array[Split] = {
    val firstRdd = rdds.first
    val array = new Array[Split](part.numPartitions)
    for (i <- 0 until array.size) {
      array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) =>
        dependencies(j) match {
          case s: ShuffleDependency[_, _, _] =>
            new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep
          case _ =>
            new NarrowCoGroupSplitDep(r, r.splits(i)): CoGroupSplitDep
        }
      }.toList)
    }
    array
  }
  override def splits = splits_

  override val partitioner = Some(part)

  override def preferredLocations(s: Split) = Nil
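
  // Build the co-grouped map for this split: read narrow parents directly and fetch
  // shuffled map outputs from the map output servers, then return the grouped pairs.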
  override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = {
    val split = s.asInstanceOf[CoGroupSplit]
    val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
    // For each key, one ArrayBuffer of values per parent RDD
    def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
      map.getOrElseUpdate(k, Array.fill(rdds.size)(new ArrayBuffer[Any]))
    }
    for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
      case NarrowCoGroupSplitDep(rdd, itsSplit) => {
        // Read them from the parent
        for ((k: K, v) <- rdd.iterator(itsSplit)) {
          getSeq(k)(depNum) += v
        }
      }
      case ShuffleCoGroupSplitDep(shuffleId) => {
        // Read map outputs of shuffle
        logInfo("Grabbing map outputs for shuffle ID " + shuffleId)
        val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
        val serverUris = MapOutputTracker.getServerUris(shuffleId)
        for ((serverUri, index) <- serverUris.zipWithIndex) {
          splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
        }
        for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) {
          for (i <- inputIds) {
            val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, split.index)
            val inputStream = new ObjectInputStream(new URL(url).openStream())
            logInfo("Opened stream to " + url)
            try {
              // Read (key, values) pairs until the stream ends with an EOFException
              while (true) {
                val (k, vs) = inputStream.readObject().asInstanceOf[(K, Seq[Any])]
                val mySeq = getSeq(k)
                for (v <- vs)
                  mySeq(depNum) += v
              }
            } catch {
              case e: EOFException => {}
            }
            inputStream.close()
          }
        }
      }
    }
    map.iterator
  }
}