diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index a990700729..d71933ab58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -250,7 +250,12 @@ object EnsureRequirements extends Rule[SparkPlan] {
 
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
     // TODO: remove this after we create a physical operator for `RepartitionByExpression`.
-    case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
+    // SPARK-35989: AQE may change the partition number, so we should retain a REPARTITION_BY_NUM
+    // shuffle whose partition number was explicitly specified by the user. We also cannot remove
+    // REBALANCE_PARTITIONS_BY_COL, which is a special shuffle used to rebalance partitions.
+    // So, here we only remove REPARTITION_BY_COL shuffles when AQE is enabled.
+    case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin)
+        if shuffleOrigin == REPARTITION_BY_COL || !conf.adaptiveExecutionEnabled =>
       def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = {
         partitioning match {
           case lower: HashPartitioning if upper.semanticEquals(lower) => true
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 0ea75996b3..fad6ed104f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, REPARTITION_BY_COL, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
 import org.apache.spark.sql.functions._
@@ -420,7 +420,8 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
 
     val inputPlan = ShuffleExchangeExec(
       partitioning,
-      DummySparkPlan(outputPartitioning = partitioning))
+      DummySparkPlan(outputPartitioning = partitioning),
+      REPARTITION_BY_COL)
     val outputPlan = EnsureRequirements.apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index b46cc9f427..80a0c31f56 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1526,13 +1526,13 @@ class AdaptiveQueryExecSuite
       val dfRepartitionWithNum = df.repartition(5, 'b)
       dfRepartitionWithNum.collect()
       val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
-      // The top shuffle from repartition is optimized out.
-      assert(!hasRepartitionShuffle(planWithNum))
+      // The top shuffle from repartition is not optimized out.
+      assert(hasRepartitionShuffle(planWithNum))
       val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum)
       assert(bhjWithNum.length == 1)
       checkNumLocalShuffleReaders(planWithNum, 1)
-      // Probe side is not coalesced.
-      assert(bhjWithNum.head.right.find(_.isInstanceOf[CustomShuffleReaderExec]).isEmpty)
+      // Probe side is coalesced.
+      assert(bhjWithNum.head.right.find(_.isInstanceOf[CustomShuffleReaderExec]).nonEmpty)
 
       // Repartition with non-default partition num specified.
       val dfRepartitionWithNum2 = df.repartition(3, 'b)
@@ -1575,17 +1575,16 @@ class AdaptiveQueryExecSuite
       val dfRepartitionWithNum = df.repartition(5, 'b)
       dfRepartitionWithNum.collect()
       val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
-      // The top shuffle from repartition is optimized out.
-      assert(!hasRepartitionShuffle(planWithNum))
+      // The top shuffle from repartition is not optimized out.
+      assert(hasRepartitionShuffle(planWithNum))
       val smjWithNum = findTopLevelSortMergeJoin(planWithNum)
       assert(smjWithNum.length == 1)
-      // No skew join due to the repartition.
-      assert(!smjWithNum.head.isSkewJoin)
-      // No coalesce due to the num in repartition.
+      // The skew join can apply because the repartition shuffle is not optimized out.
+      assert(smjWithNum.head.isSkewJoin)
       val customReadersWithNum = collect(smjWithNum.head) {
         case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c
       }
-      assert(customReadersWithNum.isEmpty)
+      assert(customReadersWithNum.nonEmpty)
 
       // Repartition with non-default partition num specified.
       val dfRepartitionWithNum2 = df.repartition(3, 'b)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 0fd9f14299..8f7616ccb4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -21,11 +21,12 @@ import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
 import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
-class EnsureRequirementsSuite extends SharedSparkSession {
+class EnsureRequirementsSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   private val exprA = Literal(1)
   private val exprB = Literal(2)
   private val exprC = Literal(3)
@@ -133,4 +134,26 @@ class EnsureRequirementsSuite extends SharedSparkSession {
       }.size == 2)
     }
   }
+
+  test("SPARK-35989: Do not remove REPARTITION_BY_NUM shuffle if AQE is enabled") {
+    import testImplicits._
+    Seq(true, false).foreach { enableAqe =>
+      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAqe.toString,
+        SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+        val df1 = Seq((1, 2)).toDF("c1", "c2")
+        val df2 = Seq((1, 3)).toDF("c3", "c4")
+        val res = df1.join(df2, $"c1" === $"c3").repartition(3, $"c1")
+        val num = collect(res.queryExecution.executedPlan) {
+          case shuffle: ShuffleExchangeExec if shuffle.shuffleOrigin == REPARTITION_BY_NUM =>
+            shuffle
+        }.size
+        if (enableAqe) {
+          assert(num == 1)
+        } else {
+          assert(num == 0)
+        }
+      }
+    }
+  }
 }
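
For reviewers, a standalone sketch (not part of the patch) of the distinction the new guard relies on. It assumes a Spark build from this era, where REPARTITION_BY_COL and REPARTITION_BY_NUM are exposed in org.apache.spark.sql.execution.exchange; the object name ShuffleOriginSketch is made up for illustration. repartition(col) only pins the hash columns, so its shuffle stays safe to elide, while repartition(n, col) also pins the partition count, which AQE could otherwise change:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeExec}

object ShuffleOriginSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      // Disabled so executedPlan is not wrapped in AdaptiveSparkPlanExec,
      // letting the plain TreeNode.collect below find the shuffle nodes.
      .config("spark.sql.adaptive.enabled", "false")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b")).toDF("c1", "c2")

    def origins(plan: SparkPlan) =
      plan.collect { case s: ShuffleExchangeExec => s.shuffleOrigin }

    // Only the hash columns are pinned; the partition count is left to the
    // planner, so EnsureRequirements (and AQE) may freely rework this shuffle.
    val byCol = df.repartition($"c1").queryExecution.executedPlan
    assert(origins(byCol).contains(REPARTITION_BY_COL))

    // The user pinned the partition count to 3; this shuffle must survive
    // under AQE, which is exactly what the new guard in apply() enforces.
    val byNum = df.repartition(3, $"c1").queryExecution.executedPlan
    assert(origins(byNum).contains(REPARTITION_BY_NUM))

    spark.stop()
  }
}

With AQE disabled, EnsureRequirements may still elide either shuffle when the child is already partitioned the same way, which is what the enableAqe = false branch of the new test (num == 0) exercises.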