Use java.util.HashMap in shuffles
parent d6ec664b48
commit c40e766368
@@ -3,7 +3,7 @@ package spark
 import java.io.BufferedOutputStream
 import java.io.FileOutputStream
 import java.io.ObjectOutputStream
-import scala.collection.mutable.HashMap
+import java.util.{HashMap => JHashMap}
 
 
 class ShuffleMapTask(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_], val partition: Int, locs: Seq[String])
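An aside on the new import line (illustrative, not part of the commit): Scala's `{HashMap => JHashMap}` form renames a class at import time, giving the Java map an unambiguous local name. The sketch below, with hypothetical names, also shows the semantic difference that motivates the switch: the Scala map's get returns an Option, the Java map's get returns the value or null.

import scala.collection.mutable.{HashMap => SHashMap}
import java.util.{HashMap => JHashMap}

object ImportRenameDemo {
  def main(args: Array[String]): Unit = {
    val scalaMap = new SHashMap[String, Int]    // Scala mutable map: get returns Option
    val javaMap = new JHashMap[String, Integer] // Java map: get returns the value or null
    scalaMap("x") = 1
    javaMap.put("x", 1)
    println(scalaMap.get("x"))  // Some(1)
    println(javaMap.get("y"))   // null
  }
}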
@@ -14,21 +14,27 @@ extends DAGTask[String](stageId) with Logging {
     val numOutputSplits = dep.partitioner.numPartitions
     val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
     val partitioner = dep.partitioner.asInstanceOf[Partitioner]
-    val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
+    val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
     for (elem <- rdd.iterator(split)) {
       val (k, v) = elem.asInstanceOf[(Any, Any)]
       var bucketId = partitioner.getPartition(k)
       val bucket = buckets(bucketId)
-      bucket(k) = bucket.get(k) match {
-        case Some(c) => aggregator.mergeValue(c, v)
-        case None => aggregator.createCombiner(v)
-      }
+      var existing = bucket.get(k)
+      if (existing == null) {
+        bucket.put(k, aggregator.createCombiner(v))
+      } else {
+        bucket.put(k, aggregator.mergeValue(existing, v))
+      }
     }
     val ser = SparkEnv.get.serializer.newInstance()
     for (i <- 0 until numOutputSplits) {
       val file = LocalFileShuffle.getOutputFile(dep.shuffleId, partition, i)
       val out = ser.outputStream(new BufferedOutputStream(new FileOutputStream(file)))
-      buckets(i).foreach(pair => out.writeObject(pair))
+      val iter = buckets(i).entrySet().iterator()
+      while (iter.hasNext()) {
+        val entry = iter.next()
+        out.writeObject((entry.getKey, entry.getValue))
+      }
       // TODO: have some kind of EOF marker
       out.close()
     }
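The likely motivation (not stated in the commit message, so treat this as an inference): scala.collection.mutable.HashMap.get allocates an Option on every lookup, whereas java.util.HashMap.get returns the value directly, or null when the key is absent. Since this lookup sits in the shuffle's per-record inner loop, the rewrite above trades pattern matching on Option for a plain null check. Below is a self-contained sketch of the same get-then-put merge pattern, using word counting as a stand-in for the Aggregator; all names are illustrative.

import java.util.{HashMap => JHashMap}

object MergePatternDemo {
  def main(args: Array[String]): Unit = {
    // Use a boxed value type (java.lang.Integer) so get() can signal an
    // absent key with null, as the shuffle code does with Any/AnyRef values.
    val counts = new JHashMap[String, Integer]
    for (word <- Seq("the", "quick", "the")) {
      val existing = counts.get(word)
      if (existing == null) {
        counts.put(word, 1)            // absent: create the initial combiner
      } else {
        counts.put(word, existing + 1) // present: merge the new value in
      }
    }
    println(counts)  // e.g. {the=2, quick=1}
  }
}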
@@ -1,6 +1,6 @@
 package spark
 
-import scala.collection.mutable.HashMap
+import java.util.{HashMap => JHashMap}
 
 
 class ShuffledRDDSplit(val idx: Int) extends Split {
@@ -27,15 +27,26 @@ extends RDD[(K, C)](parent.context) {
   override val dependencies = List(dep)
 
   override def compute(split: Split): Iterator[(K, C)] = {
-    val combiners = new HashMap[K, C]
+    val combiners = new JHashMap[K, C]
     def mergePair(k: K, c: C) {
-      combiners(k) = combiners.get(k) match {
-        case Some(oldC) => aggregator.mergeCombiners(oldC, c)
-        case None => c
-      }
+      val oldC = combiners.get(k)
+      if (oldC == null) {
+        combiners.put(k, c)
+      } else {
+        combiners.put(k, aggregator.mergeCombiners(oldC, c))
+      }
     }
     val fetcher = SparkEnv.get.shuffleFetcher
     fetcher.fetch[K, C](dep.shuffleId, split.index, mergePair)
-    combiners.iterator
+    return new Iterator[(K, C)] {
+      var iter = combiners.entrySet().iterator()
+
+      def hasNext(): Boolean = iter.hasNext()
+
+      def next(): (K, C) = {
+        val entry = iter.next()
+        (entry.getKey, entry.getValue)
+      }
+    }
   }
 }
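Because java.util.HashMap has no Scala-style iterator of pairs, compute now hand-rolls one over the map's entry set instead of returning combiners.iterator. A standalone sketch of that wrapping follows; the helper name pairIterator and the demo are mine, not from the commit.

import java.util.{HashMap => JHashMap}

object EntryIteratorDemo {
  // Expose a java.util.HashMap as a Scala Iterator of key/value pairs,
  // mirroring the anonymous Iterator that ShuffledRDD.compute returns.
  def pairIterator[K, V](map: JHashMap[K, V]): Iterator[(K, V)] = new Iterator[(K, V)] {
    private val iter = map.entrySet().iterator()
    def hasNext: Boolean = iter.hasNext
    def next(): (K, V) = {
      val entry = iter.next()
      (entry.getKey, entry.getValue)
    }
  }

  def main(args: Array[String]): Unit = {
    val m = new JHashMap[String, Integer]
    m.put("a", 1)
    m.put("b", 2)
    pairIterator(m).foreach(println)  // (a,1) and (b,2), in hash order
  }
}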