[SPARK-35618][SQL] Resolve star expressions in subqueries using outer query plans

### What changes were proposed in this pull request?
This PR supports resolving star expressions in subqueries using outer query plans.

### Why are the changes needed?
Currently, when resolving a subquery, Spark can only resolve star expressions against the inner query plan. It should also be able to resolve them against the outer query plan, for example `t1.*` in `SELECT * FROM t1, LATERAL (SELECT t1.*)`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit tests

Closes #32787 from allisonwang-db/spark-35618-resolve-star-in-subquery.

Lead-authored-by: allisonwang-db <allison.wang@databricks.com>
Co-authored-by: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
allisonwang-db 2021-07-01 09:22:55 +00:00 committed by Wenchen Fan
parent f2492772ba
commit f281736fbd
8 changed files with 218 additions and 26 deletions

View file

@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import scala.util.{Failure, Random, Success, Try}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
@ -94,8 +94,9 @@ object FakeV2SessionCatalog extends TableCatalog {
}
/**
* Provides a way to keep state during the analysis, mostly for resolving views. This enables us to
* decouple the concerns of analysis environment from the catalog.
* Provides a way to keep state during the analysis, mostly for resolving views and subqueries.
* This enables us to decouple the concerns of analysis environment from the catalog and resolve
* star expressions in subqueries that reference the outer query plans.
* The state that is kept here is per-query.
*
* Note this is thread local.
@ -115,6 +116,8 @@ object FakeV2SessionCatalog extends TableCatalog {
* if `t` was a permanent table when the current view was created, it
* should still be a permanent table when resolving the current view,
* even if a temp view `t` has been created.
* @param outerPlan The query plan from the outer query that can be used to resolve star
* expressions in a subquery.
*/
case class AnalysisContext(
catalogAndNamespace: Seq[String] = Nil,
@ -122,7 +125,8 @@ case class AnalysisContext(
maxNestedViewDepth: Int = -1,
relationCache: mutable.Map[Seq[String], LogicalPlan] = mutable.Map.empty,
referredTempViewNames: Seq[Seq[String]] = Seq.empty,
referredTempFunctionNames: Seq[String] = Seq.empty)
referredTempFunctionNames: Seq[String] = Seq.empty,
outerPlan: Option[LogicalPlan] = None)
object AnalysisContext {
private val value = new ThreadLocal[AnalysisContext]() {
@ -152,6 +156,13 @@ object AnalysisContext {
set(context)
try f finally { set(originContext) }
}
/**
 * Evaluates `f` while the thread-local analysis context is replaced by a copy whose
 * `outerPlan` is `Some(outerPlan)`. The context that was current on entry is passed back
 * to `set` once `f` finishes, whether it returns normally or throws.
 *
 * @param outerPlan the plan to expose as the outer query plan while `f` runs
 * @param f the computation to evaluate under the modified context
 * @return the value produced by `f`
 */
def withOuterPlan[A](outerPlan: LogicalPlan)(f: => A): A = {
  val saved = value.get()
  set(saved.copy(outerPlan = Some(outerPlan)))
  try {
    f
  } finally {
    set(saved)
  }
}
}
/**
@ -1579,6 +1590,30 @@ class Analyzer(override val catalogManager: CatalogManager)
}
}
/**
 * Expands a star expression into named expressions.
 *
 * The star is first expanded against `plan`. When that raises an [[AnalysisException]]
 * and the current [[AnalysisContext]] carries an outer plan that is a [[Project]] or an
 * [[Aggregate]], the expansion is retried against that outer plan's first child and every
 * attribute of the result is wrapped in an outer reference. In every other situation,
 * including a failed retry, the exception from the first attempt is rethrown.
 */
private def expand(s: Star, plan: LogicalPlan): Seq[NamedExpression] = {
  withPosition(s) {
    try {
      s.expand(plan, resolver)
    } catch {
      case original: AnalysisException =>
        AnalysisContext.get.outerPlan match {
          // Only Project and Aggregate can host star expressions.
          case Some(outer @ (_: Project | _: Aggregate)) =>
            Try(s.expand(outer.children.head, resolver)) match {
              case Success(fromOuter) => fromOuter.map(wrapOuterReference)
              case Failure(_) => throw original
            }
          // Either no outer plan is available, or the outer operator is one in which
          // the star usage is invalid: report the error from the inner attempt.
          case _ => throw original
        }
    }
  }
}
/**
* Build a project list for Project/Aggregate and expand the star if possible
*/
@ -1587,9 +1622,9 @@ class Analyzer(override val catalogManager: CatalogManager)
child: LogicalPlan): Seq[NamedExpression] = {
exprs.flatMap {
// Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
case s: Star => s.expand(child, resolver)
case s: Star => expand(s, child)
// Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
case UnresolvedAlias(s: Star, _) => expand(s, child)
case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
case o => o :: Nil
}.map(_.asInstanceOf[NamedExpression])
@ -1622,28 +1657,28 @@ class Analyzer(override val catalogManager: CatalogManager)
}
}
f1.copy(arguments = f1.arguments.flatMap {
case s: Star => s.expand(child, resolver)
case s: Star => expand(s, child)
case o => o :: Nil
})
case c: CreateNamedStruct if containsStar(c.valExprs) =>
val newChildren = c.children.grouped(2).flatMap {
case Seq(k, s : Star) => CreateStruct(s.expand(child, resolver)).children
case Seq(k, s : Star) => CreateStruct(expand(s, child)).children
case kv => kv
}
c.copy(children = newChildren.toList )
case c: CreateArray if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
case s: Star => expand(s, child)
case o => o :: Nil
})
case p: Murmur3Hash if containsStar(p.children) =>
p.copy(children = p.children.flatMap {
case s: Star => s.expand(child, resolver)
case s: Star => expand(s, child)
case o => o :: Nil
})
case p: XxHash64 if containsStar(p.children) =>
p.copy(children = p.children.flatMap {
case s: Star => s.expand(child, resolver)
case s: Star => expand(s, child)
case o => o :: Nil
})
// count(*) has been replaced by count(1)
@ -2284,14 +2319,6 @@ class Analyzer(override val catalogManager: CatalogManager)
* Note: CTEs are handled in CTESubstitution.
*/
object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper {
/**
* Wrap attributes in the expression with [[OuterReference]]s.
*/
private def wrapOuterReference[E <: Expression](e: E): E = {
e.transform { case a: Attribute => OuterReference(a) }.asInstanceOf[E]
}
/**
* Resolve the correlated expressions in a subquery, as if the expressions live in the outer
* plan. All resolved outer references are wrapped in an [[OuterReference]]
@ -2333,7 +2360,9 @@ class Analyzer(override val catalogManager: CatalogManager)
do {
// Try to resolve the subquery plan using the regular analyzer.
previous = current
current = executeSameContext(current)
current = AnalysisContext.withOuterPlan(outer) {
executeSameContext(current)
}
// Use the outer references to resolve the subquery plan if it isn't resolved yet.
if (!current.resolved) {

View file

@ -185,6 +185,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
// cannot resolve '${a.sql}' given input columns: [$from]
a.failAnalysis(errorClass = "MISSING_COLUMN", messageParameters = Seq(a.sql, from))
case s: Star =>
withPosition(s) {
throw QueryCompilationErrors.invalidStarUsageError(operator.nodeName)
}
case e: Expression if e.checkInputDataTypes().isFailure =>
e.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(message) =>

View file

@ -144,6 +144,13 @@ object SubExprUtils extends PredicateHelper {
}
}
/**
 * Returns `e` with every [[Attribute]] matched by `transform` replaced by an
 * [[OuterReference]] around that attribute.
 */
def wrapOuterReference[E <: Expression](e: E): E = {
  val rewritten = e.transform {
    case attribute: Attribute => OuterReference(attribute)
  }
  rewritten.asInstanceOf[E]
}
/**
* Given a logical plan, returns TRUE if it has an outer reference and false otherwise.
*/

View file

@ -812,4 +812,49 @@ class AnalysisErrorSuite extends AnalysisTest {
// UnresolvedHint be removed by batch `Remove Unresolved Hints`
assertAnalysisSuccess(plan, true)
}
test("SPARK-35618: Resolve star expressions in subqueries") {
  val colA = AttributeReference("a", IntegerType)()
  val colB = AttributeReference("b", IntegerType)()
  val oneRow = OneRowRelation()
  val outerRel = LocalRelation(colA, colB).as("t1")

  // t1.* in the subquery should be resolved into outer(t1.a) and outer(t1.b), which the
  // scalar subquery check then rejects for returning two columns.
  val qualifiedStar =
    Project(ScalarSubquery(oneRow.select(star("t1"))).as("sub") :: Nil, outerRel)
  assertAnalysisError(
    qualifiedStar,
    "Scalar subquery must return only one column, but got 2" :: Nil)

  // array(t1.*) in the subquery should be resolved into array(outer(t1.a), outer(t1.b)).
  val starArray = CreateArray(Seq(star("t1")))
  val starInFunction =
    Project(ScalarSubquery(oneRow.select(starArray)).as("sub") :: Nil, outerRel)
  assertAnalysisError(
    starInFunction,
    "Expressions referencing the outer query are not supported outside" +
      " of WHERE/HAVING clauses" :: Nil)

  // t2.* matches neither the inner nor the outer plan, so the error reported must be the
  // initial analysis exception.
  val unknownQualifier =
    Project(ScalarSubquery(oneRow.select(star("t2"))).as("sub") :: Nil, outerRel)
  assertAnalysisError(
    unknownQualifier,
    "cannot resolve 't2.*' given input columns ''" :: Nil)
}
test("SPARK-35618: Invalid star usage in subqueries") {
  val colA = AttributeReference("a", IntegerType)()
  val colB = AttributeReference("b", IntegerType)()
  val colC = AttributeReference("c", IntegerType)()
  val left = LocalRelation(colA, colB).as("t1")
  val right = LocalRelation(colB, colC).as("t2")
  // Both cases put the star in a subquery's WHERE clause and expect the same message.
  val invalidStarInFilter = "Invalid usage of '*' in Filter" :: Nil

  // SELECT * FROM t1 WHERE a = (SELECT sum(c) FROM t2 WHERE t1.* = t2.b)
  val sumSubquery = ScalarSubquery(right.select(sum(colC)).where(star("t1") === colB))
  assertAnalysisError(Filter(EqualTo(colA, sumSubquery), left), invalidStarInFilter)

  // SELECT * FROM t1 JOIN t2 ON (EXISTS (SELECT 1 FROM t2 WHERE t1.* = b))
  val existsSubquery = Exists(right.select(1).where(star("t1") === colB))
  assertAnalysisError(
    left.join(right, condition = Some(existsSubquery)),
    invalidStarInFilter)
}
}

View file

@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, InSubquery, LateralSubquery, ListQuery, OuterReference}
import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, GetStructField, InSubquery, LateralSubquery, ListQuery, OuterReference}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.plans.logical._
@ -177,4 +178,66 @@ class ResolveSubquerySuite extends AnalysisTest {
condition = Some(sum('a) === sum('c)))
assertAnalysisError(plan, Seq("Invalid expressions: [sum(a), sum(c)]"))
}
test("SPARK-35618: lateral join with star expansion") {
  // Builds a new `t1 AS t1` on every use, so each plan gets its own alias node.
  def leftSide = t1.as("t1")
  val t1Qualifier = Seq("t1")
  val outerRefA = OuterReference(a.withQualifier(t1Qualifier)).as(a.name)
  val outerRefB = OuterReference(b.withQualifier(t1Qualifier)).as(b.name)

  // SELECT * FROM t1, LATERAL (SELECT *)
  val bareStar = lateralJoin(leftSide, t0.select(star()))
  checkAnalysis(
    bareStar,
    LateralJoin(t1, LateralSubquery(Project(Nil, t0)), Inner, None))

  // SELECT * FROM t1, LATERAL (SELECT t1.*)
  val outerStar = lateralJoin(leftSide, t0.select(star("t1")))
  checkAnalysis(
    outerStar,
    LateralJoin(
      t1,
      LateralSubquery(Project(Seq(outerRefA, outerRefB), t0), Seq(a, b)),
      Inner,
      None))

  // SELECT * FROM t1, LATERAL (SELECT * FROM t2)
  val innerStar = lateralJoin(leftSide, t2.select(star()))
  checkAnalysis(
    innerStar,
    LateralJoin(t1, LateralSubquery(Project(Seq(b, c), t2)), Inner, None))

  // SELECT * FROM t1, LATERAL (SELECT t1.*, t2.* FROM t2)
  val bothStars = lateralJoin(leftSide, t2.as("t2").select(star("t1"), star("t2")))
  checkAnalysis(
    bothStars,
    LateralJoin(
      t1,
      LateralSubquery(Project(Seq(outerRefA, outerRefB, b, c), t2.as("t2")), Seq(a, b)),
      Inner,
      None))

  // SELECT * FROM t1, LATERAL (SELECT t2.*)
  val unknownStar = lateralJoin(leftSide, t0.select(star("t2")))
  assertAnalysisError(unknownStar, Seq("cannot resolve 't2.*' given input columns ''"))

  // Check case sensitivities.
  // SELECT * FROM t1, LATERAL (SELECT T1.*)
  val upperCaseStar = lateralJoin(leftSide, t0.select(star("T1")))
  assertAnalysisError(upperCaseStar, "cannot resolve 'T1.*' given input columns ''" :: Nil)
  assertAnalysisSuccess(upperCaseStar, caseSensitive = false)
}
test("SPARK-35618: lateral join with star expansion in functions") {
  val t1Qualifier = Seq("t1")
  val outerRefA = OuterReference(a.withQualifier(t1Qualifier))
  val outerRefB = OuterReference(b.withQualifier(t1Qualifier))

  // array(t1.*) is expected to expand into an array over the two outer references.
  val starArray = CreateArray(Seq(star("t1")))
  val expandedArray = CreateArray(Seq(outerRefA, outerRefB))
  val arrayOverStar = lateralJoin(t1.as("t1"), t0.select(starArray))
  checkAnalysis(
    arrayOverStar,
    LateralJoin(
      t1,
      LateralSubquery(t0.select(expandedArray.as(expandedArray.sql)), Seq(a, b)),
      Inner,
      None))

  // A star as the argument of count is expected to be reported as invalid.
  val countOverStar = lateralJoin(t1.as("t1"), t0.select(Count(star("t1"))))
  assertAnalysisError(countOverStar, Seq("Invalid usage of '*' in expression 'count'"))
}
test("SPARK-35618: lateral join with struct type star expansion") {
  // SELECT * FROM t4, LATERAL (SELECT x.*)
  val structStar = lateralJoin(t4, t0.select(star("x")))
  // The star is expected to expand into field 0 of the outer reference to `x`, aliased
  // with the name of `a`.
  val firstField = GetStructField(OuterReference(x), 0).as(a.name)
  checkAnalysis(
    structStar,
    LateralJoin(t4, LateralSubquery(Project(Seq(firstField), t0), Seq(x)), Inner, None))
}
}

View file

@ -12,9 +12,12 @@ SELECT * FROM t1, LATERAL (SELECT t1.c1 + t2.c1 FROM t2);
-- lateral join with star expansion
SELECT * FROM t1, LATERAL (SELECT *);
SELECT * FROM t1, LATERAL (SELECT * FROM t2);
-- TODO(SPARK-35618): resolve star expressions in subquery
-- SELECT * FROM t1, LATERAL (SELECT t1.*);
-- SELECT * FROM t1, LATERAL (SELECT t1.*, t2.* FROM t2);
SELECT * FROM t1, LATERAL (SELECT t1.*);
SELECT * FROM t1, LATERAL (SELECT t1.*, t2.* FROM t2);
SELECT * FROM t1, LATERAL (SELECT t1.* FROM t2 AS t1);
-- expect error: cannot resolve 't1.*'
-- TODO: Currently we don't allow deep correlation so t1.* cannot be resolved using the outermost query.
SELECT * FROM t1, LATERAL (SELECT t1.*, t2.* FROM t2, LATERAL (SELECT t1.*, t2.*, t3.* FROM t2 AS t3));
-- lateral join with different join types
SELECT * FROM t1 JOIN LATERAL (SELECT c1 + c2 AS c3) ON c2 = c3;

View file

@ -195,7 +195,7 @@ SELECT t1.x.y.* FROM t1
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
cannot resolve 't1.x.y.*' given input columns 'i1'
cannot resolve 't1.x.y.*' given input columns 'i1'; line 1 pos 7
-- !query

View file

@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 40
-- Number of queries: 44
-- !query
@ -80,6 +80,46 @@ struct<c1:int,c2:int,c1:int,c2:int>
1 2 0 3
-- !query
SELECT * FROM t1, LATERAL (SELECT t1.*)
-- !query schema
struct<c1:int,c2:int,c1:int,c2:int>
-- !query output
0 1 0 1
1 2 1 2
-- !query
SELECT * FROM t1, LATERAL (SELECT t1.*, t2.* FROM t2)
-- !query schema
struct<c1:int,c2:int,c1:int,c2:int,c1:int,c2:int>
-- !query output
0 1 0 1 0 2
0 1 0 1 0 3
1 2 1 2 0 2
1 2 1 2 0 3
-- !query
SELECT * FROM t1, LATERAL (SELECT t1.* FROM t2 AS t1)
-- !query schema
struct<c1:int,c2:int,c1:int,c2:int>
-- !query output
0 1 0 2
0 1 0 3
1 2 0 2
1 2 0 3
-- !query
SELECT * FROM t1, LATERAL (SELECT t1.*, t2.* FROM t2, LATERAL (SELECT t1.*, t2.*, t3.* FROM t2 AS t3))
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
cannot resolve 't1.*' given input columns 'c1, c2'; line 1 pos 70
-- !query
SELECT * FROM t1 JOIN LATERAL (SELECT c1 + c2 AS c3) ON c2 = c3
-- !query schema