[SPARK-35989][SQL] Only remove redundant shuffle if shuffle origin is REPARTITION_BY_COL in AQE

### What changes were proposed in this pull request?

Skip removing a redundant shuffle in AQE if its shuffle origin is not `REPARTITION_BY_COL`.

### Why are the changes needed?

`REPARTITION_BY_COL` does not guarantee a specific output partition number, so its shuffle can be removed safely in AQE.

For `REPARTITION_BY_NUM`, we should retain the shuffle because its partition number is explicitly specified by the user.
For `REBALANCE_PARTITIONS_BY_COL`, the shuffle is a special one used to rebalance partitions, so we should not remove it either. A sketch of how each origin arises follows.
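
For reference, a minimal sketch of how each shuffle origin arises from the user-facing API. The session setup and the `REBALANCE` hint syntax are assumed from the Spark 3.2 API, not taken from this PR:

```scala
import org.apache.spark.sql.SparkSession

object ShuffleOriginSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b")).toDF("c1", "c2")

    // REPARTITION_BY_COL: no partition number given, so AQE is free to
    // choose (and later change) the number of output partitions.
    df.repartition($"c1")

    // REPARTITION_BY_NUM: the user pinned the partition number to 5, so
    // AQE must not remove this shuffle.
    df.repartition(5, $"c1")

    // REBALANCE_PARTITIONS_BY_COL: a special shuffle inserted to even out
    // partition sizes, which AQE later splits or merges; it must be kept.
    spark.sql("SELECT /*+ REBALANCE(c1) */ * FROM VALUES (1, 'a') AS t(c1, c2)")

    spark.stop()
  }
}
```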

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a test in `EnsureRequirementsSuite`, and updated existing assertions in `PlannerSuite` and `AdaptiveQueryExecSuite` (see the diffs below).

Closes #33188 from ulysses-you/SPARK-35989.

Lead-authored-by: ulysses-you <ulyssesyou18@gmail.com>
Co-authored-by: ulysses <ulyssesyou18@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 7fe4c4a9ad)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
commit ed7c81dfaa (parent 39b3a04bfe), committed 2021-07-05 17:10:42 +08:00
4 changed files with 42 additions and 14 deletions

#### sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

```diff
@@ -250,7 +250,12 @@ object EnsureRequirements extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
     // TODO: remove this after we create a physical operator for `RepartitionByExpression`.
-    case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
+    // SPARK-35989: AQE will change the partition number, so we should retain the
+    // REPARTITION_BY_NUM shuffle which is specified by the user. We also cannot remove
+    // REBALANCE_PARTITIONS_BY_COL, a special shuffle used to rebalance partitions.
+    // So here we only remove REPARTITION_BY_COL in AQE.
+    case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin)
+        if shuffleOrigin == REPARTITION_BY_COL || !conf.adaptiveExecutionEnabled =>
       def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = {
         partitioning match {
           case lower: HashPartitioning if upper.semanticEquals(lower) => true
```
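
For intuition, the plan shape this rule targets looks like the new test below: a join already hash-partitions its output by the join key, so a trailing repartition on that key is redundant. A minimal sketch, with illustrative session setup; the shuffle partition number is pinned so the two partitionings compare as semantically equal:

```scala
import org.apache.spark.sql.SparkSession

object RedundantShuffleSketch {
  def main(args: Array[String]): Unit = {
    // Pin shuffle partitions to 3 so repartition(3, ...) yields the same
    // partitioning the join already produces.
    val spark = SparkSession.builder()
      .master("local[2]")
      .config("spark.sql.shuffle.partitions", "3")
      .config("spark.sql.autoBroadcastJoinThreshold", "-1")
      .getOrCreate()
    import spark.implicits._

    val df1 = Seq((1, 2)).toDF("c1", "c2")
    val df2 = Seq((1, 3)).toDF("c3", "c4")
    val joined = df1.join(df2, $"c1" === $"c3")

    // The join output is already hash-partitioned by c1, so this shuffle is
    // redundant; with origin REPARTITION_BY_COL it may be removed even in AQE.
    joined.repartition($"c1").explain()

    // Same plan shape, but the user pinned the partition number: the origin is
    // REPARTITION_BY_NUM, so the shuffle must survive when AQE is enabled.
    joined.repartition(3, $"c1").explain()

    spark.stop()
  }
}
```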

#### sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

```diff
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, REPARTITION_BY_COL, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
 import org.apache.spark.sql.functions._
@@ -420,7 +420,8 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
     val inputPlan = ShuffleExchangeExec(
       partitioning,
-      DummySparkPlan(outputPartitioning = partitioning))
+      DummySparkPlan(outputPartitioning = partitioning),
+      REPARTITION_BY_COL)
     val outputPlan = EnsureRequirements.apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) {
```

#### sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

```diff
@@ -1526,13 +1526,13 @@ class AdaptiveQueryExecSuite
         val dfRepartitionWithNum = df.repartition(5, 'b)
         dfRepartitionWithNum.collect()
         val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
-        // The top shuffle from repartition is optimized out.
-        assert(!hasRepartitionShuffle(planWithNum))
+        // The top shuffle from repartition is not optimized out.
+        assert(hasRepartitionShuffle(planWithNum))
         val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum)
         assert(bhjWithNum.length == 1)
         checkNumLocalShuffleReaders(planWithNum, 1)
-        // Probe side is not coalesced.
-        assert(bhjWithNum.head.right.find(_.isInstanceOf[CustomShuffleReaderExec]).isEmpty)
+        // Probe side is coalesced.
+        assert(bhjWithNum.head.right.find(_.isInstanceOf[CustomShuffleReaderExec]).nonEmpty)
 
         // Repartition with non-default partition num specified.
         val dfRepartitionWithNum2 = df.repartition(3, 'b)
@@ -1575,17 +1575,16 @@ class AdaptiveQueryExecSuite
         val dfRepartitionWithNum = df.repartition(5, 'b)
         dfRepartitionWithNum.collect()
         val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
-        // The top shuffle from repartition is optimized out.
-        assert(!hasRepartitionShuffle(planWithNum))
+        // The top shuffle from repartition is not optimized out.
+        assert(hasRepartitionShuffle(planWithNum))
         val smjWithNum = findTopLevelSortMergeJoin(planWithNum)
         assert(smjWithNum.length == 1)
-        // No skew join due to the repartition.
-        assert(!smjWithNum.head.isSkewJoin)
-        // No coalesce due to the num in repartition.
+        // Skew join can apply as the repartition is not optimized out.
+        assert(smjWithNum.head.isSkewJoin)
         val customReadersWithNum = collect(smjWithNum.head) {
           case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c
         }
-        assert(customReadersWithNum.isEmpty)
+        assert(customReadersWithNum.nonEmpty)
 
         // Repartition with non-default partition num specified.
         val dfRepartitionWithNum2 = df.repartition(3, 'b)
```

#### sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala

```diff
@@ -21,11 +21,12 @@ import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
 import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
-class EnsureRequirementsSuite extends SharedSparkSession {
+class EnsureRequirementsSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   private val exprA = Literal(1)
   private val exprB = Literal(2)
   private val exprC = Literal(3)
@@ -133,4 +134,26 @@ class EnsureRequirementsSuite extends SharedSparkSession {
       }.size == 2)
     }
   }
+
+  test("SPARK-35989: Do not remove REPARTITION_BY_NUM shuffle if AQE is enabled") {
+    import testImplicits._
+    Seq(true, false).foreach { enableAqe =>
+      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAqe.toString,
+        SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+        val df1 = Seq((1, 2)).toDF("c1", "c2")
+        val df2 = Seq((1, 3)).toDF("c3", "c4")
+        val res = df1.join(df2, $"c1" === $"c3").repartition(3, $"c1")
+        val num = collect(res.queryExecution.executedPlan) {
+          case shuffle: ShuffleExchangeExec if shuffle.shuffleOrigin == REPARTITION_BY_NUM =>
+            shuffle
+        }.size
+        if (enableAqe) {
+          assert(num == 1)
+        } else {
+          assert(num == 0)
+        }
+      }
+    }
+  }
 }
```