[SPARK-35845][SQL] OuterReference resolution should reject ambiguous column names
### What changes were proposed in this pull request?

The current OuterReference resolution is a bit weird: when the outer plan has more than one child, it resolves OuterReference from the output of each child, one by one, left to right.

This is incorrect in the case of a join, as the column name can be ambiguous if both the left and right sides output this column.

This PR fixes this bug by resolving OuterReference with `outerPlan.resolveChildren`, instead of something like `outerPlan.children.foreach(_.resolve(...))`.

### Why are the changes needed?

Bug fix.

### Does this PR introduce _any_ user-facing change?

The problem only occurs in joins, and join conditions don't support correlated subqueries yet. So this PR only improves the error message. Before this PR, people see

```
java.lang.UnsupportedOperationException
Cannot generate code for expression: outer(t1a#291)
```

### How was this patch tested?

A new test.

Closes #33004 from cloud-fan/outer-ref.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
This commit is contained in:
parent
df55945804
commit
20edfdd39a
|
@ -2285,8 +2285,8 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
}
|
||||
|
||||
/**
|
||||
* Resolve the correlated expressions in a subquery by using the outer plan's references. All
|
||||
* resolved outer references are wrapped in an [[OuterReference]]
|
||||
* Resolve the correlated expressions in a subquery, as if the expressions live in the outer
|
||||
* plan. All resolved outer references are wrapped in an [[OuterReference]]
|
||||
*/
|
||||
private def resolveOuterReferences(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
|
||||
plan.resolveOperatorsDownWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) {
|
||||
|
@ -2295,7 +2295,7 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
case u @ UnresolvedAttribute(nameParts) =>
|
||||
withPosition(u) {
|
||||
try {
|
||||
outer.resolve(nameParts, resolver) match {
|
||||
outer.resolveChildren(nameParts, resolver) match {
|
||||
case Some(outerAttr) => wrapOuterReference(outerAttr)
|
||||
case None => u
|
||||
}
|
||||
|
@ -2317,7 +2317,7 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
*/
|
||||
private def resolveSubQuery(
|
||||
e: SubqueryExpression,
|
||||
plans: Seq[LogicalPlan])(
|
||||
outer: LogicalPlan)(
|
||||
f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
|
||||
// Step 1: Resolve the outer expressions.
|
||||
var previous: LogicalPlan = null
|
||||
|
@ -2328,10 +2328,8 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
current = executeSameContext(current)
|
||||
|
||||
// Use the outer references to resolve the subquery plan if it isn't resolved yet.
|
||||
val i = plans.iterator
|
||||
val afterResolve = current
|
||||
while (!current.resolved && current.fastEquals(afterResolve) && i.hasNext) {
|
||||
current = resolveOuterReferences(current, i.next())
|
||||
if (!current.resolved) {
|
||||
current = resolveOuterReferences(current, outer)
|
||||
}
|
||||
} while (!current.resolved && !current.fastEquals(previous))
|
||||
|
||||
|
@ -2354,20 +2352,20 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
* (2) Any aggregate expression(s) that reference outer attributes are pushed down to
|
||||
* outer plan to get evaluated.
|
||||
*/
|
||||
private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
|
||||
private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
|
||||
plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
|
||||
case s @ ScalarSubquery(sub, _, exprId, _) if !sub.resolved =>
|
||||
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
|
||||
resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId))
|
||||
case e @ Exists(sub, _, exprId, _) if !sub.resolved =>
|
||||
resolveSubQuery(e, plans)(Exists(_, _, exprId))
|
||||
resolveSubQuery(e, outer)(Exists(_, _, exprId))
|
||||
case InSubquery(values, l @ ListQuery(_, _, exprId, _, _))
|
||||
if values.forall(_.resolved) && !l.resolved =>
|
||||
val expr = resolveSubQuery(l, plans)((plan, exprs) => {
|
||||
val expr = resolveSubQuery(l, outer)((plan, exprs) => {
|
||||
ListQuery(plan, exprs, exprId, plan.output)
|
||||
})
|
||||
InSubquery(values, expr.asInstanceOf[ListQuery])
|
||||
case s @ LateralSubquery(sub, _, exprId, _) if !sub.resolved =>
|
||||
resolveSubQuery(s, plans)(LateralSubquery(_, _, exprId))
|
||||
resolveSubQuery(s, outer)(LateralSubquery(_, _, exprId))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2377,14 +2375,17 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
|
||||
_.containsPattern(PLAN_EXPRESSION), ruleId) {
|
||||
case j: LateralJoin if j.left.resolved =>
|
||||
resolveSubQueries(j, j.children)
|
||||
// We can't pass `LateralJoin` as the outer plan, as its right child is not resolved yet
|
||||
// and we can't call `LateralJoin.resolveChildren` to resolve outer references. Here we
|
||||
// create a fake Project node as the outer plan.
|
||||
resolveSubQueries(j, Project(Nil, j.left))
|
||||
// Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
|
||||
case q: UnaryNode if q.childrenResolved =>
|
||||
resolveSubQueries(q, q.children)
|
||||
resolveSubQueries(q, q)
|
||||
case j: Join if j.childrenResolved && j.duplicateResolved =>
|
||||
resolveSubQueries(j, j.children)
|
||||
resolveSubQueries(j, j)
|
||||
case s: SupportsSubquery if s.childrenResolved =>
|
||||
resolveSubQueries(s, s.children)
|
||||
resolveSubQueries(s, s)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -258,13 +258,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
|
|||
def apply(
|
||||
innerPlan: LogicalPlan,
|
||||
outerPlan: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
|
||||
apply(innerPlan, Seq(outerPlan))
|
||||
}
|
||||
|
||||
def apply(
|
||||
innerPlan: LogicalPlan,
|
||||
outerPlans: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
|
||||
val outputSet = AttributeSet(outerPlans.flatMap(_.outputSet))
|
||||
val outputPlanInputAttrs = outerPlan.inputSet
|
||||
|
||||
// The return type of the recursion.
|
||||
// The first parameter is a new logical plan with correlation eliminated.
|
||||
|
@ -486,7 +480,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
|
|||
}
|
||||
}
|
||||
val (newChild, joinCond, _) = decorrelate(BooleanSimplification(innerPlan), AttributeSet.empty)
|
||||
val (plan, conditions) = deduplicate(newChild, joinCond, outputSet)
|
||||
val (plan, conditions) = deduplicate(newChild, joinCond, outputPlanInputAttrs)
|
||||
(plan, stripOuterReferences(conditions))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -220,7 +220,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
|
|||
*/
|
||||
private def pullOutCorrelatedPredicates(
|
||||
sub: LogicalPlan,
|
||||
outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
|
||||
outer: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
|
||||
val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]]
|
||||
|
||||
/** Determine which correlated predicate references are missing from this plan. */
|
||||
|
@ -272,10 +272,10 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
|
|||
// In case of a collision, change the subquery plan's output to use
|
||||
// different attribute by creating alias(s).
|
||||
val baseConditions = predicateMap.values.flatten.toSeq
|
||||
val (newPlan, newCond) = if (outer.nonEmpty) {
|
||||
val outputSet = outer.map(_.outputSet).reduce(_ ++ _)
|
||||
val outerPlanInputAttrs = outer.inputSet
|
||||
val (newPlan, newCond) = if (outerPlanInputAttrs.nonEmpty) {
|
||||
val (plan, deDuplicatedConditions) =
|
||||
DecorrelateInnerQuery.deduplicate(transformed, baseConditions, outputSet)
|
||||
DecorrelateInnerQuery.deduplicate(transformed, baseConditions, outerPlanInputAttrs)
|
||||
(plan, stripOuterReferences(deDuplicatedConditions))
|
||||
} else {
|
||||
(transformed, stripOuterReferences(baseConditions))
|
||||
|
@ -283,7 +283,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
|
|||
(newPlan, newCond)
|
||||
}
|
||||
|
||||
private def rewriteSubQueries(plan: LogicalPlan, outerPlans: Seq[LogicalPlan]): LogicalPlan = {
|
||||
private def rewriteSubQueries(plan: LogicalPlan): LogicalPlan = {
|
||||
/**
|
||||
* This function is used as an aid to enforce idempotency of pullUpCorrelatedPredicate rule.
|
||||
* In the first call to rewriteSubqueries, all the outer references from the subplan are
|
||||
|
@ -296,7 +296,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
|
|||
if (newCond.isEmpty) oldCond else newCond
|
||||
}
|
||||
|
||||
def decorrelate(sub: LogicalPlan, outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
|
||||
def decorrelate(sub: LogicalPlan, outer: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
|
||||
if (SQLConf.get.decorrelateInnerQueryEnabled) {
|
||||
DecorrelateInnerQuery(sub, outer)
|
||||
} else {
|
||||
|
@ -306,16 +306,16 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
|
|||
|
||||
plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
|
||||
case ScalarSubquery(sub, children, exprId, conditions) if children.nonEmpty =>
|
||||
val (newPlan, newCond) = decorrelate(sub, outerPlans)
|
||||
val (newPlan, newCond) = decorrelate(sub, plan)
|
||||
ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions))
|
||||
case Exists(sub, children, exprId, conditions) if children.nonEmpty =>
|
||||
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
|
||||
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
|
||||
Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions))
|
||||
case ListQuery(sub, children, exprId, childOutputs, conditions) if children.nonEmpty =>
|
||||
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
|
||||
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
|
||||
ListQuery(newPlan, children, exprId, childOutputs, getJoinCondition(newCond, conditions))
|
||||
case LateralSubquery(sub, children, exprId, conditions) if children.nonEmpty =>
|
||||
val (newPlan, newCond) = decorrelate(sub, outerPlans)
|
||||
val (newPlan, newCond) = decorrelate(sub, plan)
|
||||
LateralSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions))
|
||||
}
|
||||
}
|
||||
|
@ -326,7 +326,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
|
|||
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
|
||||
_.containsPattern(PLAN_EXPRESSION)) {
|
||||
case j: LateralJoin =>
|
||||
val newPlan = rewriteSubQueries(j, j.children)
|
||||
val newPlan = rewriteSubQueries(j)
|
||||
// Since a lateral join's output depends on its left child output and its lateral subquery's
|
||||
// plan output, we need to trim the domain attributes added to the subquery's plan output
|
||||
// to preserve the original output of the join.
|
||||
|
@ -337,9 +337,9 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
|
|||
}
|
||||
// Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
|
||||
case q: UnaryNode =>
|
||||
rewriteSubQueries(q, q.children)
|
||||
rewriteSubQueries(q)
|
||||
case s: SupportsSubquery =>
|
||||
rewriteSubQueries(s, s.children)
|
||||
rewriteSubQueries(s)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ class DecorrelateInnerQuerySuite extends PlanTest {
|
|||
outerPlan: LogicalPlan,
|
||||
correctAnswer: LogicalPlan,
|
||||
conditions: Seq[Expression]): Unit = {
|
||||
val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan)
|
||||
val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan.select())
|
||||
assert(!hasOuterReferences(outputPlan))
|
||||
comparePlans(outputPlan, correctAnswer)
|
||||
assert(joinCond.length == conditions.length)
|
||||
|
@ -90,7 +90,7 @@ class DecorrelateInnerQuerySuite extends PlanTest {
|
|||
Project(Seq(a),
|
||||
Filter(OuterReference(a) === a,
|
||||
testRelation))
|
||||
val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan)
|
||||
val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan.select())
|
||||
val a1 = outputPlan.output.head
|
||||
val correctAnswer =
|
||||
Project(Seq(Alias(a, a1.name)(a1.exprId)),
|
||||
|
@ -197,7 +197,7 @@ class DecorrelateInnerQuerySuite extends PlanTest {
|
|||
Inner,
|
||||
Some(OuterReference(x) === a),
|
||||
JoinHint.NONE)
|
||||
val error = intercept[AssertionError] { DecorrelateInnerQuery(innerPlan, outerPlan) }
|
||||
val error = intercept[AssertionError] { DecorrelateInnerQuery(innerPlan, outerPlan.select()) }
|
||||
assert(error.getMessage.contains("Correlated column is not allowed in join"))
|
||||
}
|
||||
|
||||
|
|
|
@ -71,3 +71,12 @@ WHERE t1a IN (SELECT t2a
|
|||
WHERE EXISTS (SELECT min(t2a)
|
||||
FROM t3));
|
||||
|
||||
CREATE TEMPORARY VIEW t1_copy AS SELECT * FROM VALUES
|
||||
(1, 2, 3)
|
||||
AS t1(t1a, t1b, t1c);
|
||||
|
||||
-- invalid because column name `t1a` is ambiguous in the subquery.
|
||||
SELECT t1.t1a
|
||||
FROM t1
|
||||
JOIN t1_copy
|
||||
ON EXISTS (SELECT 1 FROM t2 WHERE t2a > t1a)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
-- Automatically generated by SQLQueryTestSuite
|
||||
-- Number of queries: 8
|
||||
-- Number of queries: 10
|
||||
|
||||
|
||||
-- !query
|
||||
|
@ -116,3 +116,25 @@ Aggregate [min(outer(t2a#x)) AS min(outer(t2.t2a))#x]
|
|||
+- Project [t3a#x, t3b#x, t3c#x]
|
||||
+- SubqueryAlias t3
|
||||
+- LocalRelation [t3a#x, t3b#x, t3c#x]
|
||||
|
||||
|
||||
-- !query
|
||||
CREATE TEMPORARY VIEW t1_copy AS SELECT * FROM VALUES
|
||||
(1, 2, 3)
|
||||
AS t1(t1a, t1b, t1c)
|
||||
-- !query schema
|
||||
struct<>
|
||||
-- !query output
|
||||
|
||||
|
||||
|
||||
-- !query
|
||||
SELECT t1.t1a
|
||||
FROM t1
|
||||
JOIN t1_copy
|
||||
ON EXISTS (SELECT 1 FROM t2 WHERE t2a > t1a)
|
||||
-- !query schema
|
||||
struct<>
|
||||
-- !query output
|
||||
org.apache.spark.sql.AnalysisException
|
||||
cannot resolve 't1a' given input columns: [t2.t2a, t2.t2b, t2.t2c]; line 4 pos 44
|
||||
|
|
Loading…
Reference in a new issue