[SPARK-36028][SQL] Allow Project to host outer references in scalar subqueries
### What changes were proposed in this pull request? This PR allows the `Project` node to host outer references in scalar subqueries when `decorrelateInnerQuery` is enabled. It is already supported by the new decorrelation framework and the `RewriteCorrelatedScalarSubquery` rule. Note currently by default all correlated subqueries will be decorrelated, which is not necessarily the most optimal approach. Consider `SELECT (SELECT c1) FROM t`. This should be optimized as `SELECT c1 FROM t` instead of rewriting it as a left outer join. This will be done in a separate PR to optimize correlated scalar/lateral subqueries with OneRowRelation. ### Why are the changes needed? To allow more types of correlated scalar subqueries. ### Does this PR introduce _any_ user-facing change? Yes. This PR allows outer query column references in the SELECT cluase of a correlated scalar subquery. For example: ```sql SELECT (SELECT c1) FROM t; ``` Before this change: ``` org.apache.spark.sql.AnalysisException: Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses ``` After this change: ``` +------------------+ |scalarsubquery(c1)| +------------------+ |0 | |1 | +------------------+ ``` ### How was this patch tested? Added unit tests and SQL tests. Closes #33235 from allisonwang-db/spark-36028-outer-in-project. Authored-by: allisonwang-db <allison.wang@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
bad6f89ae2
commit
ca348e50a4
|
@ -725,9 +725,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
|
||||||
s"Filter/Aggregate/Project and a few commands: $plan")
|
s"Filter/Aggregate/Project and a few commands: $plan")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Validate to make sure the correlations appearing in the query are valid and
|
||||||
|
// allowed by spark.
|
||||||
|
checkCorrelationsInSubquery(expr.plan, isScalarOrLateral = true)
|
||||||
|
|
||||||
case _: LateralSubquery =>
|
case _: LateralSubquery =>
|
||||||
assert(plan.isInstanceOf[LateralJoin])
|
assert(plan.isInstanceOf[LateralJoin])
|
||||||
|
// Validate to make sure the correlations appearing in the query are valid and
|
||||||
|
// allowed by spark.
|
||||||
|
checkCorrelationsInSubquery(expr.plan, isScalarOrLateral = true)
|
||||||
|
|
||||||
case inSubqueryOrExistsSubquery =>
|
case inSubqueryOrExistsSubquery =>
|
||||||
plan match {
|
plan match {
|
||||||
|
@ -736,11 +742,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
|
||||||
failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in" +
|
failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in" +
|
||||||
s" Filter/Join and a few commands: $plan")
|
s" Filter/Join and a few commands: $plan")
|
||||||
}
|
}
|
||||||
|
// Validate to make sure the correlations appearing in the query are valid and
|
||||||
|
// allowed by spark.
|
||||||
|
checkCorrelationsInSubquery(expr.plan)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate to make sure the correlations appearing in the query are valid and
|
|
||||||
// allowed by spark.
|
|
||||||
checkCorrelationsInSubquery(expr.plan, isLateral = plan.isInstanceOf[LateralJoin])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -779,7 +784,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
|
||||||
* Validates to make sure the outer references appearing inside the subquery
|
* Validates to make sure the outer references appearing inside the subquery
|
||||||
* are allowed.
|
* are allowed.
|
||||||
*/
|
*/
|
||||||
private def checkCorrelationsInSubquery(sub: LogicalPlan, isLateral: Boolean = false): Unit = {
|
private def checkCorrelationsInSubquery(
|
||||||
|
sub: LogicalPlan,
|
||||||
|
isScalarOrLateral: Boolean = false): Unit = {
|
||||||
// Validate that correlated aggregate expression do not contain a mixture
|
// Validate that correlated aggregate expression do not contain a mixture
|
||||||
// of outer and local references.
|
// of outer and local references.
|
||||||
def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = {
|
def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = {
|
||||||
|
@ -800,11 +807,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check whether the logical plan node can host outer references.
|
// Check whether the logical plan node can host outer references.
|
||||||
// A `Project` can host outer references if it is inside a lateral subquery.
|
// A `Project` can host outer references if it is inside a scalar or a lateral subquery and
|
||||||
// Otherwise, only Filter can only outer references.
|
// DecorrelateInnerQuery is enabled. Otherwise, only Filter can only outer references.
|
||||||
def canHostOuter(plan: LogicalPlan): Boolean = plan match {
|
def canHostOuter(plan: LogicalPlan): Boolean = plan match {
|
||||||
case _: Filter => true
|
case _: Filter => true
|
||||||
case _: Project => isLateral
|
case _: Project => isScalarOrLateral && SQLConf.get.decorrelateInnerQueryEnabled
|
||||||
case _ => false
|
case _ => false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -824,13 +824,6 @@ class AnalysisErrorSuite extends AnalysisTest {
|
||||||
Project(ScalarSubquery(t0.select(star("t1"))).as("sub") :: Nil, t1),
|
Project(ScalarSubquery(t0.select(star("t1"))).as("sub") :: Nil, t1),
|
||||||
"Scalar subquery must return only one column, but got 2" :: Nil)
|
"Scalar subquery must return only one column, but got 2" :: Nil)
|
||||||
|
|
||||||
// array(t1.*) in the subquery should be resolved into array(outer(t1.a), outer(t1.b))
|
|
||||||
val array = CreateArray(Seq(star("t1")))
|
|
||||||
assertAnalysisError(
|
|
||||||
Project(ScalarSubquery(t0.select(array)).as("sub") :: Nil, t1),
|
|
||||||
"Expressions referencing the outer query are not supported outside" +
|
|
||||||
" of WHERE/HAVING clauses" :: Nil)
|
|
||||||
|
|
||||||
// t2.* cannot be resolved and the error should be the initial analysis exception.
|
// t2.* cannot be resolved and the error should be the initial analysis exception.
|
||||||
assertAnalysisError(
|
assertAnalysisError(
|
||||||
Project(ScalarSubquery(t0.select(star("t2"))).as("sub") :: Nil, t1),
|
Project(ScalarSubquery(t0.select(star("t2"))).as("sub") :: Nil, t1),
|
||||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
|
||||||
import org.apache.spark.sql.AnalysisException
|
import org.apache.spark.sql.AnalysisException
|
||||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||||
import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, GetStructField, InSubquery, LateralSubquery, ListQuery, OuterReference}
|
import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, GetStructField, InSubquery, LateralSubquery, ListQuery, OuterReference, ScalarSubquery}
|
||||||
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
|
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
|
||||||
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
|
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
|
||||||
import org.apache.spark.sql.catalyst.plans.logical._
|
import org.apache.spark.sql.catalyst.plans.logical._
|
||||||
|
@ -240,4 +240,28 @@ class ResolveSubquerySuite extends AnalysisTest {
|
||||||
Inner, None)
|
Inner, None)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-36028: resolve scalar subqueries with outer references in Project") {
|
||||||
|
// SELECT (SELECT a) FROM t1
|
||||||
|
checkAnalysis(
|
||||||
|
Project(ScalarSubquery(t0.select('a)).as("sub") :: Nil, t1),
|
||||||
|
Project(ScalarSubquery(Project(OuterReference(a) :: Nil, t0), Seq(a)).as("sub") :: Nil, t1)
|
||||||
|
)
|
||||||
|
// SELECT (SELECT a + b + c AS r FROM t2) FROM t1
|
||||||
|
checkAnalysis(
|
||||||
|
Project(ScalarSubquery(
|
||||||
|
t2.select(('a + 'b + 'c).as("r"))).as("sub") :: Nil, t1),
|
||||||
|
Project(ScalarSubquery(
|
||||||
|
Project((OuterReference(a) + b + c).as("r") :: Nil, t2), Seq(a)).as("sub") :: Nil, t1)
|
||||||
|
)
|
||||||
|
// SELECT (SELECT array(t1.*) AS arr) FROM t1
|
||||||
|
checkAnalysis(
|
||||||
|
Project(ScalarSubquery(t0.select(
|
||||||
|
CreateArray(Seq(star("t1"))).as("arr"))
|
||||||
|
).as("sub") :: Nil, t1.as("t1")),
|
||||||
|
Project(ScalarSubquery(Project(
|
||||||
|
CreateArray(Seq(OuterReference(a), OuterReference(b))).as("arr") :: Nil, t0
|
||||||
|
), Seq(a, b)).as("sub") :: Nil, t1)
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -137,4 +137,11 @@ SELECT t1a,
|
||||||
(SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2,
|
(SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2,
|
||||||
(SELECT sort_array(collect_set(t2d)) FROM t2 WHERE t2a = t1a) collect_set_t2,
|
(SELECT sort_array(collect_set(t2d)) FROM t2 WHERE t2a = t1a) collect_set_t2,
|
||||||
(SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2
|
(SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2
|
||||||
FROM t1;
|
FROM t1;
|
||||||
|
|
||||||
|
-- SPARK-36028: Allow Project to host outer references in scalar subqueries
|
||||||
|
SELECT t1c, (SELECT t1c) FROM t1;
|
||||||
|
SELECT t1c, (SELECT t1c WHERE t1c = 8) FROM t1;
|
||||||
|
SELECT t1c, t1d, (SELECT c + d FROM (SELECT t1c AS c, t1d AS d)) FROM t1;
|
||||||
|
SELECT t1c, (SELECT SUM(c) FROM (SELECT t1c AS c)) FROM t1;
|
||||||
|
SELECT t1a, (SELECT SUM(t2b) FROM t2 JOIN (SELECT t1a AS a) ON t2a = a) FROM t1;
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
-- Automatically generated by SQLQueryTestSuite
|
-- Automatically generated by SQLQueryTestSuite
|
||||||
-- Number of queries: 12
|
-- Number of queries: 17
|
||||||
|
|
||||||
|
|
||||||
-- !query
|
-- !query
|
||||||
|
@ -222,3 +222,98 @@ val1d 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB9000000
|
||||||
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
|
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
|
||||||
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
|
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
|
||||||
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
|
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
|
||||||
|
|
||||||
|
|
||||||
|
-- !query
|
||||||
|
SELECT t1c, (SELECT t1c) FROM t1
|
||||||
|
-- !query schema
|
||||||
|
struct<t1c:int,scalarsubquery(t1c):int>
|
||||||
|
-- !query output
|
||||||
|
12 12
|
||||||
|
12 12
|
||||||
|
16 16
|
||||||
|
16 16
|
||||||
|
16 16
|
||||||
|
16 16
|
||||||
|
8 8
|
||||||
|
8 8
|
||||||
|
NULL NULL
|
||||||
|
NULL NULL
|
||||||
|
NULL NULL
|
||||||
|
NULL NULL
|
||||||
|
|
||||||
|
|
||||||
|
-- !query
|
||||||
|
SELECT t1c, (SELECT t1c WHERE t1c = 8) FROM t1
|
||||||
|
-- !query schema
|
||||||
|
struct<t1c:int,scalarsubquery(t1c, t1c):int>
|
||||||
|
-- !query output
|
||||||
|
12 NULL
|
||||||
|
12 NULL
|
||||||
|
16 NULL
|
||||||
|
16 NULL
|
||||||
|
16 NULL
|
||||||
|
16 NULL
|
||||||
|
8 8
|
||||||
|
8 8
|
||||||
|
NULL NULL
|
||||||
|
NULL NULL
|
||||||
|
NULL NULL
|
||||||
|
NULL NULL
|
||||||
|
|
||||||
|
|
||||||
|
-- !query
|
||||||
|
SELECT t1c, t1d, (SELECT c + d FROM (SELECT t1c AS c, t1d AS d)) FROM t1
|
||||||
|
-- !query schema
|
||||||
|
struct<t1c:int,t1d:bigint,scalarsubquery(t1c, t1d):bigint>
|
||||||
|
-- !query output
|
||||||
|
12 10 22
|
||||||
|
12 21 33
|
||||||
|
16 19 35
|
||||||
|
16 19 35
|
||||||
|
16 19 35
|
||||||
|
16 22 38
|
||||||
|
8 10 18
|
||||||
|
8 10 18
|
||||||
|
NULL 12 NULL
|
||||||
|
NULL 19 NULL
|
||||||
|
NULL 19 NULL
|
||||||
|
NULL 25 NULL
|
||||||
|
|
||||||
|
|
||||||
|
-- !query
|
||||||
|
SELECT t1c, (SELECT SUM(c) FROM (SELECT t1c AS c)) FROM t1
|
||||||
|
-- !query schema
|
||||||
|
struct<t1c:int,scalarsubquery(t1c):bigint>
|
||||||
|
-- !query output
|
||||||
|
12 12
|
||||||
|
12 12
|
||||||
|
16 16
|
||||||
|
16 16
|
||||||
|
16 16
|
||||||
|
16 16
|
||||||
|
8 8
|
||||||
|
8 8
|
||||||
|
NULL NULL
|
||||||
|
NULL NULL
|
||||||
|
NULL NULL
|
||||||
|
NULL NULL
|
||||||
|
|
||||||
|
|
||||||
|
-- !query
|
||||||
|
SELECT t1a, (SELECT SUM(t2b) FROM t2 JOIN (SELECT t1a AS a) ON t2a = a) FROM t1
|
||||||
|
-- !query schema
|
||||||
|
struct<t1a:string,scalarsubquery(t1a):bigint>
|
||||||
|
-- !query output
|
||||||
|
val1a NULL
|
||||||
|
val1a NULL
|
||||||
|
val1a NULL
|
||||||
|
val1a NULL
|
||||||
|
val1b 36
|
||||||
|
val1c 24
|
||||||
|
val1d NULL
|
||||||
|
val1d NULL
|
||||||
|
val1d NULL
|
||||||
|
val1e 8
|
||||||
|
val1e 8
|
||||||
|
val1e 8
|
||||||
|
|
Loading…
Reference in a new issue