diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index baa2540252..c49f2168ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -100,6 +100,7 @@ object TreePattern extends Enumeration {
   val LEFT_SEMI_OR_ANTI_JOIN: Value = Value
   val LIMIT: Value = Value
   val LOCAL_RELATION: Value = Value
+  val LOGICAL_QUERY_STAGE: Value = Value
   val NATURAL_LIKE_JOIN: Value = Value
   val OUTER_JOIN: Value = Value
   val PROJECT: Value = Value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
index 648d2e7117..ea2fb1c313 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive
 import org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelationBase
 import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, LOGICAL_QUERY_STAGE, TRUE_OR_FALSE_LITERAL}
 import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys
 
 /**
@@ -53,8 +54,13 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
       empty(j)
   }
 
-  // TODO we need use transformUpWithPruning instead of transformUp
-  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    // The LOCAL_RELATION and TRUE_OR_FALSE_LITERAL patterns are matched in
+    // `PropagateEmptyRelationBase.commonApplyFunc`.
+    // The LOGICAL_QUERY_STAGE pattern is matched in `PropagateEmptyRelationBase.commonApplyFunc`
+    // and `AQEPropagateEmptyRelation.eliminateSingleColumnNullAwareAntiJoin`.
+    // Note that we cannot specify a ruleId here since LogicalQueryStage is not immutable.
+    _.containsAnyPattern(LOGICAL_QUERY_STAGE, LOCAL_RELATION, TRUE_OR_FALSE_LITERAL)) {
     eliminateSingleColumnNullAwareAntiJoin.orElse(commonApplyFunc)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala
index bff142315f..8bb3708390 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.adaptive
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LOGICAL_QUERY_STAGE, TreePattern}
 import org.apache.spark.sql.execution.SparkPlan
 
 /**
@@ -39,6 +40,7 @@ case class LogicalQueryStage(
   override def output: Seq[Attribute] = logicalPlan.output
   override val isStreaming: Boolean = logicalPlan.isStreaming
   override val outputOrdering: Seq[SortOrder] = physicalPlan.outputOrdering
+  override protected val nodePatterns: Seq[TreePattern] = Seq(LOGICAL_QUERY_STAGE)
 
   override def computeStats(): Statistics = {
     // TODO this is not accurate when there is other physical nodes above QueryStageExec.
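For context, here is a minimal sketch of the mechanism this patch relies on: each tree node advertises its own patterns (as `LogicalQueryStage` now does via `nodePatterns`), every node caches the union of patterns in its subtree, and `transformUpWithPruning` uses that cache to skip whole subtrees where the rule can never fire. The `Node`/`Pattern` model below is a simplified stand-in, not Catalyst's actual `TreeNode`/`TreePatternBits` implementation.

```scala
// Simplified model of Catalyst's pattern-bit pruning (illustrative only).
object TreePatternPruningSketch {
  sealed trait Pattern
  case object LOGICAL_QUERY_STAGE extends Pattern
  case object LOCAL_RELATION extends Pattern
  case object TRUE_OR_FALSE_LITERAL extends Pattern

  final case class Node(
      name: String,
      selfPatterns: Set[Pattern],
      children: Seq[Node] = Nil) {

    // Analogous to TreePatternBits: a node's own patterns plus the union of
    // all descendants' patterns, computed once and cached.
    lazy val treePatterns: Set[Pattern] =
      selfPatterns ++ children.flatMap(_.treePatterns)

    def containsAnyPattern(patterns: Pattern*): Boolean =
      patterns.exists(treePatterns.contains)

    // Analogous to transformUpWithPruning: if the predicate says none of the
    // trigger patterns occur in this subtree, return it untouched without
    // recursing into it at all.
    def transformUpWithPruning(cond: Node => Boolean)(
        rule: PartialFunction[Node, Node]): Node = {
      if (!cond(this)) this
      else {
        val newChildren = children.map(_.transformUpWithPruning(cond)(rule))
        rule.applyOrElse(copy(children = newChildren), identity[Node])
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val plan = Node("Join", Set.empty, Seq(
      Node("LogicalQueryStage", Set(LOGICAL_QUERY_STAGE)),
      Node("Project", Set.empty, Seq(Node("Scan", Set.empty)))))

    // Only branches containing a trigger pattern are visited; the
    // Project/Scan branch is skipped entirely without being traversed.
    val rewritten = plan.transformUpWithPruning(
      _.containsAnyPattern(LOGICAL_QUERY_STAGE, LOCAL_RELATION, TRUE_OR_FALSE_LITERAL)) {
      case n @ Node("LogicalQueryStage", _, _) => n.copy(name = "Visited")
    }
    println(rewritten)
  }
}
```

This also motivates the comment in the patch about `ruleId`: Catalyst can additionally cache that a rule was ineffective on a given (immutable) subtree, but since `LogicalQueryStage` can change as query stages materialize, that extra optimization is not safe here and only pattern pruning is used.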