[SPARK-28441][SQL][PYTHON] Fix error when non-foldable expression is used in correlated scalar subquery

## What changes were proposed in this pull request?

In SPARK-15370, We checked the expression at the root of the correlated subquery, in order to fix count bug. If a `PythonUDF` in in the checking path, evaluating it causes the failure as we can't statically evaluate `PythonUDF`. The Python UDF test added at SPARK-28277 shows this issue.

If we can statically evaluate the expression, we intercept NULL values coming from the outer join and replace them with the value that the subquery's expression like before, if it is not, we replace them with the `PythonUDF` expression, with statically evaluated parameters.

After this, the last query in `udf-except.sql` which throws `java.lang.UnsupportedOperationException` can be run:

```
SELECT t1.k
FROM   t1
WHERE  t1.v <= (SELECT   udf(max(udf(t2.v)))
                FROM     t2
                WHERE    udf(t2.k) = udf(t1.k))
MINUS
SELECT t1.k
FROM   t1
WHERE  udf(t1.v) >= (SELECT   min(udf(t2.v))
                FROM     t2
                WHERE    t2.k = t1.k)
-- !query 2 schema
struct<k:string>
-- !query 2 output
two
```

Note that this issue is also for other non-foldable expressions, like rand. As like PythonUDF, we can't call `eval` on this kind of expressions in optimization. The evaluation needs to defer to query runtime.

## How was this patch tested?

Added tests.

Closes #25204 from viirya/SPARK-28441.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Liang-Chi Hsieh 2019-07-27 10:38:34 +08:00 committed by Wenchen Fan
parent 836a8ff2b9
commit 558dd23601
4 changed files with 291 additions and 26 deletions

View file

@ -2833,7 +2833,7 @@ object EliminateUnions extends Rule[LogicalPlan] {
* rule can't work for those parameters.
*/
object CleanupAliases extends Rule[LogicalPlan] {
private def trimAliases(e: Expression): Expression = {
def trimAliases(e: Expression): Expression = {
e.transformDown {
case Alias(child, _) => child
case MultiAlias(child, _) => child

View file

@ -675,7 +675,9 @@ object ColumnPruning extends Rule[LogicalPlan] {
*/
private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child)))
if p2.outputSet.subsetOf(child.outputSet) =>
if p2.outputSet.subsetOf(child.outputSet) &&
// We only remove attribute-only project.
p2.projectList.forall(_.isInstanceOf[AttributeReference]) =>
p1.copy(child = f.copy(child = child))
}
}

View file

@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.CleanupAliases
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@ -317,24 +318,40 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
}
/**
* Statically evaluate an expression containing zero or more placeholders, given a set
* of bindings for placeholder values.
* Checks if given expression is foldable. Evaluates it and returns it as literal, if yes.
* If not, returns the original expression without evaluation.
*/
private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = {
private def tryEvalExpr(expr: Expression): Expression = {
// Removes Alias over given expression, because Alias is not foldable.
if (!CleanupAliases.trimAliases(expr).foldable) {
// SPARK-28441: Some expressions, like PythonUDF, can't be statically evaluated.
// Needs to evaluate them on query runtime.
expr
} else {
Literal.create(expr.eval(), expr.dataType)
}
}
/**
* Statically evaluate an expression containing zero or more placeholders, given a set
* of bindings for placeholder values, if the expression is evaluable. If it is not,
* bind statically evaluated expression results to an expression.
*/
private def bindingExpr(
expr: Expression,
bindings: Map[ExprId, Expression]): Expression = {
val rewrittenExpr = expr transform {
case r: AttributeReference =>
bindings(r.exprId) match {
case Some(v) => Literal.create(v, r.dataType)
case None => Literal.default(NullType)
}
bindings.getOrElse(r.exprId, Literal.default(NullType))
}
Option(rewrittenExpr.eval())
tryEvalExpr(rewrittenExpr)
}
/**
* Statically evaluate an expression containing one or more aggregates on an empty input.
*/
private def evalAggOnZeroTups(expr: Expression) : Option[Any] = {
private def evalAggOnZeroTups(expr: Expression) : Expression = {
// AggregateExpressions are Unevaluable, so we need to replace all aggregates
// in the expression with the value they would return for zero input tuples.
// Also replace attribute refs (for example, for grouping columns) with NULL.
@ -344,7 +361,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
case _: AttributeReference => Literal.default(NullType)
}
Option(rewrittenExpr.eval())
tryEvalExpr(rewrittenExpr)
}
/**
@ -354,19 +372,33 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
* [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in
* CheckAnalysis become less restrictive, this method will need to change.
*/
private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = {
private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Expression] = {
// Inputs to this method will start with a chain of zero or more SubqueryAlias
// and Project operators, followed by an optional Filter, followed by an
// Aggregate. Traverse the operators recursively.
def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match {
def evalPlan(lp : LogicalPlan) : Map[ExprId, Expression] = lp match {
case SubqueryAlias(_, child) => evalPlan(child)
case Filter(condition, child) =>
val bindings = evalPlan(child)
if (bindings.isEmpty) bindings
else {
val exprResult = evalExpr(condition, bindings).getOrElse(false)
.asInstanceOf[Boolean]
if (exprResult) bindings else Map.empty
if (bindings.isEmpty) {
bindings
} else {
val bindCondition = bindingExpr(condition, bindings)
if (!bindCondition.foldable) {
// We can't evaluate the condition. Evaluate it in query runtime.
bindings.map { case (id, expr) =>
val newExpr = If(bindCondition, expr, Literal.create(null, expr.dataType))
(id, newExpr)
}
} else {
// The bound condition can be evaluated.
bindCondition.eval() match {
// For filter condition, null is the same as false.
case null | false => Map.empty
case true => bindings
}
}
}
case Project(projectList, child) =>
@ -374,7 +406,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
if (bindings.isEmpty) {
bindings
} else {
projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
projectList.map(ne => (ne.exprId, bindingExpr(ne, bindings))).toMap
}
case Aggregate(_, aggExprs, _) =>
@ -382,8 +414,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
// for joining with the outer query block. Fill those expressions in with
// nulls and statically evaluate the remainder.
aggExprs.map {
case ref: AttributeReference => (ref.exprId, None)
case alias @ Alias(_: AttributeReference, _) => (alias.exprId, None)
case ref: AttributeReference => (ref.exprId, Literal.create(null, ref.dataType))
case alias @ Alias(_: AttributeReference, _) =>
(alias.exprId, Literal.create(null, alias.dataType))
case ne => (ne.exprId, evalAggOnZeroTups(ne))
}.toMap
@ -394,7 +427,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
val resultMap = evalPlan(plan)
// By convention, the scalar subquery result is the leftmost field.
resultMap.getOrElse(plan.output.head.exprId, None)
resultMap.get(plan.output.head.exprId) match {
case Some(Literal(null, _)) | None => None
case o => o
}
}
/**
@ -473,7 +509,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
currentChild.output :+
Alias(
If(IsNull(alwaysTrueRef),
Literal.create(resultWithZeroTups.get, origOutput.dataType),
resultWithZeroTups.get,
aggValRef), origOutput.name)(exprId = origOutput.exprId),
Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
@ -494,11 +530,11 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
case op => sys.error(s"Unexpected operator $op in corelated subquery")
}
// CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups
// CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
// WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
// ELSE (aggregate value) END AS (original column name)
val caseExpr = Alias(CaseWhen(Seq(
(IsNull(alwaysTrueRef), Literal.create(resultWithZeroTups.get, origOutput.dataType)),
(IsNull(alwaysTrueRef), resultWithZeroTups.get),
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
aggValRef),
origOutput.name)(exprId = origOutput.exprId)

View file

@ -1384,4 +1384,231 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
assert(subqueryExecs.forall(_.name.startsWith("scalar-subquery#")),
"SubqueryExec name should start with scalar-subquery#")
}
test("SPARK-28441: COUNT bug in WHERE clause (Filter) with PythonUDF") {
import IntegratedUDFTestUtils._
val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)
// Case 1: Canonical example of the COUNT bug
checkAnswer(
sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = r.c) < l.a"),
Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
// Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses
// a rewrite that is vulnerable to the COUNT bug
checkAnswer(
sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = r.c) = 0"),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
// Case 3: COUNT bug without a COUNT aggregate
checkAnswer(
sql("SELECT l.a FROM l WHERE (SELECT udf(sum(r.d)) is null FROM r WHERE l.a = r.c)"),
Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil)
}
test("SPARK-28441: COUNT bug in SELECT clause (Project) with PythonUDF") {
import IntegratedUDFTestUtils._
val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)
checkAnswer(
sql("SELECT a, (SELECT udf(count(*)) FROM r WHERE l.a = r.c) AS cnt FROM l"),
Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0)
:: Row(null, 0) :: Row(6, 1) :: Nil)
}
test("SPARK-28441: COUNT bug in HAVING clause (Filter) with PythonUDF") {
import IntegratedUDFTestUtils._
val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)
checkAnswer(
sql("""
|SELECT
| l.a AS grp_a
|FROM l GROUP BY l.a
|HAVING
| (
| SELECT udf(count(*)) FROM r WHERE grp_a = r.c
| ) = 0
|ORDER BY grp_a""".stripMargin),
Row(null) :: Row(1) :: Nil)
}
test("SPARK-28441: COUNT bug in Aggregate with PythonUDF") {
import IntegratedUDFTestUtils._
val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)
checkAnswer(
sql("""
|SELECT
| l.a AS aval,
| sum(
| (
| SELECT udf(count(*)) FROM r WHERE l.a = r.c
| )
| ) AS cnt
|FROM l GROUP BY l.a ORDER BY aval""".stripMargin),
Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1) :: Nil)
}
test("SPARK-28441: COUNT bug negative examples with PythonUDF") {
import IntegratedUDFTestUtils._
val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)
// Case 1: Potential COUNT bug case that was working correctly prior to the fix
checkAnswer(
sql("SELECT l.a FROM l WHERE (SELECT udf(sum(r.d)) FROM r WHERE l.a = r.c) is null"),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil)
// Case 2: COUNT aggregate but no COUNT bug due to > 0 test.
checkAnswer(
sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = r.c) > 0"),
Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil)
// Case 3: COUNT inside aggregate expression but no COUNT bug.
checkAnswer(
sql("""
|SELECT
| l.a
|FROM l
|WHERE
| (
| SELECT udf(count(*)) + udf(sum(r.d))
| FROM r WHERE l.a = r.c
| ) = 0""".stripMargin),
Nil)
}
test("SPARK-28441: COUNT bug in nested subquery with PythonUDF") {
import IntegratedUDFTestUtils._
val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)
checkAnswer(
sql("""
|SELECT l.a FROM l
|WHERE (
| SELECT cntPlusOne + 1 AS cntPlusTwo FROM (
| SELECT cnt + 1 AS cntPlusOne FROM (
| SELECT udf(sum(r.c)) s, udf(count(*)) cnt FROM r WHERE l.a = r.c
| HAVING cnt = 0
| )
| )
|) = 2""".stripMargin),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
}
test("SPARK-28441: COUNT bug with nasty predicate expr with PythonUDF") {
import IntegratedUDFTestUtils._
val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)
checkAnswer(
sql("""
|SELECT
| l.a
|FROM l WHERE
| (
| SELECT CASE WHEN udf(count(*)) = 1 THEN null ELSE udf(count(*)) END AS cnt
| FROM r WHERE l.a = r.c
| ) = 0""".stripMargin),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
}
test("SPARK-28441: COUNT bug with attribute ref in subquery input and output with PythonUDF") {
import IntegratedUDFTestUtils._
val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)
checkAnswer(
sql(
"""
|SELECT
| l.b,
| (
| SELECT (r.c + udf(count(*))) is null
| FROM r
| WHERE l.a = r.c GROUP BY r.c
| )
|FROM l
""".stripMargin),
Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil)
}
test("SPARK-28441: COUNT bug with non-foldable expression") {
// Case 1: Canonical example of the COUNT bug
checkAnswer(
sql("SELECT l.a FROM l WHERE (SELECT count(*) + cast(rand() as int) FROM r " +
"WHERE l.a = r.c) < l.a"),
Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
// Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses
// a rewrite that is vulnerable to the COUNT bug
checkAnswer(
sql("SELECT l.a FROM l WHERE (SELECT count(*) + cast(rand() as int) FROM r " +
"WHERE l.a = r.c) = 0"),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
// Case 3: COUNT bug without a COUNT aggregate
checkAnswer(
sql("SELECT l.a FROM l WHERE (SELECT sum(r.d) is null from r " +
"WHERE l.a = r.c)"),
Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil)
}
test("SPARK-28441: COUNT bug in nested subquery with non-foldable expr") {
checkAnswer(
sql("""
|SELECT l.a FROM l
|WHERE (
| SELECT cntPlusOne + 1 AS cntPlusTwo FROM (
| SELECT cnt + 1 AS cntPlusOne FROM (
| SELECT sum(r.c) s, (count(*) + cast(rand() as int)) cnt FROM r
| WHERE l.a = r.c HAVING cnt = 0
| )
| )
|) = 2""".stripMargin),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
}
test("SPARK-28441: COUNT bug with non-foldable expression in Filter condition") {
val df = sql("""
|SELECT
| l.a
|FROM l WHERE
| (
| SELECT cntPlusOne + 1 as cntPlusTwo FROM
| (
| SELECT cnt + 1 as cntPlusOne FROM
| (
| SELECT sum(r.c) s, count(*) cnt FROM r WHERE l.a = r.c HAVING cnt > 0
| )
| )
| ) = 2""".stripMargin)
val df2 = sql("""
|SELECT
| l.a
|FROM l WHERE
| (
| SELECT cntPlusOne + 1 AS cntPlusTwo
| FROM
| (
| SELECT cnt + 1 AS cntPlusOne
| FROM
| (
| SELECT sum(r.c) s, count(*) cnt FROM r
| WHERE l.a = r.c HAVING (cnt + cast(rand() as int)) > 0
| )
| )
| ) = 2""".stripMargin)
checkAnswer(df, df2)
checkAnswer(df, Nil)
}
}