[SPARK-35989][SQL] Only remove redundant shuffle if shuffle origin is REPARTITION_BY_COL in AQE
### What changes were proposed in this pull request? In AQE, skip removing a redundant shuffle unless its shuffle origin is `REPARTITION_BY_COL`. ### Why are the changes needed? `REPARTITION_BY_COL` doesn't guarantee the number of output partitions, so we can remove it safely in AQE. For `REPARTITION_BY_NUM`, we should retain the shuffle, whose partition number is specified by the user. `REBALANCE_PARTITIONS_BY_COL` is a special shuffle used to rebalance partitions, so we should not remove it either. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added tests. Closes #33188 from ulysses-you/SPARK-35989. Lead-authored-by: ulysses-you <ulyssesyou18@gmail.com> Co-authored-by: ulysses <ulyssesyou18@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
044dddf288
commit
7fe4c4a9ad
|
@ -250,7 +250,12 @@ object EnsureRequirements extends Rule[SparkPlan] {
|
||||||
|
|
||||||
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
|
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
|
||||||
// TODO: remove this after we create a physical operator for `RepartitionByExpression`.
|
// TODO: remove this after we create a physical operator for `RepartitionByExpression`.
|
||||||
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
|
// SPARK-35989: AQE will change the partition number so we should retain the REPARTITION_BY_NUM
|
||||||
|
// shuffle which is specified by the user. We also cannot remove REBALANCE_PARTITIONS_BY_COL,
|
||||||
|
// it is a special shuffle used to rebalance partitions.
|
||||||
|
// So, here we only remove REPARTITION_BY_COL in AQE.
|
||||||
|
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin)
|
||||||
|
if shuffleOrigin == REPARTITION_BY_COL || !conf.adaptiveExecutionEnabled =>
|
||||||
def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = {
|
def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = {
|
||||||
partitioning match {
|
partitioning match {
|
||||||
case lower: HashPartitioning if upper.semanticEquals(lower) => true
|
case lower: HashPartitioning if upper.semanticEquals(lower) => true
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
|
||||||
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
|
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
|
||||||
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
|
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
|
||||||
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
|
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
|
||||||
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ShuffleExchangeExec}
|
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, REPARTITION_BY_COL, ReusedExchangeExec, ShuffleExchangeExec}
|
||||||
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
|
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
|
||||||
import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
|
import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
|
@ -420,7 +420,8 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
|
||||||
|
|
||||||
val inputPlan = ShuffleExchangeExec(
|
val inputPlan = ShuffleExchangeExec(
|
||||||
partitioning,
|
partitioning,
|
||||||
DummySparkPlan(outputPartitioning = partitioning))
|
DummySparkPlan(outputPartitioning = partitioning),
|
||||||
|
REPARTITION_BY_COL)
|
||||||
val outputPlan = EnsureRequirements.apply(inputPlan)
|
val outputPlan = EnsureRequirements.apply(inputPlan)
|
||||||
assertDistributionRequirementsAreSatisfied(outputPlan)
|
assertDistributionRequirementsAreSatisfied(outputPlan)
|
||||||
if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) {
|
if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) {
|
||||||
|
|
|
@ -1526,13 +1526,13 @@ class AdaptiveQueryExecSuite
|
||||||
val dfRepartitionWithNum = df.repartition(5, 'b)
|
val dfRepartitionWithNum = df.repartition(5, 'b)
|
||||||
dfRepartitionWithNum.collect()
|
dfRepartitionWithNum.collect()
|
||||||
val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
|
val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
|
||||||
// The top shuffle from repartition is optimized out.
|
// The top shuffle from repartition is not optimized out.
|
||||||
assert(!hasRepartitionShuffle(planWithNum))
|
assert(hasRepartitionShuffle(planWithNum))
|
||||||
val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum)
|
val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum)
|
||||||
assert(bhjWithNum.length == 1)
|
assert(bhjWithNum.length == 1)
|
||||||
checkNumLocalShuffleReaders(planWithNum, 1)
|
checkNumLocalShuffleReaders(planWithNum, 1)
|
||||||
// Probe side is not coalesced.
|
// Probe side is coalesced.
|
||||||
assert(bhjWithNum.head.right.find(_.isInstanceOf[CustomShuffleReaderExec]).isEmpty)
|
assert(bhjWithNum.head.right.find(_.isInstanceOf[CustomShuffleReaderExec]).nonEmpty)
|
||||||
|
|
||||||
// Repartition with partition non-default num specified.
|
// Repartition with partition non-default num specified.
|
||||||
val dfRepartitionWithNum2 = df.repartition(3, 'b)
|
val dfRepartitionWithNum2 = df.repartition(3, 'b)
|
||||||
|
@ -1575,17 +1575,16 @@ class AdaptiveQueryExecSuite
|
||||||
val dfRepartitionWithNum = df.repartition(5, 'b)
|
val dfRepartitionWithNum = df.repartition(5, 'b)
|
||||||
dfRepartitionWithNum.collect()
|
dfRepartitionWithNum.collect()
|
||||||
val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
|
val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
|
||||||
// The top shuffle from repartition is optimized out.
|
// The top shuffle from repartition is not optimized out.
|
||||||
assert(!hasRepartitionShuffle(planWithNum))
|
assert(hasRepartitionShuffle(planWithNum))
|
||||||
val smjWithNum = findTopLevelSortMergeJoin(planWithNum)
|
val smjWithNum = findTopLevelSortMergeJoin(planWithNum)
|
||||||
assert(smjWithNum.length == 1)
|
assert(smjWithNum.length == 1)
|
||||||
// No skew join due to the repartition.
|
// Skew join can apply as the repartition is not optimized out.
|
||||||
assert(!smjWithNum.head.isSkewJoin)
|
assert(smjWithNum.head.isSkewJoin)
|
||||||
// No coalesce due to the num in repartition.
|
|
||||||
val customReadersWithNum = collect(smjWithNum.head) {
|
val customReadersWithNum = collect(smjWithNum.head) {
|
||||||
case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c
|
case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c
|
||||||
}
|
}
|
||||||
assert(customReadersWithNum.isEmpty)
|
assert(customReadersWithNum.nonEmpty)
|
||||||
|
|
||||||
// Repartition with default non-partition num specified.
|
// Repartition with default non-partition num specified.
|
||||||
val dfRepartitionWithNum2 = df.repartition(3, 'b)
|
val dfRepartitionWithNum2 = df.repartition(3, 'b)
|
||||||
|
|
|
@ -21,11 +21,12 @@ import org.apache.spark.sql.catalyst.expressions.Literal
|
||||||
import org.apache.spark.sql.catalyst.plans.Inner
|
import org.apache.spark.sql.catalyst.plans.Inner
|
||||||
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
|
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
|
||||||
import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
|
import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
|
||||||
|
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
|
||||||
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
|
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
|
||||||
import org.apache.spark.sql.internal.SQLConf
|
import org.apache.spark.sql.internal.SQLConf
|
||||||
import org.apache.spark.sql.test.SharedSparkSession
|
import org.apache.spark.sql.test.SharedSparkSession
|
||||||
|
|
||||||
class EnsureRequirementsSuite extends SharedSparkSession {
|
class EnsureRequirementsSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
|
||||||
private val exprA = Literal(1)
|
private val exprA = Literal(1)
|
||||||
private val exprB = Literal(2)
|
private val exprB = Literal(2)
|
||||||
private val exprC = Literal(3)
|
private val exprC = Literal(3)
|
||||||
|
@ -133,4 +134,26 @@ class EnsureRequirementsSuite extends SharedSparkSession {
|
||||||
}.size == 2)
|
}.size == 2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-35989: Do not remove REPARTITION_BY_NUM shuffle if AQE is enabled") {
|
||||||
|
import testImplicits._
|
||||||
|
Seq(true, false).foreach { enableAqe =>
|
||||||
|
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAqe.toString,
|
||||||
|
SQLConf.SHUFFLE_PARTITIONS.key -> "3",
|
||||||
|
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
|
||||||
|
val df1 = Seq((1, 2)).toDF("c1", "c2")
|
||||||
|
val df2 = Seq((1, 3)).toDF("c3", "c4")
|
||||||
|
val res = df1.join(df2, $"c1" === $"c3").repartition(3, $"c1")
|
||||||
|
val num = collect(res.queryExecution.executedPlan) {
|
||||||
|
case shuffle: ShuffleExchangeExec if shuffle.shuffleOrigin == REPARTITION_BY_NUM =>
|
||||||
|
shuffle
|
||||||
|
}.size
|
||||||
|
if (enableAqe) {
|
||||||
|
assert(num == 1)
|
||||||
|
} else {
|
||||||
|
assert(num == 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue