[SPARK-36184][SQL] Use ValidateRequirements instead of EnsureRequirements to skip AQE rules that add extra shuffles
### What changes were proposed in this pull request?
Currently, two AQE rules `OptimizeLocalShuffleReader` and `OptimizeSkewedJoin` run `EnsureRequirements` at the end to check if there are extra shuffles in the optimized plan and revert the optimization if extra shuffles are introduced.
This PR proposes to run `ValidateRequirements` instead, which is much simpler than `EnsureRequirements`. This PR also moves the check into `AdaptiveSparkPlanExec`, so that it is centralized rather than repeated in each rule. After centralization, the final stage is optimized under the same batch name as normal stages, which makes more sense.
### Why are the changes needed?
`EnsureRequirements` is a big rule and even contains optimizations (it removes unnecessary shuffles). `ValidateRequirements` is much faster to run and can avoid potential bugs, as it performs no optimization and is a pure check.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
existing tests.
Closes #33396 from cloud-fan/aqe.
Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 8396a70ddc)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
c3a23ce49b
commit
5b98ec2527
|
@ -130,6 +130,27 @@ case class AdaptiveSparkPlanExec(
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Applies the given rules to `plan` one after another and returns the resulting plan.
 *
 * Each rule is run on the output of the previous one. If a rule is a
 * [[CustomShuffleReaderRule]] whose `mayAddExtraShuffles` is true, its output is kept only
 * when `ValidateRequirements.validate` accepts it; otherwise the rule's output is dropped and
 * the plan from before the rule is carried forward. Every rule application and the overall
 * batch are reported to `planChangeLogger`.
 *
 * @param plan  the physical plan to optimize
 * @param rules the rules to apply, in order
 * @return the plan after all rules have been applied (or reverted)
 */
private def optimizeQueryStage(plan: SparkPlan, rules: Seq[Rule[SparkPlan]]): SparkPlan = {
|
||||
val optimized = rules.foldLeft(plan) { case (latestPlan, rule) =>
|
||||
val applied = rule.apply(latestPlan)
|
||||
val result = rule match {
|
||||
// Only rules that declare they may add extra shuffles are validated after being applied.
case c: CustomShuffleReaderRule if c.mayAddExtraShuffles =>
|
||||
if (ValidateRequirements.validate(applied)) {
|
||||
applied
|
||||
} else {
|
||||
logDebug(s"Rule ${rule.ruleName} is not applied due to additional shuffles " +
|
||||
"will be introduced.")
|
||||
// Validation failed: revert to the plan as it was before this rule ran.
latestPlan
|
||||
}
|
||||
case _ => applied
|
||||
}
|
||||
// Log against `result` (not `applied`) so a reverted rule is reported as making no change.
planChangeLogger.logRule(rule.ruleName, latestPlan, result)
|
||||
result
|
||||
}
|
||||
planChangeLogger.logBatch("AQE Query Stage Optimization", plan, optimized)
|
||||
optimized
|
||||
}
|
||||
|
||||
@transient private val costEvaluator =
|
||||
conf.getConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS) match {
|
||||
case Some(className) => CostEvaluator.instantiate(className, session.sparkContext.getConf)
|
||||
|
@ -274,10 +295,7 @@ case class AdaptiveSparkPlanExec(
|
|||
}
|
||||
|
||||
// Run the final plan when there's no more unfinished stages.
|
||||
currentPhysicalPlan = applyPhysicalRules(
|
||||
result.newPlan,
|
||||
finalStageOptimizerRules,
|
||||
Some((planChangeLogger, "AQE Final Query Stage Optimization")))
|
||||
currentPhysicalPlan = optimizeQueryStage(result.newPlan, finalStageOptimizerRules)
|
||||
isFinalPlan = true
|
||||
executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
|
||||
currentPhysicalPlan
|
||||
|
@ -480,8 +498,7 @@ case class AdaptiveSparkPlanExec(
|
|||
}
|
||||
|
||||
private def newQueryStage(e: Exchange): QueryStageExec = {
|
||||
val optimizedPlan = applyPhysicalRules(
|
||||
e.child, queryStageOptimizerRules, Some((planChangeLogger, "AQE Query Stage Optimization")))
|
||||
val optimizedPlan = optimizeQueryStage(e.child, queryStageOptimizerRules)
|
||||
val queryStage = e match {
|
||||
case s: ShuffleExchangeLike =>
|
||||
val newShuffle = applyPhysicalRules(
|
||||
|
|
|
@ -30,4 +30,6 @@ trait CustomShuffleReaderRule extends Rule[SparkPlan] {
|
|||
* Returns the list of [[ShuffleOrigin]]s supported by this rule.
|
||||
*/
|
||||
def supportedShuffleOrigins: Seq[ShuffleOrigin]
|
||||
|
||||
def mayAddExtraShuffles: Boolean = false
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.adaptive
|
|||
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
|
||||
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
|
||||
import org.apache.spark.sql.execution._
|
||||
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, REBALANCE_PARTITIONS_BY_NONE, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
|
||||
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REBALANCE_PARTITIONS_BY_NONE, ShuffleExchangeLike, ShuffleOrigin}
|
||||
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
||||
|
@ -38,12 +38,12 @@ object OptimizeLocalShuffleReader extends CustomShuffleReaderRule {
|
|||
override val supportedShuffleOrigins: Seq[ShuffleOrigin] =
|
||||
Seq(ENSURE_REQUIREMENTS, REBALANCE_PARTITIONS_BY_NONE)
|
||||
|
||||
private val ensureRequirements = EnsureRequirements
|
||||
override def mayAddExtraShuffles: Boolean = true
|
||||
|
||||
// The build side is a broadcast query stage which should have been optimized using local reader
|
||||
// already. So we only need to deal with probe side here.
|
||||
private def createProbeSideLocalReader(plan: SparkPlan): SparkPlan = {
|
||||
val optimizedPlan = plan.transformDown {
|
||||
plan.transformDown {
|
||||
case join @ BroadcastJoinWithShuffleLeft(shuffleStage, BuildRight) =>
|
||||
val localReader = createLocalReader(shuffleStage)
|
||||
join.asInstanceOf[BroadcastHashJoinExec].copy(left = localReader)
|
||||
|
@ -51,19 +51,6 @@ object OptimizeLocalShuffleReader extends CustomShuffleReaderRule {
|
|||
val localReader = createLocalReader(shuffleStage)
|
||||
join.asInstanceOf[BroadcastHashJoinExec].copy(right = localReader)
|
||||
}
|
||||
|
||||
val numShuffles = ensureRequirements.apply(optimizedPlan).collect {
|
||||
case e: ShuffleExchangeExec => e
|
||||
}.length
|
||||
|
||||
// Check whether additional shuffle introduced. If introduced, revert the local reader.
|
||||
if (numShuffles > 0) {
|
||||
logDebug("OptimizeLocalShuffleReader rule is not applied due" +
|
||||
" to additional shuffles will be introduced.")
|
||||
plan
|
||||
} else {
|
||||
optimizedPlan
|
||||
}
|
||||
}
|
||||
|
||||
private def createLocalReader(plan: SparkPlan): CustomShuffleReaderExec = {
|
||||
|
|
|
@ -23,7 +23,7 @@ import org.apache.commons.io.FileUtils
|
|||
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.execution._
|
||||
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleExchangeExec, ShuffleOrigin}
|
||||
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleOrigin}
|
||||
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
||||
|
@ -52,7 +52,7 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
|
|||
|
||||
override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS)
|
||||
|
||||
private val ensureRequirements = EnsureRequirements
|
||||
override def mayAddExtraShuffles: Boolean = true
|
||||
|
||||
/**
|
||||
* A partition is considered as a skewed partition if its size is larger than the median
|
||||
|
@ -248,18 +248,11 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
|
|||
// Shuffle
|
||||
// Sort
|
||||
// Shuffle
|
||||
val optimizePlan = optimizeSkewJoin(plan)
|
||||
val numShuffles = ensureRequirements.apply(optimizePlan).collect {
|
||||
case e: ShuffleExchangeExec => e
|
||||
}.length
|
||||
|
||||
if (numShuffles > 0) {
|
||||
logDebug("OptimizeSkewedJoin rule is not applied due" +
|
||||
" to additional shuffles will be introduced.")
|
||||
plan
|
||||
} else {
|
||||
optimizePlan
|
||||
}
|
||||
// Or
|
||||
// SHJ
|
||||
// Shuffle
|
||||
// Shuffle
|
||||
optimizeSkewJoin(plan)
|
||||
} else {
|
||||
plan
|
||||
}
|
||||
|
|
|
@ -1425,8 +1425,7 @@ class AdaptiveQueryExecSuite
|
|||
Seq("=== Result of Batch AQE Preparations ===",
|
||||
"=== Result of Batch AQE Post Stage Creation ===",
|
||||
"=== Result of Batch AQE Replanning ===",
|
||||
"=== Result of Batch AQE Query Stage Optimization ===",
|
||||
"=== Result of Batch AQE Final Query Stage Optimization ===").foreach { expectedMsg =>
|
||||
"=== Result of Batch AQE Query Stage Optimization ===").foreach { expectedMsg =>
|
||||
assert(testAppender.loggingEvents.exists(_.getRenderedMessage.contains(expectedMsg)))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue