[SPARK-21759][SQL] In.checkInputDataTypes should not wrongly report unresolved plans for IN correlated subquery
## What changes were proposed in this pull request? With the check for structural integrity proposed in SPARK-21726, it is found that the optimization rule `PullupCorrelatedPredicates` can produce unresolved plans. For a correlated IN query looks like: SELECT t1.a FROM t1 WHERE t1.a IN (SELECT t2.c FROM t2 WHERE t1.b < t2.d); The query plan might look like: Project [a#0] +- Filter a#0 IN (list#4 [b#1]) : +- Project [c#2] : +- Filter (outer(b#1) < d#3) : +- LocalRelation <empty>, [c#2, d#3] +- LocalRelation <empty>, [a#0, b#1] After `PullupCorrelatedPredicates`, it produces query plan like: 'Project [a#0] +- 'Filter a#0 IN (list#4 [(b#1 < d#3)]) : +- Project [c#2, d#3] : +- LocalRelation <empty>, [c#2, d#3] +- LocalRelation <empty>, [a#0, b#1] Because the correlated predicate involves another attribute `d#3` in subquery, it has been pulled out and added into the `Project` on the top of the subquery. When `list` in `In` contains just one `ListQuery`, `In.checkInputDataTypes` checks if the size of `value` expressions matches the output size of subquery. In the above example, there is only `value` expression and the subquery output has two attributes `c#2, d#3`, so it fails the check and `In.resolved` returns `false`. We should not let `In.checkInputDataTypes` wrongly report unresolved plans to fail the structural integrity check. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #18968 from viirya/SPARK-21759.
This commit is contained in:
parent
9e33954ddf
commit
183d4cb71f
|
@ -1286,8 +1286,10 @@ class Analyzer(
|
||||||
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
|
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
|
||||||
case e @ Exists(sub, _, exprId) if !sub.resolved =>
|
case e @ Exists(sub, _, exprId) if !sub.resolved =>
|
||||||
resolveSubQuery(e, plans)(Exists(_, _, exprId))
|
resolveSubQuery(e, plans)(Exists(_, _, exprId))
|
||||||
case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved =>
|
case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved =>
|
||||||
val expr = resolveSubQuery(l, plans)(ListQuery(_, _, exprId))
|
val expr = resolveSubQuery(l, plans)((plan, exprs) => {
|
||||||
|
ListQuery(plan, exprs, exprId, plan.output)
|
||||||
|
})
|
||||||
In(value, Seq(expr))
|
In(value, Seq(expr))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -402,7 +402,7 @@ object TypeCoercion {
|
||||||
|
|
||||||
// Handle type casting required between value expression and subquery output
|
// Handle type casting required between value expression and subquery output
|
||||||
// in IN subquery.
|
// in IN subquery.
|
||||||
case i @ In(a, Seq(ListQuery(sub, children, exprId)))
|
case i @ In(a, Seq(ListQuery(sub, children, exprId, _)))
|
||||||
if !i.resolved && flattenExpr(a).length == sub.output.length =>
|
if !i.resolved && flattenExpr(a).length == sub.output.length =>
|
||||||
// LHS is the value expression of IN subquery.
|
// LHS is the value expression of IN subquery.
|
||||||
val lhs = flattenExpr(a)
|
val lhs = flattenExpr(a)
|
||||||
|
@ -434,7 +434,8 @@ object TypeCoercion {
|
||||||
case _ => CreateStruct(castedLhs)
|
case _ => CreateStruct(castedLhs)
|
||||||
}
|
}
|
||||||
|
|
||||||
In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId)))
|
val newSub = Project(castedRhs, sub)
|
||||||
|
In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output)))
|
||||||
} else {
|
} else {
|
||||||
i
|
i
|
||||||
}
|
}
|
||||||
|
|
|
@ -138,32 +138,33 @@ case class Not(child: Expression)
|
||||||
case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|
case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|
||||||
|
|
||||||
require(list != null, "list should not be null")
|
require(list != null, "list should not be null")
|
||||||
|
|
||||||
override def checkInputDataTypes(): TypeCheckResult = {
|
override def checkInputDataTypes(): TypeCheckResult = {
|
||||||
list match {
|
val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType))
|
||||||
case ListQuery(sub, _, _) :: Nil =>
|
if (mismatchOpt.isDefined) {
|
||||||
val valExprs = value match {
|
list match {
|
||||||
case cns: CreateNamedStruct => cns.valExprs
|
case ListQuery(_, _, _, childOutputs) :: Nil =>
|
||||||
case expr => Seq(expr)
|
val valExprs = value match {
|
||||||
}
|
case cns: CreateNamedStruct => cns.valExprs
|
||||||
if (valExprs.length != sub.output.length) {
|
case expr => Seq(expr)
|
||||||
TypeCheckResult.TypeCheckFailure(
|
|
||||||
s"""
|
|
||||||
|The number of columns in the left hand side of an IN subquery does not match the
|
|
||||||
|number of columns in the output of subquery.
|
|
||||||
|#columns in left hand side: ${valExprs.length}.
|
|
||||||
|#columns in right hand side: ${sub.output.length}.
|
|
||||||
|Left side columns:
|
|
||||||
|[${valExprs.map(_.sql).mkString(", ")}].
|
|
||||||
|Right side columns:
|
|
||||||
|[${sub.output.map(_.sql).mkString(", ")}].
|
|
||||||
""".stripMargin)
|
|
||||||
} else {
|
|
||||||
val mismatchedColumns = valExprs.zip(sub.output).flatMap {
|
|
||||||
case (l, r) if l.dataType != r.dataType =>
|
|
||||||
s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
|
|
||||||
case _ => None
|
|
||||||
}
|
}
|
||||||
if (mismatchedColumns.nonEmpty) {
|
if (valExprs.length != childOutputs.length) {
|
||||||
|
TypeCheckResult.TypeCheckFailure(
|
||||||
|
s"""
|
||||||
|
|The number of columns in the left hand side of an IN subquery does not match the
|
||||||
|
|number of columns in the output of subquery.
|
||||||
|
|#columns in left hand side: ${valExprs.length}.
|
||||||
|
|#columns in right hand side: ${childOutputs.length}.
|
||||||
|
|Left side columns:
|
||||||
|
|[${valExprs.map(_.sql).mkString(", ")}].
|
||||||
|
|Right side columns:
|
||||||
|
|[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin)
|
||||||
|
} else {
|
||||||
|
val mismatchedColumns = valExprs.zip(childOutputs).flatMap {
|
||||||
|
case (l, r) if l.dataType != r.dataType =>
|
||||||
|
s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
|
||||||
|
case _ => None
|
||||||
|
}
|
||||||
TypeCheckResult.TypeCheckFailure(
|
TypeCheckResult.TypeCheckFailure(
|
||||||
s"""
|
s"""
|
||||||
|The data type of one or more elements in the left hand side of an IN subquery
|
|The data type of one or more elements in the left hand side of an IN subquery
|
||||||
|
@ -173,20 +174,14 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|
||||||
|Left side:
|
|Left side:
|
||||||
|[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
|
|[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
|
||||||
|Right side:
|
|Right side:
|
||||||
|[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
|
|[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin)
|
||||||
""".stripMargin)
|
|
||||||
} else {
|
|
||||||
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
|
|
||||||
}
|
}
|
||||||
}
|
case _ =>
|
||||||
case _ =>
|
|
||||||
val mismatchOpt = list.find(l => l.dataType != value.dataType)
|
|
||||||
if (mismatchOpt.isDefined) {
|
|
||||||
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
|
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
|
||||||
s"${value.dataType} != ${mismatchOpt.get.dataType}")
|
s"${value.dataType} != ${mismatchOpt.get.dataType}")
|
||||||
} else {
|
}
|
||||||
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
|
} else {
|
||||||
}
|
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -274,9 +274,15 @@ object ScalarSubquery {
|
||||||
case class ListQuery(
|
case class ListQuery(
|
||||||
plan: LogicalPlan,
|
plan: LogicalPlan,
|
||||||
children: Seq[Expression] = Seq.empty,
|
children: Seq[Expression] = Seq.empty,
|
||||||
exprId: ExprId = NamedExpression.newExprId)
|
exprId: ExprId = NamedExpression.newExprId,
|
||||||
|
childOutputs: Seq[Attribute] = Seq.empty)
|
||||||
extends SubqueryExpression(plan, children, exprId) with Unevaluable {
|
extends SubqueryExpression(plan, children, exprId) with Unevaluable {
|
||||||
override def dataType: DataType = plan.schema.fields.head.dataType
|
override def dataType: DataType = if (childOutputs.length > 1) {
|
||||||
|
childOutputs.toStructType
|
||||||
|
} else {
|
||||||
|
childOutputs.head.dataType
|
||||||
|
}
|
||||||
|
override lazy val resolved: Boolean = childrenResolved && plan.resolved && childOutputs.nonEmpty
|
||||||
override def nullable: Boolean = false
|
override def nullable: Boolean = false
|
||||||
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
|
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
|
||||||
override def toString: String = s"list#${exprId.id} $conditionString"
|
override def toString: String = s"list#${exprId.id} $conditionString"
|
||||||
|
@ -284,7 +290,8 @@ case class ListQuery(
|
||||||
ListQuery(
|
ListQuery(
|
||||||
plan.canonicalized,
|
plan.canonicalized,
|
||||||
children.map(_.canonicalized),
|
children.map(_.canonicalized),
|
||||||
ExprId(0))
|
ExprId(0),
|
||||||
|
childOutputs.map(_.canonicalized.asInstanceOf[Attribute]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -68,11 +68,11 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
|
||||||
case (p, Not(Exists(sub, conditions, _))) =>
|
case (p, Not(Exists(sub, conditions, _))) =>
|
||||||
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
|
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
|
||||||
Join(outerPlan, sub, LeftAnti, joinCond)
|
Join(outerPlan, sub, LeftAnti, joinCond)
|
||||||
case (p, In(value, Seq(ListQuery(sub, conditions, _)))) =>
|
case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) =>
|
||||||
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
|
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
|
||||||
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
|
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
|
||||||
Join(outerPlan, sub, LeftSemi, joinCond)
|
Join(outerPlan, sub, LeftSemi, joinCond)
|
||||||
case (p, Not(In(value, Seq(ListQuery(sub, conditions, _))))) =>
|
case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) =>
|
||||||
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
|
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
|
||||||
// Construct the condition. A NULL in one of the conditions is regarded as a positive
|
// Construct the condition. A NULL in one of the conditions is regarded as a positive
|
||||||
// result; such a row will be filtered out by the Anti-Join operator.
|
// result; such a row will be filtered out by the Anti-Join operator.
|
||||||
|
@ -116,7 +116,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
|
||||||
val exists = AttributeReference("exists", BooleanType, nullable = false)()
|
val exists = AttributeReference("exists", BooleanType, nullable = false)()
|
||||||
newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
|
newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
|
||||||
exists
|
exists
|
||||||
case In(value, Seq(ListQuery(sub, conditions, _))) =>
|
case In(value, Seq(ListQuery(sub, conditions, _, _))) =>
|
||||||
val exists = AttributeReference("exists", BooleanType, nullable = false)()
|
val exists = AttributeReference("exists", BooleanType, nullable = false)()
|
||||||
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
|
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
|
||||||
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
|
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
|
||||||
|
@ -227,9 +227,9 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
|
||||||
case Exists(sub, children, exprId) if children.nonEmpty =>
|
case Exists(sub, children, exprId) if children.nonEmpty =>
|
||||||
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
|
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
|
||||||
Exists(newPlan, newCond, exprId)
|
Exists(newPlan, newCond, exprId)
|
||||||
case ListQuery(sub, _, exprId) =>
|
case ListQuery(sub, _, exprId, childOutputs) =>
|
||||||
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
|
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
|
||||||
ListQuery(newPlan, newCond, exprId)
|
ListQuery(newPlan, newCond, exprId, childOutputs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,52 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.catalyst.optimizer
|
||||||
|
|
||||||
|
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||||
|
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||||
|
import org.apache.spark.sql.catalyst.expressions.{In, ListQuery}
|
||||||
|
import org.apache.spark.sql.catalyst.plans.PlanTest
|
||||||
|
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
|
||||||
|
import org.apache.spark.sql.catalyst.rules.RuleExecutor
|
||||||
|
|
||||||
|
class PullupCorrelatedPredicatesSuite extends PlanTest {
|
||||||
|
|
||||||
|
object Optimize extends RuleExecutor[LogicalPlan] {
|
||||||
|
val batches =
|
||||||
|
Batch("PullupCorrelatedPredicates", Once,
|
||||||
|
PullupCorrelatedPredicates) :: Nil
|
||||||
|
}
|
||||||
|
|
||||||
|
val testRelation = LocalRelation('a.int, 'b.double)
|
||||||
|
val testRelation2 = LocalRelation('c.int, 'd.double)
|
||||||
|
|
||||||
|
test("PullupCorrelatedPredicates should not produce unresolved plan") {
|
||||||
|
val correlatedSubquery =
|
||||||
|
testRelation2
|
||||||
|
.where('b < 'd)
|
||||||
|
.select('c)
|
||||||
|
val outerQuery =
|
||||||
|
testRelation
|
||||||
|
.where(In('a, Seq(ListQuery(correlatedSubquery))))
|
||||||
|
.select('a).analyze
|
||||||
|
assert(outerQuery.resolved)
|
||||||
|
|
||||||
|
val optimized = Optimize.execute(outerQuery)
|
||||||
|
assert(optimized.resolved)
|
||||||
|
}
|
||||||
|
}
|
|
@ -80,8 +80,7 @@ number of columns in the output of subquery.
|
||||||
Left side columns:
|
Left side columns:
|
||||||
[t1.`t1a`].
|
[t1.`t1a`].
|
||||||
Right side columns:
|
Right side columns:
|
||||||
[t2.`t2a`, t2.`t2b`].
|
[t2.`t2a`, t2.`t2b`].;
|
||||||
;
|
|
||||||
|
|
||||||
|
|
||||||
-- !query 6
|
-- !query 6
|
||||||
|
@ -102,5 +101,4 @@ number of columns in the output of subquery.
|
||||||
Left side columns:
|
Left side columns:
|
||||||
[t1.`t1a`, t1.`t1b`].
|
[t1.`t1a`, t1.`t1b`].
|
||||||
Right side columns:
|
Right side columns:
|
||||||
[t2.`t2a`].
|
[t2.`t2a`].;
|
||||||
;
|
|
||||||
|
|
Loading…
Reference in a new issue