Unit tests for shuffle operations. Fixes #33.
parent 7b25ab87af
commit d0a9966555
119 src/test/spark/ShuffleSuite.scala (new file)
@@ -0,0 +1,119 @@
package spark

import org.scalatest.FunSuite
import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
import org.scalacheck.Prop._

import SparkContext._

class ShuffleSuite extends FunSuite {
  test("groupByKey") {
    val sc = new SparkContext("local", "test")
    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
    val groups = pairs.groupByKey().collect()
    assert(groups.size === 2)
    val valuesFor1 = groups.find(_._1 == 1).get._2
    assert(valuesFor1.toList.sorted === List(1, 2, 3))
    val valuesFor2 = groups.find(_._1 == 2).get._2
    assert(valuesFor2.toList.sorted === List(1))
  }

  test("groupByKey with duplicates") {
    val sc = new SparkContext("local", "test")
    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
    val groups = pairs.groupByKey().collect()
    assert(groups.size === 2)
    val valuesFor1 = groups.find(_._1 == 1).get._2
    assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
    val valuesFor2 = groups.find(_._1 == 2).get._2
    assert(valuesFor2.toList.sorted === List(1))
  }

  test("groupByKey with many output partitions") {
    val sc = new SparkContext("local", "test")
    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
    val groups = pairs.groupByKey(10).collect()
    assert(groups.size === 2)
    val valuesFor1 = groups.find(_._1 == 1).get._2
    assert(valuesFor1.toList.sorted === List(1, 2, 3))
    val valuesFor2 = groups.find(_._1 == 2).get._2
    assert(valuesFor2.toList.sorted === List(1))
  }

  test("reduceByKey") {
    val sc = new SparkContext("local", "test")
    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
    val sums = pairs.reduceByKey(_+_).collect()
    assert(sums.toSet === Set((1, 7), (2, 1)))
  }

  test("reduceByKey with collectAsMap") {
    val sc = new SparkContext("local", "test")
    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
    val sums = pairs.reduceByKey(_+_).collectAsMap()
    assert(sums.size === 2)
    assert(sums(1) === 7)
    assert(sums(2) === 1)
  }

  test("reduceByKey with many output partitions") {
    val sc = new SparkContext("local", "test")
    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
    val sums = pairs.reduceByKey(_+_, 10).collect()
    assert(sums.toSet === Set((1, 7), (2, 1)))
  }

  test("join") {
    val sc = new SparkContext("local", "test")
    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
    val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
    val joined = rdd1.join(rdd2).collect()
    assert(joined.size === 4)
    assert(joined.toSet === Set(
      (1, (1, 'x')),
      (1, (2, 'x')),
      (2, (1, 'y')),
      (2, (1, 'z'))
    ))
  }

  test("join all-to-all") {
    val sc = new SparkContext("local", "test")
    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
    val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
    val joined = rdd1.join(rdd2).collect()
    assert(joined.size === 6)
    assert(joined.toSet === Set(
      (1, (1, 'x')),
      (1, (1, 'y')),
      (1, (2, 'x')),
      (1, (2, 'y')),
      (1, (3, 'x')),
      (1, (3, 'y'))
    ))
  }

  test("join with no matches") {
    val sc = new SparkContext("local", "test")
    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
    val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
    val joined = rdd1.join(rdd2).collect()
    assert(joined.size === 0)
  }

  test("join with many output partitions") {
    val sc = new SparkContext("local", "test")
    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
    val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
    val joined = rdd1.join(rdd2, 10).collect()
    assert(joined.size === 4)
    assert(joined.toSet === Set(
      (1, (1, 'x')),
      (1, (2, 'x')),
      (2, (1, 'y')),
      (2, (1, 'z'))
    ))
  }
}
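For readers skimming the diff, here is a minimal sketch (not part of this commit) of the three shuffle operations the suite exercises, run on small inputs. It assumes the same spark package and two-argument local SparkContext constructor used in the tests above; the object name ShuffleExample is hypothetical.

package spark

import SparkContext._

// Hypothetical driver, not part of this commit: runs the shuffle
// operations tested above on a local context and prints the results.
object ShuffleExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "example")
    val pairs = sc.parallelize(Array((1, 1), (1, 2), (2, 1)))
    // groupByKey: all values sharing a key land in one collection,
    // e.g. (1, Seq(1, 2)) and (2, Seq(1)).
    println(pairs.groupByKey().collect().mkString(", "))
    // reduceByKey: values for each key are folded with the given
    // function, e.g. (1, 3) and (2, 1) for _ + _.
    println(pairs.reduceByKey(_ + _).collect().mkString(", "))
    // join: values from both RDDs are paired per matching key,
    // e.g. (1, (1, 'x')), (1, (2, 'x')), (2, (1, 'y')).
    val other = sc.parallelize(Array((1, 'x'), (2, 'y')))
    println(pairs.join(other).collect().mkString(", "))
  }
}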