From c40e766368112cf1709b286799731dbf64fe2b51 Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Mon, 6 Feb 2012 19:20:25 -0800
Subject: [PATCH] Use java.util.HashMap in shuffles

---
 .../src/main/scala/spark/ShuffleMapTask.scala | 18 ++++++++++-----
 core/src/main/scala/spark/ShuffledRDD.scala   | 23 ++++++++++++++-----
 2 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/core/src/main/scala/spark/ShuffleMapTask.scala b/core/src/main/scala/spark/ShuffleMapTask.scala
index eb6a5e2df3..7b08a21fca 100644
--- a/core/src/main/scala/spark/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/ShuffleMapTask.scala
@@ -3,7 +3,7 @@ package spark
 import java.io.BufferedOutputStream
 import java.io.FileOutputStream
 import java.io.ObjectOutputStream
-import scala.collection.mutable.HashMap
+import java.util.{HashMap => JHashMap}
 
 class ShuffleMapTask(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_], val partition: Int, locs: Seq[String])
@@ -14,21 +14,27 @@ extends DAGTask[String](stageId) with Logging {
     val numOutputSplits = dep.partitioner.numPartitions
     val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
     val partitioner = dep.partitioner.asInstanceOf[Partitioner]
-    val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
+    val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
     for (elem <- rdd.iterator(split)) {
       val (k, v) = elem.asInstanceOf[(Any, Any)]
       var bucketId = partitioner.getPartition(k)
       val bucket = buckets(bucketId)
-      bucket(k) = bucket.get(k) match {
-        case Some(c) => aggregator.mergeValue(c, v)
-        case None => aggregator.createCombiner(v)
+      var existing = bucket.get(k)
+      if (existing == null) {
+        bucket.put(k, aggregator.createCombiner(v))
+      } else {
+        bucket.put(k, aggregator.mergeValue(existing, v))
       }
     }
     val ser = SparkEnv.get.serializer.newInstance()
     for (i <- 0 until numOutputSplits) {
       val file = LocalFileShuffle.getOutputFile(dep.shuffleId, partition, i)
       val out = ser.outputStream(new BufferedOutputStream(new FileOutputStream(file)))
-      buckets(i).foreach(pair => out.writeObject(pair))
+      val iter = buckets(i).entrySet().iterator()
+      while (iter.hasNext()) {
+        val entry = iter.next()
+        out.writeObject((entry.getKey, entry.getValue))
+      }
       // TODO: have some kind of EOF marker
       out.close()
     }
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index 4ab1958ea1..9cada0617e 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -1,6 +1,6 @@
 package spark
 
-import scala.collection.mutable.HashMap
+import java.util.{HashMap => JHashMap}
 
 class ShuffledRDDSplit(val idx: Int) extends Split {
@@ -27,15 +27,26 @@ extends RDD[(K, C)](parent.context) {
   override val dependencies = List(dep)
 
   override def compute(split: Split): Iterator[(K, C)] = {
-    val combiners = new HashMap[K, C]
+    val combiners = new JHashMap[K, C]
     def mergePair(k: K, c: C) {
-      combiners(k) = combiners.get(k) match {
-        case Some(oldC) => aggregator.mergeCombiners(oldC, c)
-        case None => c
+      val oldC = combiners.get(k)
+      if (oldC == null) {
+        combiners.put(k, c)
+      } else {
+        combiners.put(k, aggregator.mergeCombiners(oldC, c))
       }
     }
     val fetcher = SparkEnv.get.shuffleFetcher
     fetcher.fetch[K, C](dep.shuffleId, split.index, mergePair)
-    combiners.iterator
+    return new Iterator[(K, C)] {
+      var iter = combiners.entrySet().iterator()
+
+      def hasNext(): Boolean = iter.hasNext()
+
+      def next(): (K, C) = {
+        val entry = iter.next()
+        (entry.getKey, entry.getValue)
+      }
+    }
   }
 }