Merge pull request #469 from stephenh/samepartitionercombine

If combineByKey is using the same partitioner, skip the shuffle.
Matei Zaharia 2013-02-16 10:07:42 -08:00
commit 9d979fb630
2 changed files with 26 additions and 1 deletion
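As a usage sketch of what this optimization enables (not part of the commit; object name, job name, and data are illustrative, assuming the pre-Apache `spark` package in local mode): an RDD that has already been shuffled by `partitionBy` can now be reduced by key with no second shuffle, because `reduceByKey` defaults to the RDD's existing partitioner and `combineByKey` detects the match.

// Sketch only: illustrative names; assumes 2013-era Spark ("spark" package).
import spark.{SparkContext, HashPartitioner}
import spark.SparkContext._  // implicits that provide PairRDDFunctions

object SkipShuffleSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "skip-shuffle-sketch")
    // partitionBy performs the only shuffle in this lineage
    val pairs = sc.parallelize(Seq((1, 1), (1, 2), (2, 3)))
      .partitionBy(new HashPartitioner(2))
    // reduceByKey reuses pairs' partitioner, so combineByKey takes the
    // new branch and combines values within each partition; no new ShuffledRDD
    val sums = pairs.reduceByKey(_ + _)
    println(sums.collect().toSeq)  // sums per key: (1,3), (2,3)
    sc.stop()
  }
}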


@@ -62,7 +62,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     }
     val aggregator =
       new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
-    if (mapSideCombine) {
+    if (self.partitioner == Some(partitioner)) {
+      self.mapPartitions(aggregator.combineValuesByKey(_), true)
+    } else if (mapSideCombine) {
       val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
       val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
       partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
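One subtlety worth noting: the new guard compares partitioners with `==`, not reference identity. `spark.HashPartitioner` overrides `equals` to compare partition counts, so a fresh but equivalent partitioner still takes the no-shuffle branch. A minimal sketch (hypothetical names, not from this commit):

// Sketch: partitioner equality, not identity, is what triggers the new branch.
import spark.{SparkContext, HashPartitioner}
import spark.SparkContext._

object PartitionerEqualitySketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "partitioner-equality-sketch")
    val pairs = sc.parallelize(Seq((1, 1), (2, 2)))
      .partitionBy(new HashPartitioner(2))
    // A distinct but equal HashPartitioner matches self.partitioner,
    // so combineByKey with it would skip the shuffle as well.
    assert(pairs.partitioner == Some(new HashPartitioner(2)))
    sc.stop()
  }
}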


@@ -1,6 +1,7 @@
 package spark
 
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
 
 import org.scalatest.FunSuite
 import org.scalatest.matchers.ShouldMatchers
@@ -99,6 +100,28 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     assert(sums.toSet === Set((1, 7), (2, 1)))
   }
 
+  test("reduceByKey with partitioner") {
+    sc = new SparkContext("local", "test")
+    val p = new Partitioner() {
+      def numPartitions = 2
+      def getPartition(key: Any) = key.asInstanceOf[Int]
+    }
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
+    val sums = pairs.reduceByKey(_+_)
+    assert(sums.collect().toSet === Set((1, 4), (0, 1)))
+    assert(sums.partitioner === Some(p))
+    // count the dependencies to make sure there is only 1 ShuffledRDD
+    val deps = new HashSet[RDD[_]]()
+    def visit(r: RDD[_]) {
+      for (dep <- r.dependencies) {
+        deps += dep.rdd
+        visit(dep.rdd)
+      }
+    }
+    visit(sums)
+    assert(deps.size === 2) // ShuffledRDD, ParallelCollection
+  }
+
   test("join") {
     sc = new SparkContext("local", "test")
     val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))