Merge pull request #469 from stephenh/samepartitionercombine
If combineByKey is using the same partitioner, skip the shuffle.
This commit is contained in:
commit
9d979fb630
|
@ -62,7 +62,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
|
|||
}
|
||||
val aggregator =
|
||||
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
|
||||
if (mapSideCombine) {
|
||||
if (self.partitioner == Some(partitioner)) {
|
||||
self.mapPartitions(aggregator.combineValuesByKey(_), true)
|
||||
} else if (mapSideCombine) {
|
||||
val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
|
||||
val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
|
||||
partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package spark
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.collection.mutable.HashSet
|
||||
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.matchers.ShouldMatchers
|
||||
|
@ -99,6 +100,28 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
|
|||
assert(sums.toSet === Set((1, 7), (2, 1)))
|
||||
}
|
||||
|
||||
test("reduceByKey with partitioner") {
  sc = new SparkContext("local", "test")

  // Two-partition partitioner that uses the Int key itself as the partition index;
  // the only keys used below are 0 and 1.
  val keyAsIndex = new Partitioner() {
    def numPartitions = 2
    def getPartition(key: Any) = key.asInstanceOf[Int]
  }

  // Partition the input up front, then reduce without naming a partitioner.
  val prePartitioned =
    sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(keyAsIndex)
  val totals = prePartitioned.reduceByKey(_ + _)

  // The per-key sums must be correct and the result must keep the same partitioner instance.
  assert(totals.collect().toSet === Set((1, 4), (0, 1)))
  assert(totals.partitioner === Some(keyAsIndex))

  // Gather every distinct RDD reachable through the dependencies of `totals`, so the
  // size of the set shows how many upstream RDDs (and therefore shuffles) exist.
  val ancestors = new HashSet[RDD[_]]()
  def collectAncestors(rdd: RDD[_]) {
    rdd.dependencies.foreach { parentDep =>
      ancestors += parentDep.rdd
      collectAncestors(parentDep.rdd)
    }
  }
  collectAncestors(totals)

  // Exactly two upstream RDDs are expected, i.e. only one ShuffledRDD
  // (presumably the one created by partitionBy; confirm against PairRDDFunctions).
  assert(ancestors.size === 2) // ShuffledRDD, ParallelCollection
}
|
||||
|
||||
test("join") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
|
||||
|
|
Loading…
Reference in a new issue