From 558dd2360163250e9fb55c3a49f87c907b65ea0d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 27 Jul 2019 10:38:34 +0800 Subject: [PATCH] [SPARK-28441][SQL][PYTHON] Fix error when non-foldable expression is used in correlated scalar subquery ## What changes were proposed in this pull request? In SPARK-15370, We checked the expression at the root of the correlated subquery, in order to fix count bug. If a `PythonUDF` in in the checking path, evaluating it causes the failure as we can't statically evaluate `PythonUDF`. The Python UDF test added at SPARK-28277 shows this issue. If we can statically evaluate the expression, we intercept NULL values coming from the outer join and replace them with the value that the subquery's expression like before, if it is not, we replace them with the `PythonUDF` expression, with statically evaluated parameters. After this, the last query in `udf-except.sql` which throws `java.lang.UnsupportedOperationException` can be run: ``` SELECT t1.k FROM t1 WHERE t1.v <= (SELECT udf(max(udf(t2.v))) FROM t2 WHERE udf(t2.k) = udf(t1.k)) MINUS SELECT t1.k FROM t1 WHERE udf(t1.v) >= (SELECT min(udf(t2.v)) FROM t2 WHERE t2.k = t1.k) -- !query 2 schema struct -- !query 2 output two ``` Note that this issue is also for other non-foldable expressions, like rand. As like PythonUDF, we can't call `eval` on this kind of expressions in optimization. The evaluation needs to defer to query runtime. ## How was this patch tested? Added tests. Closes #25204 from viirya/SPARK-28441. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../sql/catalyst/optimizer/subquery.scala | 84 +++++-- .../org/apache/spark/sql/SubquerySuite.scala | 227 ++++++++++++++++++ 4 files changed, 291 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 021fb26bf7..5bf4dc1f04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2833,7 +2833,7 @@ object EliminateUnions extends Rule[LogicalPlan] { * rule can't work for those parameters. */ object CleanupAliases extends Rule[LogicalPlan] { - private def trimAliases(e: Expression): Expression = { + def trimAliases(e: Expression): Expression = { e.transformDown { case Alias(child, _) => child case MultiAlias(child, _) => child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index af90ef4267..1c36cdcb00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -675,7 +675,9 @@ object ColumnPruning extends Rule[LogicalPlan] { */ private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp { case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child))) - if p2.outputSet.subsetOf(child.outputSet) => + if p2.outputSet.subsetOf(child.outputSet) && + // We only remove attribute-only project. + p2.projectList.forall(_.isInstanceOf[AttributeReference]) => p1.copy(child = f.copy(child = child)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index e78ed1c3c5..4f7333c387 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.CleanupAliases import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -317,24 +318,40 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { } /** - * Statically evaluate an expression containing zero or more placeholders, given a set - * of bindings for placeholder values. + * Checks if given expression is foldable. Evaluates it and returns it as literal, if yes. + * If not, returns the original expression without evaluation. */ - private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = { + private def tryEvalExpr(expr: Expression): Expression = { + // Removes Alias over given expression, because Alias is not foldable. + if (!CleanupAliases.trimAliases(expr).foldable) { + // SPARK-28441: Some expressions, like PythonUDF, can't be statically evaluated. + // Needs to evaluate them on query runtime. + expr + } else { + Literal.create(expr.eval(), expr.dataType) + } + } + + /** + * Statically evaluate an expression containing zero or more placeholders, given a set + * of bindings for placeholder values, if the expression is evaluable. If it is not, + * bind statically evaluated expression results to an expression. + */ + private def bindingExpr( + expr: Expression, + bindings: Map[ExprId, Expression]): Expression = { val rewrittenExpr = expr transform { case r: AttributeReference => - bindings(r.exprId) match { - case Some(v) => Literal.create(v, r.dataType) - case None => Literal.default(NullType) - } + bindings.getOrElse(r.exprId, Literal.default(NullType)) } - Option(rewrittenExpr.eval()) + + tryEvalExpr(rewrittenExpr) } /** * Statically evaluate an expression containing one or more aggregates on an empty input. */ - private def evalAggOnZeroTups(expr: Expression) : Option[Any] = { + private def evalAggOnZeroTups(expr: Expression) : Expression = { // AggregateExpressions are Unevaluable, so we need to replace all aggregates // in the expression with the value they would return for zero input tuples. // Also replace attribute refs (for example, for grouping columns) with NULL. @@ -344,7 +361,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { case _: AttributeReference => Literal.default(NullType) } - Option(rewrittenExpr.eval()) + + tryEvalExpr(rewrittenExpr) } /** @@ -354,19 +372,33 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in * CheckAnalysis become less restrictive, this method will need to change. */ - private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = { + private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Expression] = { // Inputs to this method will start with a chain of zero or more SubqueryAlias // and Project operators, followed by an optional Filter, followed by an // Aggregate. Traverse the operators recursively. - def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match { + def evalPlan(lp : LogicalPlan) : Map[ExprId, Expression] = lp match { case SubqueryAlias(_, child) => evalPlan(child) case Filter(condition, child) => val bindings = evalPlan(child) - if (bindings.isEmpty) bindings - else { - val exprResult = evalExpr(condition, bindings).getOrElse(false) - .asInstanceOf[Boolean] - if (exprResult) bindings else Map.empty + if (bindings.isEmpty) { + bindings + } else { + val bindCondition = bindingExpr(condition, bindings) + + if (!bindCondition.foldable) { + // We can't evaluate the condition. Evaluate it in query runtime. + bindings.map { case (id, expr) => + val newExpr = If(bindCondition, expr, Literal.create(null, expr.dataType)) + (id, newExpr) + } + } else { + // The bound condition can be evaluated. + bindCondition.eval() match { + // For filter condition, null is the same as false. + case null | false => Map.empty + case true => bindings + } + } } case Project(projectList, child) => @@ -374,7 +406,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { if (bindings.isEmpty) { bindings } else { - projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap + projectList.map(ne => (ne.exprId, bindingExpr(ne, bindings))).toMap } case Aggregate(_, aggExprs, _) => @@ -382,8 +414,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { // for joining with the outer query block. Fill those expressions in with // nulls and statically evaluate the remainder. aggExprs.map { - case ref: AttributeReference => (ref.exprId, None) - case alias @ Alias(_: AttributeReference, _) => (alias.exprId, None) + case ref: AttributeReference => (ref.exprId, Literal.create(null, ref.dataType)) + case alias @ Alias(_: AttributeReference, _) => + (alias.exprId, Literal.create(null, alias.dataType)) case ne => (ne.exprId, evalAggOnZeroTups(ne)) }.toMap @@ -394,7 +427,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { val resultMap = evalPlan(plan) // By convention, the scalar subquery result is the leftmost field. - resultMap.getOrElse(plan.output.head.exprId, None) + resultMap.get(plan.output.head.exprId) match { + case Some(Literal(null, _)) | None => None + case o => o + } } /** @@ -473,7 +509,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { currentChild.output :+ Alias( If(IsNull(alwaysTrueRef), - Literal.create(resultWithZeroTups.get, origOutput.dataType), + resultWithZeroTups.get, aggValRef), origOutput.name)(exprId = origOutput.exprId), Join(currentChild, Project(query.output :+ alwaysTrueExpr, query), @@ -494,11 +530,11 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { case op => sys.error(s"Unexpected operator $op in corelated subquery") } - // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups + // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups // WHEN NOT (original HAVING clause expr) THEN CAST(null AS ) // ELSE (aggregate value) END AS (original column name) val caseExpr = Alias(CaseWhen(Seq( - (IsNull(alwaysTrueRef), Literal.create(resultWithZeroTups.get, origOutput.dataType)), + (IsNull(alwaysTrueRef), resultWithZeroTups.get), (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))), aggValRef), origOutput.name)(exprId = origOutput.exprId) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index b2c3868407..4ec85b0ac6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -1384,4 +1384,231 @@ class SubquerySuite extends QueryTest with SharedSQLContext { assert(subqueryExecs.forall(_.name.startsWith("scalar-subquery#")), "SubqueryExec name should start with scalar-subquery#") } + + test("SPARK-28441: COUNT bug in WHERE clause (Filter) with PythonUDF") { + import IntegratedUDFTestUtils._ + + val pythonTestUDF = TestPythonUDF(name = "udf") + registerTestUDF(pythonTestUDF, spark) + + // Case 1: Canonical example of the COUNT bug + checkAnswer( + sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = r.c) < l.a"), + Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil) + // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses + // a rewrite that is vulnerable to the COUNT bug + checkAnswer( + sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = r.c) = 0"), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) + // Case 3: COUNT bug without a COUNT aggregate + checkAnswer( + sql("SELECT l.a FROM l WHERE (SELECT udf(sum(r.d)) is null FROM r WHERE l.a = r.c)"), + Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil) + } + + test("SPARK-28441: COUNT bug in SELECT clause (Project) with PythonUDF") { + import IntegratedUDFTestUtils._ + + val pythonTestUDF = TestPythonUDF(name = "udf") + registerTestUDF(pythonTestUDF, spark) + + checkAnswer( + sql("SELECT a, (SELECT udf(count(*)) FROM r WHERE l.a = r.c) AS cnt FROM l"), + Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0) + :: Row(null, 0) :: Row(6, 1) :: Nil) + } + + test("SPARK-28441: COUNT bug in HAVING clause (Filter) with PythonUDF") { + import IntegratedUDFTestUtils._ + + val pythonTestUDF = TestPythonUDF(name = "udf") + registerTestUDF(pythonTestUDF, spark) + + checkAnswer( + sql(""" + |SELECT + | l.a AS grp_a + |FROM l GROUP BY l.a + |HAVING + | ( + | SELECT udf(count(*)) FROM r WHERE grp_a = r.c + | ) = 0 + |ORDER BY grp_a""".stripMargin), + Row(null) :: Row(1) :: Nil) + } + + test("SPARK-28441: COUNT bug in Aggregate with PythonUDF") { + import IntegratedUDFTestUtils._ + + val pythonTestUDF = TestPythonUDF(name = "udf") + registerTestUDF(pythonTestUDF, spark) + + checkAnswer( + sql(""" + |SELECT + | l.a AS aval, + | sum( + | ( + | SELECT udf(count(*)) FROM r WHERE l.a = r.c + | ) + | ) AS cnt + |FROM l GROUP BY l.a ORDER BY aval""".stripMargin), + Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1) :: Nil) + } + + test("SPARK-28441: COUNT bug negative examples with PythonUDF") { + import IntegratedUDFTestUtils._ + + val pythonTestUDF = TestPythonUDF(name = "udf") + registerTestUDF(pythonTestUDF, spark) + + // Case 1: Potential COUNT bug case that was working correctly prior to the fix + checkAnswer( + sql("SELECT l.a FROM l WHERE (SELECT udf(sum(r.d)) FROM r WHERE l.a = r.c) is null"), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil) + // Case 2: COUNT aggregate but no COUNT bug due to > 0 test. + checkAnswer( + sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = r.c) > 0"), + Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil) + // Case 3: COUNT inside aggregate expression but no COUNT bug. + checkAnswer( + sql(""" + |SELECT + | l.a + |FROM l + |WHERE + | ( + | SELECT udf(count(*)) + udf(sum(r.d)) + | FROM r WHERE l.a = r.c + | ) = 0""".stripMargin), + Nil) + } + + test("SPARK-28441: COUNT bug in nested subquery with PythonUDF") { + import IntegratedUDFTestUtils._ + + val pythonTestUDF = TestPythonUDF(name = "udf") + registerTestUDF(pythonTestUDF, spark) + + checkAnswer( + sql(""" + |SELECT l.a FROM l + |WHERE ( + | SELECT cntPlusOne + 1 AS cntPlusTwo FROM ( + | SELECT cnt + 1 AS cntPlusOne FROM ( + | SELECT udf(sum(r.c)) s, udf(count(*)) cnt FROM r WHERE l.a = r.c + | HAVING cnt = 0 + | ) + | ) + |) = 2""".stripMargin), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) + } + + test("SPARK-28441: COUNT bug with nasty predicate expr with PythonUDF") { + import IntegratedUDFTestUtils._ + + val pythonTestUDF = TestPythonUDF(name = "udf") + registerTestUDF(pythonTestUDF, spark) + + checkAnswer( + sql(""" + |SELECT + | l.a + |FROM l WHERE + | ( + | SELECT CASE WHEN udf(count(*)) = 1 THEN null ELSE udf(count(*)) END AS cnt + | FROM r WHERE l.a = r.c + | ) = 0""".stripMargin), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) + } + + test("SPARK-28441: COUNT bug with attribute ref in subquery input and output with PythonUDF") { + import IntegratedUDFTestUtils._ + + val pythonTestUDF = TestPythonUDF(name = "udf") + registerTestUDF(pythonTestUDF, spark) + + checkAnswer( + sql( + """ + |SELECT + | l.b, + | ( + | SELECT (r.c + udf(count(*))) is null + | FROM r + | WHERE l.a = r.c GROUP BY r.c + | ) + |FROM l + """.stripMargin), + Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) :: + Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil) + } + + test("SPARK-28441: COUNT bug with non-foldable expression") { + // Case 1: Canonical example of the COUNT bug + checkAnswer( + sql("SELECT l.a FROM l WHERE (SELECT count(*) + cast(rand() as int) FROM r " + + "WHERE l.a = r.c) < l.a"), + Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil) + // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses + // a rewrite that is vulnerable to the COUNT bug + checkAnswer( + sql("SELECT l.a FROM l WHERE (SELECT count(*) + cast(rand() as int) FROM r " + + "WHERE l.a = r.c) = 0"), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) + // Case 3: COUNT bug without a COUNT aggregate + checkAnswer( + sql("SELECT l.a FROM l WHERE (SELECT sum(r.d) is null from r " + + "WHERE l.a = r.c)"), + Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil) + } + + test("SPARK-28441: COUNT bug in nested subquery with non-foldable expr") { + checkAnswer( + sql(""" + |SELECT l.a FROM l + |WHERE ( + | SELECT cntPlusOne + 1 AS cntPlusTwo FROM ( + | SELECT cnt + 1 AS cntPlusOne FROM ( + | SELECT sum(r.c) s, (count(*) + cast(rand() as int)) cnt FROM r + | WHERE l.a = r.c HAVING cnt = 0 + | ) + | ) + |) = 2""".stripMargin), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) + } + + test("SPARK-28441: COUNT bug with non-foldable expression in Filter condition") { + val df = sql(""" + |SELECT + | l.a + |FROM l WHERE + | ( + | SELECT cntPlusOne + 1 as cntPlusTwo FROM + | ( + | SELECT cnt + 1 as cntPlusOne FROM + | ( + | SELECT sum(r.c) s, count(*) cnt FROM r WHERE l.a = r.c HAVING cnt > 0 + | ) + | ) + | ) = 2""".stripMargin) + val df2 = sql(""" + |SELECT + | l.a + |FROM l WHERE + | ( + | SELECT cntPlusOne + 1 AS cntPlusTwo + | FROM + | ( + | SELECT cnt + 1 AS cntPlusOne + | FROM + | ( + | SELECT sum(r.c) s, count(*) cnt FROM r + | WHERE l.a = r.c HAVING (cnt + cast(rand() as int)) > 0 + | ) + | ) + | ) = 2""".stripMargin) + checkAnswer(df, df2) + checkAnswer(df, Nil) + } }