[SPARK-35455][SQL] Unify empty relation optimization between normal and AQE optimizer

### What changes were proposed in this pull request?

* remove `EliminateUnnecessaryJoin`, using `AQEPropagateEmptyRelation` instead.
* eliminate join, aggregate, limit, repartition, sort, generate which is beneficial.

### Why are the changes needed?

Make `EliminateUnnecessaryJoin` available with more case.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Add test.

Closes #32602 from ulysses-you/SPARK-35455.

Authored-by: ulysses-you <ulyssesyou18@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
ulysses-you 2021-05-25 08:59:59 +00:00 committed by Wenchen Fan
parent 4a6d844184
commit 631077db08
6 changed files with 215 additions and 136 deletions

View file

@ -26,59 +26,46 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_FALSE_LITERAL}
/**
* Collapse plans consisting empty local relations generated by [[PruneFilters]].
* 1. Binary(or Higher)-node Logical Plans
* - Union with all empty children.
* - Join with one or two empty children (including Intersect/Except).
* 2. Unary-node Logical Plans
* - Project/Filter/Sample/Join/Limit/Repartition with all empty children.
* - Join with false condition.
* - Aggregate with all empty children and at least one grouping expression.
* - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
* The base class of two rules in the normal and AQE Optimizer. It simplifies query plans with
* empty or non-empty relations:
* 1. Binary-node Logical Plans
* - Join with one or two empty children (including Intersect/Except).
* - Left semi Join
* Right side is non-empty and condition is empty. Eliminate join to its left side.
* - Left anti join
* Right side is non-empty and condition is empty. Eliminate join to an empty
* [[LocalRelation]].
* 2. Unary-node Logical Plans
* - Limit/Repartition with all empty children.
* - Aggregate with all empty children and at least one grouping expression.
* - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
*/
object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper with CastSupport {
private def isEmptyLocalRelation(plan: LogicalPlan): Boolean = plan match {
abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSupport {
protected def isEmpty(plan: LogicalPlan): Boolean = plan match {
case p: LocalRelation => p.data.isEmpty
case _ => false
}
private def empty(plan: LogicalPlan) =
protected def nonEmpty(plan: LogicalPlan): Boolean = plan match {
case p: LocalRelation => p.data.nonEmpty
case _ => false
}
protected def empty(plan: LogicalPlan): LocalRelation =
LocalRelation(plan.output, data = Seq.empty, isStreaming = plan.isStreaming)
// Construct a project list from plan's output, while the value is always NULL.
private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] =
plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) }
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsAnyPattern(LOCAL_RELATION, TRUE_OR_FALSE_LITERAL), ruleId) {
case p: Union if p.children.exists(isEmptyLocalRelation) =>
val newChildren = p.children.filterNot(isEmptyLocalRelation)
if (newChildren.isEmpty) {
empty(p)
} else {
val newPlan = if (newChildren.size > 1) Union(newChildren) else newChildren.head
val outputs = newPlan.output.zip(p.output)
// the original Union may produce different output attributes than the new one so we alias
// them if needed
if (outputs.forall { case (newAttr, oldAttr) => newAttr.exprId == oldAttr.exprId }) {
newPlan
} else {
val outputAliases = outputs.map { case (newAttr, oldAttr) =>
val newExplicitMetadata =
if (oldAttr.metadata != newAttr.metadata) Some(oldAttr.metadata) else None
Alias(newAttr, oldAttr.name)(oldAttr.exprId, explicitMetadata = newExplicitMetadata)
}
Project(outputAliases, newPlan)
}
}
protected def commonApplyFunc: PartialFunction[LogicalPlan, LogicalPlan] = {
// Joins on empty LocalRelations generated from streaming sources are not eliminated
// as stateful streaming joins need to perform other state management operations other than
// just processing the input data.
case p @ Join(_, _, joinType, conditionOpt, _)
if !p.children.exists(_.isStreaming) =>
val isLeftEmpty = isEmptyLocalRelation(p.left)
val isRightEmpty = isEmptyLocalRelation(p.right)
val isLeftEmpty = isEmpty(p.left)
val isRightEmpty = isEmpty(p.right)
val isFalseCondition = conditionOpt match {
case Some(FalseLiteral) => true
case _ => false
@ -103,14 +90,15 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper wit
Project(nullValueProjectList(p.left) ++ p.right.output, p.right)
case _ => p
}
} else if (joinType == LeftSemi && conditionOpt.isEmpty && nonEmpty(p.right)) {
p.left
} else if (joinType == LeftAnti && conditionOpt.isEmpty && nonEmpty(p.right)) {
empty(p)
} else {
p
}
case p: UnaryNode if p.children.nonEmpty && p.children.forall(isEmptyLocalRelation) => p match {
case _: Project => empty(p)
case _: Filter => empty(p)
case _: Sample => empty(p)
case p: UnaryNode if p.children.nonEmpty && p.children.forall(isEmpty) => p match {
case _: Sort => empty(p)
case _: GlobalLimit if !p.isStreaming => empty(p)
case _: LocalLimit if !p.isStreaming => empty(p)
@ -137,3 +125,55 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper wit
}
}
}
/**
* This rule runs in the normal optimizer and optimizes more cases
* compared to [[PropagateEmptyRelationBase]]:
* 1. Higher-node Logical Plans
* - Union with all empty children.
* 2. Unary-node Logical Plans
* - Project/Filter/Sample with all empty children.
*
* The reason why we don't apply this rule at AQE optimizer side is: the benefit is not big enough
* and it may introduce extra exchanges.
*/
object PropagateEmptyRelation extends PropagateEmptyRelationBase {
private def applyFunc: PartialFunction[LogicalPlan, LogicalPlan] = {
case p: Union if p.children.exists(isEmpty) =>
val newChildren = p.children.filterNot(isEmpty)
if (newChildren.isEmpty) {
empty(p)
} else {
val newPlan = if (newChildren.size > 1) Union(newChildren) else newChildren.head
val outputs = newPlan.output.zip(p.output)
// the original Union may produce different output attributes than the new one so we alias
// them if needed
if (outputs.forall { case (newAttr, oldAttr) => newAttr.exprId == oldAttr.exprId }) {
newPlan
} else {
val outputAliases = outputs.map { case (newAttr, oldAttr) =>
val newExplicitMetadata =
if (oldAttr.metadata != newAttr.metadata) Some(oldAttr.metadata) else None
Alias(newAttr, oldAttr.name)(oldAttr.exprId, explicitMetadata = newExplicitMetadata)
}
Project(outputAliases, newPlan)
}
}
case p: UnaryNode if p.children.nonEmpty && p.children.forall(isEmpty) && canPropagate(p) =>
empty(p)
}
// extract the pattern avoid conflict with commonApplyFunc
private def canPropagate(plan: LogicalPlan): Boolean = plan match {
case _: Project => true
case _: Filter => true
case _: Sample => true
case _ => false
}
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsAnyPattern(LOCAL_RELATION, TRUE_OR_FALSE_LITERAL), ruleId) {
applyFunc.orElse(commonApplyFunc)
}
}

View file

@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity, PlanHelper}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
@ -27,7 +28,9 @@ import org.apache.spark.util.Utils
*/
class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] {
private val defaultBatches = Seq(
Batch("Eliminate Unnecessary Join", Once, EliminateUnnecessaryJoin),
Batch("Propagate Empty Relations", Once,
AQEPropagateEmptyRelation,
UpdateAttributeNullability),
Batch("Demote BroadcastHashJoin", Once, DemoteBroadcastHashJoin)
)

View file

@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelationBase
import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys
/**
* This rule runs in the AQE optimizer and optimizes more cases
* compared to [[PropagateEmptyRelationBase]]:
* 1. Join is single column NULL-aware anti join (NAAJ)
* Broadcasted [[HashedRelation]] is [[HashedRelationWithAllNullKeys]]. Eliminate join to an
* empty [[LocalRelation]].
*/
object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
override protected def isEmpty(plan: LogicalPlan): Boolean =
super.isEmpty(plan) || getRowCount(plan).contains(0)
override protected def nonEmpty(plan: LogicalPlan): Boolean =
super.nonEmpty(plan) || getRowCount(plan).exists(_ > 0)
private def getRowCount(plan: LogicalPlan): Option[BigInt] = plan match {
case LogicalQueryStage(_, stage: QueryStageExec) if stage.resultOption.get().isDefined =>
stage.getRuntimeStatistics.rowCount
case _ => None
}
private def isRelationWithAllNullKeys(plan: LogicalPlan): Boolean = plan match {
case LogicalQueryStage(_, stage: BroadcastQueryStageExec)
if stage.resultOption.get().isDefined =>
stage.broadcast.relationFuture.get().value == HashedRelationWithAllNullKeys
case _ => false
}
private def eliminateSingleColumnNullAwareAntiJoin: PartialFunction[LogicalPlan, LogicalPlan] = {
case j @ ExtractSingleColumnNullAwareAntiJoin(_, _) if isRelationWithAllNullKeys(j.right) =>
empty(j)
}
// TODO we need use transformUpWithPruning instead of transformUp
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
eliminateSingleColumnNullAwareAntiJoin.orElse(commonApplyFunc)
}
}

View file

@ -1,91 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
import org.apache.spark.sql.catalyst.plans.{Inner, LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys
/**
* This optimization rule detects and eliminates unnecessary Join:
* 1. Join is single column NULL-aware anti join (NAAJ), and broadcasted [[HashedRelation]]
* is [[HashedRelationWithAllNullKeys]]. Eliminate join to an empty [[LocalRelation]].
*
* 2. Join is inner join, and either side of join is empty. Eliminate join to an empty
* [[LocalRelation]].
*
* 3. Join is left semi join
* 3.1. Join right side is empty. Eliminate join to an empty [[LocalRelation]].
* 3.2. Join right side is non-empty and condition is empty. Eliminate join to its left side.
*
* 4. Join is left anti join
* 4.1. Join right side is empty. Eliminate join to its left side.
* 4.2. Join right side is non-empty and condition is empty. Eliminate join to an empty
* [[LocalRelation]].
*
* This applies to all joins (sort merge join, shuffled hash join, broadcast hash join, and
* broadcast nested loop join), because sort merge join and shuffled hash join will be changed
* to broadcast hash join with AQE at the first place.
*/
object EliminateUnnecessaryJoin extends Rule[LogicalPlan] {
private def isRelationWithAllNullKeys(plan: LogicalPlan) = plan match {
case LogicalQueryStage(_, stage: BroadcastQueryStageExec)
if stage.resultOption.get().isDefined =>
stage.broadcast.relationFuture.get().value == HashedRelationWithAllNullKeys
case _ => false
}
private def checkRowCount(plan: LogicalPlan, hasRow: Boolean): Boolean = plan match {
case LogicalQueryStage(_, stage: QueryStageExec) if stage.resultOption.get().isDefined =>
stage.getRuntimeStatistics.rowCount match {
case Some(count) => hasRow == (count > 0)
case _ => false
}
case _ => false
}
def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
case j @ ExtractSingleColumnNullAwareAntiJoin(_, _) if isRelationWithAllNullKeys(j.right) =>
LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
case j @ Join(_, _, Inner, _, _) if checkRowCount(j.left, hasRow = false) ||
checkRowCount(j.right, hasRow = false) =>
LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
case j @ Join(_, _, LeftSemi, condition, _) =>
if (checkRowCount(j.right, hasRow = false)) {
LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
} else if (condition.isEmpty && checkRowCount(j.right, hasRow = true)) {
j.left
} else {
j
}
case j @ Join(_, _, LeftAnti, condition, _) =>
if (checkRowCount(j.right, hasRow = false)) {
j.left
} else if (condition.isEmpty && checkRowCount(j.right, hasRow = true)) {
LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
} else {
j
}
}
}

View file

@ -1382,7 +1382,7 @@ abstract class DynamicPartitionPruningSuiteBase
withSQLConf(
SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true",
SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> EliminateUnnecessaryJoin.ruleName) {
SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {
val df = sql(
"""
|SELECT * FROM fact_sk f

View file

@ -236,7 +236,8 @@ class AdaptiveQueryExecSuite
test("Empty stage coalesced to 1-partition RDD") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true") {
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {
val df1 = spark.range(10).withColumn("a", 'id)
val df2 = spark.range(10).withColumn("b", 'id)
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
@ -1233,7 +1234,7 @@ class AdaptiveQueryExecSuite
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString,
// This test is a copy of test(SPARK-32573), in order to test the configuration
// `spark.sql.adaptive.optimizer.excludedRules` works as expect.
SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> EliminateUnnecessaryJoin.ruleName) {
SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT * FROM testData2 t1 WHERE t1.b NOT IN (SELECT b FROM testData3)")
val bhj = findTopLevelBroadcastHashJoin(plan)
@ -1307,6 +1308,71 @@ class AdaptiveQueryExecSuite
}
}
test("SPARK-35455: Unify empty relation optimization between normal and AQE optimizer " +
"- single join") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
Seq(
// left semi join and empty left side
("SELECT * FROM (SELECT * FROM testData WHERE value = '0')t1 LEFT SEMI JOIN " +
"testData2 t2 ON t1.key = t2.a", true),
// left anti join and empty left side
("SELECT * FROM (SELECT * FROM testData WHERE value = '0')t1 LEFT ANTI JOIN " +
"testData2 t2 ON t1.key = t2.a", true),
// left outer join and empty left side
("SELECT * FROM (SELECT * FROM testData WHERE key = 0)t1 LEFT JOIN testData2 t2 ON " +
"t1.key = t2.a", true),
// left outer join and non-empty left side
("SELECT * FROM testData t1 LEFT JOIN testData2 t2 ON " +
"t1.key = t2.a", false),
// right outer join and empty right side
("SELECT * FROM testData t1 RIGHT JOIN (SELECT * FROM testData2 WHERE b = 0)t2 ON " +
"t1.key = t2.a", true),
// right outer join and non-empty right side
("SELECT * FROM testData t1 RIGHT JOIN testData2 t2 ON " +
"t1.key = t2.a", false),
// full outer join and both side empty
("SELECT * FROM (SELECT * FROM testData WHERE key = 0)t1 FULL JOIN " +
"(SELECT * FROM testData2 WHERE b = 0)t2 ON t1.key = t2.a", true),
// full outer join and left side empty right side non-empty
("SELECT * FROM (SELECT * FROM testData WHERE key = 0)t1 FULL JOIN " +
"testData2 t2 ON t1.key = t2.a", true)
).foreach { case (query, isEliminated) =>
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query)
assert(findTopLevelBaseJoin(plan).size == 1)
assert(findTopLevelBaseJoin(adaptivePlan).isEmpty == isEliminated, adaptivePlan)
}
}
}
test("SPARK-35455: Unify empty relation optimization between normal and AQE optimizer " +
"- multi join") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
Seq(
"""
|SELECT * FROM testData t1
| JOIN (SELECT * FROM testData2 WHERE b = 0) t2 ON t1.key = t2.a
| LEFT JOIN testData2 t3 ON t1.key = t3.a
|""".stripMargin,
"""
|SELECT * FROM (SELECT * FROM testData WHERE key = 0) t1
| LEFT ANTI JOIN testData2 t2
| FULL JOIN (SELECT * FROM testData2 WHERE b = 0) t3 ON t1.key = t3.a
|""".stripMargin,
"""
|SELECT * FROM testData t1
| LEFT SEMI JOIN (SELECT * FROM testData2 WHERE b = 0)
| RIGHT JOIN testData2 t3 on t1.key = t3.a
|""".stripMargin
).foreach { query =>
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query)
assert(findTopLevelBaseJoin(plan).size == 2)
assert(findTopLevelBaseJoin(adaptivePlan).isEmpty)
}
}
}
test("SPARK-32753: Only copy tags to node with no tags") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
withTempView("v1") {