From dd80457ffb1c129a1ca3c53bcf3ea5feed7ebc57 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 3 Aug 2021 18:28:52 +0800
Subject: [PATCH] [SPARK-36315][SQL] Only skip AQEShuffleReadRule in the final stage if it breaks the distribution requirement

### What changes were proposed in this pull request?

This is a follow-up of https://github.com/apache/spark/pull/30494

This PR proposes a new way to optimize the final query stage in AQE. We first collect the effective user-specified repartition (semantically, a user-specified repartition is only effective if it is the root node or sits under a few simple nodes), and derive the required distribution for the final plan. When we optimize the final query stage, we skip an `AQEShuffleReadRule` only if it would break that required distribution.

### Why are the changes needed?

The current approach to optimizing the final query stage is hacky and heavy-handed. For example, the newly added rule `OptimizeSkewInRebalancePartitions` can rarely be applied, because query plans very commonly contain shuffles with origin `ENSURE_REQUIREMENTS`, which that rule does not support.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Updated tests.

Closes #33541 from cloud-fan/aqe.

Authored-by: Wenchen Fan
Signed-off-by: Wenchen Fan
---
 .../spark/sql/execution/QueryExecution.scala |   2 +-
 .../adaptive/AQEShuffleReadExec.scala |  36 +++-
 .../adaptive/AQEShuffleReadRule.scala |  12 +-
 .../sql/execution/adaptive/AQEUtils.scala |  60 +++++++
 .../adaptive/AdaptiveSparkPlanExec.scala |  63 ++++---
 .../adaptive/CoalesceShufflePartitions.scala |  10 +-
 .../OptimizeShuffleWithLocalRead.scala |  12 +-
 .../OptimizeSkewInRebalancePartitions.scala |   8 +-
 .../adaptive/OptimizeSkewedJoin.scala |  15 +-
 .../exchange/EnsureRequirements.scala |  16 +-
 .../exchange/ValidateRequirements.scala |   4 +
 .../spark/sql/execution/PlannerSuite.scala |   2 +
 .../adaptive/AdaptiveQueryExecSuite.scala | 162 ++++++++++--------
 .../exchange/EnsureRequirementsSuite.scala |  27 +--
 .../execution/joins/BroadcastJoinSuite.scala |   2 +
 .../execution/joins/ExistenceJoinSuite.scala |   2 +
 .../sql/execution/joins/InnerJoinSuite.scala |   2 +
 .../sql/execution/joins/OuterJoinSuite.scala |   2 +
 18 files changed, 270 insertions(+), 167 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 5a654add31..6c16dceae6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -409,7 +409,7 @@ object QueryExecution {
       PlanDynamicPruningFilters(sparkSession),
       PlanSubqueries(sparkSession),
       RemoveRedundantProjects,
-      EnsureRequirements,
+      EnsureRequirements(),
       // `RemoveRedundantSorts` needs to be added after `EnsureRequirements` to guarantee the same
       // number of partitions when instantiating PartitioningCollection.
RemoveRedundantSorts, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala index 0768b9b3d6..af62157fe2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala @@ -22,13 +22,13 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.vectorized.ColumnarBatch - /** * A wrapper of shuffle query stage, which follows the given partition arrangement. * @@ -51,6 +51,7 @@ case class AQEShuffleReadExec private( override def supportsColumnar: Boolean = child.supportsColumnar override def output: Seq[Attribute] = child.output + override lazy val outputPartitioning: Partitioning = { // If it is a local shuffle read with one mapper per task, then the output partitioning is // the same as the plan before shuffle. @@ -69,6 +70,21 @@ case class AQEShuffleReadExec private( case _ => throw new IllegalStateException("operating on canonicalization plan") } + } else if (isCoalescedRead) { + // For coalesced shuffle read, the data distribution is not changed, only the number of + // partitions is changed. + child.outputPartitioning match { + case h: HashPartitioning => + CurrentOrigin.withOrigin(h.origin)(h.copy(numPartitions = partitionSpecs.length)) + case r: RangePartitioning => + CurrentOrigin.withOrigin(r.origin)(r.copy(numPartitions = partitionSpecs.length)) + // This can only happen for `REBALANCE_PARTITIONS_BY_NONE`, which uses + // `RoundRobinPartitioning` but we don't need to retain the number of partitions. 
+ case r: RoundRobinPartitioning => + r.copy(numPartitions = partitionSpecs.length) + case other => throw new IllegalStateException( + "Unexpected partitioning for coalesced shuffle read: " + other) + } } else { UnknownPartitioning(partitionSpecs.length) } @@ -92,7 +108,7 @@ case class AQEShuffleReadExec private( /** * Returns true iff some partitions were actually combined */ - private def isCoalesced(spec: ShufflePartitionSpec) = spec match { + private def isCoalescedSpec(spec: ShufflePartitionSpec) = spec match { case CoalescedPartitionSpec(0, 0, _) => true case s: CoalescedPartitionSpec => s.endReducerIndex - s.startReducerIndex > 1 case _ => false @@ -102,7 +118,7 @@ case class AQEShuffleReadExec private( * Returns true iff some non-empty partitions were combined */ def hasCoalescedPartition: Boolean = { - partitionSpecs.exists(isCoalesced) + partitionSpecs.exists(isCoalescedSpec) } def hasSkewedPartition: Boolean = @@ -112,6 +128,16 @@ case class AQEShuffleReadExec private( partitionSpecs.exists(_.isInstanceOf[PartialMapperPartitionSpec]) || partitionSpecs.exists(_.isInstanceOf[CoalescedMapperPartitionSpec]) + def isCoalescedRead: Boolean = { + partitionSpecs.sliding(2).forall { + // A single partition spec which is `CoalescedPartitionSpec` also means coalesced read. + case Seq(_: CoalescedPartitionSpec) => true + case Seq(l: CoalescedPartitionSpec, r: CoalescedPartitionSpec) => + l.endReducerIndex <= r.startReducerIndex + case _ => false + } + } + private def shuffleStage = child match { case stage: ShuffleQueryStageExec => Some(stage) case _ => None @@ -159,7 +185,7 @@ case class AQEShuffleReadExec private( if (hasCoalescedPartition) { val numCoalescedPartitionsMetric = metrics("numCoalescedPartitions") - val x = partitionSpecs.count(isCoalesced) + val x = partitionSpecs.count(isCoalescedSpec) numCoalescedPartitionsMetric.set(x) driverAccumUpdates += numCoalescedPartitionsMetric.id -> x } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadRule.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadRule.scala index 1c7f2eacda..c303e85643 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadRule.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadRule.scala @@ -19,17 +19,19 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.exchange.ShuffleOrigin +import org.apache.spark.sql.execution.exchange.{ShuffleExchangeLike, ShuffleOrigin} /** - * Adaptive Query Execution rule that may create [[AQEShuffleReadExec]] on top of query stages. + * A rule that may create [[AQEShuffleReadExec]] on top of [[ShuffleQueryStageExec]] and change the + * plan output partitioning. The AQE framework will skip the rule if it leads to extra shuffles. */ trait AQEShuffleReadRule extends Rule[SparkPlan] { - /** * Returns the list of [[ShuffleOrigin]]s supported by this rule. 
*/ - def supportedShuffleOrigins: Seq[ShuffleOrigin] + protected def supportedShuffleOrigins: Seq[ShuffleOrigin] - def mayAddExtraShuffles: Boolean = false + protected def isSupported(shuffle: ShuffleExchangeLike): Boolean = { + supportedShuffleOrigins.contains(shuffle.shuffleOrigin) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala new file mode 100644 index 0000000000..277af212d8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution, HashPartitioning, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{CollectMetricsExec, FilterExec, ProjectExec, SortExec, SparkPlan} +import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeExec} + +object AQEUtils { + + // Analyze the given plan and calculate the required distribution of this plan w.r.t. the + // user-specified repartition. + def getRequiredDistribution(p: SparkPlan): Option[Distribution] = p match { + // User-specified repartition is only effective when it's the root node, or under + // Project/Filter/LocalSort/CollectMetrics. + // Note: we only care about `HashPartitioning` as `EnsureRequirements` can only optimize out + // user-specified repartition with `HashPartitioning`. + case ShuffleExchangeExec(h: HashPartitioning, _, shuffleOrigin) + if shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM => + val numPartitions = if (shuffleOrigin == REPARTITION_BY_NUM) { + Some(h.numPartitions) + } else { + None + } + Some(HashClusteredDistribution(h.expressions, numPartitions)) + case f: FilterExec => getRequiredDistribution(f.child) + case s: SortExec if !s.global => getRequiredDistribution(s.child) + case c: CollectMetricsExec => getRequiredDistribution(c.child) + case p: ProjectExec => + getRequiredDistribution(p.child).flatMap { + case h: HashClusteredDistribution => + if (h.expressions.forall(e => p.projectList.exists(_.semanticEquals(e)))) { + Some(h) + } else { + // It's possible that the user-specified repartition is effective but the output + // partitioning is not retained, e.g. `df.repartition(a, b).select(c)`. We can't + // handle this case with required distribution. Here we return None and later on + // `EnsureRequirements` will skip optimizing out the user-specified repartition. 
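To make the `df.repartition(a, b).select(c)` case concrete, here is a small editorial sketch (not from the patch) of the traversal idea behind `AQEUtils.getRequiredDistribution`. It uses simplified stand-in node types instead of real `SparkPlan` operators, and it collapses "no requirement" and "requirement that cannot be expressed" into one `None` for brevity:

```scala
// Editorial sketch, not part of the patch: a simplified model of the traversal.
// `Repartition` stands in for ShuffleExchangeExec(HashPartitioning, REPARTITION_BY_COL/NUM);
// the distribution is modeled as the clustering columns plus an optional pinned partition count.
sealed trait Node
case class Repartition(cols: Seq[String], numPartitions: Option[Int], child: Node) extends Node
case class Project(output: Seq[String], child: Node) extends Node
case class Filter(child: Node) extends Node
case class Leaf() extends Node

case class RequiredDist(cols: Seq[String], numPartitions: Option[Int])

// Walk down through the few nodes that keep the repartition "effective"; give up (None)
// once a projection no longer produces all of the clustering columns.
def requiredDistribution(plan: Node): Option[RequiredDist] = plan match {
  case Repartition(cols, num, _) => Some(RequiredDist(cols, num))
  case Filter(child)             => requiredDistribution(child)
  case Project(output, child)    =>
    requiredDistribution(child).filter(d => d.cols.forall(output.contains))
  case _                         => None
}

// df.repartition('a, 'b)            -> requirement on (a, b), partition count left free
assert(requiredDistribution(Repartition(Seq("a", "b"), None, Leaf()))
  .contains(RequiredDist(Seq("a", "b"), None)))
// df.repartition(5, 'b).filter(...) -> requirement on (b) with exactly 5 partitions
assert(requiredDistribution(Filter(Repartition(Seq("b"), Some(5), Leaf())))
  .contains(RequiredDist(Seq("b"), Some(5))))
// df.repartition('a, 'b).select('c) -> projection drops a and b, so no expressible
//                                      requirement; EnsureRequirements then keeps the shuffle
assert(requiredDistribution(Project(Seq("c"), Repartition(Seq("a", "b"), None, Leaf()))).isEmpty)
```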
+ None + } + case other => Some(other) + } + case _ => Some(UnspecifiedDistribution) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index c03bb4b50d..9db457425c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.errors.QueryExecutionErrors @@ -83,12 +84,28 @@ case class AdaptiveSparkPlanExec( // The logical plan optimizer for re-optimizing the current logical plan. @transient private val optimizer = new AQEOptimizer(conf) + // `EnsureRequirements` may remove user-specified repartition and assume the query plan won't + // change its output partitioning. This assumption is not true in AQE. Here we check the + // `inputPlan` which has not been processed by `EnsureRequirements` yet, to find out the + // effective user-specified repartition. Later on, the AQE framework will make sure the final + // output partitioning is not changed w.r.t the effective user-specified repartition. + @transient private val requiredDistribution: Option[Distribution] = if (isSubquery) { + // Subquery output does not need a specific output partitioning. + Some(UnspecifiedDistribution) + } else { + AQEUtils.getRequiredDistribution(inputPlan) + } + // A list of physical plan rules to be applied before creation of query stages. The physical // plan should reach a final status of query stages (i.e., no more addition or removal of // Exchange nodes) after running these rules. private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq( RemoveRedundantProjects, - EnsureRequirements, + // For cases like `df.repartition(a, b).select(c)`, there is no distribution requirement for + // the final plan, but we do need to respect the user-specified repartition. Here we ask + // `EnsureRequirements` to not optimize out the user-specified repartition-by-col to work + // around this case. + EnsureRequirements(optimizeOutRepartition = requiredDistribution.isDefined), RemoveRedundantSorts, DisableUnnecessaryBucketedScan ) ++ context.session.sessionState.queryStagePrepRules @@ -114,33 +131,24 @@ case class AdaptiveSparkPlanExec( CollapseCodegenStages() ) ++ context.session.sessionState.postStageCreationRules - // The partitioning of the query output depends on the shuffle(s) in the final stage. If the - // original plan contains a repartition operator, we need to preserve the specified partitioning, - // whether or not the repartition-introduced shuffle is optimized out because of an underlying - // shuffle of the same partitioning. Thus, we need to exclude some `AQEShuffleReadRule`s - // from the final stage, depending on the presence and properties of repartition operators. 
- private def finalStageOptimizerRules: Seq[Rule[SparkPlan]] = { - val origins = inputPlan.collect { - case s: ShuffleExchangeLike => s.shuffleOrigin - } - val allRules = queryStageOptimizerRules ++ postStageCreationRules - allRules.filter { - case c: AQEShuffleReadRule => - origins.forall(c.supportedShuffleOrigins.contains) - case _ => true - } - } - - private def optimizeQueryStage(plan: SparkPlan, rules: Seq[Rule[SparkPlan]]): SparkPlan = { - val optimized = rules.foldLeft(plan) { case (latestPlan, rule) => + private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = { + val optimized = queryStageOptimizerRules.foldLeft(plan) { case (latestPlan, rule) => val applied = rule.apply(latestPlan) val result = rule match { - case c: AQEShuffleReadRule if c.mayAddExtraShuffles => - if (ValidateRequirements.validate(applied)) { + case _: AQEShuffleReadRule if !applied.fastEquals(latestPlan) => + val distribution = if (isFinalStage) { + // If `requiredDistribution` is None, it means `EnsureRequirements` will not optimize + // out the user-specified repartition, thus we don't have a distribution requirement + // for the final plan. + requiredDistribution.getOrElse(UnspecifiedDistribution) + } else { + UnspecifiedDistribution + } + if (ValidateRequirements.validate(applied, distribution)) { applied } else { - logDebug(s"Rule ${rule.ruleName} is not applied due to additional shuffles " + - "will be introduced.") + logDebug(s"Rule ${rule.ruleName} is not applied as it breaks the " + + "distribution requirement of the query plan.") latestPlan } case _ => applied @@ -303,7 +311,10 @@ case class AdaptiveSparkPlanExec( } // Run the final plan when there's no more unfinished stages. - currentPhysicalPlan = optimizeQueryStage(result.newPlan, finalStageOptimizerRules) + currentPhysicalPlan = applyPhysicalRules( + optimizeQueryStage(result.newPlan, isFinalStage = true), + postStageCreationRules, + Some((planChangeLogger, "AQE Post Stage Creation"))) isFinalPlan = true executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan))) currentPhysicalPlan @@ -520,7 +531,7 @@ case class AdaptiveSparkPlanExec( } private def newQueryStage(e: Exchange): QueryStageExec = { - val optimizedPlan = optimizeQueryStage(e.child, queryStageOptimizerRules) + val optimizedPlan = optimizeQueryStage(e.child, isFinalStage = false) val queryStage = e match { case s: ShuffleExchangeLike => val newShuffle = applyPhysicalRules( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala index 7f3e453f90..75c53b4f76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala @@ -33,6 +33,10 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe Seq(ENSURE_REQUIREMENTS, REPARTITION_BY_COL, REBALANCE_PARTITIONS_BY_NONE, REBALANCE_PARTITIONS_BY_COL) + override def isSupported(shuffle: ShuffleExchangeLike): Boolean = { + shuffle.outputPartitioning != SinglePartition && super.isSupported(shuffle) + } + override def apply(plan: SparkPlan): SparkPlan = { if (!conf.coalesceShufflePartitionsEnabled) { return plan @@ -52,7 +56,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe val shuffleStageInfos = collectShuffleStageInfos(plan) // ShuffleExchanges introduced by 
repartition do not support changing the number of partitions. // We change the number of partitions in the stage only if all the ShuffleExchanges support it. - if (!shuffleStageInfos.forall(s => supportCoalesce(s.shuffleStage.shuffle))) { + if (!shuffleStageInfos.forall(s => isSupported(s.shuffleStage.shuffle))) { plan } else { // Ideally, this rule should simply coalesce partition w.r.t. the target size specified by @@ -106,10 +110,6 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe }.getOrElse(plan) case other => other.mapChildren(updateShuffleReads(_, specsMap)) } - - private def supportCoalesce(s: ShuffleExchangeLike): Boolean = { - s.outputPartitioning != SinglePartition && supportedShuffleOrigins.contains(s.shuffleOrigin) - } } private class ShuffleStageInfo( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala index 844acbd1a2..cf1c7ecedd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala @@ -38,7 +38,9 @@ object OptimizeShuffleWithLocalRead extends AQEShuffleReadRule { override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS, REBALANCE_PARTITIONS_BY_NONE) - override def mayAddExtraShuffles: Boolean = true + override protected def isSupported(shuffle: ShuffleExchangeLike): Boolean = { + shuffle.outputPartitioning != SinglePartition && super.isSupported(shuffle) + } // The build side is a broadcast query stage which should have been optimized using local read // already. So we only need to deal with probe side here. @@ -136,14 +138,10 @@ object OptimizeShuffleWithLocalRead extends AQEShuffleReadRule { def canUseLocalShuffleRead(plan: SparkPlan): Boolean = plan match { case s: ShuffleQueryStageExec => - s.mapStats.isDefined && supportLocalRead(s.shuffle) + s.mapStats.isDefined && isSupported(s.shuffle) case AQEShuffleReadExec(s: ShuffleQueryStageExec, _) => - s.mapStats.isDefined && supportLocalRead(s.shuffle) && + s.mapStats.isDefined && isSupported(s.shuffle) && s.shuffle.shuffleOrigin == ENSURE_REQUIREMENTS case _ => false } - - private def supportLocalRead(s: ShuffleExchangeLike): Boolean = { - s.outputPartitioning != SinglePartition && supportedShuffleOrigins.contains(s.shuffleOrigin) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala index dc43740335..1752907a9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala @@ -38,7 +38,8 @@ import org.apache.spark.sql.internal.SQLConf * ShuffleQueryStageExec. 
*/ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { - override def supportedShuffleOrigins: Seq[ShuffleOrigin] = + + override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(REBALANCE_PARTITIONS_BY_NONE, REBALANCE_PARTITIONS_BY_COL) /** @@ -92,9 +93,8 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { } plan match { - case shuffle: ShuffleQueryStageExec - if supportedShuffleOrigins.contains(shuffle.shuffle.shuffleOrigin) => - tryOptimizeSkewedPartitions(shuffle) + case stage: ShuffleQueryStageExec if isSupported(stage.shuffle) => + tryOptimizeSkewedPartitions(stage) case _ => plan } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index fbfbce6233..88abe68197 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -52,8 +52,6 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule { override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS) - override def mayAddExtraShuffles: Boolean = true - /** * A partition is considered as a skewed partition if its size is larger than the median * partition size * SKEW_JOIN_SKEWED_PARTITION_FACTOR and also larger than @@ -257,13 +255,12 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule { plan } } -} -private object ShuffleStage { - def unapply(plan: SparkPlan): Option[ShuffleQueryStageExec] = plan match { - case s: ShuffleQueryStageExec if s.mapStats.isDefined && - OptimizeSkewedJoin.supportedShuffleOrigins.contains(s.shuffle.shuffleOrigin) => - Some(s) - case _ => None + object ShuffleStage { + def unapply(plan: SparkPlan): Option[ShuffleQueryStageExec] = plan match { + case s: ShuffleQueryStageExec if s.mapStats.isDefined && isSupported(s.shuffle) => + Some(s) + case _ => None + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index d71933ab58..23716f1081 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -32,8 +32,14 @@ import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoin * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for * each operator by inserting [[ShuffleExchangeExec]] Operators where required. Also ensure that * the input partition ordering requirements are met. + * + * @param optimizeOutRepartition A flag to indicate that if this rule should optimize out + * user-specified repartition shuffles or not. This is mostly true, + * but can be false in AQE when AQE optimization may change the plan + * output partitioning and need to retain the user-specified + * repartition shuffles in the plan. 
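For reference, the two ways this flag ends up being set elsewhere in the patch, shown together here as a short snippet (the `requiredDistribution` value is a local stand-in for the private field in `AdaptiveSparkPlanExec`):

```scala
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.exchange.EnsureRequirements

// Non-AQE planning (QueryExecution) keeps the default behaviour: redundant user-specified
// repartition shuffles may be optimized out.
val nonAqeRule = EnsureRequirements() // same as EnsureRequirements(optimizeOutRepartition = true)

// AQE only allows the optimization when it extracted a distribution requirement it can
// enforce on the final plan; otherwise the repartition shuffle must stay in the plan.
val requiredDistribution: Option[Distribution] = Some(UnspecifiedDistribution)
val aqeRule = EnsureRequirements(optimizeOutRepartition = requiredDistribution.isDefined)
```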
*/ -object EnsureRequirements extends Rule[SparkPlan] { +case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Rule[SparkPlan] { private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution @@ -249,13 +255,9 @@ object EnsureRequirements extends Rule[SparkPlan] { } def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - // TODO: remove this after we create a physical operator for `RepartitionByExpression`. - // SPARK-35989: AQE will change the partition number so we should retain the REPARTITION_BY_NUM - // shuffle which is specified by user. And also we can not remove REBALANCE_PARTITIONS_BY_COL, - // it is a special shuffle used to rebalance partitions. - // So, here we only remove REPARTITION_BY_COL in AQE. case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin) - if shuffleOrigin == REPARTITION_BY_COL || !conf.adaptiveExecutionEnabled => + if optimizeOutRepartition && + (shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM) => def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = { partitioning match { case lower: HashPartitioning if upper.semanticEquals(lower) => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala index 6964d9c9dd..5003db6a16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala @@ -30,6 +30,10 @@ import org.apache.spark.sql.execution._ */ object ValidateRequirements extends Logging { + def validate(plan: SparkPlan, requiredDistribution: Distribution): Boolean = { + validate(plan) && plan.outputPartitioning.satisfies(requiredDistribution) + } + def validate(plan: SparkPlan): Boolean = { plan.children.forall(validate) && validateInternal(plan) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index fad6ed104f..df310cbaee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -40,6 +40,8 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { setupTestData() + private val EnsureRequirements = new EnsureRequirements() + private def testPartialAggregationPlan(query: LogicalPlan): Unit = { val planner = spark.sessionState.planner import planner._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index dda94f18c0..ca8295e163 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -1496,6 +1496,60 @@ class AdaptiveQueryExecSuite }.isDefined } + def checkBHJ( + df: Dataset[Row], + optimizeOutRepartition: Boolean, + probeSideLocalRead: Boolean, + probeSideCoalescedRead: Boolean): Unit = { + df.collect() + val plan = df.queryExecution.executedPlan + // There should be only one shuffle that can't do local read, which is either the top shuffle + 
// from repartition, or BHJ probe side shuffle. + checkNumLocalShuffleReads(plan, 1) + assert(hasRepartitionShuffle(plan) == !optimizeOutRepartition) + val bhj = findTopLevelBroadcastHashJoin(plan) + assert(bhj.length == 1) + + // Build side should do local read. + val buildSide = find(bhj.head.left)(_.isInstanceOf[AQEShuffleReadExec]) + assert(buildSide.isDefined) + assert(buildSide.get.asInstanceOf[AQEShuffleReadExec].isLocalRead) + + val probeSide = find(bhj.head.right)(_.isInstanceOf[AQEShuffleReadExec]) + if (probeSideLocalRead || probeSideCoalescedRead) { + assert(probeSide.isDefined) + if (probeSideLocalRead) { + assert(probeSide.get.asInstanceOf[AQEShuffleReadExec].isLocalRead) + } else { + assert(probeSide.get.asInstanceOf[AQEShuffleReadExec].hasCoalescedPartition) + } + } else { + assert(probeSide.isEmpty) + } + } + + def checkSMJ( + df: Dataset[Row], + optimizeOutRepartition: Boolean, + optimizeSkewJoin: Boolean, + coalescedRead: Boolean): Unit = { + df.collect() + val plan = df.queryExecution.executedPlan + assert(hasRepartitionShuffle(plan) == !optimizeOutRepartition) + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.length == 1) + assert(smj.head.isSkewJoin == optimizeSkewJoin) + val aqeReads = collect(smj.head) { + case c: AQEShuffleReadExec => c + } + if (coalescedRead || optimizeSkewJoin) { + assert(aqeReads.length == 2) + if (coalescedRead) assert(aqeReads.forall(_.hasCoalescedPartition)) + } else { + assert(aqeReads.isEmpty) + } + } + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.SHUFFLE_PARTITIONS.key -> "5") { val df = sql( @@ -1509,44 +1563,25 @@ class AdaptiveQueryExecSuite withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { // Repartition with no partition num specified. - val dfRepartition = df.repartition('b) - dfRepartition.collect() - val plan = dfRepartition.queryExecution.executedPlan - // The top shuffle from repartition is optimized out. - assert(!hasRepartitionShuffle(plan)) - val bhj = findTopLevelBroadcastHashJoin(plan) - assert(bhj.length == 1) - checkNumLocalShuffleReads(plan, 1) - // Probe side is coalesced. - val aqeRead = bhj.head.right.find(_.isInstanceOf[AQEShuffleReadExec]) - assert(aqeRead.isDefined) - assert(aqeRead.get.asInstanceOf[AQEShuffleReadExec].hasCoalescedPartition) + checkBHJ(df.repartition('b), + // The top shuffle from repartition is optimized out. + optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = true) - // Repartition with partition default num specified. - val dfRepartitionWithNum = df.repartition(5, 'b) - dfRepartitionWithNum.collect() - val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan - // The top shuffle from repartition is not optimized out. - assert(hasRepartitionShuffle(planWithNum)) - val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum) - assert(bhjWithNum.length == 1) - checkNumLocalShuffleReads(planWithNum, 1) - // Probe side is coalesced. - assert(bhjWithNum.head.right.find(_.isInstanceOf[AQEShuffleReadExec]).nonEmpty) + // Repartition with default partition num (5 in test env) specified. + checkBHJ(df.repartition(5, 'b), + // The top shuffle from repartition is optimized out + // The final plan must have 5 partitions, no optimization can be made to the probe side. + optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = false) - // Repartition with partition non-default num specified. 
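The distinction the updated test comments draw between `repartition(5, 'b)` (optimized out, but the final stage may not change the partition count) and `repartition(4, 'b)` (kept) comes down to whether the final plan's partitioning still satisfies a distribution with a pinned partition number. A small illustration using the catalyst classes — editorial, not from the patch, and it assumes Spark is on the classpath:

```scala
// A user-specified partition number pins the required distribution's partition count,
// so a coalesced read in the final stage would no longer satisfy it.
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{HashClusteredDistribution, HashPartitioning}
import org.apache.spark.sql.types.IntegerType

val b = AttributeReference("b", IntegerType)()
// Roughly what AQEUtils derives for df.repartition(5, 'b): cluster by b, exactly 5 partitions.
val required = HashClusteredDistribution(Seq(b), Some(5))

assert(HashPartitioning(Seq(b), 5).satisfies(required))  // final plan keeps 5 partitions: OK
assert(!HashPartitioning(Seq(b), 3).satisfies(required)) // coalesced to 3: requirement broken
```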
- val dfRepartitionWithNum2 = df.repartition(3, 'b) - dfRepartitionWithNum2.collect() - val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan - // The top shuffle from repartition is not optimized out, and this is the only shuffle that - // does not have local shuffle read. - assert(hasRepartitionShuffle(planWithNum2)) - val bhjWithNum2 = findTopLevelBroadcastHashJoin(planWithNum2) - assert(bhjWithNum2.length == 1) - checkNumLocalShuffleReads(planWithNum2, 1) - val aqeRead2 = bhjWithNum2.head.right.find(_.isInstanceOf[AQEShuffleReadExec]) - assert(aqeRead2.isDefined) - assert(aqeRead2.get.asInstanceOf[AQEShuffleReadExec].isLocalRead) + // Repartition with non-default partition num specified. + checkBHJ(df.repartition(4, 'b), + // The top shuffle from repartition is not optimized out + optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true) + + // Repartition by col and project away the partition cols + checkBHJ(df.repartition('b).select('key), + // The top shuffle from repartition is not optimized out + optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true) } // Force skew join @@ -1556,46 +1591,25 @@ class AdaptiveQueryExecSuite SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0", SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") { // Repartition with no partition num specified. - val dfRepartition = df.repartition('b) - dfRepartition.collect() - val plan = dfRepartition.queryExecution.executedPlan - // The top shuffle from repartition is optimized out. - assert(!hasRepartitionShuffle(plan)) - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.length == 1) - // No skew join due to the repartition. - assert(!smj.head.isSkewJoin) - // Both sides are coalesced. - val aqeReads = collect(smj.head) { - case c: AQEShuffleReadExec if c.hasCoalescedPartition => c - } - assert(aqeReads.length == 2) + checkSMJ(df.repartition('b), + // The top shuffle from repartition is optimized out. + optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = true) - // Repartition with default partition num specified. - val dfRepartitionWithNum = df.repartition(5, 'b) - dfRepartitionWithNum.collect() - val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan - // The top shuffle from repartition is not optimized out. - assert(hasRepartitionShuffle(planWithNum)) - val smjWithNum = findTopLevelSortMergeJoin(planWithNum) - assert(smjWithNum.length == 1) - // Skew join can apply as the repartition is not optimized out. - assert(smjWithNum.head.isSkewJoin) - val aqeReadsWithNum = collect(smjWithNum.head) { - case c: AQEShuffleReadExec => c - } - assert(aqeReadsWithNum.nonEmpty) + // Repartition with default partition num (5 in test env) specified. + checkSMJ(df.repartition(5, 'b), + // The top shuffle from repartition is optimized out. + // The final plan must have 5 partitions, can't do coalesced read. + optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = false) - // Repartition with default non-partition num specified. - val dfRepartitionWithNum2 = df.repartition(3, 'b) - dfRepartitionWithNum2.collect() - val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan - // The top shuffle from repartition is not optimized out. - assert(hasRepartitionShuffle(planWithNum2)) - val smjWithNum2 = findTopLevelSortMergeJoin(planWithNum2) - assert(smjWithNum2.length == 1) - // Skew join can apply as the repartition is not optimized out. 
- assert(smjWithNum2.head.isSkewJoin) + // Repartition with non-default partition num specified. + checkSMJ(df.repartition(4, 'b), + // The top shuffle from repartition is not optimized out. + optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false) + + // Repartition by col and project away the partition cols + checkSMJ(df.repartition('b).select('key), + // The top shuffle from repartition is not optimized out. + optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 8f7616ccb4..0425be6f9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -21,16 +21,17 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection} import org.apache.spark.sql.execution.{DummySparkPlan, SortExec} -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -class EnsureRequirementsSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { +class EnsureRequirementsSuite extends SharedSparkSession { private val exprA = Literal(1) private val exprB = Literal(2) private val exprC = Literal(3) + private val EnsureRequirements = new EnsureRequirements() + test("reorder should handle PartitioningCollection") { val plan1 = DummySparkPlan( outputPartitioning = PartitioningCollection(Seq( @@ -134,26 +135,4 @@ class EnsureRequirementsSuite extends SharedSparkSession with AdaptiveSparkPlanH }.size == 2) } } - - test("SPARK-35989: Do not remove REPARTITION_BY_NUM shuffle if AQE is enabled") { - import testImplicits._ - Seq(true, false).foreach { enableAqe => - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAqe.toString, - SQLConf.SHUFFLE_PARTITIONS.key -> "3", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df1 = Seq((1, 2)).toDF("c1", "c2") - val df2 = Seq((1, 3)).toDF("c3", "c4") - val res = df1.join(df2, $"c1" === $"c3").repartition(3, $"c1") - val num = collect(res.queryExecution.executedPlan) { - case shuffle: ShuffleExchangeExec if shuffle.shuffleOrigin == REPARTITION_BY_NUM => - shuffle - }.size - if (enableAqe) { - assert(num == 1) - } else { - assert(num == 0) - } - } - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 92c38ee228..83163cfb42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -48,6 +48,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils protected var spark: SparkSession = null + private val EnsureRequirements = new EnsureRequirements() + /** * Create a new [[SparkSession]] running in local-cluster mode with unsafe and codegen enabled. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index 3588b9dda9..71e59ad446 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StructT class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession { + private val EnsureRequirements = new EnsureRequirements() + private lazy val left = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, 2.0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 5262320134..653049bfdf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -33,6 +33,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { import testImplicits.newProductEncoder import testImplicits.localSeqToDatasetHolder + private val EnsureRequirements = new EnsureRequirements() + private lazy val myUpperCaseData = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, "A"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 744ee1ca1a..f704fdb996 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} class OuterJoinSuite extends SparkPlanTest with SharedSparkSession { + private val EnsureRequirements = new EnsureRequirements() + private lazy val left = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, 2.0),