[SPARK-33472][SQL] Adjust RemoveRedundantSorts rule order
### What changes were proposed in this pull request? This PR switched the order for the rule `RemoveRedundantSorts` and `EnsureRequirements` so that `EnsureRequirements` will be invoked before `RemoveRedundantSorts` to avoid IllegalArgumentException when instantiating PartitioningCollection. ### Why are the changes needed? `RemoveRedundantSorts` rule uses SparkPlan's `outputPartitioning` to check whether a sort node is redundant. Currently, it is added before `EnsureRequirements`. Since `PartitioningCollection` requires left and right partitioning to have the same number of partitions, which is not necessarily true before applying `EnsureRequirements`, the rule can fail with the following exception: ``` IllegalArgumentException: requirement failed: PartitioningCollection requires all of its partitionings have the same numPartitions. ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit test Closes #30373 from allisonwang-db/sort-follow-up. Authored-by: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
ef2638c3e3
commit
a03c540cf7
|
@ -343,8 +343,10 @@ object QueryExecution {
|
|||
PlanDynamicPruningFilters(sparkSession),
|
||||
PlanSubqueries(sparkSession),
|
||||
RemoveRedundantProjects,
|
||||
RemoveRedundantSorts,
|
||||
EnsureRequirements,
|
||||
// `RemoveRedundantSorts` needs to be added before `EnsureRequirements` to guarantee the same
|
||||
// number of partitions when instantiating PartitioningCollection.
|
||||
RemoveRedundantSorts,
|
||||
DisableUnnecessaryBucketedScan,
|
||||
ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.columnarRules),
|
||||
CollapseCodegenStages(),
|
||||
|
|
|
@ -135,7 +135,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
|
|||
def longMetric(name: String): SQLMetric = metrics(name)
|
||||
|
||||
// TODO: Move to `DistributedPlan`
|
||||
/** Specifies how data is partitioned across different nodes in the cluster. */
|
||||
/**
|
||||
* Specifies how data is partitioned across different nodes in the cluster.
|
||||
* Note this method may fail if it is invoked before `EnsureRequirements` is applied
|
||||
* since `PartitioningCollection` requires all its partitionings to have
|
||||
* the same number of partitions.
|
||||
*/
|
||||
def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
|
||||
|
||||
/**
|
||||
|
|
|
@ -88,8 +88,8 @@ case class AdaptiveSparkPlanExec(
|
|||
// Exchange nodes) after running these rules.
|
||||
private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
|
||||
RemoveRedundantProjects,
|
||||
RemoveRedundantSorts,
|
||||
EnsureRequirements,
|
||||
RemoveRedundantSorts,
|
||||
DisableUnnecessaryBucketedScan
|
||||
) ++ context.session.sessionState.queryStagePrepRules
|
||||
|
||||
|
|
|
@ -18,7 +18,9 @@
|
|||
package org.apache.spark.sql.execution
|
||||
|
||||
import org.apache.spark.sql.{DataFrame, QueryTest}
|
||||
import org.apache.spark.sql.catalyst.plans.physical.{RangePartitioning, UnknownPartitioning}
|
||||
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
|
||||
import org.apache.spark.sql.execution.joins.ShuffledJoin
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.test.SharedSparkSession
|
||||
|
||||
|
@ -135,6 +137,32 @@ abstract class RemoveRedundantSortsSuiteBase
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-33472: shuffled join with different left and right side partition numbers") {
|
||||
withTempView("t1", "t2") {
|
||||
spark.range(0, 100, 1, 2).select('id as "key").createOrReplaceTempView("t1")
|
||||
(0 to 100).toDF("key").createOrReplaceTempView("t2")
|
||||
|
||||
val queryTemplate = """
|
||||
|SELECT /*+ %s(t1) */ t1.key
|
||||
|FROM t1 JOIN t2 ON t1.key = t2.key
|
||||
|WHERE t1.key > 10 AND t2.key < 50
|
||||
|ORDER BY t1.key ASC
|
||||
""".stripMargin
|
||||
|
||||
Seq(("MERGE", 3), ("SHUFFLE_HASH", 1)).foreach { case (hint, count) =>
|
||||
val query = queryTemplate.format(hint)
|
||||
val df = sql(query)
|
||||
val sparkPlan = df.queryExecution.sparkPlan
|
||||
val join = sparkPlan.collect { case j: ShuffledJoin => j }.head
|
||||
val leftPartitioning = join.left.outputPartitioning
|
||||
assert(leftPartitioning.isInstanceOf[RangePartitioning])
|
||||
assert(leftPartitioning.numPartitions == 2)
|
||||
assert(join.right.outputPartitioning == UnknownPartitioning(0))
|
||||
checkSorts(query, count, count)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class RemoveRedundantSortsSuite extends RemoveRedundantSortsSuiteBase
|
||||
|
|
Loading…
Reference in a new issue