[SPARK-36184][SQL] Use ValidateRequirements instead of EnsureRequirements to skip AQE rules that add extra shuffles

### What changes were proposed in this pull request?

Currently, two AQE rules, `OptimizeLocalShuffleReader` and `OptimizeSkewedJoin`, run `EnsureRequirements` at the end to check whether the optimized plan contains extra shuffles, and revert the optimization if it does.
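
For context, a minimal sketch of that per-rule pattern (the helper name is illustrative and the import paths are assumed; the `EnsureRequirements`/`ShuffleExchangeExec` usage mirrors the removed code in the diff below, which treats `EnsureRequirements` as a parameterless rule):

```scala
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}

// Sketch of the pre-PR check: re-run the full EnsureRequirements rule on the rule's
// output and revert to the original plan if that pass inserts any new shuffle exchange.
def revertIfExtraShuffles(original: SparkPlan, optimized: SparkPlan): SparkPlan = {
  val numShuffles = EnsureRequirements.apply(optimized).collect {
    case e: ShuffleExchangeExec => e
  }.length
  if (numShuffles > 0) original else optimized
}
```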

This PR proposes to run `ValidateRequirements` instead, which is much simpler than `EnsureRequirements`. It also moves the check into `AdaptiveSparkPlanExec`, so it is centralized there instead of being repeated in each rule. With the check centralized, the final stage is optimized under the same batch name as normal stages, which makes more sense.
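
As a hypothetical illustration of the centralized flow (the rule name and body are made up and the import paths are assumed; only `supportedShuffleOrigins` and the new `mayAddExtraShuffles` flag come from the existing trait and this PR):

```scala
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderRule
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleOrigin}

// Hypothetical stage-optimizer rule: it only declares that it may break distribution
// requirements; AdaptiveSparkPlanExec.optimizeQueryStage then runs ValidateRequirements
// on its output and falls back to the previous plan if the check fails.
object MyStageRule extends CustomShuffleReaderRule {
  override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS)
  override def mayAddExtraShuffles: Boolean = true
  override def apply(plan: SparkPlan): SparkPlan = plan // a real rule would rewrite the stage here
}
```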

### Why are the changes needed?

`EnsureRequirements` is a big rule that even performs optimizations (it removes unnecessary shuffles). `ValidateRequirements` is much faster to run and avoids potential bugs, because it performs no optimization at all: it is a pure check.
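
A minimal sketch of what the new check looks like at a call site (the helper name is illustrative and the import path is assumed; `ValidateRequirements.validate` is the API this PR switches to, as shown in the diff below):

```scala
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.ValidateRequirements

// ValidateRequirements never rewrites the plan; it only answers whether every operator's
// required child distribution and ordering are already satisfied.
def keepIfRequirementsStillMet(previous: SparkPlan, optimized: SparkPlan): SparkPlan =
  if (ValidateRequirements.validate(optimized)) optimized else previous
```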

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #33396 from cloud-fan/aqe.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 8396a70ddc)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>


@@ -130,6 +130,27 @@ case class AdaptiveSparkPlanExec(
    }
  }

  private def optimizeQueryStage(plan: SparkPlan, rules: Seq[Rule[SparkPlan]]): SparkPlan = {
    val optimized = rules.foldLeft(plan) { case (latestPlan, rule) =>
      val applied = rule.apply(latestPlan)
      val result = rule match {
        case c: CustomShuffleReaderRule if c.mayAddExtraShuffles =>
          if (ValidateRequirements.validate(applied)) {
            applied
          } else {
            logDebug(s"Rule ${rule.ruleName} is not applied due to additional shuffles " +
              "will be introduced.")
            latestPlan
          }
        case _ => applied
      }
      planChangeLogger.logRule(rule.ruleName, latestPlan, result)
      result
    }
    planChangeLogger.logBatch("AQE Query Stage Optimization", plan, optimized)
    optimized
  }

  @transient private val costEvaluator =
    conf.getConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS) match {
      case Some(className) => CostEvaluator.instantiate(className, session.sparkContext.getConf)
@@ -274,10 +295,7 @@ case class AdaptiveSparkPlanExec(
      }

      // Run the final plan when there's no more unfinished stages.
      currentPhysicalPlan = applyPhysicalRules(
        result.newPlan,
        finalStageOptimizerRules,
        Some((planChangeLogger, "AQE Final Query Stage Optimization")))
      currentPhysicalPlan = optimizeQueryStage(result.newPlan, finalStageOptimizerRules)
      isFinalPlan = true
      executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
      currentPhysicalPlan
@@ -480,8 +498,7 @@ case class AdaptiveSparkPlanExec(
  }

  private def newQueryStage(e: Exchange): QueryStageExec = {
    val optimizedPlan = applyPhysicalRules(
      e.child, queryStageOptimizerRules, Some((planChangeLogger, "AQE Query Stage Optimization")))
    val optimizedPlan = optimizeQueryStage(e.child, queryStageOptimizerRules)
    val queryStage = e match {
      case s: ShuffleExchangeLike =>
        val newShuffle = applyPhysicalRules(


@@ -30,4 +30,6 @@ trait CustomShuffleReaderRule extends Rule[SparkPlan] {
   * Returns the list of [[ShuffleOrigin]]s supported by this rule.
   */
  def supportedShuffleOrigins: Seq[ShuffleOrigin]

  def mayAddExtraShuffles: Boolean = false
}


@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, REBALANCE_PARTITIONS_BY_NONE, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REBALANCE_PARTITIONS_BY_NONE, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.internal.SQLConf
@@ -38,12 +38,12 @@ object OptimizeLocalShuffleReader extends CustomShuffleReaderRule {
  override val supportedShuffleOrigins: Seq[ShuffleOrigin] =
    Seq(ENSURE_REQUIREMENTS, REBALANCE_PARTITIONS_BY_NONE)

  private val ensureRequirements = EnsureRequirements
  override def mayAddExtraShuffles: Boolean = true

  // The build side is a broadcast query stage which should have been optimized using local reader
  // already. So we only need to deal with probe side here.
  private def createProbeSideLocalReader(plan: SparkPlan): SparkPlan = {
    val optimizedPlan = plan.transformDown {
    plan.transformDown {
      case join @ BroadcastJoinWithShuffleLeft(shuffleStage, BuildRight) =>
        val localReader = createLocalReader(shuffleStage)
        join.asInstanceOf[BroadcastHashJoinExec].copy(left = localReader)
@@ -51,19 +51,6 @@ object OptimizeLocalShuffleReader extends CustomShuffleReaderRule {
        val localReader = createLocalReader(shuffleStage)
        join.asInstanceOf[BroadcastHashJoinExec].copy(right = localReader)
    }
    val numShuffles = ensureRequirements.apply(optimizedPlan).collect {
      case e: ShuffleExchangeExec => e
    }.length
    // Check whether additional shuffle introduced. If introduced, revert the local reader.
    if (numShuffles > 0) {
      logDebug("OptimizeLocalShuffleReader rule is not applied due" +
        " to additional shuffles will be introduced.")
      plan
    } else {
      optimizedPlan
    }
  }

  private def createLocalReader(plan: SparkPlan): CustomShuffleReaderExec = {


@@ -23,7 +23,7 @@ import org.apache.commons.io.FileUtils
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleExchangeExec, ShuffleOrigin}
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleOrigin}
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
@@ -52,7 +52,7 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
  override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS)

  private val ensureRequirements = EnsureRequirements
  override def mayAddExtraShuffles: Boolean = true

  /**
   * A partition is considered as a skewed partition if its size is larger than the median
@@ -248,18 +248,11 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
      //     Shuffle
      //   Sort
      //     Shuffle
      val optimizePlan = optimizeSkewJoin(plan)
      val numShuffles = ensureRequirements.apply(optimizePlan).collect {
        case e: ShuffleExchangeExec => e
      }.length

      if (numShuffles > 0) {
        logDebug("OptimizeSkewedJoin rule is not applied due" +
          " to additional shuffles will be introduced.")
        plan
      } else {
        optimizePlan
      }
      // Or
      // SHJ
      //   Shuffle
      //   Shuffle
      optimizeSkewJoin(plan)
    } else {
      plan
    }


@@ -1425,8 +1425,7 @@ class AdaptiveQueryExecSuite
      Seq("=== Result of Batch AQE Preparations ===",
        "=== Result of Batch AQE Post Stage Creation ===",
        "=== Result of Batch AQE Replanning ===",
        "=== Result of Batch AQE Query Stage Optimization ===",
        "=== Result of Batch AQE Final Query Stage Optimization ===").foreach { expectedMsg =>
        "=== Result of Batch AQE Query Stage Optimization ===").foreach { expectedMsg =>
        assert(testAppender.loggingEvents.exists(_.getRenderedMessage.contains(expectedMsg)))
      }
    }