diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index ca52ecafa2..9ea6d2fa2f 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -76,6 +76,9 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val mapSideCombine: Boolean = false) extends Dependency[Product2[K, V]] { + if (mapSideCombine) { + require(aggregator.isDefined, "Map-side combine without Aggregator specified!") + } override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] private[spark] val keyClassName: String = reflect.classTag[K].runtimeClass.getName diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 0562d45ff5..edd69715c9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -90,7 +90,6 @@ private[spark] class BlockStoreShuffleReader[K, C]( dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { - require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index bfb4dc698e..d9fad64f34 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -188,9 +188,9 @@ private[spark] object SortShuffleManager extends Logging { log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " + s"${dependency.serializer.getClass.getName}, does not support object relocation") false - } else if (dependency.aggregator.isDefined) { - log.debug( - s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined") + } else if (dependency.mapSideCombine) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because we need to do " + + s"map-side aggregation") false } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " + diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 636b88e792..274399b9cc 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -50,7 +50,6 @@ private[spark] class SortShuffleWriter[K, V, C]( /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { sorter = if (dep.mapSideCombine) { - require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") new ExternalSorter[K, V, C]( context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) } else { @@ -107,7 +106,6 @@ private[spark] object SortShuffleWriter { def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { // We cannot bypass sorting if we need to do map-side aggregation. if (dep.mapSideCombine) { - require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") false } else { val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala index 55cebe7c8b..f29dac965c 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala @@ -85,6 +85,14 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) + // We support serialized shuffle if we do not need to do map-side aggregation + assert(canUseSerializedShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = None, + aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), + mapSideCombine = false + ))) } test("unsupported shuffle dependencies for serialized shuffle") { @@ -111,14 +119,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) - // We do not support shuffles that perform aggregation - assert(!canUseSerializedShuffle(shuffleDep( - partitioner = new HashPartitioner(2), - serializer = kryo, - keyOrdering = None, - aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), - mapSideCombine = false - ))) + // We do not support serialized shuffle if we need to do map-side aggregation assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo,