From 5b98ec252799225397ccf0cb805dc76c588b90e2 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 19 Jul 2021 14:14:40 +0800
Subject: [PATCH] [SPARK-36184][SQL] Use ValidateRequirements instead of
 EnsureRequirements to skip AQE rules that add extra shuffles

### What changes were proposed in this pull request?

Currently, two AQE rules, `OptimizeLocalShuffleReader` and `OptimizeSkewedJoin`, run `EnsureRequirements` at the end to check whether the optimized plan contains extra shuffles, and revert the optimization if it does.

This PR proposes to run `ValidateRequirements` instead, which is much simpler than `EnsureRequirements`. This PR also moves the check into `AdaptiveSparkPlanExec`, so that it is centralized rather than duplicated in each rule. After centralization, the batch name for optimizing the final stage is the same as for normal stages, which makes more sense. (A simplified sketch of this validate-and-revert pattern appears after the patch.)

### Why are the changes needed?

`EnsureRequirements` is a big rule and even contains optimizations (removing unnecessary shuffles). `ValidateRequirements` is much faster to run and avoids potential bugs, as it performs no optimization and is a pure check.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #33396 from cloud-fan/aqe.

Authored-by: Wenchen Fan
Signed-off-by: Wenchen Fan
(cherry picked from commit 8396a70ddc8e5ce2671ccb59ee4b5136ba85f85e)
Signed-off-by: Wenchen Fan
---
 .../adaptive/AdaptiveSparkPlanExec.scala      | 29 +++++++++++++++----
 .../adaptive/CustomShuffleReaderRule.scala    |  2 ++
 .../adaptive/OptimizeLocalShuffleReader.scala | 19 ++----------
 .../adaptive/OptimizeSkewedJoin.scala         | 21 +++++---------
 .../adaptive/AdaptiveQueryExecSuite.scala     |  3 +-
 5 files changed, 36 insertions(+), 38 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 13b65aa67b..93beef8b7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -130,6 +130,27 @@ case class AdaptiveSparkPlanExec(
     }
   }
 
+  private def optimizeQueryStage(plan: SparkPlan, rules: Seq[Rule[SparkPlan]]): SparkPlan = {
+    val optimized = rules.foldLeft(plan) { case (latestPlan, rule) =>
+      val applied = rule.apply(latestPlan)
+      val result = rule match {
+        case c: CustomShuffleReaderRule if c.mayAddExtraShuffles =>
+          if (ValidateRequirements.validate(applied)) {
+            applied
+          } else {
+            logDebug(s"Rule ${rule.ruleName} is not applied because it would introduce " +
+              "additional shuffles.")
+            latestPlan
+          }
+        case _ => applied
+      }
+      planChangeLogger.logRule(rule.ruleName, latestPlan, result)
+      result
+    }
+    planChangeLogger.logBatch("AQE Query Stage Optimization", plan, optimized)
+    optimized
+  }
+
   @transient private val costEvaluator =
     conf.getConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS) match {
       case Some(className) => CostEvaluator.instantiate(className, session.sparkContext.getConf)
@@ -274,10 +295,7 @@ case class AdaptiveSparkPlanExec(
         }
 
         // Run the final plan when there's no more unfinished stages.
- currentPhysicalPlan = applyPhysicalRules( - result.newPlan, - finalStageOptimizerRules, - Some((planChangeLogger, "AQE Final Query Stage Optimization"))) + currentPhysicalPlan = optimizeQueryStage(result.newPlan, finalStageOptimizerRules) isFinalPlan = true executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan))) currentPhysicalPlan @@ -480,8 +498,7 @@ case class AdaptiveSparkPlanExec( } private def newQueryStage(e: Exchange): QueryStageExec = { - val optimizedPlan = applyPhysicalRules( - e.child, queryStageOptimizerRules, Some((planChangeLogger, "AQE Query Stage Optimization"))) + val optimizedPlan = optimizeQueryStage(e.child, queryStageOptimizerRules) val queryStage = e match { case s: ShuffleExchangeLike => val newShuffle = applyPhysicalRules( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderRule.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderRule.scala index c5b8f73ea5..3004a3d689 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderRule.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderRule.scala @@ -30,4 +30,6 @@ trait CustomShuffleReaderRule extends Rule[SparkPlan] { * Returns the list of [[ShuffleOrigin]]s supported by this rule. */ def supportedShuffleOrigins: Seq[ShuffleOrigin] + + def mayAddExtraShuffles: Boolean = false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala index b17af00053..c91b999500 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, REBALANCE_PARTITIONS_BY_NONE, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REBALANCE_PARTITIONS_BY_NONE, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.internal.SQLConf @@ -38,12 +38,12 @@ object OptimizeLocalShuffleReader extends CustomShuffleReaderRule { override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS, REBALANCE_PARTITIONS_BY_NONE) - private val ensureRequirements = EnsureRequirements + override def mayAddExtraShuffles: Boolean = true // The build side is a broadcast query stage which should have been optimized using local reader // already. So we only need to deal with probe side here. 
private def createProbeSideLocalReader(plan: SparkPlan): SparkPlan = { - val optimizedPlan = plan.transformDown { + plan.transformDown { case join @ BroadcastJoinWithShuffleLeft(shuffleStage, BuildRight) => val localReader = createLocalReader(shuffleStage) join.asInstanceOf[BroadcastHashJoinExec].copy(left = localReader) @@ -51,19 +51,6 @@ object OptimizeLocalShuffleReader extends CustomShuffleReaderRule { val localReader = createLocalReader(shuffleStage) join.asInstanceOf[BroadcastHashJoinExec].copy(right = localReader) } - - val numShuffles = ensureRequirements.apply(optimizedPlan).collect { - case e: ShuffleExchangeExec => e - }.length - - // Check whether additional shuffle introduced. If introduced, revert the local reader. - if (numShuffles > 0) { - logDebug("OptimizeLocalShuffleReader rule is not applied due" + - " to additional shuffles will be introduced.") - plan - } else { - optimizedPlan - } } private def createLocalReader(plan: SparkPlan): CustomShuffleReaderExec = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index a284016bfb..810084a65f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -23,7 +23,7 @@ import org.apache.commons.io.FileUtils import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleExchangeExec, ShuffleOrigin} +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleOrigin} import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SQLConf @@ -52,7 +52,7 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule { override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS) - private val ensureRequirements = EnsureRequirements + override def mayAddExtraShuffles: Boolean = true /** * A partition is considered as a skewed partition if its size is larger than the median @@ -248,18 +248,11 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule { // Shuffle // Sort // Shuffle - val optimizePlan = optimizeSkewJoin(plan) - val numShuffles = ensureRequirements.apply(optimizePlan).collect { - case e: ShuffleExchangeExec => e - }.length - - if (numShuffles > 0) { - logDebug("OptimizeSkewedJoin rule is not applied due" + - " to additional shuffles will be introduced.") - plan - } else { - optimizePlan - } + // Or + // SHJ + // Shuffle + // Shuffle + optimizeSkewJoin(plan) } else { plan } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 13bba68ff6..b56e67300d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -1425,8 +1425,7 @@ class AdaptiveQueryExecSuite Seq("=== Result of Batch AQE Preparations ===", "=== Result of Batch AQE Post Stage Creation ===", "=== Result of Batch AQE Replanning ===", - "=== Result of Batch AQE Query Stage Optimization ===", - "=== Result of Batch AQE Final Query Stage Optimization ===").foreach { expectedMsg 
=> + "=== Result of Batch AQE Query Stage Optimization ===").foreach { expectedMsg => assert(testAppender.loggingEvents.exists(_.getRenderedMessage.contains(expectedMsg))) } }
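
---

As referenced in the commit message above, here is a minimal, self-contained Scala sketch of the validate-and-revert loop that the new `optimizeQueryStage` method implements. The `Plan`, `Rule`, `validate`, `SmjLike`, and `StripLeftShuffle` definitions below are simplified, hypothetical stand-ins for `SparkPlan`, `Rule[SparkPlan]`, `ValidateRequirements.validate`, and the AQE reader rules; this is a sketch of the pattern under those assumptions, not the actual Spark implementation.

```scala
// Sketch only: Plan, Rule, and validate are simplified stand-ins for
// SparkPlan, Rule[SparkPlan], and ValidateRequirements.validate.
object ValidateAndRevertSketch {
  sealed trait Plan { def children: Seq[Plan] }
  case class Scan(table: String) extends Plan { def children: Seq[Plan] = Nil }
  case class Shuffle(child: Plan) extends Plan { def children: Seq[Plan] = Seq(child) }
  // A sort-merge-join-like node: both inputs must arrive through a Shuffle.
  case class SmjLike(left: Plan, right: Plan) extends Plan {
    def children: Seq[Plan] = Seq(left, right)
  }

  trait Rule {
    def name: String
    def mayAddExtraShuffles: Boolean = false
    def apply(plan: Plan): Plan
  }

  // Stand-in for ValidateRequirements.validate: a pure, recursive check that
  // every operator's input requirement holds. It never rewrites the plan,
  // which is exactly why it is cheaper and safer than EnsureRequirements.
  def validate(plan: Plan): Boolean = plan match {
    case SmjLike(l, r) =>
      l.isInstanceOf[Shuffle] && r.isInstanceOf[Shuffle] && validate(l) && validate(r)
    case other => other.children.forall(validate)
  }

  def optimizeQueryStage(plan: Plan, rules: Seq[Rule]): Plan =
    rules.foldLeft(plan) { (latest, rule) =>
      val applied = rule.apply(latest)
      if (rule.mayAddExtraShuffles && !validate(applied)) {
        // Revert: keeping the unoptimized plan is better than forcing an
        // extra shuffle to be inserted downstream.
        println(s"Rule ${rule.name} is not applied because it would introduce extra shuffles.")
        latest
      } else {
        applied
      }
    }

  def main(args: Array[String]): Unit = {
    // A rule that strips the left-side shuffle, breaking the join's input
    // requirement; validation catches this and the optimization is reverted.
    val stripLeftShuffle: Rule = new Rule {
      def name: String = "StripLeftShuffle"
      override def mayAddExtraShuffles: Boolean = true
      def apply(plan: Plan): Plan = plan match {
        case SmjLike(Shuffle(l), r) => SmjLike(l, r)
        case other => other
      }
    }
    val plan = SmjLike(Shuffle(Scan("t1")), Shuffle(Scan("t2")))
    assert(optimizeQueryStage(plan, Seq(stripLeftShuffle)) == plan) // reverted
  }
}
```

The design point mirrored here is that validation is a pure boolean pass: since it never rewrites the plan, reverting amounts to keeping the previous plan, and the check itself cannot introduce bugs through optimization, which is the argument the commit message makes against reusing `EnsureRequirements` for this purpose.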