[SPARK-36315][SQL] Only skip AQEShuffleReadRule in the final stage if it breaks the distribution requirement

### What changes were proposed in this pull request?

This is a followup of https://github.com/apache/spark/pull/30494

This PR proposes a new way to optimize the final query stage in AQE. We first collect the effective user-specified repartition (semantically, a user-specified repartition is only effective if it is the root node or sits under a few simple nodes) and derive the required distribution for the final plan. When optimizing the final query stage, we skip an `AQEShuffleReadRule` if it breaks the required distribution.

### Why are the changes needed?

The current solution for optimizing the final query stage is hacky and overly conservative. For example, the newly added rule `OptimizeSkewInRebalancePartitions` can hardly ever apply, because it is very common for the query plan to contain shuffles with origin `ENSURE_REQUIREMENTS`, which `OptimizeSkewInRebalancePartitions` does not support.

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

updated tests

Closes #33541 from cloud-fan/aqe.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Wenchen Fan 2021-08-03 18:28:52 +08:00
parent 1deb386727
commit dd80457ffb
18 changed files with 270 additions and 167 deletions

View file

@ -409,7 +409,7 @@ object QueryExecution {
PlanDynamicPruningFilters(sparkSession),
PlanSubqueries(sparkSession),
RemoveRedundantProjects,
EnsureRequirements,
EnsureRequirements(),
// `RemoveRedundantSorts` needs to be added after `EnsureRequirements` to guarantee the same
// number of partitions when instantiating PartitioningCollection.
RemoveRedundantSorts,

View file

@ -22,13 +22,13 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.ColumnarBatch
/**
* A wrapper of shuffle query stage, which follows the given partition arrangement.
*
@ -51,6 +51,7 @@ case class AQEShuffleReadExec private(
override def supportsColumnar: Boolean = child.supportsColumnar
override def output: Seq[Attribute] = child.output
override lazy val outputPartitioning: Partitioning = {
// If it is a local shuffle read with one mapper per task, then the output partitioning is
// the same as the plan before shuffle.
@ -69,6 +70,21 @@ case class AQEShuffleReadExec private(
case _ =>
throw new IllegalStateException("operating on canonicalization plan")
}
} else if (isCoalescedRead) {
// For coalesced shuffle read, the data distribution is not changed, only the number of
// partitions is changed.
child.outputPartitioning match {
case h: HashPartitioning =>
CurrentOrigin.withOrigin(h.origin)(h.copy(numPartitions = partitionSpecs.length))
case r: RangePartitioning =>
CurrentOrigin.withOrigin(r.origin)(r.copy(numPartitions = partitionSpecs.length))
// This can only happen for `REBALANCE_PARTITIONS_BY_NONE`, which uses
// `RoundRobinPartitioning` but we don't need to retain the number of partitions.
case r: RoundRobinPartitioning =>
r.copy(numPartitions = partitionSpecs.length)
case other => throw new IllegalStateException(
"Unexpected partitioning for coalesced shuffle read: " + other)
}
} else {
UnknownPartitioning(partitionSpecs.length)
}
@ -92,7 +108,7 @@ case class AQEShuffleReadExec private(
/**
* Returns true iff some partitions were actually combined
*/
private def isCoalesced(spec: ShufflePartitionSpec) = spec match {
private def isCoalescedSpec(spec: ShufflePartitionSpec) = spec match {
case CoalescedPartitionSpec(0, 0, _) => true
case s: CoalescedPartitionSpec => s.endReducerIndex - s.startReducerIndex > 1
case _ => false
@ -102,7 +118,7 @@ case class AQEShuffleReadExec private(
* Returns true iff some non-empty partitions were combined
*/
def hasCoalescedPartition: Boolean = {
partitionSpecs.exists(isCoalesced)
partitionSpecs.exists(isCoalescedSpec)
}
def hasSkewedPartition: Boolean =
@ -112,6 +128,16 @@ case class AQEShuffleReadExec private(
partitionSpecs.exists(_.isInstanceOf[PartialMapperPartitionSpec]) ||
partitionSpecs.exists(_.isInstanceOf[CoalescedMapperPartitionSpec])
/**
 * Returns true when this read consists purely of [[CoalescedPartitionSpec]]s whose reducer
 * ranges appear in order, i.e. each spec ends at or before the point where the next one starts.
 *
 * The specs are inspected through sliding windows of size two. A window of a single
 * `CoalescedPartitionSpec` (only one spec in total) also counts as a coalesced read. If there
 * are no specs at all, no window is produced and the check is vacuously true.
 */
def isCoalescedRead: Boolean = {
  // Decides whether one window of neighbouring specs is compatible with a coalesced read.
  def isOrderedCoalescedWindow(window: Seq[ShufflePartitionSpec]): Boolean = window match {
    // A single partition spec which is `CoalescedPartitionSpec` also means coalesced read.
    case Seq(_: CoalescedPartitionSpec) => true
    case Seq(prev: CoalescedPartitionSpec, next: CoalescedPartitionSpec) =>
      prev.endReducerIndex <= next.startReducerIndex
    // Any other kind of spec in the window rules out a coalesced read.
    case _ => false
  }

  partitionSpecs.sliding(2).forall(isOrderedCoalescedWindow)
}
private def shuffleStage = child match {
case stage: ShuffleQueryStageExec => Some(stage)
case _ => None
@ -159,7 +185,7 @@ case class AQEShuffleReadExec private(
if (hasCoalescedPartition) {
val numCoalescedPartitionsMetric = metrics("numCoalescedPartitions")
val x = partitionSpecs.count(isCoalesced)
val x = partitionSpecs.count(isCoalescedSpec)
numCoalescedPartitionsMetric.set(x)
driverAccumUpdates += numCoalescedPartitionsMetric.id -> x
}

View file

@ -19,17 +19,19 @@ package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.ShuffleOrigin
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeLike, ShuffleOrigin}
/**
* Adaptive Query Execution rule that may create [[AQEShuffleReadExec]] on top of query stages.
* A rule that may create [[AQEShuffleReadExec]] on top of [[ShuffleQueryStageExec]] and change the
* plan output partitioning. The AQE framework will skip the rule if it leads to extra shuffles.
*/
trait AQEShuffleReadRule extends Rule[SparkPlan] {
/**
* Returns the list of [[ShuffleOrigin]]s supported by this rule.
*/
def supportedShuffleOrigins: Seq[ShuffleOrigin]
protected def supportedShuffleOrigins: Seq[ShuffleOrigin]
def mayAddExtraShuffles: Boolean = false
protected def isSupported(shuffle: ShuffleExchangeLike): Boolean = {
supportedShuffleOrigins.contains(shuffle.shuffleOrigin)
}
}

View file

@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution, HashPartitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{CollectMetricsExec, FilterExec, ProjectExec, SortExec, SparkPlan}
import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeExec}
object AQEUtils {

  /**
   * Analyzes the given plan and calculates the distribution required of it w.r.t. the
   * user-specified repartition.
   *
   * A user-specified repartition is only effective when it is the root node, or when it sits
   * under Project/Filter/LocalSort/CollectMetrics nodes, so only those nodes are looked through.
   *
   * @param p the physical plan to inspect.
   * @return `Some(HashClusteredDistribution)` when an effective hash-based user repartition is
   *         found and its expressions are still available; `Some(UnspecifiedDistribution)` when
   *         no effective user repartition is found; `None` when a repartition is effective but
   *         a Project on top of it drops some of the repartition expressions.
   */
  def getRequiredDistribution(p: SparkPlan): Option[Distribution] = p match {
    // Note: we only care about `HashPartitioning` as `EnsureRequirements` can only optimize out
    // user-specified repartition with `HashPartitioning`.
    case ShuffleExchangeExec(hash: HashPartitioning, _, origin)
        if origin == REPARTITION_BY_COL || origin == REPARTITION_BY_NUM =>
      // Only a repartition with an explicit number pins the partition count.
      val requiredNumPartitions =
        if (origin == REPARTITION_BY_NUM) Some(hash.numPartitions) else None
      Some(HashClusteredDistribution(hash.expressions, requiredNumPartitions))

    // These nodes are looked through: the requirement is whatever their child requires.
    case filter: FilterExec => getRequiredDistribution(filter.child)
    case localSort: SortExec if !localSort.global => getRequiredDistribution(localSort.child)
    case metrics: CollectMetricsExec => getRequiredDistribution(metrics.child)

    case project: ProjectExec =>
      getRequiredDistribution(project.child).flatMap {
        case clustered: HashClusteredDistribution =>
          val allKeysProjected = clustered.expressions.forall { key =>
            project.projectList.exists(_.semanticEquals(key))
          }
          if (allKeysProjected) {
            Some(clustered)
          } else {
            // It's possible that the user-specified repartition is effective but the output
            // partitioning is not retained, e.g. `df.repartition(a, b).select(c)`. We can't
            // handle this case with required distribution. Here we return None and later on
            // `EnsureRequirements` will skip optimizing out the user-specified repartition.
            None
          }
        case other => Some(other)
      }

    case _ => Some(UnspecifiedDistribution)
  }
}

View file

@ -32,6 +32,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.errors.QueryExecutionErrors
@ -83,12 +84,28 @@ case class AdaptiveSparkPlanExec(
// The logical plan optimizer for re-optimizing the current logical plan.
@transient private val optimizer = new AQEOptimizer(conf)
// `EnsureRequirements` may remove user-specified repartition and assume the query plan won't
// change its output partitioning. This assumption is not true in AQE. Here we check the
// `inputPlan` which has not been processed by `EnsureRequirements` yet, to find out the
// effective user-specified repartition. Later on, the AQE framework will make sure the final
// output partitioning is not changed w.r.t the effective user-specified repartition.
@transient private val requiredDistribution: Option[Distribution] = {
  if (!isSubquery) {
    // Derive the requirement from the effective user-specified repartition in the input plan.
    AQEUtils.getRequiredDistribution(inputPlan)
  } else {
    // Subquery output does not need a specific output partitioning.
    Some(UnspecifiedDistribution)
  }
}
// A list of physical plan rules to be applied before creation of query stages. The physical
// plan should reach a final status of query stages (i.e., no more addition or removal of
// Exchange nodes) after running these rules.
private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
RemoveRedundantProjects,
EnsureRequirements,
// For cases like `df.repartition(a, b).select(c)`, there is no distribution requirement for
// the final plan, but we do need to respect the user-specified repartition. Here we ask
// `EnsureRequirements` to not optimize out the user-specified repartition-by-col to work
// around this case.
EnsureRequirements(optimizeOutRepartition = requiredDistribution.isDefined),
RemoveRedundantSorts,
DisableUnnecessaryBucketedScan
) ++ context.session.sessionState.queryStagePrepRules
@ -114,33 +131,24 @@ case class AdaptiveSparkPlanExec(
CollapseCodegenStages()
) ++ context.session.sessionState.postStageCreationRules
// The partitioning of the query output depends on the shuffle(s) in the final stage. If the
// original plan contains a repartition operator, we need to preserve the specified partitioning,
// whether or not the repartition-introduced shuffle is optimized out because of an underlying
// shuffle of the same partitioning. Thus, we need to exclude some `AQEShuffleReadRule`s
// from the final stage, depending on the presence and properties of repartition operators.
private def finalStageOptimizerRules: Seq[Rule[SparkPlan]] = {
val origins = inputPlan.collect {
case s: ShuffleExchangeLike => s.shuffleOrigin
}
val allRules = queryStageOptimizerRules ++ postStageCreationRules
allRules.filter {
case c: AQEShuffleReadRule =>
origins.forall(c.supportedShuffleOrigins.contains)
case _ => true
}
}
private def optimizeQueryStage(plan: SparkPlan, rules: Seq[Rule[SparkPlan]]): SparkPlan = {
val optimized = rules.foldLeft(plan) { case (latestPlan, rule) =>
private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = {
val optimized = queryStageOptimizerRules.foldLeft(plan) { case (latestPlan, rule) =>
val applied = rule.apply(latestPlan)
val result = rule match {
case c: AQEShuffleReadRule if c.mayAddExtraShuffles =>
if (ValidateRequirements.validate(applied)) {
case _: AQEShuffleReadRule if !applied.fastEquals(latestPlan) =>
val distribution = if (isFinalStage) {
// If `requiredDistribution` is None, it means `EnsureRequirements` will not optimize
// out the user-specified repartition, thus we don't have a distribution requirement
// for the final plan.
requiredDistribution.getOrElse(UnspecifiedDistribution)
} else {
UnspecifiedDistribution
}
if (ValidateRequirements.validate(applied, distribution)) {
applied
} else {
logDebug(s"Rule ${rule.ruleName} is not applied due to additional shuffles " +
"will be introduced.")
logDebug(s"Rule ${rule.ruleName} is not applied as it breaks the " +
"distribution requirement of the query plan.")
latestPlan
}
case _ => applied
@ -303,7 +311,10 @@ case class AdaptiveSparkPlanExec(
}
// Run the final plan when there's no more unfinished stages.
currentPhysicalPlan = optimizeQueryStage(result.newPlan, finalStageOptimizerRules)
currentPhysicalPlan = applyPhysicalRules(
optimizeQueryStage(result.newPlan, isFinalStage = true),
postStageCreationRules,
Some((planChangeLogger, "AQE Post Stage Creation")))
isFinalPlan = true
executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
currentPhysicalPlan
@ -520,7 +531,7 @@ case class AdaptiveSparkPlanExec(
}
private def newQueryStage(e: Exchange): QueryStageExec = {
val optimizedPlan = optimizeQueryStage(e.child, queryStageOptimizerRules)
val optimizedPlan = optimizeQueryStage(e.child, isFinalStage = false)
val queryStage = e match {
case s: ShuffleExchangeLike =>
val newShuffle = applyPhysicalRules(

View file

@ -33,6 +33,10 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
Seq(ENSURE_REQUIREMENTS, REPARTITION_BY_COL, REBALANCE_PARTITIONS_BY_NONE,
REBALANCE_PARTITIONS_BY_COL)
/**
 * A shuffle is supported by this rule when its output partitioning is not `SinglePartition`
 * and the parent implementation of `isSupported` also accepts it.
 */
override def isSupported(shuffle: ShuffleExchangeLike): Boolean = {
  val isSinglePartition = shuffle.outputPartitioning == SinglePartition
  !isSinglePartition && super.isSupported(shuffle)
}
override def apply(plan: SparkPlan): SparkPlan = {
if (!conf.coalesceShufflePartitionsEnabled) {
return plan
@ -52,7 +56,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
val shuffleStageInfos = collectShuffleStageInfos(plan)
// ShuffleExchanges introduced by repartition do not support changing the number of partitions.
// We change the number of partitions in the stage only if all the ShuffleExchanges support it.
if (!shuffleStageInfos.forall(s => supportCoalesce(s.shuffleStage.shuffle))) {
if (!shuffleStageInfos.forall(s => isSupported(s.shuffleStage.shuffle))) {
plan
} else {
// Ideally, this rule should simply coalesce partition w.r.t. the target size specified by
@ -106,10 +110,6 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
}.getOrElse(plan)
case other => other.mapChildren(updateShuffleReads(_, specsMap))
}
private def supportCoalesce(s: ShuffleExchangeLike): Boolean = {
s.outputPartitioning != SinglePartition && supportedShuffleOrigins.contains(s.shuffleOrigin)
}
}
private class ShuffleStageInfo(

View file

@ -38,7 +38,9 @@ object OptimizeShuffleWithLocalRead extends AQEShuffleReadRule {
override val supportedShuffleOrigins: Seq[ShuffleOrigin] =
Seq(ENSURE_REQUIREMENTS, REBALANCE_PARTITIONS_BY_NONE)
override def mayAddExtraShuffles: Boolean = true
/**
 * A shuffle is supported by this rule when its output partitioning is not `SinglePartition`
 * and the parent implementation of `isSupported` also accepts it.
 */
override protected def isSupported(shuffle: ShuffleExchangeLike): Boolean = {
  val isSinglePartition = shuffle.outputPartitioning == SinglePartition
  !isSinglePartition && super.isSupported(shuffle)
}
// The build side is a broadcast query stage which should have been optimized using local read
// already. So we only need to deal with probe side here.
@ -136,14 +138,10 @@ object OptimizeShuffleWithLocalRead extends AQEShuffleReadRule {
def canUseLocalShuffleRead(plan: SparkPlan): Boolean = plan match {
case s: ShuffleQueryStageExec =>
s.mapStats.isDefined && supportLocalRead(s.shuffle)
s.mapStats.isDefined && isSupported(s.shuffle)
case AQEShuffleReadExec(s: ShuffleQueryStageExec, _) =>
s.mapStats.isDefined && supportLocalRead(s.shuffle) &&
s.mapStats.isDefined && isSupported(s.shuffle) &&
s.shuffle.shuffleOrigin == ENSURE_REQUIREMENTS
case _ => false
}
private def supportLocalRead(s: ShuffleExchangeLike): Boolean = {
s.outputPartitioning != SinglePartition && supportedShuffleOrigins.contains(s.shuffleOrigin)
}
}

View file

@ -38,7 +38,8 @@ import org.apache.spark.sql.internal.SQLConf
* ShuffleQueryStageExec.
*/
object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule {
override def supportedShuffleOrigins: Seq[ShuffleOrigin] =
override val supportedShuffleOrigins: Seq[ShuffleOrigin] =
Seq(REBALANCE_PARTITIONS_BY_NONE, REBALANCE_PARTITIONS_BY_COL)
/**
@ -92,9 +93,8 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule {
}
plan match {
case shuffle: ShuffleQueryStageExec
if supportedShuffleOrigins.contains(shuffle.shuffle.shuffleOrigin) =>
tryOptimizeSkewedPartitions(shuffle)
case stage: ShuffleQueryStageExec if isSupported(stage.shuffle) =>
tryOptimizeSkewedPartitions(stage)
case _ => plan
}
}

View file

@ -52,8 +52,6 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule {
override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS)
override def mayAddExtraShuffles: Boolean = true
/**
* A partition is considered as a skewed partition if its size is larger than the median
* partition size * SKEW_JOIN_SKEWED_PARTITION_FACTOR and also larger than
@ -257,13 +255,12 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule {
plan
}
}
}
private object ShuffleStage {
def unapply(plan: SparkPlan): Option[ShuffleQueryStageExec] = plan match {
case s: ShuffleQueryStageExec if s.mapStats.isDefined &&
OptimizeSkewedJoin.supportedShuffleOrigins.contains(s.shuffle.shuffleOrigin) =>
Some(s)
case _ => None
object ShuffleStage {
def unapply(plan: SparkPlan): Option[ShuffleQueryStageExec] = plan match {
case s: ShuffleQueryStageExec if s.mapStats.isDefined && isSupported(s.shuffle) =>
Some(s)
case _ => None
}
}
}

View file

@ -32,8 +32,14 @@ import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoin
* [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
* each operator by inserting [[ShuffleExchangeExec]] Operators where required. Also ensure that
* the input partition ordering requirements are met.
*
* @param optimizeOutRepartition A flag to indicate whether this rule should optimize out
* user-specified repartition shuffles. This is usually true,
* but can be false in AQE, where optimizations may change the plan
* output partitioning and the user-specified repartition shuffles
* must therefore be retained in the plan.
*/
object EnsureRequirements extends Rule[SparkPlan] {
case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Rule[SparkPlan] {
private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
@ -249,13 +255,9 @@ object EnsureRequirements extends Rule[SparkPlan] {
}
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
// TODO: remove this after we create a physical operator for `RepartitionByExpression`.
// SPARK-35989: AQE will change the partition number so we should retain the REPARTITION_BY_NUM
// shuffle which is specified by user. And also we can not remove REBALANCE_PARTITIONS_BY_COL,
// it is a special shuffle used to rebalance partitions.
// So, here we only remove REPARTITION_BY_COL in AQE.
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin)
if shuffleOrigin == REPARTITION_BY_COL || !conf.adaptiveExecutionEnabled =>
if optimizeOutRepartition &&
(shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM) =>
def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = {
partitioning match {
case lower: HashPartitioning if upper.semanticEquals(lower) => true

View file

@ -30,6 +30,10 @@ import org.apache.spark.sql.execution._
*/
object ValidateRequirements extends Logging {
/**
 * Validates the plan and additionally checks its output against a required distribution.
 *
 * @param plan the physical plan to check.
 * @param requiredDistribution the distribution the plan's output partitioning must satisfy.
 * @return true only if the single-argument `validate(plan)` passes and
 *         `plan.outputPartitioning` satisfies `requiredDistribution`.
 */
def validate(plan: SparkPlan, requiredDistribution: Distribution): Boolean = {
  if (validate(plan)) {
    plan.outputPartitioning.satisfies(requiredDistribution)
  } else {
    // The output partitioning is not consulted when the plan itself is invalid.
    false
  }
}
/**
 * Recursively validates a plan: every child subtree must validate, and then the node itself
 * must pass `validateInternal`.
 *
 * @param plan the physical plan to check.
 * @return true if all children and the node itself are valid.
 */
def validate(plan: SparkPlan): Boolean = {
  // Children are checked first; the node's own check is skipped as soon as any child fails.
  val childrenValid = plan.children.forall(child => validate(child))
  childrenValid && validateInternal(plan)
}

View file

@ -40,6 +40,8 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
setupTestData()
private val EnsureRequirements = new EnsureRequirements()
private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
val planner = spark.sessionState.planner
import planner._

View file

@ -1496,6 +1496,60 @@ class AdaptiveQueryExecSuite
}.isDefined
}
/**
 * Executes `df` and checks the shape of its final plan around the top-level broadcast hash join.
 *
 * @param df the query to run.
 * @param optimizeOutRepartition whether the repartition shuffle is expected to be absent from
 *                               the executed plan.
 * @param probeSideLocalRead whether the probe side (right child) is expected to use a local
 *                           shuffle read. Takes precedence over `probeSideCoalescedRead` when
 *                           both are true.
 * @param probeSideCoalescedRead whether the probe side is expected to have coalesced partitions
 *                               (only asserted when `probeSideLocalRead` is false).
 */
def checkBHJ(
    df: Dataset[Row],
    optimizeOutRepartition: Boolean,
    probeSideLocalRead: Boolean,
    probeSideCoalescedRead: Boolean): Unit = {
  df.collect()
  val executed = df.queryExecution.executedPlan

  // There should be only one shuffle that can't do local read, which is either the top shuffle
  // from repartition, or BHJ probe side shuffle.
  checkNumLocalShuffleReads(executed, 1)
  assert(hasRepartitionShuffle(executed) == !optimizeOutRepartition)

  val joins = findTopLevelBroadcastHashJoin(executed)
  assert(joins.length == 1)
  val join = joins.head

  // Build side should do local read.
  val buildRead = find(join.left)(_.isInstanceOf[AQEShuffleReadExec])
    .collect { case read: AQEShuffleReadExec => read }
  assert(buildRead.isDefined)
  assert(buildRead.exists(_.isLocalRead))

  val probeRead = find(join.right)(_.isInstanceOf[AQEShuffleReadExec])
    .collect { case read: AQEShuffleReadExec => read }
  val expectsProbeRead = probeSideLocalRead || probeSideCoalescedRead
  if (!expectsProbeRead) {
    // No AQE shuffle read may be present on the probe side.
    assert(probeRead.isEmpty)
  } else {
    assert(probeRead.isDefined)
    if (probeSideLocalRead) {
      assert(probeRead.exists(_.isLocalRead))
    } else {
      assert(probeRead.exists(_.hasCoalescedPartition))
    }
  }
}
/**
 * Executes `df` and checks the shape of its final plan around the top-level sort merge join.
 *
 * @param df the query to run.
 * @param optimizeOutRepartition whether the repartition shuffle is expected to be absent from
 *                               the executed plan.
 * @param optimizeSkewJoin whether the join is expected to be marked as a skew join.
 * @param coalescedRead whether every AQE shuffle read under the join is expected to have
 *                      coalesced partitions.
 */
def checkSMJ(
    df: Dataset[Row],
    optimizeOutRepartition: Boolean,
    optimizeSkewJoin: Boolean,
    coalescedRead: Boolean): Unit = {
  df.collect()
  val executed = df.queryExecution.executedPlan
  assert(hasRepartitionShuffle(executed) == !optimizeOutRepartition)

  val joins = findTopLevelSortMergeJoin(executed)
  assert(joins.length == 1)
  val join = joins.head
  assert(join.isSkewJoin == optimizeSkewJoin)

  val reads = collect(join) {
    case read: AQEShuffleReadExec => read
  }
  val expectsReads = coalescedRead || optimizeSkewJoin
  if (!expectsReads) {
    // Neither optimization applies, so no AQE shuffle read may appear under the join.
    assert(reads.isEmpty)
  } else {
    // One AQE shuffle read per join side.
    assert(reads.length == 2)
    if (coalescedRead) assert(reads.forall(_.hasCoalescedPartition))
  }
}
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
val df = sql(
@ -1509,44 +1563,25 @@ class AdaptiveQueryExecSuite
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
// Repartition with no partition num specified.
val dfRepartition = df.repartition('b)
dfRepartition.collect()
val plan = dfRepartition.queryExecution.executedPlan
// The top shuffle from repartition is optimized out.
assert(!hasRepartitionShuffle(plan))
val bhj = findTopLevelBroadcastHashJoin(plan)
assert(bhj.length == 1)
checkNumLocalShuffleReads(plan, 1)
// Probe side is coalesced.
val aqeRead = bhj.head.right.find(_.isInstanceOf[AQEShuffleReadExec])
assert(aqeRead.isDefined)
assert(aqeRead.get.asInstanceOf[AQEShuffleReadExec].hasCoalescedPartition)
checkBHJ(df.repartition('b),
// The top shuffle from repartition is optimized out.
optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = true)
// Repartition with partition default num specified.
val dfRepartitionWithNum = df.repartition(5, 'b)
dfRepartitionWithNum.collect()
val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
// The top shuffle from repartition is not optimized out.
assert(hasRepartitionShuffle(planWithNum))
val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum)
assert(bhjWithNum.length == 1)
checkNumLocalShuffleReads(planWithNum, 1)
// Probe side is coalesced.
assert(bhjWithNum.head.right.find(_.isInstanceOf[AQEShuffleReadExec]).nonEmpty)
// Repartition with default partition num (5 in test env) specified.
checkBHJ(df.repartition(5, 'b),
// The top shuffle from repartition is optimized out
// The final plan must have 5 partitions, no optimization can be made to the probe side.
optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = false)
// Repartition with partition non-default num specified.
val dfRepartitionWithNum2 = df.repartition(3, 'b)
dfRepartitionWithNum2.collect()
val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
// The top shuffle from repartition is not optimized out, and this is the only shuffle that
// does not have local shuffle read.
assert(hasRepartitionShuffle(planWithNum2))
val bhjWithNum2 = findTopLevelBroadcastHashJoin(planWithNum2)
assert(bhjWithNum2.length == 1)
checkNumLocalShuffleReads(planWithNum2, 1)
val aqeRead2 = bhjWithNum2.head.right.find(_.isInstanceOf[AQEShuffleReadExec])
assert(aqeRead2.isDefined)
assert(aqeRead2.get.asInstanceOf[AQEShuffleReadExec].isLocalRead)
// Repartition with non-default partition num specified.
checkBHJ(df.repartition(4, 'b),
// The top shuffle from repartition is not optimized out
optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true)
// Repartition by col and project away the partition cols
checkBHJ(df.repartition('b).select('key),
// The top shuffle from repartition is not optimized out
optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true)
}
// Force skew join
@ -1556,46 +1591,25 @@ class AdaptiveQueryExecSuite
SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0",
SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") {
// Repartition with no partition num specified.
val dfRepartition = df.repartition('b)
dfRepartition.collect()
val plan = dfRepartition.queryExecution.executedPlan
// The top shuffle from repartition is optimized out.
assert(!hasRepartitionShuffle(plan))
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.length == 1)
// No skew join due to the repartition.
assert(!smj.head.isSkewJoin)
// Both sides are coalesced.
val aqeReads = collect(smj.head) {
case c: AQEShuffleReadExec if c.hasCoalescedPartition => c
}
assert(aqeReads.length == 2)
checkSMJ(df.repartition('b),
// The top shuffle from repartition is optimized out.
optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = true)
// Repartition with default partition num specified.
val dfRepartitionWithNum = df.repartition(5, 'b)
dfRepartitionWithNum.collect()
val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
// The top shuffle from repartition is not optimized out.
assert(hasRepartitionShuffle(planWithNum))
val smjWithNum = findTopLevelSortMergeJoin(planWithNum)
assert(smjWithNum.length == 1)
// Skew join can apply as the repartition is not optimized out.
assert(smjWithNum.head.isSkewJoin)
val aqeReadsWithNum = collect(smjWithNum.head) {
case c: AQEShuffleReadExec => c
}
assert(aqeReadsWithNum.nonEmpty)
// Repartition with default partition num (5 in test env) specified.
checkSMJ(df.repartition(5, 'b),
// The top shuffle from repartition is optimized out.
// The final plan must have 5 partitions, can't do coalesced read.
optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = false)
// Repartition with default non-partition num specified.
val dfRepartitionWithNum2 = df.repartition(3, 'b)
dfRepartitionWithNum2.collect()
val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
// The top shuffle from repartition is not optimized out.
assert(hasRepartitionShuffle(planWithNum2))
val smjWithNum2 = findTopLevelSortMergeJoin(planWithNum2)
assert(smjWithNum2.length == 1)
// Skew join can apply as the repartition is not optimized out.
assert(smjWithNum2.head.isSkewJoin)
// Repartition with non-default partition num specified.
checkSMJ(df.repartition(4, 'b),
// The top shuffle from repartition is not optimized out.
optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false)
// Repartition by col and project away the partition cols
checkSMJ(df.repartition('b).select('key),
// The top shuffle from repartition is not optimized out.
optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false)
}
}
}

View file

@ -21,16 +21,17 @@ import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
class EnsureRequirementsSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
class EnsureRequirementsSuite extends SharedSparkSession {
private val exprA = Literal(1)
private val exprB = Literal(2)
private val exprC = Literal(3)
private val EnsureRequirements = new EnsureRequirements()
test("reorder should handle PartitioningCollection") {
val plan1 = DummySparkPlan(
outputPartitioning = PartitioningCollection(Seq(
@ -134,26 +135,4 @@ class EnsureRequirementsSuite extends SharedSparkSession with AdaptiveSparkPlanH
}.size == 2)
}
}
test("SPARK-35989: Do not remove REPARTITION_BY_NUM shuffle if AQE is enabled") {
import testImplicits._
Seq(true, false).foreach { enableAqe =>
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAqe.toString,
SQLConf.SHUFFLE_PARTITIONS.key -> "3",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df1 = Seq((1, 2)).toDF("c1", "c2")
val df2 = Seq((1, 3)).toDF("c3", "c4")
val res = df1.join(df2, $"c1" === $"c3").repartition(3, $"c1")
val num = collect(res.queryExecution.executedPlan) {
case shuffle: ShuffleExchangeExec if shuffle.shuffleOrigin == REPARTITION_BY_NUM =>
shuffle
}.size
if (enableAqe) {
assert(num == 1)
} else {
assert(num == 0)
}
}
}
}
}

View file

@ -48,6 +48,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
protected var spark: SparkSession = null
private val EnsureRequirements = new EnsureRequirements()
/**
* Create a new [[SparkSession]] running in local-cluster mode with unsafe and codegen enabled.
*/

View file

@ -31,6 +31,8 @@ import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StructT
class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
private val EnsureRequirements = new EnsureRequirements()
private lazy val left = spark.createDataFrame(
sparkContext.parallelize(Seq(
Row(1, 2.0),

View file

@ -33,6 +33,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession {
import testImplicits.newProductEncoder
import testImplicits.localSeqToDatasetHolder
private val EnsureRequirements = new EnsureRequirements()
private lazy val myUpperCaseData = spark.createDataFrame(
sparkContext.parallelize(Seq(
Row(1, "A"),

View file

@ -31,6 +31,8 @@ import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
class OuterJoinSuite extends SparkPlanTest with SharedSparkSession {
private val EnsureRequirements = new EnsureRequirements()
private lazy val left = spark.createDataFrame(
sparkContext.parallelize(Seq(
Row(1, 2.0),