[SPARK-36028][SQL] Allow Project to host outer references in scalar subqueries

### What changes were proposed in this pull request?
This PR allows the `Project` node to host outer references in scalar subqueries when `decorrelateInnerQuery` is enabled. It is already supported by the new decorrelation framework and the `RewriteCorrelatedScalarSubquery` rule.

Note currently by default all correlated subqueries will be decorrelated, which is not necessarily the most optimal approach. Consider `SELECT (SELECT c1) FROM t`. This should be optimized as `SELECT c1 FROM t` instead of rewriting it as a left outer join. This will be done in a separate PR to optimize correlated scalar/lateral subqueries with OneRowRelation.

### Why are the changes needed?
To allow more types of correlated scalar subqueries.

### Does this PR introduce _any_ user-facing change?
Yes. This PR allows outer query column references in the SELECT cluase of a correlated scalar subquery. For example:
```sql
SELECT (SELECT c1) FROM t;
```
Before this change:
```
org.apache.spark.sql.AnalysisException: Expressions referencing the outer query are not supported
outside of WHERE/HAVING clauses
```

After this change:
```
+------------------+
|scalarsubquery(c1)|
+------------------+
|0                 |
|1                 |
+------------------+
```

### How was this patch tested?
Added unit tests and SQL tests.

Closes #33235 from allisonwang-db/spark-36028-outer-in-project.

Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
allisonwang-db 2021-07-07 04:25:54 +00:00 committed by Wenchen Fan
parent bad6f89ae2
commit ca348e50a4
5 changed files with 144 additions and 18 deletions

View file

@ -725,9 +725,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
s"Filter/Aggregate/Project and a few commands: $plan") s"Filter/Aggregate/Project and a few commands: $plan")
} }
} }
// Validate to make sure the correlations appearing in the query are valid and
// allowed by spark.
checkCorrelationsInSubquery(expr.plan, isScalarOrLateral = true)
case _: LateralSubquery => case _: LateralSubquery =>
assert(plan.isInstanceOf[LateralJoin]) assert(plan.isInstanceOf[LateralJoin])
// Validate to make sure the correlations appearing in the query are valid and
// allowed by spark.
checkCorrelationsInSubquery(expr.plan, isScalarOrLateral = true)
case inSubqueryOrExistsSubquery => case inSubqueryOrExistsSubquery =>
plan match { plan match {
@ -736,11 +742,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in" + failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in" +
s" Filter/Join and a few commands: $plan") s" Filter/Join and a few commands: $plan")
} }
// Validate to make sure the correlations appearing in the query are valid and
// allowed by spark.
checkCorrelationsInSubquery(expr.plan)
} }
// Validate to make sure the correlations appearing in the query are valid and
// allowed by spark.
checkCorrelationsInSubquery(expr.plan, isLateral = plan.isInstanceOf[LateralJoin])
} }
/** /**
@ -779,7 +784,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
* Validates to make sure the outer references appearing inside the subquery * Validates to make sure the outer references appearing inside the subquery
* are allowed. * are allowed.
*/ */
private def checkCorrelationsInSubquery(sub: LogicalPlan, isLateral: Boolean = false): Unit = { private def checkCorrelationsInSubquery(
sub: LogicalPlan,
isScalarOrLateral: Boolean = false): Unit = {
// Validate that correlated aggregate expression do not contain a mixture // Validate that correlated aggregate expression do not contain a mixture
// of outer and local references. // of outer and local references.
def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = { def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = {
@ -800,11 +807,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
} }
// Check whether the logical plan node can host outer references. // Check whether the logical plan node can host outer references.
// A `Project` can host outer references if it is inside a lateral subquery. // A `Project` can host outer references if it is inside a scalar or a lateral subquery and
// Otherwise, only Filter can only outer references. // DecorrelateInnerQuery is enabled. Otherwise, only Filter can only outer references.
def canHostOuter(plan: LogicalPlan): Boolean = plan match { def canHostOuter(plan: LogicalPlan): Boolean = plan match {
case _: Filter => true case _: Filter => true
case _: Project => isLateral case _: Project => isScalarOrLateral && SQLConf.get.decorrelateInnerQueryEnabled
case _ => false case _ => false
} }

View file

@ -824,13 +824,6 @@ class AnalysisErrorSuite extends AnalysisTest {
Project(ScalarSubquery(t0.select(star("t1"))).as("sub") :: Nil, t1), Project(ScalarSubquery(t0.select(star("t1"))).as("sub") :: Nil, t1),
"Scalar subquery must return only one column, but got 2" :: Nil) "Scalar subquery must return only one column, but got 2" :: Nil)
// array(t1.*) in the subquery should be resolved into array(outer(t1.a), outer(t1.b))
val array = CreateArray(Seq(star("t1")))
assertAnalysisError(
Project(ScalarSubquery(t0.select(array)).as("sub") :: Nil, t1),
"Expressions referencing the outer query are not supported outside" +
" of WHERE/HAVING clauses" :: Nil)
// t2.* cannot be resolved and the error should be the initial analysis exception. // t2.* cannot be resolved and the error should be the initial analysis exception.
assertAnalysisError( assertAnalysisError(
Project(ScalarSubquery(t0.select(star("t2"))).as("sub") :: Nil, t1), Project(ScalarSubquery(t0.select(star("t2"))).as("sub") :: Nil, t1),

View file

@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, GetStructField, InSubquery, LateralSubquery, ListQuery, OuterReference} import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, GetStructField, InSubquery, LateralSubquery, ListQuery, OuterReference, ScalarSubquery}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical._
@ -240,4 +240,28 @@ class ResolveSubquerySuite extends AnalysisTest {
Inner, None) Inner, None)
) )
} }
test("SPARK-36028: resolve scalar subqueries with outer references in Project") {
// SELECT (SELECT a) FROM t1
checkAnalysis(
Project(ScalarSubquery(t0.select('a)).as("sub") :: Nil, t1),
Project(ScalarSubquery(Project(OuterReference(a) :: Nil, t0), Seq(a)).as("sub") :: Nil, t1)
)
// SELECT (SELECT a + b + c AS r FROM t2) FROM t1
checkAnalysis(
Project(ScalarSubquery(
t2.select(('a + 'b + 'c).as("r"))).as("sub") :: Nil, t1),
Project(ScalarSubquery(
Project((OuterReference(a) + b + c).as("r") :: Nil, t2), Seq(a)).as("sub") :: Nil, t1)
)
// SELECT (SELECT array(t1.*) AS arr) FROM t1
checkAnalysis(
Project(ScalarSubquery(t0.select(
CreateArray(Seq(star("t1"))).as("arr"))
).as("sub") :: Nil, t1.as("t1")),
Project(ScalarSubquery(Project(
CreateArray(Seq(OuterReference(a), OuterReference(b))).as("arr") :: Nil, t0
), Seq(a, b)).as("sub") :: Nil, t1)
)
}
} }

View file

@ -137,4 +137,11 @@ SELECT t1a,
(SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2, (SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2,
(SELECT sort_array(collect_set(t2d)) FROM t2 WHERE t2a = t1a) collect_set_t2, (SELECT sort_array(collect_set(t2d)) FROM t2 WHERE t2a = t1a) collect_set_t2,
(SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2 (SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2
FROM t1; FROM t1;
-- SPARK-36028: Allow Project to host outer references in scalar subqueries
SELECT t1c, (SELECT t1c) FROM t1;
SELECT t1c, (SELECT t1c WHERE t1c = 8) FROM t1;
SELECT t1c, t1d, (SELECT c + d FROM (SELECT t1c AS c, t1d AS d)) FROM t1;
SELECT t1c, (SELECT SUM(c) FROM (SELECT t1c AS c)) FROM t1;
SELECT t1a, (SELECT SUM(t2b) FROM t2 JOIN (SELECT t1a AS a) ON t2a = a) FROM t1;

View file

@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite -- Automatically generated by SQLQueryTestSuite
-- Number of queries: 12 -- Number of queries: 17
-- !query -- !query
@ -222,3 +222,98 @@ val1d 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB9000000
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000 val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000 val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000 val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
-- !query
SELECT t1c, (SELECT t1c) FROM t1
-- !query schema
struct<t1c:int,scalarsubquery(t1c):int>
-- !query output
12 12
12 12
16 16
16 16
16 16
16 16
8 8
8 8
NULL NULL
NULL NULL
NULL NULL
NULL NULL
-- !query
SELECT t1c, (SELECT t1c WHERE t1c = 8) FROM t1
-- !query schema
struct<t1c:int,scalarsubquery(t1c, t1c):int>
-- !query output
12 NULL
12 NULL
16 NULL
16 NULL
16 NULL
16 NULL
8 8
8 8
NULL NULL
NULL NULL
NULL NULL
NULL NULL
-- !query
SELECT t1c, t1d, (SELECT c + d FROM (SELECT t1c AS c, t1d AS d)) FROM t1
-- !query schema
struct<t1c:int,t1d:bigint,scalarsubquery(t1c, t1d):bigint>
-- !query output
12 10 22
12 21 33
16 19 35
16 19 35
16 19 35
16 22 38
8 10 18
8 10 18
NULL 12 NULL
NULL 19 NULL
NULL 19 NULL
NULL 25 NULL
-- !query
SELECT t1c, (SELECT SUM(c) FROM (SELECT t1c AS c)) FROM t1
-- !query schema
struct<t1c:int,scalarsubquery(t1c):bigint>
-- !query output
12 12
12 12
16 16
16 16
16 16
16 16
8 8
8 8
NULL NULL
NULL NULL
NULL NULL
NULL NULL
-- !query
SELECT t1a, (SELECT SUM(t2b) FROM t2 JOIN (SELECT t1a AS a) ON t2a = a) FROM t1
-- !query schema
struct<t1a:string,scalarsubquery(t1a):bigint>
-- !query output
val1a NULL
val1a NULL
val1a NULL
val1a NULL
val1b 36
val1c 24
val1d NULL
val1d NULL
val1d NULL
val1e 8
val1e 8
val1e 8