diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index cc3cca2571..18b4a1eca4 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -439,12 +439,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
   /**
    * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. If any of
    * the RDDs already has a partitioner, choose that one, otherwise use a default HashPartitioner.
+   *
+   * The number of partitions will be the same as the number of partitions in the largest upstream
+   * RDD, as this should be least likely to cause out-of-memory errors.
    */
   def defaultPartitioner(rdds: RDD[_]*): Partitioner = {
-    for (r <- rdds if r.partitioner != None) {
+    val bySize = rdds.sortBy(_.splits.size).reverse
+    for (r <- bySize if r.partitioner != None) {
       return r.partitioner.get
     }
-    return new HashPartitioner(self.context.defaultParallelism)
+    return new HashPartitioner(bySize.head.splits.size)
   }
 
   /**
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index af1107cd19..60db759c25 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -84,10 +84,10 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
     assert(grouped4.groupByKey(3).partitioner != grouped4.partitioner)
     assert(grouped4.groupByKey(4).partitioner === grouped4.partitioner)
 
-    assert(grouped2.join(grouped4).partitioner === grouped2.partitioner)
-    assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped2.partitioner)
-    assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped2.partitioner)
-    assert(grouped2.cogroup(grouped4).partitioner === grouped2.partitioner)
+    assert(grouped2.join(grouped4).partitioner === grouped4.partitioner)
+    assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped4.partitioner)
+    assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped4.partitioner)
+    assert(grouped2.cogroup(grouped4).partitioner === grouped4.partitioner)
 
     assert(grouped2.join(reduced2).partitioner === grouped2.partitioner)
     assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner)
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 3493b9511f..ab7060a1ac 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -211,6 +211,25 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     assert(rdd.keys.collect().toList === List(1, 2))
     assert(rdd.values.collect().toList === List("a", "b"))
   }
+
+  test("default partition size uses split size") {
+    sc = new SparkContext("local", "test")
+    // specify 2000 splits
+    val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
+    // do a map, which loses the partitioner
+    val b = a.map(a => (a, (a * 2).toString))
+    // then a group by, and see we didn't revert to 2 splits
+    val c = b.groupByKey()
+    assert(c.splits.size === 2000)
+  }
+
+  test("default partition uses largest partitioner") {
+    sc = new SparkContext("local", "test")
+    val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
+    val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
+    val c = a.join(b)
+    assert(c.splits.size === 2000)
+  }
 }
 
 object ShuffleSuite {
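
For reference, below is a minimal standalone sketch (not part of the patch) of the selection order the new defaultPartitioner implements: prefer an existing partitioner, scanning RDDs from most splits to fewest, and otherwise fall back to a HashPartitioner sized by the largest RDD's split count. FakeRdd and pickPartitioner are hypothetical stand-ins for illustration only, not Spark APIs.

    // Hypothetical stand-in for an RDD: just a split count and an optional partitioner label.
    case class FakeRdd(numSplits: Int, partitioner: Option[String])

    def pickPartitioner(rdds: Seq[FakeRdd]): String = {
      // Largest RDD first, mirroring `rdds.sortBy(_.splits.size).reverse` in the patch.
      val bySize = rdds.sortBy(_.numSplits).reverse
      bySize.find(_.partitioner.isDefined) match {
        case Some(r) => r.partitioner.get                           // reuse an existing partitioner
        case None    => s"HashPartitioner(${bySize.head.numSplits})" // fall back to the largest split count
      }
    }

    // An existing partitioner still wins, even on the smaller RDD:
    // pickPartitioner(Seq(FakeRdd(2, Some("hash(2)")), FakeRdd(2000, None))) == "hash(2)"
    // With no partitioner anywhere, the fallback is sized by the largest RDD:
    // pickPartitioner(Seq(FakeRdd(2, None), FakeRdd(2000, None))) == "HashPartitioner(2000)"

The second case is what the new "default partition uses largest partitioner" test exercises: neither makeRDD output has a partitioner, so the join ends up with 2000 splits rather than the old self.context.defaultParallelism.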