From 6ed2dfbba193d29436dccae4c379dae7b5ba5bdb Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Tue, 28 Apr 2020 08:11:41 +0000
Subject: [PATCH] [SPARK-31519][SQL] Cast in having aggregate expressions
 returns the wrong result

### What changes were proposed in this pull request?
Add a new logical node AggregateWithHaving, and have the parser create this plan for HAVING. The analyzer then resolves it to Filter(..., Aggregate(...)).

### Why are the changes needed?
The SQL parser in Spark creates Filter(..., Aggregate(...)) for a HAVING query, and Spark has a special analyzer rule ResolveAggregateFunctions to resolve the aggregate functions and grouping columns in the Filter operator.

It works for simple cases in a very tricky way, as it relies on rule execution order:
1. Rule ResolveReferences hits the Aggregate operator and resolves attributes inside aggregate functions, but the function itself is still unresolved as it's an UnresolvedFunction. This stops resolution of the Filter operator, as the child Aggregate operator is still unresolved.
2. Rule ResolveFunctions resolves UnresolvedFunction. This makes the Aggregate operator resolved.
3. Rule ResolveAggregateFunctions resolves the Filter operator if its child is a resolved Aggregate. This rule can correctly resolve the grouping columns.

In the example query, I put a CAST, which needs to be resolved by rule ResolveTimeZone, which runs after ResolveAggregateFunctions. This breaks step 3, as the Aggregate operator is unresolved at that time. The analyzer then starts the next round, and the Filter operator is resolved by ResolveReferences, which wrongly resolves the grouping columns. See the demo below:
```
SELECT SUM(a) AS b, '2020-01-01' AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
```
The query's result is
```
+---+----------+
|  b|      fake|
+---+----------+
|  2|2020-01-01|
+---+----------+
```
But if we add CAST, it will return an empty result.
```
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
```

### Does this PR introduce any user-facing change?
Yes, a bug fix for CAST in HAVING aggregate expressions.

### How was this patch tested?
New UT added.

Closes #28294 from xuanyuanking/SPARK-31519.

Authored-by: Yuanjian Li
Signed-off-by: Wenchen Fan
---
 .../sql/catalyst/analysis/Analyzer.scala      | 133 ++++++++++--------
 .../sql/catalyst/analysis/unresolved.scala    |  13 +-
 .../spark/sql/catalyst/dsl/package.scala      |   8 ++
 .../sql/catalyst/parser/AstBuilder.scala      |   7 +-
 .../sql/catalyst/parser/PlanParserSuite.scala |   5 +-
 .../resources/sql-tests/inputs/having.sql     |   3 +
 .../sql-tests/results/having.sql.out          |  10 +-
 .../results/postgreSQL/window_part3.sql.out   |   6 +-
 .../sql/DataFrameWindowFunctionsSuite.scala   |  30 ++--
 9 files changed, 135 insertions(+), 80 deletions(-)
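For reference, the new golden file added at the bottom of this patch confirms the post-fix behavior: with the fix applied, the CAST variant returns the same row as the non-CAST query, with fake now date-typed:
```
+---+----------+
|  b|      fake|
+---+----------+
|  2|2020-01-01|
+---+----------+
```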
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7a2b4e63e1..9f58c165d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1239,13 +1239,13 @@ class Analyzer(
     /**
      * Resolves the attribute and extract value expressions(s) by traversing the
      * input expression in top down manner. The traversal is done in top-down manner as
-     * we need to skip over unbound lamda function expression. The lamda expressions are
+     * we need to skip over unbound lambda function expression. The lambda expressions are
      * resolved in a different rule [[ResolveLambdaVariables]]
      *
      * Example :
      * SELECT transform(array(1, 2, 3), (x, i) -> x + i)"
      *
-     * In the case above, x and i are resolved as lamda variables in [[ResolveLambdaVariables]]
+     * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]
      *
      * Note : In this routine, the unresolved attributes are resolved from the input plan's
      * children attributes.
@@ -1400,6 +1400,9 @@ class Analyzer(
           notMatchedActions = newNotMatchedActions)
       }

+      // Skip the having clause here; this will be handled in ResolveAggregateFunctions.
+      case h: AggregateWithHaving => h
+
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}")
         q.mapExpressions(resolveExpressionTopDown(_, q))
@@ -2040,62 +2043,14 @@ class Analyzer(
    */
   object ResolveAggregateFunctions extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
-      case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved =>
+      // Resolve aggregate with having clause to Filter(..., Aggregate()). Note: to avoid wrongly
+      // resolving the having condition expression, we skip resolving it in ResolveReferences
+      // and transform the node to Filter after the aggregate is resolved. See more details in
+      // SPARK-31519.
+      case AggregateWithHaving(cond, agg: Aggregate) if agg.resolved =>
+        resolveHaving(Filter(cond, agg), agg)

-      // Try resolving the condition of the filter as though it is in the aggregate clause
-      try {
-        val aggregatedCondition =
-          Aggregate(
-            grouping,
-            Alias(cond, "havingCondition")() :: Nil,
-            child)
-        val resolvedOperator = executeSameContext(aggregatedCondition)
-        def resolvedAggregateFilter =
-          resolvedOperator
-            .asInstanceOf[Aggregate]
-            .aggregateExpressions.head
-
-        // If resolution was successful and we see the filter has an aggregate in it, add it to
-        // the original aggregate operator.
-        if (resolvedOperator.resolved) {
-          // Try to replace all aggregate expressions in the filter by an alias.
-          val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
-          val transformedAggregateFilter = resolvedAggregateFilter.transform {
-            case ae: AggregateExpression =>
-              val alias = Alias(ae, ae.toString)()
-              aggregateExpressions += alias
-              alias.toAttribute
-            // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
-            case e: Expression if grouping.exists(_.semanticEquals(e)) &&
-                !ResolveGroupingAnalytics.hasGroupingFunction(e) &&
-                !agg.output.exists(_.semanticEquals(e)) =>
-              e match {
-                case ne: NamedExpression =>
-                  aggregateExpressions += ne
-                  ne.toAttribute
-                case _ =>
-                  val alias = Alias(e, e.toString)()
-                  aggregateExpressions += alias
-                  alias.toAttribute
-              }
-          }
-
-          // Push the aggregate expressions into the aggregate (if any).
-          if (aggregateExpressions.nonEmpty) {
-            Project(agg.output,
-              Filter(transformedAggregateFilter,
-                agg.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions)))
-          } else {
-            f
-          }
-        } else {
-          f
-        }
-      } catch {
-        // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
-        // just return the original plan.
-        case ae: AnalysisException => f
-      }
+      case f @ Filter(_, agg: Aggregate) if agg.resolved =>
+        resolveHaving(f, agg)

       case sort @ Sort(sortOrder, global, aggregate: Aggregate)
         if aggregate.resolved =>
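To make the new flow concrete, here is a rough before/after of the plan for the query from the description. This is a hand-written, illustrative sketch (attribute IDs omitted, not the output of any tool), matching the Project/Filter/extended-Aggregate shape that resolveHaving below produces:
```
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake
FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10

-- parsed (new):  AggregateWithHaving('b > 10, Aggregate(['b], [sum('a) AS b, cast(...) AS fake], T))
-- resolved:      Project [b, fake]
--                +- Filter (b > 10)          <- b binds to the grouping column T.b
--                   +- Aggregate [b], [sum(a) AS b, cast(...) AS fake, b], T
```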
@@ -2166,6 +2121,63 @@ class Analyzer(
     def containsAggregate(condition: Expression): Boolean = {
       condition.find(_.isInstanceOf[AggregateExpression]).isDefined
     }
+
+    def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
+      // Try resolving the condition of the filter as though it is in the aggregate clause
+      try {
+        val aggregatedCondition =
+          Aggregate(
+            agg.groupingExpressions,
+            Alias(filter.condition, "havingCondition")() :: Nil,
+            agg.child)
+        val resolvedOperator = executeSameContext(aggregatedCondition)
+        def resolvedAggregateFilter =
+          resolvedOperator
+            .asInstanceOf[Aggregate]
+            .aggregateExpressions.head
+
+        // If resolution was successful and we see the filter has an aggregate in it, add it to
+        // the original aggregate operator.
+        if (resolvedOperator.resolved) {
+          // Try to replace all aggregate expressions in the filter by an alias.
+          val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
+          val transformedAggregateFilter = resolvedAggregateFilter.transform {
+            case ae: AggregateExpression =>
+              val alias = Alias(ae, ae.toString)()
+              aggregateExpressions += alias
+              alias.toAttribute
+            // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
+            case e: Expression if agg.groupingExpressions.exists(_.semanticEquals(e)) &&
+                !ResolveGroupingAnalytics.hasGroupingFunction(e) &&
+                !agg.output.exists(_.semanticEquals(e)) =>
+              e match {
+                case ne: NamedExpression =>
+                  aggregateExpressions += ne
+                  ne.toAttribute
+                case _ =>
+                  val alias = Alias(e, e.toString)()
+                  aggregateExpressions += alias
+                  alias.toAttribute
+              }
+          }
+
+          // Push the aggregate expressions into the aggregate (if any).
+          if (aggregateExpressions.nonEmpty) {
+            Project(agg.output,
+              Filter(transformedAggregateFilter,
+                agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
+          } else {
+            filter
+          }
+        } else {
+          filter
+        }
+      } catch {
+        // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
+        // just return the original plan.
+        case ae: AnalysisException => filter
+      }
+    }
   }

 /**
@@ -2590,11 +2602,14 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
       case Filter(condition, _) if hasWindowFunction(condition) =>
-        failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses")
+        failAnalysis("It is not allowed to use window functions inside WHERE clause")
+
+      case AggregateWithHaving(condition, _) if hasWindowFunction(condition) =>
+        failAnalysis("It is not allowed to use window functions inside HAVING clause")

       // Aggregate with Having clause. This rule works with an unresolved Aggregate because
       // a resolved Aggregate will not have Window Functions.
-      case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
+      case f @ AggregateWithHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
         if child.resolved &&
           hasWindowFunction(aggregateExprs) &&
           a.expressions.forall(_.resolved) =>
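Since the error path is now split by clause, the behavior change is easiest to see on the queries exercised by the golden files and tests further down (messages exactly as asserted there):
```
SELECT * FROM empsalary WHERE row_number() OVER (ORDER BY salary) < 10;
-- AnalysisException: It is not allowed to use window functions inside WHERE clause

SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1;
-- AnalysisException: It is not allowed to use window functions inside HAVING clause
```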
- case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) + case f @ AggregateWithHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) if child.resolved && hasWindowFunction(aggregateExprs) && a.expressions.forall(_.resolved) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 6048d98033..806cdeb95c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.parser.ParserUtils -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, UnaryNode} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog} @@ -538,3 +538,14 @@ case class UnresolvedOrdinal(ordinal: Int) override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false } + +/** + * Represents unresolved aggregate with having clause, it is turned by the analyzer into a Filter. + */ +case class AggregateWithHaving( + havingCondition: Expression, + child: Aggregate) + extends UnaryNode { + override lazy val resolved: Boolean = false + override def output: Seq[Attribute] = child.output +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index cc96d905a8..f135f50493 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -364,6 +364,14 @@ package object dsl { Aggregate(groupingExprs, aliasedExprs, logicalPlan) } + def having( + groupingExprs: Expression*)( + aggregateExprs: Expression*)( + havingCondition: Expression): LogicalPlan = { + AggregateWithHaving(havingCondition, + groupBy(groupingExprs: _*)(aggregateExprs: _*).asInstanceOf[Aggregate]) + } + def window( windowExpressions: Seq[NamedExpression], partitionSpec: Seq[Expression], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e51b8f3b42..146df97d48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -629,7 +629,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case p: Predicate => p case e => Cast(e, BooleanType) } - Filter(predicate, plan) + plan match { + case aggregate: Aggregate => + AggregateWithHaving(predicate, aggregate) + case _ => + Filter(predicate, plan) + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 0a591ad9cd..bec35ae458 100644 --- 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 0a591ad9cd..bec35ae458 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -208,7 +208,7 @@ class PlanParserSuite extends AnalysisTest {
   assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b))
   assertEqual(
     "select a, b from db.c having x < 1",
-    table("db", "c").groupBy()('a, 'b).where('x < 1))
+    table("db", "c").having()('a, 'b)('x < 1))
   assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b)))
   assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b))
   assertEqual("select from tbl", OneRowRelation().select('from.as("tbl")))
@@ -574,8 +574,7 @@ class PlanParserSuite extends AnalysisTest {
     assertEqual(
       "select g from t group by g having a > (select b from s)",
       table("t")
-        .groupBy('g)('g)
-        .where('a > ScalarSubquery(table("s").select('b))))
+        .having('g)('g)('a > ScalarSubquery(table("s").select('b))))
   }

   test("table reference") {
diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql
index 868a911e78..6868b59029 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/having.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql
@@ -16,3 +16,6 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);

 -- SPARK-20329: make sure we handle timezones correctly
 SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;
+
+-- SPARK-31519: Cast in having aggregate expressions returns the wrong result
+SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out
index 5bd185d7b8..aa8ff73723 100644
--- a/sql/core/src/test/resources/sql-tests/results/having.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 5
+-- Number of queries: 6


 -- !query
@@ -47,3 +47,11 @@ struct<(a + CAST(b AS BIGINT)):bigint>
 -- !query output
 3
 7
+
+
+-- !query
+SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
+-- !query schema
+struct<b:bigint,fake:date>
+-- !query output
+2 2020-01-01
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out
index acce688092..08eba6797b 100644
--- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out
@@ -294,7 +294,7 @@ SELECT * FROM empsalary WHERE row_number() OVER (ORDER BY salary) < 10
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-It is not allowed to use window functions inside WHERE and HAVING clauses;
+It is not allowed to use window functions inside WHERE clause;


 -- !query
@@ -341,7 +341,7 @@ SELECT * FROM empsalary WHERE (rank() OVER (ORDER BY random())) > 10
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-It is not allowed to use window functions inside WHERE and HAVING clauses;
+It is not allowed to use window functions inside WHERE clause;


 -- !query
@@ -350,7 +350,7 @@ SELECT * FROM empsalary WHERE rank() OVER (ORDER BY random())
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-It is not allowed to use window functions inside WHERE and HAVING clauses;
+It is not allowed to use window functions inside WHERE clause;


 -- !query
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index d398657ec0..f72ccaa63b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -665,40 +665,46 @@ class DataFrameWindowFunctionsSuite extends QueryTest
   }

   test("SPARK-24575: Window functions inside WHERE and HAVING clauses") {
-    def checkAnalysisError(df: => DataFrame): Unit = {
+    def checkAnalysisError(df: => DataFrame, clause: String): Unit = {
       val thrownException = the[AnalysisException] thrownBy {
         df.queryExecution.analyzed
       }
-      assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses"))
+      assert(thrownException.message.contains(s"window functions inside $clause clause"))
     }

-    checkAnalysisError(testData2.select("a").where(rank().over(Window.orderBy($"b")) === 1))
-    checkAnalysisError(testData2.where($"b" === 2 && rank().over(Window.orderBy($"b")) === 1))
+    checkAnalysisError(
+      testData2.select("a").where(rank().over(Window.orderBy($"b")) === 1), "WHERE")
+    checkAnalysisError(
+      testData2.where($"b" === 2 && rank().over(Window.orderBy($"b")) === 1), "WHERE")
     checkAnalysisError(
       testData2.groupBy($"a")
         .agg(avg($"b").as("avgb"))
-        .where($"a" > $"avgb" && rank().over(Window.orderBy($"a")) === 1))
+        .where($"a" > $"avgb" && rank().over(Window.orderBy($"a")) === 1), "WHERE")
     checkAnalysisError(
       testData2.groupBy($"a")
         .agg(max($"b").as("maxb"), sum($"b").as("sumb"))
-        .where(rank().over(Window.orderBy($"a")) === 1))
+        .where(rank().over(Window.orderBy($"a")) === 1), "WHERE")
     checkAnalysisError(
       testData2.groupBy($"a")
         .agg(max($"b").as("maxb"), sum($"b").as("sumb"))
-        .where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1))
+        .where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1), "WHERE")

-    checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1"))
-    checkAnalysisError(sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1"))
+    checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1"), "WHERE")
     checkAnalysisError(
-      sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1"))
+      sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1"), "WHERE")
     checkAnalysisError(
-      sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1"))
+      sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1"),
+      "HAVING")
+    checkAnalysisError(
+      sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1"),
+      "HAVING")
     checkAnalysisError(
       sql(
         s"""SELECT a, MAX(b)
           |FROM testData2
           |GROUP BY a
-          |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin))
+          |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin),
+      "HAVING")
   }

   test("window functions in multiple selects") {