diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 46be6b4256..3db5457de8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -386,6 +386,13 @@ package object dsl {
         condition: Option[Expression] = None): LogicalPlan =
       Join(logicalPlan, otherPlan, joinType, condition, JoinHint.NONE)
 
+    def lateralJoin(
+        otherPlan: LogicalPlan,
+        joinType: JoinType = Inner,
+        condition: Option[Expression] = None): LogicalPlan = {
+      LateralJoin(logicalPlan, LateralSubquery(otherPlan), joinType, condition)
+    }
+
     def cogroup[Key: Encoder, Left: Encoder, Right: Encoder, Result: Encoder](
         otherPlan: LogicalPlan,
         func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index d157330142..0c7452a37d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -126,13 +126,15 @@ object SubExprUtils extends PredicateHelper {
   /**
    * Returns an expression after removing the OuterReference shell.
    */
-  def stripOuterReference(e: Expression): Expression = e.transform { case OuterReference(r) => r }
+  def stripOuterReference[E <: Expression](e: E): E = {
+    e.transform { case OuterReference(r) => r }.asInstanceOf[E]
+  }
 
   /**
    * Returns the list of expressions after removing the OuterReference shell from each of
    * the expression.
    */
-  def stripOuterReferences(e: Seq[Expression]): Seq[Expression] = e.map(stripOuterReference)
+  def stripOuterReferences[E <: Expression](e: Seq[E]): Seq[E] = e.map(stripOuterReference)
 
   /**
    * Returns the logical plan after removing the OuterReference shell from all the expressions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
index f30dd9949f..71f3897ccf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
@@ -156,6 +156,23 @@ object DecorrelateInnerQuery extends PredicateHelper {
     expressions.map(replaceOuterReference(_, outerReferenceMap))
   }
 
+  /**
+   * Replace all outer references in the given named expressions and keep the output
+   * attributes unchanged.
+   */
+  private def replaceOuterInNamedExpressions(
+      expressions: Seq[NamedExpression],
+      outerReferenceMap: AttributeMap[Attribute]): Seq[NamedExpression] = {
+    expressions.map { expr =>
+      val newExpr = replaceOuterReference(expr, outerReferenceMap)
+      if (!newExpr.toAttribute.semanticEquals(expr.toAttribute)) {
+        Alias(newExpr, expr.name)(expr.exprId)
+      } else {
+        newExpr
+      }
+    }
+  }
+
   /**
    * Return all references that are presented in the join conditions but not in the output
    * of the given named expressions.
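The three hunks above work together: the DSL gains a `lateralJoin` helper for tests, `stripOuterReference` becomes generic so callers keep their `NamedExpression` types, and `replaceOuterInNamedExpressions` re-aliases each rewritten entry under its original exprId. The following is a minimal standalone sketch, not part of the patch, showing why that re-aliasing matters; it assumes the Catalyst classpath of the Spark 3.2-era code this change targets, and `domainX` plus the object name are invented for illustration:

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

object ExprIdPreservationSketch extends App {
  // An attribute of the outer query that the inner query references.
  val x = AttributeReference("x", IntegerType)()
  // Decorrelation maps each outer attribute to a domain attribute that
  // carries a fresh exprId.
  val domainX = x.newInstance()

  // A project-list entry of the inner query. OuterReference is itself a
  // NamedExpression exposing the name and exprId of the wrapped attribute.
  val entry: NamedExpression = OuterReference(x)

  // Substituting the domain attribute directly changes the entry's output
  // attribute, so parent operators holding the old exprId would dangle.
  val naive = entry.transform { case OuterReference(_) => domainX }
  assert(!naive.asInstanceOf[NamedExpression].toAttribute
    .semanticEquals(entry.toAttribute))

  // replaceOuterInNamedExpressions instead re-aliases the rewritten
  // expression under the original name and exprId, keeping the output stable.
  val kept = Alias(domainX, entry.name)(exprId = entry.exprId)
  assert(kept.toAttribute.semanticEquals(entry.toAttribute))
}
```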
@@ -429,8 +446,9 @@ object DecorrelateInnerQuery extends PredicateHelper {
         val newOuterReferences = parentOuterReferences ++ outerReferences
         val (newChild, joinCond, outerReferenceMap) =
           decorrelate(child, newOuterReferences, aggregated)
-        // Replace all outer references in the original project list.
-        val newProjectList = replaceOuterReferences(projectList, outerReferenceMap)
+        // Replace all outer references in the original project list and keep the output
+        // attributes unchanged.
+        val newProjectList = replaceOuterInNamedExpressions(projectList, outerReferenceMap)
         // Preserve required domain attributes in the join condition by adding the missing
         // references to the new project list.
         val referencesToAdd = missingReferences(newProjectList, joinCond)
@@ -442,9 +460,10 @@ object DecorrelateInnerQuery extends PredicateHelper {
         val newOuterReferences = parentOuterReferences ++ outerReferences
         val (newChild, joinCond, outerReferenceMap) =
           decorrelate(child, newOuterReferences, aggregated = true)
-        // Replace all outer references in grouping and aggregate expressions.
+        // Replace all outer references in grouping and aggregate expressions, and keep
+        // the output attributes unchanged.
         val newGroupingExpr = replaceOuterReferences(groupingExpressions, outerReferenceMap)
-        val newAggExpr = replaceOuterReferences(aggregateExpressions, outerReferenceMap)
+        val newAggExpr = replaceOuterInNamedExpressions(aggregateExpressions, outerReferenceMap)
         // Add all required domain attributes to both grouping and aggregate expressions.
         val referencesToAdd = missingReferences(newAggExpr, joinCond)
         val newAggregate = a.copy(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index aa2221b398..d30481ffc7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -179,6 +179,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       // non-nullable when an empty relation child of a Union is removed
       UpdateAttributeNullability) ::
     Batch("Pullup Correlated Expressions", Once,
+      OptimizeOneRowRelationSubquery,
       PullupCorrelatedPredicates) ::
     // Subquery batch applies the optimizer rules recursively. Therefore, it makes no sense
     // to enforce idempotence on it and we change this batch from Once to FixedPoint(1).
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 53448fbe92..6d6b8b7d8a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
@@ -711,3 +712,47 @@ object RewriteLateralSubquery extends Rule[LogicalPlan] {
     Join(left, newRight, joinType, newCond, JoinHint.NONE)
   }
 }
+
+/**
+ * This rule optimizes subqueries with OneRowRelation as leaf nodes.
+ */
+object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
+
+  object OneRowSubquery {
+    def unapply(plan: LogicalPlan): Option[Seq[NamedExpression]] = {
+      CollapseProject(EliminateSubqueryAliases(plan)) match {
+        case Project(projectList, _: OneRowRelation) => Some(stripOuterReferences(projectList))
+        case _ => None
+      }
+    }
+  }
+
+  private def hasCorrelatedSubquery(plan: LogicalPlan): Boolean = {
+    plan.find(_.expressions.exists(SubqueryExpression.hasCorrelatedSubquery)).isDefined
+  }
+
+  /**
+   * Rewrite a subquery expression into one or more expressions. The rewrite can only be done
+   * if there are no nested subqueries in the subquery plan.
+   */
+  private def rewrite(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries {
+    case LateralJoin(left, right @ LateralSubquery(OneRowSubquery(projectList), _, _, _), _, None)
+        if !hasCorrelatedSubquery(right.plan) && right.joinCond.isEmpty =>
+      Project(left.output ++ projectList, left)
+    case p: LogicalPlan => p.transformExpressionsUpWithPruning(
+      _.containsPattern(SCALAR_SUBQUERY)) {
+      case s @ ScalarSubquery(OneRowSubquery(projectList), _, _, _)
+          if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty =>
+        assert(projectList.size == 1)
+        projectList.head
+    }
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!conf.getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY)) {
+      plan
+    } else {
+      rewrite(plan)
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index db7fd5c3a1..3c9946ba37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -435,6 +435,23 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
     subqueries ++ subqueries.flatMap(_.subqueriesAll)
   }
 
+  /**
+   * Returns a copy of this node where the given partial function has been recursively applied
+   * first to the subqueries in this node's children, then this node's children, and finally
+   * this node itself (post-order). When the partial function does not apply to a given node,
+   * it is left unchanged.
+   */
+  def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = {
+    transformUp { case plan =>
+      val transformed = plan transformExpressionsUp {
+        case planExpression: PlanExpression[PlanType] =>
+          val newPlan = planExpression.plan.transformUpWithSubqueries(f)
+          planExpression.withNewPlan(newPlan)
+      }
+      f.applyOrElse[PlanType, PlanType](transformed, identity)
+    }
+  }
+
   /**
    * A variant of `collect`. This method not only apply the given function to all elements in this
    * plan, also considering all the plans in its (nested) subqueries
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b3557ecf36..0b76b947a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2613,6 +2613,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val OPTIMIZE_ONE_ROW_RELATION_SUBQUERY =
+    buildConf("spark.sql.optimizer.optimizeOneRowRelationSubquery")
+      .internal()
+      .doc("When true, the optimizer will inline subqueries with OneRowRelation as leaf nodes.")
+      .version("3.2.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val TOP_K_SORT_FALLBACK_THRESHOLD =
     buildConf("spark.sql.execution.topKSortFallbackThreshold")
       .internal()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala
index 92995c2e85..b8886a5c0b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala
@@ -32,6 +32,7 @@ class DecorrelateInnerQuerySuite extends PlanTest {
   val x = AttributeReference("x", IntegerType)()
   val y = AttributeReference("y", IntegerType)()
   val z = AttributeReference("z", IntegerType)()
+  val t0 = OneRowRelation()
   val testRelation = LocalRelation(a, b, c)
   val testRelation2 = LocalRelation(x, y, z)
 
@@ -203,23 +204,24 @@ class DecorrelateInnerQuerySuite extends PlanTest {
 
   test("correlated values in project") {
     val outerPlan = testRelation2
-    val innerPlan = Project(Seq(OuterReference(x), OuterReference(y)), OneRowRelation())
-    val correctAnswer = Project(Seq(x, y), DomainJoin(Seq(x, y), OneRowRelation()))
+    val innerPlan = Project(Seq(OuterReference(x).as("x1"), OuterReference(y).as("y1")), t0)
+    val correctAnswer = Project(
+      Seq(x.as("x1"), y.as("y1"), x, y), DomainJoin(Seq(x, y), t0))
     check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y))
   }
 
   test("correlated values in project with alias") {
     val outerPlan = testRelation2
     val innerPlan =
-      Project(Seq(OuterReference(x), 'y1, 'sum),
+      Project(Seq(OuterReference(x).as("x1"), 'y1, 'sum),
         Project(Seq(
           OuterReference(x),
           OuterReference(y).as("y1"),
           Add(OuterReference(x), OuterReference(y)).as("sum")),
           testRelation)).analyze
     val correctAnswer =
-      Project(Seq(x, 'y1, 'sum, y),
-        Project(Seq(x, y.as("y1"), (x + y).as("sum"), y),
+      Project(Seq(x.as("x1"), 'y1, 'sum, x, y),
+        Project(Seq(x.as(x.name), y.as("y1"), (x + y).as("sum"), x, y),
          DomainJoin(Seq(x, y), testRelation))).analyze
     check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y))
   }
@@ -228,13 +230,13 @@ class DecorrelateInnerQuerySuite extends PlanTest {
     val outerPlan = testRelation2
     val innerPlan =
       Project(
-        Seq(OuterReference(x)),
+        Seq(OuterReference(x).as("x1")),
         Filter(
           And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)),
           testRelation
         )
       )
-    val correctAnswer = Project(Seq(a, c), Filter(b === 1, testRelation))
+    val correctAnswer = Project(Seq(a.as("x1"), a, c), Filter(b === 1, testRelation))
     check(innerPlan, outerPlan, correctAnswer, Seq(x === a, x + y === c))
   }
 
@@ -242,14 +244,14 @@ class DecorrelateInnerQuerySuite extends PlanTest {
     val outerPlan = testRelation2
     val innerPlan =
       Project(
-        Seq(OuterReference(y)),
+        Seq(OuterReference(y).as("y1")),
         Filter(
           And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)),
           testRelation
         )
       )
     val correctAnswer =
-      Project(Seq(y, a, c),
+      Project(Seq(y.as("y1"), y, a, c),
         Filter(b === 1,
           DomainJoin(Seq(y), testRelation)
         )
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala
new file mode 100644
index 0000000000..4203859226
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.CleanupAliases
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.ScalarSubquery
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{DomainJoin, LocalRelation, LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.internal.SQLConf
+
+class OptimizeOneRowRelationSubquerySuite extends PlanTest {
+
+  private var optimizeOneRowRelationSubqueryEnabled: Boolean = _
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    optimizeOneRowRelationSubqueryEnabled =
+      SQLConf.get.getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY)
+    SQLConf.get.setConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY, true)
+  }
+
+  protected override def afterAll(): Unit = {
+    SQLConf.get.setConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY,
+      optimizeOneRowRelationSubqueryEnabled)
+    super.afterAll()
+  }
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Subquery", Once,
+        OptimizeOneRowRelationSubquery,
+        PullupCorrelatedPredicates) ::
+      Batch("Cleanup", FixedPoint(10),
+        CleanupAliases) :: Nil
+  }
+
+  private def assertHasDomainJoin(plan: LogicalPlan): Unit = {
+    assert(plan.collectWithSubqueries { case d: DomainJoin => d }.nonEmpty,
+      s"Plan does not contain DomainJoin:\n$plan")
+  }
+
+  val t0 = OneRowRelation()
+  val a = 'a.int
+  val b = 'b.int
+  val t1 = LocalRelation(a, b)
+  val t2 = LocalRelation('c.int, 'd.int)
+
+  test("Optimize scalar subquery with a single project") {
+    // SELECT (SELECT a) FROM t1
+    val query = t1.select(ScalarSubquery(t0.select('a)).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = t1.select('a.as("sub"))
+    comparePlans(optimized, correctAnswer.analyze)
+  }
+
+  test("Optimize lateral subquery with a single project") {
+    Seq(Inner, LeftOuter, Cross).foreach { joinType =>
+      // SELECT * FROM t1 JOIN LATERAL (SELECT a, b)
+      val query = t1.lateralJoin(t0.select('a, 'b), joinType, None)
+      val optimized = Optimize.execute(query.analyze)
+      val correctAnswer = t1.select('a, 'b, 'a.as("a"), 'b.as("b"))
+      comparePlans(optimized, correctAnswer.analyze)
+    }
+  }
+
+  test("Optimize subquery with subquery alias") {
+    val inner = t0.select('a).as("t2")
+    val query = t1.select(ScalarSubquery(inner).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = t1.select('a.as("sub"))
+    comparePlans(optimized, correctAnswer.analyze)
+  }
+
+  test("Optimize scalar subquery with multiple projects") {
+    // SELECT (SELECT a1 + b1 FROM (SELECT a AS a1, b AS b1)) FROM t1
+    val inner = t0.select('a.as("a1"), 'b.as("b1")).select(('a1 + 'b1).as("c"))
+    val query = t1.select(ScalarSubquery(inner).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = t1.select(('a + 'b).as("c").as("sub"))
+    comparePlans(optimized, correctAnswer.analyze)
+  }
+
+  test("Optimize lateral subquery with multiple projects") {
+    Seq(Inner, LeftOuter, Cross).foreach { joinType =>
+      val inner = t0.select('a.as("a1"), 'b.as("b1"))
+        .select(('a1 + 'b1).as("c1"), ('a1 - 'b1).as("c2"))
+      val query = t1.lateralJoin(inner, joinType, None)
+      val optimized = Optimize.execute(query.analyze)
+      val correctAnswer = t1.select('a, 'b, ('a + 'b).as("c1"), ('a - 'b).as("c2"))
+      comparePlans(optimized, correctAnswer.analyze)
+    }
+  }
+
+  test("Optimize subquery with nested correlated subqueries") {
+    // SELECT (SELECT (SELECT b) FROM (SELECT a AS b)) FROM t1
+    val inner = t0.select('a.as("b")).select(ScalarSubquery(t0.select('b)).as("s"))
+    val query = t1.select(ScalarSubquery(inner).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = t1.select('a.as("s").as("sub"))
+    comparePlans(optimized, correctAnswer.analyze)
+  }
+
+  test("Batch should be idempotent") {
+    // SELECT (SELECT 1 WHERE a = a + 1) FROM t1
+    val inner = t0.select(1).where('a === 'a + 1)
+    val query = t1.select(ScalarSubquery(inner).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    val doubleOptimized = Optimize.execute(optimized)
+    comparePlans(optimized, doubleOptimized, checkAnalysis = false)
+  }
+
+  test("Should not optimize scalar subquery with operators other than project") {
+    // SELECT (SELECT a AS a1 WHERE a = 1) FROM t1
+    val inner = t0.where('a === 1).select('a.as("a1"))
+    val query = t1.select(ScalarSubquery(inner).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    assertHasDomainJoin(optimized)
+  }
+
+  test("Should not optimize subquery with non-deterministic expressions") {
+    // SELECT (SELECT r FROM (SELECT a + rand() AS r)) FROM t1
+    val inner = t0.select(('a + rand(0)).as("r")).select('r)
+    val query = t1.select(ScalarSubquery(inner).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    assertHasDomainJoin(optimized)
+  }
+
+  test("Should not optimize lateral join with non-empty join conditions") {
+    Seq(Inner, LeftOuter).foreach { joinType =>
+      // SELECT * FROM t1 JOIN LATERAL (SELECT a AS a1, b AS b1) ON a = b1
+      val query = t1.lateralJoin(t0.select('a.as("a1"), 'b.as("b1")), joinType, Some('a === 'b1))
+      val optimized = Optimize.execute(query.analyze)
+      assertHasDomainJoin(optimized)
+    }
+  }
+
+  test("Should not optimize subquery with nested subqueries that can't be optimized") {
+    // SELECT (SELECT (SELECT a WHERE a = 1) FROM (SELECT a AS a)) FROM t1
+    // Filter (a = 1) cannot be optimized.
+    val inner = t0.select('a).where('a === 1)
+    val subquery = t0.select('a.as("a"))
+      .select(ScalarSubquery(inner).as("s")).select('s + 1)
+    val query = t1.select(ScalarSubquery(subquery).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    assertHasDomainJoin(optimized)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index e06af08147..c3362b377e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -1838,7 +1838,8 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
   }
 
   test("Subquery reuse across the whole plan") {
-    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+      SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY.key -> "false") {
       val df = sql(
         """
           |SELECT (SELECT avg(key) FROM testData), (SELECT (SELECT avg(key) FROM testData))
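For context, here is a hedged end-to-end illustration of the behavior the new rule and its `spark.sql.optimizer.optimizeOneRowRelationSubquery` flag control. It is not part of the diff; the view and column names are invented, and it assumes a build that includes this patch:

```scala
import org.apache.spark.sql.SparkSession

object OneRowSubqueryDemo extends App {
  val spark = SparkSession.builder().master("local[1]").getOrCreate()
  import spark.implicits._

  Seq((1, 2), (3, 4)).toDF("a", "b").createOrReplaceTempView("t1")

  // The correlated scalar subquery has OneRowRelation as its leaf, so with
  // the flag at its default (true) the optimizer inlines it, rewriting
  // SELECT (SELECT a + b) FROM t1 into SELECT a + b FROM t1.
  val df = spark.sql("SELECT (SELECT a + b) AS s FROM t1")
  df.explain(true)  // expect a single Project over t1 in the optimized plan
  df.show()         // rows: s = 3 and s = 7

  spark.stop()
}
```

Disabling the flag keeps the nested subquery structure intact, which is why the updated SubquerySuite test above turns it off before checking subquery reuse on `(SELECT (SELECT avg(key) FROM testData))`.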