[SPARK-15832][SQL] Embedded IN/EXISTS predicate subquery throws TreeNodeException
## What changes were proposed in this pull request? Queries with embedded existential sub-query predicates throws exception when building the physical plan. Example failing query: ```SQL scala> Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") scala> Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") scala> sql("select c1 from t1 where (case when c2 in (select c2 from t2) then 2 else 3 end) IN (select c2 from t1)").show() Binding attribute, tree: c2#239 org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: c2#239 at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:50) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:88) ... at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:87) at org.apache.spark.sql.execution.joins.HashJoin$$anonfun$4.apply(HashJoin.scala:66) at org.apache.spark.sql.execution.joins.HashJoin$$anonfun$4.apply(HashJoin.scala:66) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.immutable.List.foreach(List.scala:381) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.immutable.List.map(List.scala:285) at org.apache.spark.sql.execution.joins.HashJoin$class.org$apache$spark$sql$execution$joins$HashJoin$$x$8(HashJoin.scala:66) at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.org$apache$spark$sql$execution$joins$HashJoin$$x$8$lzycompute(BroadcastHashJoinExec.scala:38) at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.org$apache$spark$sql$execution$joins$HashJoin$$x$8(BroadcastHashJoinExec.scala:38) at org.apache.spark.sql.execution.joins.HashJoin$class.buildKeys(HashJoin.scala:63) at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.buildKeys$lzycompute(BroadcastHashJoinExec.scala:38) at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.buildKeys(BroadcastHashJoinExec.scala:38) at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.requiredChildDistribution(BroadcastHashJoinExec.scala:52) ``` **Problem description:** When the left hand side expression of an existential sub-query predicate contains another embedded sub-query predicate, the RewritePredicateSubquery optimizer rule does not resolve the embedded sub-query expressions into existential joins.For example, the above query has the following optimized plan, which fails during physical plan build. ```SQL == Optimized Logical Plan == Project [_1#224 AS c1#227] +- Join LeftSemi, (CASE WHEN predicate-subquery#255 [(_2#225 = c2#239)] THEN 2 ELSE 3 END = c2#228#262) : +- SubqueryAlias predicate-subquery#255 [(_2#225 = c2#239)] : +- LocalRelation [c2#239] :- LocalRelation [_1#224, _2#225] +- LocalRelation [c2#228#262] == Physical Plan == org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: c2#239 ``` **Solution:** In RewritePredicateSubquery, before rewriting the outermost predicate sub-query, resolve any embedded existential sub-queries. The Optimized plan for the above query after the changes looks like below. ```SQL == Optimized Logical Plan == Project [_1#224 AS c1#227] +- Join LeftSemi, (CASE WHEN exists#285 THEN 2 ELSE 3 END = c2#228#284) :- Join ExistenceJoin(exists#285), (_2#225 = c2#239) : :- LocalRelation [_1#224, _2#225] : +- LocalRelation [c2#239] +- LocalRelation [c2#228#284] == Physical Plan == *Project [_1#224 AS c1#227] +- *BroadcastHashJoin [CASE WHEN exists#285 THEN 2 ELSE 3 END], [c2#228#284], LeftSemi, BuildRight :- *BroadcastHashJoin [_2#225], [c2#239], ExistenceJoin(exists#285), BuildRight : :- LocalTableScan [_1#224, _2#225] : +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint))) : +- LocalTableScan [c2#239] +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint))) +- LocalTableScan [c2#228#284] +- LocalTableScan [c222#36], [[111],[222]] ``` ## How was this patch tested? Added new test cases in SubquerySuite.scala Author: Ioana Delaney <ioanamdelaney@gmail.com> Closes #13570 from ioana-delaney/fixEmbedSubPredV1.
This commit is contained in:
parent
9770f6ee60
commit
0ff8a68b9f
|
@ -1716,31 +1716,52 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
|
|||
// Filter the plan by applying left semi and left anti joins.
|
||||
withSubquery.foldLeft(newFilter) {
|
||||
case (p, PredicateSubquery(sub, conditions, _, _)) =>
|
||||
Join(p, sub, LeftSemi, conditions.reduceOption(And))
|
||||
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions.reduceOption(And), p)
|
||||
Join(outerPlan, sub, LeftSemi, joinCond)
|
||||
case (p, Not(PredicateSubquery(sub, conditions, false, _))) =>
|
||||
Join(p, sub, LeftAnti, conditions.reduceOption(And))
|
||||
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions.reduceOption(And), p)
|
||||
Join(outerPlan, sub, LeftAnti, joinCond)
|
||||
case (p, Not(PredicateSubquery(sub, conditions, true, _))) =>
|
||||
// This is a NULL-aware (left) anti join (NAAJ).
|
||||
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
|
||||
// Construct the condition. A NULL in one of the conditions is regarded as a positive
|
||||
// result; such a row will be filtered out by the Anti-Join operator.
|
||||
val anyNull = conditions.map(IsNull).reduceLeft(Or)
|
||||
val condition = conditions.reduceLeft(And)
|
||||
|
||||
// Note that will almost certainly be planned as a Broadcast Nested Loop join. Use EXISTS
|
||||
// if performance matters to you.
|
||||
Join(p, sub, LeftAnti, Option(Or(anyNull, condition)))
|
||||
// Note that will almost certainly be planned as a Broadcast Nested Loop join.
|
||||
// Use EXISTS if performance matters to you.
|
||||
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions.reduceLeftOption(And), p)
|
||||
val anyNull = splitConjunctivePredicates(joinCond.get).map(IsNull).reduceLeft(Or)
|
||||
Join(outerPlan, sub, LeftAnti, Option(Or(anyNull, joinCond.get)))
|
||||
case (p, predicate) =>
|
||||
var joined = p
|
||||
val replaced = predicate transformUp {
|
||||
case PredicateSubquery(sub, conditions, nullAware, _) =>
|
||||
// TODO: support null-aware join
|
||||
val exists = AttributeReference("exists", BooleanType, nullable = false)()
|
||||
joined = Join(joined, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
|
||||
exists
|
||||
}
|
||||
Project(p.output, Filter(replaced, joined))
|
||||
val (newCond, inputPlan) = rewriteExistentialExpr(Option(predicate), p)
|
||||
Project(p.output, Filter(newCond.get, inputPlan))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a predicate expression and an input plan, it rewrites
|
||||
* any embedded existential sub-query into an existential join.
|
||||
* It returns the rewritten expression together with the updated plan.
|
||||
* Currently, it does not support null-aware joins. Embedded NOT IN predicates
|
||||
* are blocked in the Analyzer.
|
||||
*/
|
||||
private def rewriteExistentialExpr(
|
||||
expr: Option[Expression],
|
||||
plan: LogicalPlan): (Option[Expression], LogicalPlan) = {
|
||||
var newPlan = plan
|
||||
expr match {
|
||||
case Some(e) =>
|
||||
val newExpr = e transformUp {
|
||||
case PredicateSubquery(sub, conditions, nullAware, _) =>
|
||||
// TODO: support null-aware join
|
||||
val exists = AttributeReference("exists", BooleanType, nullable = false)()
|
||||
newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
|
||||
exists
|
||||
}
|
||||
(Option(newExpr), newPlan)
|
||||
case None =>
|
||||
(expr, plan)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -266,6 +266,172 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
|
|||
Row(null) :: Row(1) :: Row(3) :: Nil)
|
||||
}
|
||||
|
||||
test("SPARK-15832: Test embedded existential predicate sub-queries") {
|
||||
withTempTable("t1", "t2", "t3", "t4", "t5") {
|
||||
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
|
||||
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
|
||||
Seq((1, 1), (2, 2), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t3")
|
||||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
| select c1 from t1
|
||||
| where c2 IN (select c2 from t2)
|
||||
|
|
||||
""".stripMargin),
|
||||
Row(1) :: Row(2) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
| select c1 from t1
|
||||
| where c2 NOT IN (select c2 from t2)
|
||||
|
|
||||
""".stripMargin),
|
||||
Nil)
|
||||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
| select c1 from t1
|
||||
| where EXISTS (select c2 from t2)
|
||||
|
|
||||
""".stripMargin),
|
||||
Row(1) :: Row(2) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
| select c1 from t1
|
||||
| where NOT EXISTS (select c2 from t2)
|
||||
|
|
||||
""".stripMargin),
|
||||
Nil)
|
||||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
| select c1 from t1
|
||||
| where NOT EXISTS (select c2 from t2) and
|
||||
| c2 IN (select c2 from t3)
|
||||
|
|
||||
""".stripMargin),
|
||||
Nil)
|
||||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
| select c1 from t1
|
||||
| where (case when c2 IN (select 1 as one) then 1
|
||||
| else 2 end) = c1
|
||||
|
|
||||
""".stripMargin),
|
||||
Row(1) :: Row(2) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
| select c1 from t1
|
||||
| where (case when c2 IN (select 1 as one) then 1
|
||||
| else 2 end)
|
||||
| IN (select c2 from t2)
|
||||
|
|
||||
""".stripMargin),
|
||||
Row(1) :: Row(2) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
| select c1 from t1
|
||||
| where (case when c2 IN (select c2 from t2) then 1
|
||||
| else 2 end)
|
||||
| IN (select c2 from t3)
|
||||
|
|
||||
""".stripMargin),
|
||||
Row(1) :: Row(2) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
| select c1 from t1
|
||||
| where (case when c2 IN (select c2 from t2) then 1
|
||||
| when c2 IN (select c2 from t3) then 2
|
||||
| else 3 end)
|
||||
| IN (select c2 from t1)
|
||||
|
|
||||
""".stripMargin),
|
||||
Row(1) :: Row(2) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
| select c1 from t1
|
||||
| where (c1, (case when c2 IN (select c2 from t2) then 1
|
||||
| when c2 IN (select c2 from t3) then 2
|
||||
| else 3 end))
|
||||
| IN (select c1, c2 from t1)
|
||||
|
|
||||
""".stripMargin),
|
||||
Row(1) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
| select c1 from t3
|
||||
| where ((case when c2 IN (select c2 from t2) then 1 else 2 end),
|
||||
| (case when c2 IN (select c2 from t3) then 2 else 3 end))
|
||||
| IN (select c1, c2 from t3)
|
||||
|
|
||||
""".stripMargin),
|
||||
Row(1) :: Row(2) :: Row(1) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
| select c1 from t1
|
||||
| where ((case when EXISTS (select c2 from t2) then 1 else 2 end),
|
||||
| (case when c2 IN (select c2 from t3) then 2 else 3 end))
|
||||
| IN (select c1, c2 from t3)
|
||||
|
|
||||
""".stripMargin),
|
||||
Row(1) :: Row(2) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
| select c1 from t1
|
||||
| where (case when c2 IN (select c2 from t2) then 3
|
||||
| else 2 end)
|
||||
| NOT IN (select c2 from t3)
|
||||
|
|
||||
""".stripMargin),
|
||||
Row(1) :: Row(2) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
| select c1 from t1
|
||||
| where ((case when c2 IN (select c2 from t2) then 1 else 2 end),
|
||||
| (case when NOT EXISTS (select c2 from t3) then 2
|
||||
| when EXISTS (select c2 from t2) then 3
|
||||
| else 3 end))
|
||||
| NOT IN (select c1, c2 from t3)
|
||||
|
|
||||
""".stripMargin),
|
||||
Row(1) :: Row(2) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
| select c1 from t1
|
||||
| where (select max(c1) from t2 where c2 IN (select c2 from t3))
|
||||
| IN (select c2 from t2)
|
||||
|
|
||||
""".stripMargin),
|
||||
Row(1) :: Row(2) :: Nil)
|
||||
}
|
||||
}
|
||||
|
||||
test("correlated scalar subquery in where") {
|
||||
checkAnswer(
|
||||
sql("select * from l where b < (select max(d) from r where a = c)"),
|
||||
|
|
Loading…
Reference in a new issue