[SPARK-35845][SQL] OuterReference resolution should reject ambiguous column names

### What changes were proposed in this pull request?

The current OuterReference resolution is flawed: when the outer plan has more than one child, it resolves OuterReference against the output of each child, one by one, from left to right, and takes the first match.

This is incorrect in the case of a join, as the column name can be ambiguous if both the left and right sides output this column.

This PR fixes the bug by resolving OuterReference with `outerPlan.resolveChildren`, instead of something like `outerPlan.children.foreach(_.resolve(...))`.

### Why are the changes needed?

bug fix

### Does this PR introduce _any_ user-facing change?

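The problem only occurs in joins, and join conditions don't support correlated subqueries yet, so this PR only improves the error message. Before this PR, people see
```
java.lang.UnsupportedOperationException
Cannot generate code for expression: outer(t1a#291)
```

After this PR, such a query fails during analysis with an `org.apache.spark.sql.AnalysisException` instead (see the golden-file output of the new test).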

### How was this patch tested?

a new test

Closes #33004 from cloud-fan/outer-ref.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
This commit is contained in:
Wenchen Fan 2021-06-23 14:32:34 +08:00 committed by Gengliang Wang
parent df55945804
commit 20edfdd39a
6 changed files with 68 additions and 42 deletions

View file

@ -2285,8 +2285,8 @@ class Analyzer(override val catalogManager: CatalogManager)
}
/**
* Resolve the correlated expressions in a subquery by using the an outer plans' references. All
* resolved outer references are wrapped in an [[OuterReference]]
* Resolve the correlated expressions in a subquery, as if the expressions live in the outer
* plan. All resolved outer references are wrapped in an [[OuterReference]]
*/
private def resolveOuterReferences(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
plan.resolveOperatorsDownWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) {
@ -2295,7 +2295,7 @@ class Analyzer(override val catalogManager: CatalogManager)
case u @ UnresolvedAttribute(nameParts) =>
withPosition(u) {
try {
outer.resolve(nameParts, resolver) match {
outer.resolveChildren(nameParts, resolver) match {
case Some(outerAttr) => wrapOuterReference(outerAttr)
case None => u
}
@ -2317,7 +2317,7 @@ class Analyzer(override val catalogManager: CatalogManager)
*/
private def resolveSubQuery(
e: SubqueryExpression,
plans: Seq[LogicalPlan])(
outer: LogicalPlan)(
f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
// Step 1: Resolve the outer expressions.
var previous: LogicalPlan = null
@ -2328,10 +2328,8 @@ class Analyzer(override val catalogManager: CatalogManager)
current = executeSameContext(current)
// Use the outer references to resolve the subquery plan if it isn't resolved yet.
val i = plans.iterator
val afterResolve = current
while (!current.resolved && current.fastEquals(afterResolve) && i.hasNext) {
current = resolveOuterReferences(current, i.next())
if (!current.resolved) {
current = resolveOuterReferences(current, outer)
}
} while (!current.resolved && !current.fastEquals(previous))
@ -2354,20 +2352,20 @@ class Analyzer(override val catalogManager: CatalogManager)
* (2) Any aggregate expression(s) that reference outer attributes are pushed down to
* outer plan to get evaluated.
*/
private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
case s @ ScalarSubquery(sub, _, exprId, _) if !sub.resolved =>
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId, _) if !sub.resolved =>
resolveSubQuery(e, plans)(Exists(_, _, exprId))
resolveSubQuery(e, outer)(Exists(_, _, exprId))
case InSubquery(values, l @ ListQuery(_, _, exprId, _, _))
if values.forall(_.resolved) && !l.resolved =>
val expr = resolveSubQuery(l, plans)((plan, exprs) => {
val expr = resolveSubQuery(l, outer)((plan, exprs) => {
ListQuery(plan, exprs, exprId, plan.output)
})
InSubquery(values, expr.asInstanceOf[ListQuery])
case s @ LateralSubquery(sub, _, exprId, _) if !sub.resolved =>
resolveSubQuery(s, plans)(LateralSubquery(_, _, exprId))
resolveSubQuery(s, outer)(LateralSubquery(_, _, exprId))
}
}
@ -2377,14 +2375,17 @@ class Analyzer(override val catalogManager: CatalogManager)
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
_.containsPattern(PLAN_EXPRESSION), ruleId) {
case j: LateralJoin if j.left.resolved =>
resolveSubQueries(j, j.children)
// We can't pass `LateralJoin` as the outer plan, as its right child is not resolved yet
// and we can't call `LateralJoin.resolveChildren` to resolve outer references. Here we
// create a fake Project node as the outer plan.
resolveSubQueries(j, Project(Nil, j.left))
// Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
case q: UnaryNode if q.childrenResolved =>
resolveSubQueries(q, q.children)
resolveSubQueries(q, q)
case j: Join if j.childrenResolved && j.duplicateResolved =>
resolveSubQueries(j, j.children)
resolveSubQueries(j, j)
case s: SupportsSubquery if s.childrenResolved =>
resolveSubQueries(s, s.children)
resolveSubQueries(s, s)
}
}

View file

@ -258,13 +258,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
def apply(
innerPlan: LogicalPlan,
outerPlan: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
apply(innerPlan, Seq(outerPlan))
}
def apply(
innerPlan: LogicalPlan,
outerPlans: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
val outputSet = AttributeSet(outerPlans.flatMap(_.outputSet))
val outputPlanInputAttrs = outerPlan.inputSet
// The return type of the recursion.
// The first parameter is a new logical plan with correlation eliminated.
@ -486,7 +480,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
}
}
val (newChild, joinCond, _) = decorrelate(BooleanSimplification(innerPlan), AttributeSet.empty)
val (plan, conditions) = deduplicate(newChild, joinCond, outputSet)
val (plan, conditions) = deduplicate(newChild, joinCond, outputPlanInputAttrs)
(plan, stripOuterReferences(conditions))
}
}

View file

@ -220,7 +220,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
*/
private def pullOutCorrelatedPredicates(
sub: LogicalPlan,
outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
outer: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]]
/** Determine which correlated predicate references are missing from this plan. */
@ -272,10 +272,10 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
// In case of a collision, change the subquery plan's output to use
// different attribute by creating alias(s).
val baseConditions = predicateMap.values.flatten.toSeq
val (newPlan, newCond) = if (outer.nonEmpty) {
val outputSet = outer.map(_.outputSet).reduce(_ ++ _)
val outerPlanInputAttrs = outer.inputSet
val (newPlan, newCond) = if (outerPlanInputAttrs.nonEmpty) {
val (plan, deDuplicatedConditions) =
DecorrelateInnerQuery.deduplicate(transformed, baseConditions, outputSet)
DecorrelateInnerQuery.deduplicate(transformed, baseConditions, outerPlanInputAttrs)
(plan, stripOuterReferences(deDuplicatedConditions))
} else {
(transformed, stripOuterReferences(baseConditions))
@ -283,7 +283,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
(newPlan, newCond)
}
private def rewriteSubQueries(plan: LogicalPlan, outerPlans: Seq[LogicalPlan]): LogicalPlan = {
private def rewriteSubQueries(plan: LogicalPlan): LogicalPlan = {
/**
* This function is used as a aid to enforce idempotency of pullUpCorrelatedPredicate rule.
* In the first call to rewriteSubqueries, all the outer references from the subplan are
@ -296,7 +296,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
if (newCond.isEmpty) oldCond else newCond
}
def decorrelate(sub: LogicalPlan, outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
def decorrelate(sub: LogicalPlan, outer: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
if (SQLConf.get.decorrelateInnerQueryEnabled) {
DecorrelateInnerQuery(sub, outer)
} else {
@ -306,16 +306,16 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
case ScalarSubquery(sub, children, exprId, conditions) if children.nonEmpty =>
val (newPlan, newCond) = decorrelate(sub, outerPlans)
val (newPlan, newCond) = decorrelate(sub, plan)
ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions))
case Exists(sub, children, exprId, conditions) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions))
case ListQuery(sub, children, exprId, childOutputs, conditions) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
ListQuery(newPlan, children, exprId, childOutputs, getJoinCondition(newCond, conditions))
case LateralSubquery(sub, children, exprId, conditions) if children.nonEmpty =>
val (newPlan, newCond) = decorrelate(sub, outerPlans)
val (newPlan, newCond) = decorrelate(sub, plan)
LateralSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions))
}
}
@ -326,7 +326,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(PLAN_EXPRESSION)) {
case j: LateralJoin =>
val newPlan = rewriteSubQueries(j, j.children)
val newPlan = rewriteSubQueries(j)
// Since a lateral join's output depends on its left child output and its lateral subquery's
// plan output, we need to trim the domain attributes added to the subquery's plan output
// to preserve the original output of the join.
@ -337,9 +337,9 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
}
// Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
case q: UnaryNode =>
rewriteSubQueries(q, q.children)
rewriteSubQueries(q)
case s: SupportsSubquery =>
rewriteSubQueries(s, s.children)
rewriteSubQueries(s)
}
}

View file

@ -44,7 +44,7 @@ class DecorrelateInnerQuerySuite extends PlanTest {
outerPlan: LogicalPlan,
correctAnswer: LogicalPlan,
conditions: Seq[Expression]): Unit = {
val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan)
val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan.select())
assert(!hasOuterReferences(outputPlan))
comparePlans(outputPlan, correctAnswer)
assert(joinCond.length == conditions.length)
@ -90,7 +90,7 @@ class DecorrelateInnerQuerySuite extends PlanTest {
Project(Seq(a),
Filter(OuterReference(a) === a,
testRelation))
val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan)
val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan.select())
val a1 = outputPlan.output.head
val correctAnswer =
Project(Seq(Alias(a, a1.name)(a1.exprId)),
@ -197,7 +197,7 @@ class DecorrelateInnerQuerySuite extends PlanTest {
Inner,
Some(OuterReference(x) === a),
JoinHint.NONE)
val error = intercept[AssertionError] { DecorrelateInnerQuery(innerPlan, outerPlan) }
val error = intercept[AssertionError] { DecorrelateInnerQuery(innerPlan, outerPlan.select()) }
assert(error.getMessage.contains("Correlated column is not allowed in join"))
}

View file

@ -71,3 +71,12 @@ WHERE t1a IN (SELECT t2a
WHERE EXISTS (SELECT min(t2a)
FROM t3));
CREATE TEMPORARY VIEW t1_copy AS SELECT * FROM VALUES
(1, 2, 3)
AS t1(t1a, t1b, t1c);
-- invalid because column name `t1a` is ambiguous in the subquery.
SELECT t1.t1a
FROM t1
JOIN t1_copy
ON EXISTS (SELECT 1 FROM t2 WHERE t2a > t1a)

View file

@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 8
-- Number of queries: 10
-- !query
@ -116,3 +116,25 @@ Aggregate [min(outer(t2a#x)) AS min(outer(t2.t2a))#x]
+- Project [t3a#x, t3b#x, t3c#x]
+- SubqueryAlias t3
+- LocalRelation [t3a#x, t3b#x, t3c#x]
-- !query
CREATE TEMPORARY VIEW t1_copy AS SELECT * FROM VALUES
(1, 2, 3)
AS t1(t1a, t1b, t1c)
-- !query schema
struct<>
-- !query output
-- !query
SELECT t1.t1a
FROM t1
JOIN t1_copy
ON EXISTS (SELECT 1 FROM t2 WHERE t2a > t1a)
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
cannot resolve 't1a' given input columns: [t2.t2a, t2.t2b, t2.t2c]; line 4 pos 44