[SPARK-36063][SQL] Optimize OneRowRelation subqueries

### What changes were proposed in this pull request?
This PR adds optimization for scalar and lateral subqueries with OneRowRelation as leaf nodes. It inlines such subqueries before decorrelation to avoid rewriting them as left outer joins. It also introduces a flag to turn on/off this optimization: `spark.sql.optimizer.optimizeOneRowRelationSubquery` (default: True).

For example:
```sql
select (select c1) from t
```
Analyzed plan:
```
Project [scalar-subquery#17 [c1#18] AS scalarsubquery(c1)#22]
:  +- Project [outer(c1#18)]
:     +- OneRowRelation
+- LocalRelation [c1#18, c2#19]
```

Optimized plan before this PR:
```
Project [c1#18#25 AS scalarsubquery(c1)#22]
+- Join LeftOuter, (c1#24 <=> c1#18)
   :- LocalRelation [c1#18]
   +- Aggregate [c1#18], [c1#18 AS c1#18#25, c1#18 AS c1#24]
      +- LocalRelation [c1#18]
```

Optimized plan after this PR:
```
LocalRelation [scalarsubquery(c1)#22]
```

### Why are the changes needed?
To optimize query plans.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added new unit tests.

Closes #33284 from allisonwang-db/spark-36063-optimize-subquery-one-row-relation.

Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit de8e4be92c)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
allisonwang-db 2021-07-22 10:48:32 +08:00 committed by Wenchen Fan
parent d01e53208b
commit 31bb9e04ad
10 changed files with 283 additions and 16 deletions

View file

@ -386,6 +386,13 @@ package object dsl {
condition: Option[Expression] = None): LogicalPlan =
Join(logicalPlan, otherPlan, joinType, condition, JoinHint.NONE)
def lateralJoin(
otherPlan: LogicalPlan,
joinType: JoinType = Inner,
condition: Option[Expression] = None): LogicalPlan = {
LateralJoin(logicalPlan, LateralSubquery(otherPlan), joinType, condition)
}
def cogroup[Key: Encoder, Left: Encoder, Right: Encoder, Result: Encoder](
otherPlan: LogicalPlan,
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],

View file

@ -126,13 +126,15 @@ object SubExprUtils extends PredicateHelper {
/**
* Returns an expression after removing the OuterReference shell.
*/
def stripOuterReference(e: Expression): Expression = e.transform { case OuterReference(r) => r }
def stripOuterReference[E <: Expression](e: E): E = {
e.transform { case OuterReference(r) => r }.asInstanceOf[E]
}
/**
* Returns the list of expressions after removing the OuterReference shell from each of
* the expression.
*/
def stripOuterReferences(e: Seq[Expression]): Seq[Expression] = e.map(stripOuterReference)
def stripOuterReferences[E <: Expression](e: Seq[E]): Seq[E] = e.map(stripOuterReference)
/**
* Returns the logical plan after removing the OuterReference shell from all the expressions

View file

@ -156,6 +156,23 @@ object DecorrelateInnerQuery extends PredicateHelper {
expressions.map(replaceOuterReference(_, outerReferenceMap))
}
/**
* Replace all outer references in the given named expressions and keep the output
* attributes unchanged.
*/
private def replaceOuterInNamedExpressions(
expressions: Seq[NamedExpression],
outerReferenceMap: AttributeMap[Attribute]): Seq[NamedExpression] = {
expressions.map { expr =>
val newExpr = replaceOuterReference(expr, outerReferenceMap)
if (!newExpr.toAttribute.semanticEquals(expr.toAttribute)) {
Alias(newExpr, expr.name)(expr.exprId)
} else {
newExpr
}
}
}
/**
* Return all references that are presented in the join conditions but not in the output
* of the given named expressions.
@ -429,8 +446,9 @@ object DecorrelateInnerQuery extends PredicateHelper {
val newOuterReferences = parentOuterReferences ++ outerReferences
val (newChild, joinCond, outerReferenceMap) =
decorrelate(child, newOuterReferences, aggregated)
// Replace all outer references in the original project list.
val newProjectList = replaceOuterReferences(projectList, outerReferenceMap)
// Replace all outer references in the original project list and keep the output
// attributes unchanged.
val newProjectList = replaceOuterInNamedExpressions(projectList, outerReferenceMap)
// Preserve required domain attributes in the join condition by adding the missing
// references to the new project list.
val referencesToAdd = missingReferences(newProjectList, joinCond)
@ -442,9 +460,10 @@ object DecorrelateInnerQuery extends PredicateHelper {
val newOuterReferences = parentOuterReferences ++ outerReferences
val (newChild, joinCond, outerReferenceMap) =
decorrelate(child, newOuterReferences, aggregated = true)
// Replace all outer references in grouping and aggregate expressions.
// Replace all outer references in grouping and aggregate expressions, and keep
// the output attributes unchanged.
val newGroupingExpr = replaceOuterReferences(groupingExpressions, outerReferenceMap)
val newAggExpr = replaceOuterReferences(aggregateExpressions, outerReferenceMap)
val newAggExpr = replaceOuterInNamedExpressions(aggregateExpressions, outerReferenceMap)
// Add all required domain attributes to both grouping and aggregate expressions.
val referencesToAdd = missingReferences(newAggExpr, joinCond)
val newAggregate = a.copy(

View file

@ -179,6 +179,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
// non-nullable when an empty relation child of a Union is removed
UpdateAttributeNullability) ::
Batch("Pullup Correlated Expressions", Once,
OptimizeOneRowRelationSubquery,
PullupCorrelatedPredicates) ::
// Subquery batch applies the optimizer rules recursively. Therefore, it makes no sense
// to enforce idempotence on it and we change this batch from Once to FixedPoint(1).

View file

@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
@ -711,3 +712,47 @@ object RewriteLateralSubquery extends Rule[LogicalPlan] {
Join(left, newRight, joinType, newCond, JoinHint.NONE)
}
}
/**
* This rule optimizes subqueries with OneRowRelation as leaf nodes.
*/
object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
object OneRowSubquery {
def unapply(plan: LogicalPlan): Option[Seq[NamedExpression]] = {
CollapseProject(EliminateSubqueryAliases(plan)) match {
case Project(projectList, _: OneRowRelation) => Some(stripOuterReferences(projectList))
case _ => None
}
}
}
private def hasCorrelatedSubquery(plan: LogicalPlan): Boolean = {
plan.find(_.expressions.exists(SubqueryExpression.hasCorrelatedSubquery)).isDefined
}
/**
* Rewrite a subquery expression into one or more expressions. The rewrite can only be done
* if there is no nested subqueries in the subquery plan.
*/
private def rewrite(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries {
case LateralJoin(left, right @ LateralSubquery(OneRowSubquery(projectList), _, _, _), _, None)
if !hasCorrelatedSubquery(right.plan) && right.joinCond.isEmpty =>
Project(left.output ++ projectList, left)
case p: LogicalPlan => p.transformExpressionsUpWithPruning(
_.containsPattern(SCALAR_SUBQUERY)) {
case s @ ScalarSubquery(OneRowSubquery(projectList), _, _, _)
if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty =>
assert(projectList.size == 1)
projectList.head
}
}
def apply(plan: LogicalPlan): LogicalPlan = {
if (!conf.getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY)) {
plan
} else {
rewrite(plan)
}
}
}

View file

@ -435,6 +435,23 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
subqueries ++ subqueries.flatMap(_.subqueriesAll)
}
/**
* Returns a copy of this node where the given partial function has been recursively applied
* first to the subqueries in this node's children, then this node's children, and finally
* this node itself (post-order). When the partial function does not apply to a given node,
* it is left unchanged.
*/
def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = {
transformUp { case plan =>
val transformed = plan transformExpressionsUp {
case planExpression: PlanExpression[PlanType] =>
val newPlan = planExpression.plan.transformUpWithSubqueries(f)
planExpression.withNewPlan(newPlan)
}
f.applyOrElse[PlanType, PlanType](transformed, identity)
}
}
/**
* A variant of `collect`. This method not only apply the given function to all elements in this
* plan, also considering all the plans in its (nested) subqueries

View file

@ -2613,6 +2613,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
val OPTIMIZE_ONE_ROW_RELATION_SUBQUERY =
buildConf("spark.sql.optimizer.optimizeOneRowRelationSubquery")
.internal()
.doc("When true, the optimizer will inline subqueries with OneRowRelation as leaf nodes.")
.version("3.2.0")
.booleanConf
.createWithDefault(true)
val TOP_K_SORT_FALLBACK_THRESHOLD =
buildConf("spark.sql.execution.topKSortFallbackThreshold")
.internal()

View file

@ -32,6 +32,7 @@ class DecorrelateInnerQuerySuite extends PlanTest {
val x = AttributeReference("x", IntegerType)()
val y = AttributeReference("y", IntegerType)()
val z = AttributeReference("z", IntegerType)()
val t0 = OneRowRelation()
val testRelation = LocalRelation(a, b, c)
val testRelation2 = LocalRelation(x, y, z)
@ -203,23 +204,24 @@ class DecorrelateInnerQuerySuite extends PlanTest {
test("correlated values in project") {
val outerPlan = testRelation2
val innerPlan = Project(Seq(OuterReference(x), OuterReference(y)), OneRowRelation())
val correctAnswer = Project(Seq(x, y), DomainJoin(Seq(x, y), OneRowRelation()))
val innerPlan = Project(Seq(OuterReference(x).as("x1"), OuterReference(y).as("y1")), t0)
val correctAnswer = Project(
Seq(x.as("x1"), y.as("y1"), x, y), DomainJoin(Seq(x, y), t0))
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y))
}
test("correlated values in project with alias") {
val outerPlan = testRelation2
val innerPlan =
Project(Seq(OuterReference(x), 'y1, 'sum),
Project(Seq(OuterReference(x).as("x1"), 'y1, 'sum),
Project(Seq(
OuterReference(x),
OuterReference(y).as("y1"),
Add(OuterReference(x), OuterReference(y)).as("sum")),
testRelation)).analyze
val correctAnswer =
Project(Seq(x, 'y1, 'sum, y),
Project(Seq(x, y.as("y1"), (x + y).as("sum"), y),
Project(Seq(x.as("x1"), 'y1, 'sum, x, y),
Project(Seq(x.as(x.name), y.as("y1"), (x + y).as("sum"), x, y),
DomainJoin(Seq(x, y), testRelation))).analyze
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y))
}
@ -228,13 +230,13 @@ class DecorrelateInnerQuerySuite extends PlanTest {
val outerPlan = testRelation2
val innerPlan =
Project(
Seq(OuterReference(x)),
Seq(OuterReference(x).as("x1")),
Filter(
And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)),
testRelation
)
)
val correctAnswer = Project(Seq(a, c), Filter(b === 1, testRelation))
val correctAnswer = Project(Seq(a.as("x1"), a, c), Filter(b === 1, testRelation))
check(innerPlan, outerPlan, correctAnswer, Seq(x === a, x + y === c))
}
@ -242,14 +244,14 @@ class DecorrelateInnerQuerySuite extends PlanTest {
val outerPlan = testRelation2
val innerPlan =
Project(
Seq(OuterReference(y)),
Seq(OuterReference(y).as("y1")),
Filter(
And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)),
testRelation
)
)
val correctAnswer =
Project(Seq(y, a, c),
Project(Seq(y.as("y1"), y, a, c),
Filter(b === 1,
DomainJoin(Seq(y), testRelation)
)

View file

@ -0,0 +1,165 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis.CleanupAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{DomainJoin, LocalRelation, LogicalPlan, OneRowRelation}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
class OptimizeOneRowRelationSubquerySuite extends PlanTest {
private var optimizeOneRowRelationSubqueryEnabled: Boolean = _
protected override def beforeAll(): Unit = {
super.beforeAll()
optimizeOneRowRelationSubqueryEnabled =
SQLConf.get.getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY)
SQLConf.get.setConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY, true)
}
protected override def afterAll(): Unit = {
SQLConf.get.setConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY,
optimizeOneRowRelationSubqueryEnabled)
super.afterAll()
}
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subquery", Once,
OptimizeOneRowRelationSubquery,
PullupCorrelatedPredicates) ::
Batch("Cleanup", FixedPoint(10),
CleanupAliases) :: Nil
}
private def assertHasDomainJoin(plan: LogicalPlan): Unit = {
assert(plan.collectWithSubqueries { case d: DomainJoin => d }.nonEmpty,
s"Plan does not contain DomainJoin:\n$plan")
}
val t0 = OneRowRelation()
val a = 'a.int
val b = 'b.int
val t1 = LocalRelation(a, b)
val t2 = LocalRelation('c.int, 'd.int)
test("Optimize scalar subquery with a single project") {
// SELECT (SELECT a) FROM t1
val query = t1.select(ScalarSubquery(t0.select('a)).as("sub"))
val optimized = Optimize.execute(query.analyze)
val correctAnswer = t1.select('a.as("sub"))
comparePlans(optimized, correctAnswer.analyze)
}
test("Optimize lateral subquery with a single project") {
Seq(Inner, LeftOuter, Cross).foreach { joinType =>
// SELECT * FROM t1 JOIN LATERAL (SELECT a, b)
val query = t1.lateralJoin(t0.select('a, 'b), joinType, None)
val optimized = Optimize.execute(query.analyze)
val correctAnswer = t1.select('a, 'b, 'a.as("a"), 'b.as("b"))
comparePlans(optimized, correctAnswer.analyze)
}
}
test("Optimize subquery with subquery alias") {
val inner = t0.select('a).as("t2")
val query = t1.select(ScalarSubquery(inner).as("sub"))
val optimized = Optimize.execute(query.analyze)
val correctAnswer = t1.select('a.as("sub"))
comparePlans(optimized, correctAnswer.analyze)
}
test("Optimize scalar subquery with multiple projects") {
// SELECT (SELECT a1 + b1 FROM (SELECT a AS a1, b AS b1)) FROM t1
val inner = t0.select('a.as("a1"), 'b.as("b1")).select(('a1 + 'b1).as("c"))
val query = t1.select(ScalarSubquery(inner).as("sub"))
val optimized = Optimize.execute(query.analyze)
val correctAnswer = t1.select(('a + 'b).as("c").as("sub"))
comparePlans(optimized, correctAnswer.analyze)
}
test("Optimize lateral subquery with multiple projects") {
Seq(Inner, LeftOuter, Cross).foreach { joinType =>
val inner = t0.select('a.as("a1"), 'b.as("b1"))
.select(('a1 + 'b1).as("c1"), ('a1 - 'b1).as("c2"))
val query = t1.lateralJoin(inner, joinType, None)
val optimized = Optimize.execute(query.analyze)
val correctAnswer = t1.select('a, 'b, ('a + 'b).as("c1"), ('a - 'b).as("c2"))
comparePlans(optimized, correctAnswer.analyze)
}
}
test("Optimize subquery with nested correlated subqueries") {
// SELECT (SELECT (SELECT b) FROM (SELECT a AS b)) FROM t1
val inner = t0.select('a.as("b")).select(ScalarSubquery(t0.select('b)).as("s"))
val query = t1.select(ScalarSubquery(inner).as("sub"))
val optimized = Optimize.execute(query.analyze)
val correctAnswer = t1.select('a.as("s").as("sub"))
comparePlans(optimized, correctAnswer.analyze)
}
test("Batch should be idempotent") {
// SELECT (SELECT 1 WHERE a = a + 1) FROM t1
val inner = t0.select(1).where('a === 'a + 1)
val query = t1.select(ScalarSubquery(inner).as("sub"))
val optimized = Optimize.execute(query.analyze)
val doubleOptimized = Optimize.execute(optimized)
comparePlans(optimized, doubleOptimized, checkAnalysis = false)
}
test("Should not optimize scalar subquery with operators other than project") {
// SELECT (SELECT a AS a1 WHERE a = 1) FROM t1
val inner = t0.where('a === 1).select('a.as("a1"))
val query = t1.select(ScalarSubquery(inner).as("sub"))
val optimized = Optimize.execute(query.analyze)
assertHasDomainJoin(optimized)
}
test("Should not optimize subquery with non-deterministic expressions") {
// SELECT (SELECT r FROM (SELECT a + rand() AS r)) FROM t1
val inner = t0.select(('a + rand(0)).as("r")).select('r)
val query = t1.select(ScalarSubquery(inner).as("sub"))
val optimized = Optimize.execute(query.analyze)
assertHasDomainJoin(optimized)
}
test("Should not optimize lateral join with non-empty join conditions") {
Seq(Inner, LeftOuter).foreach { joinType =>
// SELECT * FROM t1 JOIN LATERAL (SELECT a AS a1, b AS b1) ON a = b1
val query = t1.lateralJoin(t0.select('a.as("a1"), 'b.as("b1")), joinType, Some('a === 'b1))
val optimized = Optimize.execute(query.analyze)
assertHasDomainJoin(optimized)
}
}
test("Should not optimize subquery with nested subqueries that can't be optimized") {
// SELECT (SELECT (SELECT a WHERE a = 1) FROM (SELECT a AS a)) FROM t1
// Filter (a = 1) cannot be optimized.
val inner = t0.select('a).where('a === 1)
val subquery = t0.select('a.as("a"))
.select(ScalarSubquery(inner).as("s")).select('s + 1)
val query = t1.select(ScalarSubquery(subquery).as("sub"))
val optimized = Optimize.execute(query.analyze)
assertHasDomainJoin(optimized)
}
}

View file

@ -1838,7 +1838,8 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
test("Subquery reuse across the whole plan") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY.key -> "false") {
val df = sql(
"""
|SELECT (SELECT avg(key) FROM testData), (SELECT (SELECT avg(key) FROM testData))