[SPARK-35675][SQL] EnsureRequirements remove shuffle should respect PartitioningCollection

### What changes were proposed in this pull request?

Add `PartitioningCollection` handling in `EnsureRequirements` when removing redundant shuffles.

### Why are the changes needed?

Currently `EnsureRequirements` only checks whether the child's output partitioning is a semantically equal `HashPartitioning` before removing a
redundant shuffle. We can enhance this case by also matching against `PartitioningCollection`.

### Does this PR introduce _any_ user-facing change?

Yes, the physical query plan might change (a redundant shuffle may now be removed).

### How was this patch tested?

Add test.

Closes #32815 from ulysses-you/shuffle-node.

Authored-by: ulysses-you <ulyssesyou18@gmail.com>
Signed-off-by: Kent Yao <yao@apache.org>
This commit is contained in:
ulysses-you 2021-06-10 13:03:47 +08:00 committed by Kent Yao
parent 87d2ffbbcf
commit 8dde20a993
2 changed files with 27 additions and 3 deletions

View file

@ -251,10 +251,20 @@ object EnsureRequirements extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
// TODO: remove this after we create a physical operator for `RepartitionByExpression`.
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
child.outputPartitioning match {
case lower: HashPartitioning if upper.semanticEquals(lower) => child
case _ => operator
def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = {
partitioning match {
case lower: HashPartitioning if upper.semanticEquals(lower) => true
case lower: PartitioningCollection =>
lower.partitionings.exists(hasSemanticEqualPartitioning)
case _ => false
}
}
if (hasSemanticEqualPartitioning(child.outputPartitioning)) {
child
} else {
operator
}
case operator: SparkPlan =>
ensureDistributionAndOrdering(reorderJoinPredicates(operator))
}

View file

@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
class EnsureRequirementsSuite extends SharedSparkSession {
@ -119,4 +120,17 @@ class EnsureRequirementsSuite extends SharedSparkSession {
case other => fail(other.toString)
}
}
test("SPARK-35675: EnsureRequirements remove shuffle should respect PartitioningCollection") {
import testImplicits._
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
val df1 = Seq((1, 2)).toDF("c1", "c2")
val df2 = Seq((1, 3)).toDF("c3", "c4")
val res = df1.join(df2, $"c1" === $"c3").repartition($"c1")
assert(res.queryExecution.executedPlan.collect {
case s: ShuffleExchangeLike => s
}.size == 2)
}
}
}