Added a unit test for cross-partition balancing in sort, and changes to

RangePartitioner to make it pass. It turns out that the first partition
was always kind of small due to how we picked partition boundaries.
This commit is contained in:
Matei Zaharia 2012-08-03 16:37:35 -04:00
parent 508221b8e6
commit 6da2bcdba1
3 changed files with 84 additions and 41 deletions

View file

@ -35,35 +35,41 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
private val ascending: Boolean = true)
extends Partitioner {
// An array of upper bounds for the first (partitions - 1) partitions
private val rangeBounds: Array[K] = {
val rddSize = rdd.count()
val maxSampleSize = partitions * 10.0
val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
val rddSample = rdd.sample(true, frac, 1).map(_._1).collect().sortWith(_ < _)
if (rddSample.length == 0) {
if (partitions == 1) {
Array()
} else {
val bounds = new Array[K](partitions)
for (i <- 0 until partitions) {
bounds(i) = rddSample(i * rddSample.length / partitions)
val rddSize = rdd.count()
val maxSampleSize = partitions * 10.0
val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
val rddSample = rdd.sample(true, frac, 1).map(_._1).collect().sortWith(_ < _)
if (rddSample.length == 0) {
Array()
} else {
val bounds = new Array[K](partitions - 1)
for (i <- 0 until partitions - 1) {
val index = (rddSample.length - 1) * (i + 1) / partitions
bounds(i) = rddSample(index)
}
bounds
}
bounds
}
}
def numPartitions = rangeBounds.length
def numPartitions = partitions
def getPartition(key: Any): Int = {
// TODO: Use a binary search here if number of partitions is large
val k = key.asInstanceOf[K]
var partition = 0
while (partition < rangeBounds.length - 1 && k > rangeBounds(partition)) {
while (partition < rangeBounds.length && k > rangeBounds(partition)) {
partition += 1
}
if (ascending) {
partition
} else {
rangeBounds.length - 1 - partition
rangeBounds.length - partition
}
}

View file

@ -261,6 +261,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
.map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x))))
.saveAsSequenceFile(path)
}
/** A private method for tests, to look at the contents of each partition */
private[spark] def collectPartitions(): Array[Array[T]] = {
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
}
}
class MappedRDD[U: ClassManifest, T: ClassManifest](

View file

@ -2,54 +2,86 @@ package spark
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
import SparkContext._
class SortingSuite extends FunSuite with BeforeAndAfter {
class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with Logging {
var sc: SparkContext = _
after {
if(sc != null) {
if (sc != null) {
sc.stop()
}
}
test("sortByKey") {
sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)))
assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)))
assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
}
test("sortLargeArray") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr)
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
test("large array") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr)
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
}
test("sortDescending") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr)
assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
test("sort descending") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr)
assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
}
test("morePartitionsThanElements") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 30)
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
test("more partitions than elements") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 30)
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
}
test("emptyRDD") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = new Array[(Int, Int)](0)
val pairs = sc.parallelize(pairArr)
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
test("empty RDD") {
sc = new SparkContext("local", "test")
val pairArr = new Array[(Int, Int)](0)
val pairs = sc.parallelize(pairArr)
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
}
test("partition balancing") {
sc = new SparkContext("local", "test")
val pairArr = (1 to 1000).map(x => (x, x)).toArray
val sorted = sc.parallelize(pairArr, 4).sortByKey()
assert(sorted.collect() === pairArr.sortBy(_._1))
val partitions = sorted.collectPartitions()
logInfo("partition lengths: " + partitions.map(_.length).mkString(", "))
partitions(0).length should be > 150
partitions(1).length should be > 150
partitions(2).length should be > 150
partitions(3).length should be > 150
partitions(0).last should be < partitions(1).head
partitions(1).last should be < partitions(2).head
partitions(2).last should be < partitions(3).head
}
test("partition balancing for descending sort") {
sc = new SparkContext("local", "test")
val pairArr = (1 to 1000).map(x => (x, x)).toArray
val sorted = sc.parallelize(pairArr, 4).sortByKey(false)
assert(sorted.collect() === pairArr.sortBy(_._1).reverse)
val partitions = sorted.collectPartitions()
logInfo("partition lengths: " + partitions.map(_.length).mkString(", "))
partitions(0).length should be > 150
partitions(1).length should be > 150
partitions(2).length should be > 150
partitions(3).length should be > 150
partitions(0).last should be > partitions(1).head
partitions(1).last should be > partitions(2).head
partitions(2).last should be > partitions(3).head
}
}