Fixed the bug where the shuffle serializer was ignored by the new shuffle
block iterators for local blocks. Also added a unit test for that.
parent dbbedfc535
commit 6ea085169d
@@ -163,7 +163,7 @@ object BlockFetcherIterator {
     // these all at once because they will just memory-map some files, so they won't consume
     // any memory that might exceed our maxBytesInFlight
     for (id <- localBlockIds) {
-      getLocal(id) match {
+      getLocalFromDisk(id, serializer) match {
         case Some(iter) => {
           // Pass 0 as size since it's not in flight
           results.put(new FetchResult(id, 0, () => iter))
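The essence of the change above, as a self-contained sketch. Everything here except the getLocal/getLocalFromDisk names is invented for illustration and is not Spark's actual internals: on the old path, local blocks were always decoded with a default serializer, so a shuffle-specific serializer (e.g. Kryo) only took effect for remote blocks.

trait Serializer {
  def deserialize(bytes: Array[Byte]): Iterator[Any]
}

class LocalBlockReader(defaultSerializer: Serializer, blocks: Map[String, Array[Byte]]) {
  // Buggy shape: the shuffle's serializer never reaches the local path,
  // so local blocks are always decoded with the default serializer.
  def getLocal(id: String): Option[Iterator[Any]] =
    blocks.get(id).map(defaultSerializer.deserialize)

  // Fixed shape: the caller threads the shuffle's serializer through,
  // mirroring the getLocalFromDisk(id, serializer) call in the hunk above.
  def getLocalFromDisk(id: String, serializer: Serializer): Option[Iterator[Any]] =
    blocks.get(id).map(serializer.deserialize)
}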
@@ -99,7 +99,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     val sums = pairs.reduceByKey(_+_, 10).collect()
     assert(sums.toSet === Set((1, 7), (2, 1)))
   }

   test("reduceByKey with partitioner") {
     sc = new SparkContext("local", "test")
     val p = new Partitioner() {
@@ -272,7 +272,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     }
     // partitionBy so we have a narrow dependency
     val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
     // more partitions/no partitioner so a shuffle dependency
     val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
     val c = a.subtract(b)
     assert(c.collect().toSet === Set((1, "a"), (3, "c")))
@@ -298,18 +298,33 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     }
     // partitionBy so we have a narrow dependency
     val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
     // more partitions/no partitioner so a shuffle dependency
     val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
     val c = a.subtractByKey(b)
     assert(c.collect().toSet === Set((1, "a"), (1, "a")))
     assert(c.partitioner.get === p)
   }

+  test("shuffle serializer") {
+    // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+    sc = new SparkContext("local-cluster[1,2,512]", "test")
+    val a = sc.parallelize(1 to 10, 2)
+    val b = a.map { x =>
+      (x, new ShuffleSuite.NonJavaSerializableClass(x * 2))
+    }
+    // If the Kryo serializer is not used correctly, the shuffle would fail because the
+    // default Java serializer cannot handle the non serializable class.
+    val c = new ShuffledRDD(b, new HashPartitioner(3), classOf[spark.KryoSerializer].getName)
+    assert(c.count === 10)
+  }
 }

 object ShuffleSuite {

   def mergeCombineException(x: Int, y: Int): Int = {
     throw new SparkException("Exception for map-side combine.")
     x + y
   }
+
+  class NonJavaSerializableClass(val value: Int)
 }
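Why the test needs NonJavaSerializableClass: the class deliberately omits the java.io.Serializable marker, so plain Java serialization cannot write it, while Kryo does not require the marker at all. A minimal standalone sketch (separate from the commit, with an invented NonSerializableDemo wrapper) of the failure mode the test guards against:

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

object NonSerializableDemo {
  class NonJavaSerializableClass(val value: Int) // no Serializable marker

  def main(args: Array[String]): Unit = {
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    try {
      out.writeObject(new NonJavaSerializableClass(42))
    } catch {
      case e: NotSerializableException =>
        // This is the error the shuffle would hit if the serializer passed to
        // ShuffledRDD were ignored and Java serialization were used instead.
        println("Java serialization failed: " + e)
    }
  }
}

Because the test runs on local-cluster[1,2,512], some shuffle blocks are fetched locally and some remotely, so it exercises exactly the local-block path fixed in BlockFetcherIterator above.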