Use java.util.HashMap in shuffles

Matei Zaharia 2012-02-06 19:20:25 -08:00
parent d6ec664b48
commit c40e766368
2 changed files with 29 additions and 12 deletions
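The diff below swaps scala.collection.mutable.HashMap for java.util.HashMap on the shuffle's hot path. The relevant API difference is that java.util.HashMap.get returns null on a miss instead of allocating an Option for every lookup, so per-record combining becomes a plain null check. A minimal standalone sketch of that idiom; the createCombiner/mergeValue functions here are hypothetical stand-ins for Spark's Aggregator:

```scala
import java.util.{HashMap => JHashMap}

object CombineSketch {
  // Hypothetical combiner functions standing in for an Aggregator
  def createCombiner(v: Int): List[Int] = List(v)
  def mergeValue(c: List[Int], v: Int): List[Int] = v :: c

  def main(args: Array[String]): Unit = {
    val bucket = new JHashMap[String, List[Int]]
    for ((k, v) <- Seq(("a", 1), ("b", 2), ("a", 3))) {
      // get returns null on a miss, so no Option is allocated per record
      val existing = bucket.get(k)
      if (existing == null) {
        bucket.put(k, createCombiner(v))
      } else {
        bucket.put(k, mergeValue(existing, v))
      }
    }
    println(bucket) // e.g. {a=[3, 1], b=[2]}; iteration order is unspecified
  }
}
```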

View file

@@ -3,7 +3,7 @@ package spark
 import java.io.BufferedOutputStream
 import java.io.FileOutputStream
 import java.io.ObjectOutputStream
-import scala.collection.mutable.HashMap
+import java.util.{HashMap => JHashMap}
 
 class ShuffleMapTask(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_], val partition: Int, locs: Seq[String])
@@ -14,21 +14,27 @@ extends DAGTask[String](stageId) with Logging {
     val numOutputSplits = dep.partitioner.numPartitions
     val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
     val partitioner = dep.partitioner.asInstanceOf[Partitioner]
-    val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
+    val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
     for (elem <- rdd.iterator(split)) {
       val (k, v) = elem.asInstanceOf[(Any, Any)]
       var bucketId = partitioner.getPartition(k)
       val bucket = buckets(bucketId)
-      bucket(k) = bucket.get(k) match {
-        case Some(c) => aggregator.mergeValue(c, v)
-        case None => aggregator.createCombiner(v)
+      var existing = bucket.get(k)
+      if (existing == null) {
+        bucket.put(k, aggregator.createCombiner(v))
+      } else {
+        bucket.put(k, aggregator.mergeValue(existing, v))
       }
     }
     val ser = SparkEnv.get.serializer.newInstance()
     for (i <- 0 until numOutputSplits) {
       val file = LocalFileShuffle.getOutputFile(dep.shuffleId, partition, i)
       val out = ser.outputStream(new BufferedOutputStream(new FileOutputStream(file)))
-      buckets(i).foreach(pair => out.writeObject(pair))
+      val iter = buckets(i).entrySet().iterator()
+      while (iter.hasNext()) {
+        val entry = iter.next()
+        out.writeObject((entry.getKey, entry.getValue))
+      }
       // TODO: have some kind of EOF marker
       out.close()
     }

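A self-contained sketch of the write-out loop in ShuffleMapTask above. A plain ObjectOutputStream stands in for the stream obtained from SparkEnv's serializer (an assumption for illustration); the point is that each bucket entry is written as a (key, value) tuple by walking the Java map's entrySet():

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import java.util.{HashMap => JHashMap}

object WriteBucketSketch {
  def main(args: Array[String]): Unit = {
    val bucket = new JHashMap[String, Int]
    bucket.put("a", 1)
    bucket.put("b", 2)

    // Plain ObjectOutputStream standing in for the serializer's output stream
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    val iter = bucket.entrySet().iterator()
    while (iter.hasNext()) {
      val entry = iter.next()
      // Each record is serialized as a (key, value) tuple, as in the diff above
      out.writeObject((entry.getKey, entry.getValue))
    }
    out.close()
    println(s"wrote ${bytes.size()} bytes for ${bucket.size()} records")
  }
}
```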
View file

@@ -1,6 +1,6 @@
 package spark
 
-import scala.collection.mutable.HashMap
+import java.util.{HashMap => JHashMap}
 
 class ShuffledRDDSplit(val idx: Int) extends Split {
@@ -27,15 +27,26 @@ extends RDD[(K, C)](parent.context) {
   override val dependencies = List(dep)
 
   override def compute(split: Split): Iterator[(K, C)] = {
-    val combiners = new HashMap[K, C]
+    val combiners = new JHashMap[K, C]
     def mergePair(k: K, c: C) {
-      combiners(k) = combiners.get(k) match {
-        case Some(oldC) => aggregator.mergeCombiners(oldC, c)
-        case None => c
+      val oldC = combiners.get(k)
+      if (oldC == null) {
+        combiners.put(k, c)
+      } else {
+        combiners.put(k, aggregator.mergeCombiners(oldC, c))
       }
     }
     val fetcher = SparkEnv.get.shuffleFetcher
     fetcher.fetch[K, C](dep.shuffleId, split.index, mergePair)
-    combiners.iterator
+    return new Iterator[(K, C)] {
+      var iter = combiners.entrySet().iterator()
+      def hasNext(): Boolean = iter.hasNext()
+      def next(): (K, C) = {
+        val entry = iter.next()
+        (entry.getKey, entry.getValue)
+      }
+    }
   }
 }
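The new compute() returns the combined map through an anonymous adapter rather than combiners.iterator. A minimal standalone version of that adapter, assuming all it needs to do is expose the Java map's entries as a Scala Iterator of tuples:

```scala
import java.util.{HashMap => JHashMap}

object EntryIteratorSketch {
  // Adapter from a Java map to a Scala Iterator of key/value pairs,
  // mirroring the anonymous Iterator returned by compute() in the diff
  def tupleIterator[K, V](map: JHashMap[K, V]): Iterator[(K, V)] = new Iterator[(K, V)] {
    private val iter = map.entrySet().iterator()
    def hasNext: Boolean = iter.hasNext()
    def next(): (K, V) = {
      val entry = iter.next()
      (entry.getKey, entry.getValue)
    }
  }

  def main(args: Array[String]): Unit = {
    val combiners = new JHashMap[String, Int]
    combiners.put("a", 1)
    combiners.put("b", 2)
    tupleIterator(combiners).foreach(println) // e.g. (a,1) then (b,2); order unspecified
  }
}
```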