From 4ff9f1fe3bedb5422453838d904112454ddd5675 Mon Sep 17 00:00:00 2001
From: ulysses-you
Date: Tue, 27 Apr 2021 13:05:57 +0000
Subject: [PATCH] [SPARK-35239][SQL] Coalesce shuffle partition should handle
 empty input RDD

### What changes were proposed in this pull request?

Create an empty partition for the custom shuffle reader if the input RDD is empty.

### Why are the changes needed?

If a shuffle's input RDD has 0 partitions, its map output statistics (`ShuffleQueryStageExec#mapStats`) will be `None`. And if every shuffle stage's input RDD has 0 partitions, we currently skip them and lose the chance to coalesce the partitions.

We can simply create an empty partition for these custom shuffle readers to reduce the partition number.
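For example (a minimal sketch adapted from the new test below; the spark-shell session and the AQE settings shown are illustrative, not part of this patch):

```scala
// Assumes a spark-shell session with AQE and partition coalescing enabled, e.g.
// spark.sql.adaptive.enabled=true and spark.sql.adaptive.coalescePartitions.enabled=true.
spark.sql("CREATE TABLE t (c1 int) USING PARQUET")            // empty table, so the shuffle input RDD has 0 partitions
spark.sql("SELECT c1, count(*) FROM t GROUP BY c1").collect() // mapStats is None for every shuffle stage
// Before this patch the aggregation keeps all configured shuffle partitions (all empty);
// with this patch CoalesceShufflePartitions hands the reader a single empty partition,
// i.e. CoalescedPartitionSpec(0, 0).
```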
### Does this PR introduce _any_ user-facing change?

Yes, the number of shuffle partitions might change in AQE.

### How was this patch tested?

Added a new test.

Closes #32362 from ulysses-you/SPARK-35239.

Authored-by: ulysses-you
Signed-off-by: Wenchen Fan
---
 .../adaptive/CoalesceShufflePartitions.scala  | 29 ++++++++++++-------
 .../adaptive/ShufflePartitionsUtil.scala      |  4 +++
 .../adaptive/AdaptiveQueryExecSuite.scala     | 15 ++++++++++
 3 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
index bd45863652..d50e32c8b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.adaptive
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
 import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.internal.SQLConf
 
@@ -54,8 +54,21 @@ case class CoalesceShufflePartitions(session: SparkSession) extends CustomShuffl
     if (!shuffleStages.forall(s => supportCoalesce(s.shuffle))) {
       plan
     } else {
+      def insertCustomShuffleReader(partitionSpecs: Seq[ShufflePartitionSpec]): SparkPlan = {
+        // This transformation adds new nodes, so we must use `transformUp` here.
+        val stageIds = shuffleStages.map(_.id).toSet
+        plan.transformUp {
+          // even for shuffle exchange whose input RDD has 0 partition, we should still update its
+          // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same
+          // number of output partitions.
+          case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) =>
+            CustomShuffleReaderExec(stage, partitionSpecs)
+        }
+      }
+
       // `ShuffleQueryStageExec#mapStats` returns None when the input RDD has 0 partitions,
       // we should skip it when calculating the `partitionStartIndices`.
+      // If all input RDDs have 0 partition, we create empty partition for every shuffle reader.
       val validMetrics = shuffleStages.flatMap(_.mapStats)
 
       // We may have different pre-shuffle partition numbers, don't reduce shuffle partition number
       // in that case. For example when we union fully aggregated data (data is arranged to a single
       // partition) and a result of a SortMergeJoin (multiple partitions).
       val distinctNumPreShufflePartitions =
         validMetrics.map(stats => stats.bytesByPartitionId.length).distinct
-      if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1) {
+      if (validMetrics.isEmpty) {
+        insertCustomShuffleReader(ShufflePartitionsUtil.createEmptyPartition() :: Nil)
+      } else if (distinctNumPreShufflePartitions.length == 1) {
         // We fall back to Spark default parallelism if the minimum number of coalesced partitions
         // is not set, so to avoid perf regressions compared to no coalescing.
         val minPartitionNum = conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM)
@@ -77,15 +92,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends CustomShuffl
         if (partitionSpecs.length == distinctNumPreShufflePartitions.head) {
           plan
         } else {
-          // This transformation adds new nodes, so we must use `transformUp` here.
-          val stageIds = shuffleStages.map(_.id).toSet
-          plan.transformUp {
-            // even for shuffle exchange whose input RDD has 0 partition, we should still update its
-            // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same
-            // number of output partitions.
-            case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) =>
-              CustomShuffleReaderExec(stage, partitionSpecs)
-          }
+          insertCustomShuffleReader(partitionSpecs)
         }
       } else {
         plan

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
index ed92af6adc..a70a5322a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
@@ -125,6 +125,10 @@ object ShufflePartitionsUtil extends Logging {
     partitionSpecs.toSeq
   }
 
+  def createEmptyPartition(): ShufflePartitionSpec = {
+    CoalescedPartitionSpec(0, 0)
+  }
+
   /**
    * Given a list of size, return an array of indices to split the list into multiple partitions,
    * so that the size sum of each partition is close to the target size. Each index indicates the

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 31b6921132..2598d3ba8b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1575,4 +1575,19 @@ class AdaptiveQueryExecSuite
       checkNoCoalescePartitions(df.sort($"key"), ENSURE_REQUIREMENTS)
     }
   }
+
+  test("SPARK-35239: Coalesce shuffle partition should handle empty input RDD") {
+    withTable("t") {
+      withSQLConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+        SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
+        spark.sql("CREATE TABLE t (c1 int) USING PARQUET")
+        val (_, adaptive) = runAdaptiveAndVerifyResult("SELECT c1, count(*) FROM t GROUP BY c1")
+        assert(
+          collect(adaptive) {
+            case c @ CustomShuffleReaderExec(_, partitionSpecs) if partitionSpecs.length == 1 => c
+          }.length == 1
+        )
+      }
+    }
+  }
 }