Added a unit test for cross-partition balancing in sort, and changed
RangePartitioner to make it pass. It turns out that the first partition was always undersized because of how we picked partition boundaries.
This commit is contained in:
parent
508221b8e6
commit
6da2bcdba1
|
@ -35,35 +35,41 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
|
|||
private val ascending: Boolean = true)
|
||||
extends Partitioner {
|
||||
|
||||
// An array of upper bounds for the first (partitions - 1) partitions
|
||||
private val rangeBounds: Array[K] = {
|
||||
val rddSize = rdd.count()
|
||||
val maxSampleSize = partitions * 10.0
|
||||
val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
|
||||
val rddSample = rdd.sample(true, frac, 1).map(_._1).collect().sortWith(_ < _)
|
||||
if (rddSample.length == 0) {
|
||||
if (partitions == 1) {
|
||||
Array()
|
||||
} else {
|
||||
val bounds = new Array[K](partitions)
|
||||
for (i <- 0 until partitions) {
|
||||
bounds(i) = rddSample(i * rddSample.length / partitions)
|
||||
val rddSize = rdd.count()
|
||||
val maxSampleSize = partitions * 10.0
|
||||
val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
|
||||
val rddSample = rdd.sample(true, frac, 1).map(_._1).collect().sortWith(_ < _)
|
||||
if (rddSample.length == 0) {
|
||||
Array()
|
||||
} else {
|
||||
val bounds = new Array[K](partitions - 1)
|
||||
for (i <- 0 until partitions - 1) {
|
||||
val index = (rddSample.length - 1) * (i + 1) / partitions
|
||||
bounds(i) = rddSample(index)
|
||||
}
|
||||
bounds
|
||||
}
|
||||
bounds
|
||||
}
|
||||
}
|
||||
|
||||
def numPartitions = rangeBounds.length
|
||||
def numPartitions = partitions
|
||||
|
||||
def getPartition(key: Any): Int = {
|
||||
// TODO: Use a binary search here if number of partitions is large
|
||||
val k = key.asInstanceOf[K]
|
||||
var partition = 0
|
||||
while (partition < rangeBounds.length - 1 && k > rangeBounds(partition)) {
|
||||
while (partition < rangeBounds.length && k > rangeBounds(partition)) {
|
||||
partition += 1
|
||||
}
|
||||
if (ascending) {
|
||||
partition
|
||||
} else {
|
||||
rangeBounds.length - 1 - partition
|
||||
rangeBounds.length - partition
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -261,6 +261,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
|
|||
.map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x))))
|
||||
.saveAsSequenceFile(path)
|
||||
}
|
||||
|
||||
/**
 * A private method for tests, to look at the contents of each partition.
 *
 * Runs a job on this RDD through `sc.runJob`, turning the iterator handed to the
 * job function into an array, and returns the resulting array of arrays.
 * NOTE(review): presumably one inner array per partition, in partition order --
 * confirm against SparkContext.runJob.
 */
|
||||
private[spark] def collectPartitions(): Array[Array[T]] = {
|
||||
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
|
||||
}
|
||||
}
|
||||
|
||||
class MappedRDD[U: ClassManifest, T: ClassManifest](
|
||||
|
|
|
@ -2,54 +2,86 @@ package spark
|
|||
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.BeforeAndAfter
|
||||
import org.scalatest.matchers.ShouldMatchers
|
||||
import SparkContext._
|
||||
|
||||
class SortingSuite extends FunSuite with BeforeAndAfter {
|
||||
class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with Logging {
|
||||
|
||||
var sc: SparkContext = _
|
||||
|
||||
after {
|
||||
if(sc != null) {
|
||||
if (sc != null) {
|
||||
sc.stop()
|
||||
}
|
||||
}
|
||||
|
||||
test("sortByKey") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)))
|
||||
assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)))
|
||||
assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
|
||||
}
|
||||
|
||||
test("sortLargeArray") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rand = new scala.util.Random()
|
||||
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
|
||||
val pairs = sc.parallelize(pairArr)
|
||||
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
|
||||
test("large array") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rand = new scala.util.Random()
|
||||
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
|
||||
val pairs = sc.parallelize(pairArr)
|
||||
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
|
||||
}
|
||||
|
||||
test("sortDescending") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rand = new scala.util.Random()
|
||||
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
|
||||
val pairs = sc.parallelize(pairArr)
|
||||
assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
|
||||
test("sort descending") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rand = new scala.util.Random()
|
||||
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
|
||||
val pairs = sc.parallelize(pairArr)
|
||||
assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
|
||||
}
|
||||
|
||||
test("morePartitionsThanElements") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rand = new scala.util.Random()
|
||||
val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
|
||||
val pairs = sc.parallelize(pairArr, 30)
|
||||
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
|
||||
test("more partitions than elements") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rand = new scala.util.Random()
|
||||
val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
|
||||
val pairs = sc.parallelize(pairArr, 30)
|
||||
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
|
||||
}
|
||||
|
||||
test("emptyRDD") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rand = new scala.util.Random()
|
||||
val pairArr = new Array[(Int, Int)](0)
|
||||
val pairs = sc.parallelize(pairArr)
|
||||
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
|
||||
test("empty RDD") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairArr = new Array[(Int, Int)](0)
|
||||
val pairs = sc.parallelize(pairArr)
|
||||
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
|
||||
}
|
||||
|
||||
// Checks that an ascending sortByKey over the keys 1..1000, parallelized into 4 slices,
// yields correctly sorted output and that the first four result partitions are
// reasonably balanced and non-overlapping.
test("partition balancing") {
|
||||
sc = new SparkContext("local", "test")
|
||||
// Keys 1..1000, each paired with itself as the value.
val pairArr = (1 to 1000).map(x => (x, x)).toArray
|
||||
val sorted = sc.parallelize(pairArr, 4).sortByKey()
|
||||
// The overall result must still be fully sorted by key.
assert(sorted.collect() === pairArr.sortBy(_._1))
|
||||
val partitions = sorted.collectPartitions()
|
||||
logInfo("partition lengths: " + partitions.map(_.length).mkString(", "))
|
||||
// An even split of 1000 elements over 4 partitions would be 250 each; the 150
// threshold presumably leaves slack for sampling error -- TODO confirm.
// NOTE(review): assumes sortByKey() produces at least 4 output partitions here.
partitions(0).length should be > 150
|
||||
partitions(1).length should be > 150
|
||||
partitions(2).length should be > 150
|
||||
partitions(3).length should be > 150
|
||||
// Ranges must not overlap: each partition's last element is below the next one's first.
partitions(0).last should be < partitions(1).head
|
||||
partitions(1).last should be < partitions(2).head
|
||||
partitions(2).last should be < partitions(3).head
|
||||
}
|
||||
|
||||
// Descending counterpart of the balancing check: sortByKey(false) over the keys 1..1000,
// parallelized into 4 slices, must yield reverse-sorted output with the first four result
// partitions reasonably balanced and non-overlapping.
test("partition balancing for descending sort") {
|
||||
sc = new SparkContext("local", "test")
|
||||
// Keys 1..1000, each paired with itself as the value.
val pairArr = (1 to 1000).map(x => (x, x)).toArray
|
||||
val sorted = sc.parallelize(pairArr, 4).sortByKey(false)
|
||||
// The overall result must be the key-sorted input in reverse order.
assert(sorted.collect() === pairArr.sortBy(_._1).reverse)
|
||||
val partitions = sorted.collectPartitions()
|
||||
logInfo("partition lengths: " + partitions.map(_.length).mkString(", "))
|
||||
// An even split of 1000 elements over 4 partitions would be 250 each; the 150
// threshold presumably leaves slack for sampling error -- TODO confirm.
// NOTE(review): assumes sortByKey(false) produces at least 4 output partitions here.
partitions(0).length should be > 150
|
||||
partitions(1).length should be > 150
|
||||
partitions(2).length should be > 150
|
||||
partitions(3).length should be > 150
|
||||
// Ranges must not overlap: each partition's last element is above the next one's first.
partitions(0).last should be > partitions(1).head
|
||||
partitions(1).last should be > partitions(2).head
|
||||
partitions(2).last should be > partitions(3).head
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue