Unit tests for shuffle operations. Fixes #33.

This commit is contained in:
Matei Zaharia 2010-11-12 16:12:14 -08:00
parent 7b25ab87af
commit d0a9966555

View file

@ -0,0 +1,119 @@
package spark
import org.scalatest.FunSuite
import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
import org.scalacheck.Prop._
import SparkContext._
class ShuffleSuite extends FunSuite {
test("groupByKey") {
val sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
val groups = pairs.groupByKey().collect()
assert(groups.size === 2)
val valuesFor1 = groups.find(_._1 == 1).get._2
assert(valuesFor1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
}
test("groupByKey with duplicates") {
val sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val groups = pairs.groupByKey().collect()
assert(groups.size === 2)
val valuesFor1 = groups.find(_._1 == 1).get._2
assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
}
test("groupByKey with many output partitions") {
val sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
val groups = pairs.groupByKey(10).collect()
assert(groups.size === 2)
val valuesFor1 = groups.find(_._1 == 1).get._2
assert(valuesFor1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
}
test("reduceByKey") {
val sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
}
test("reduceByKey with collectAsMap") {
val sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_).collectAsMap()
assert(sums.size === 2)
assert(sums(1) === 7)
assert(sums(2) === 1)
}
test("reduceByKey with many output partitons") {
val sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_, 10).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
}
test("join") {
val sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.join(rdd2).collect()
assert(joined.size === 4)
assert(joined.toSet === Set(
(1, (1, 'x')),
(1, (2, 'x')),
(2, (1, 'y')),
(2, (1, 'z'))
))
}
test("join all-to-all") {
val sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
val joined = rdd1.join(rdd2).collect()
assert(joined.size === 6)
assert(joined.toSet === Set(
(1, (1, 'x')),
(1, (1, 'y')),
(1, (2, 'x')),
(1, (2, 'y')),
(1, (3, 'x')),
(1, (3, 'y'))
))
}
test("join with no matches") {
val sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
val joined = rdd1.join(rdd2).collect()
assert(joined.size === 0)
}
test("join with many output partitions") {
val sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.join(rdd2, 10).collect()
assert(joined.size === 4)
assert(joined.toSet === Set(
(1, (1, 'x')),
(1, (2, 'x')),
(2, (1, 'y')),
(2, (1, 'z'))
))
}
}