diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7a2b4e63e1..9f58c165d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1239,13 +1239,13 @@ class Analyzer(
     /**
      * Resolves the attribute and extract value expressions(s) by traversing the
      * input expression in top down manner. The traversal is done in top-down manner as
-     * we need to skip over unbound lamda function expression. The lamda expressions are
+     * we need to skip over unbound lambda function expression. The lambda expressions are
      * resolved in a different rule [[ResolveLambdaVariables]]
      *
      * Example :
      * SELECT transform(array(1, 2, 3), (x, i) -> x + i)"
      *
-     * In the case above, x and i are resolved as lamda variables in [[ResolveLambdaVariables]]
+     * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]
      *
      * Note : In this routine, the unresolved attributes are resolved from the input plan's
      * children attributes.
@@ -1400,6 +1400,9 @@ class Analyzer(
             notMatchedActions = newNotMatchedActions)
       }
 
+      // Skip the having clause here; it will be handled in ResolveAggregateFunctions.
+      case h: AggregateWithHaving => h
+
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}")
         q.mapExpressions(resolveExpressionTopDown(_, q))
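For context on why `ResolveReferences` must skip the new node: a HAVING condition can reference a grouping column that shares its name with a SELECT-list alias, and eager resolution picks the wrong one. A minimal reproduction sketch of the SPARK-31519 bug this guards against (assumes a running `SparkSession` named `spark`; not part of the patch):

```scala
// `b` is both a grouping column and the alias of SUM(a). HAVING must see the
// grouping column, not the alias; resolving it too early picked the alias.
val df = spark.sql(
  """SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake
    |FROM VALUES (1, 10), (2, 20) AS T(a, b)
    |GROUP BY b
    |HAVING b > 10""".stripMargin)

// With this patch, HAVING's `b` resolves to the grouping column: only the
// b = 20 group survives, so the expected output is the row (2, 2020-01-01).
df.show()
```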
@@ -2040,62 +2043,14 @@ class Analyzer(
    */
   object ResolveAggregateFunctions extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
-      case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved =>
+      // Resolve an aggregate with a having clause into Filter(..., Aggregate()). Note: to avoid
+      // resolving the having condition expression incorrectly, we skip it in ResolveReferences
+      // and only turn the node into a Filter after the aggregate is resolved. See SPARK-31519.
+      case AggregateWithHaving(cond, agg: Aggregate) if agg.resolved =>
+        resolveHaving(Filter(cond, agg), agg)
 
-        // Try resolving the condition of the filter as though it is in the aggregate clause
-        try {
-          val aggregatedCondition =
-            Aggregate(
-              grouping,
-              Alias(cond, "havingCondition")() :: Nil,
-              child)
-          val resolvedOperator = executeSameContext(aggregatedCondition)
-          def resolvedAggregateFilter =
-            resolvedOperator
-              .asInstanceOf[Aggregate]
-              .aggregateExpressions.head
-
-          // If resolution was successful and we see the filter has an aggregate in it, add it to
-          // the original aggregate operator.
-          if (resolvedOperator.resolved) {
-            // Try to replace all aggregate expressions in the filter by an alias.
-            val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
-            val transformedAggregateFilter = resolvedAggregateFilter.transform {
-              case ae: AggregateExpression =>
-                val alias = Alias(ae, ae.toString)()
-                aggregateExpressions += alias
-                alias.toAttribute
-              // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
-              case e: Expression if grouping.exists(_.semanticEquals(e)) &&
-                  !ResolveGroupingAnalytics.hasGroupingFunction(e) &&
-                  !agg.output.exists(_.semanticEquals(e)) =>
-                e match {
-                  case ne: NamedExpression =>
-                    aggregateExpressions += ne
-                    ne.toAttribute
-                  case _ =>
-                    val alias = Alias(e, e.toString)()
-                    aggregateExpressions += alias
-                    alias.toAttribute
-                }
-            }
-
-            // Push the aggregate expressions into the aggregate (if any).
-            if (aggregateExpressions.nonEmpty) {
-              Project(agg.output,
-                Filter(transformedAggregateFilter,
-                  agg.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions)))
-            } else {
-              f
-            }
-          } else {
-            f
-          }
-        } catch {
-          // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
-          // just return the original plan.
-          case ae: AnalysisException => f
-        }
+      case f @ Filter(_, agg: Aggregate) if agg.resolved =>
+        resolveHaving(f, agg)
 
       case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
@@ -2166,6 +2121,63 @@ class Analyzer(
     def containsAggregate(condition: Expression): Boolean = {
       condition.find(_.isInstanceOf[AggregateExpression]).isDefined
     }
+
+    def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
+      // Try resolving the condition of the filter as though it is in the aggregate clause
+      try {
+        val aggregatedCondition =
+          Aggregate(
+            agg.groupingExpressions,
+            Alias(filter.condition, "havingCondition")() :: Nil,
+            agg.child)
+        val resolvedOperator = executeSameContext(aggregatedCondition)
+        def resolvedAggregateFilter =
+          resolvedOperator
+            .asInstanceOf[Aggregate]
+            .aggregateExpressions.head
+
+        // If resolution was successful and we see the filter has an aggregate in it, add it to
+        // the original aggregate operator.
+        if (resolvedOperator.resolved) {
+          // Try to replace all aggregate expressions in the filter by an alias.
+          val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
+          val transformedAggregateFilter = resolvedAggregateFilter.transform {
+            case ae: AggregateExpression =>
+              val alias = Alias(ae, ae.toString)()
+              aggregateExpressions += alias
+              alias.toAttribute
+            // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
+            case e: Expression if agg.groupingExpressions.exists(_.semanticEquals(e)) &&
+                !ResolveGroupingAnalytics.hasGroupingFunction(e) &&
+                !agg.output.exists(_.semanticEquals(e)) =>
+              e match {
+                case ne: NamedExpression =>
+                  aggregateExpressions += ne
+                  ne.toAttribute
+                case _ =>
+                  val alias = Alias(e, e.toString)()
+                  aggregateExpressions += alias
+                  alias.toAttribute
+              }
+          }
+
+          // Push the aggregate expressions into the aggregate (if any).
+          if (aggregateExpressions.nonEmpty) {
+            Project(agg.output,
+              Filter(transformedAggregateFilter,
+                agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
+          } else {
+            filter
+          }
+        } else {
+          filter
+        }
+      } catch {
+        // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
+        // just return the original plan.
+        case ae: AnalysisException => filter
+      }
+    }
   }
 
   /**
@@ -2590,11 +2602,14 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
 
       case Filter(condition, _) if hasWindowFunction(condition) =>
-        failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses")
+        failAnalysis("It is not allowed to use window functions inside WHERE clause")
+
+      case AggregateWithHaving(condition, _) if hasWindowFunction(condition) =>
+        failAnalysis("It is not allowed to use window functions inside HAVING clause")
 
       // Aggregate with Having clause. This rule works with an unresolved Aggregate because
       // a resolved Aggregate will not have Window Functions.
-      case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
+      case f @ AggregateWithHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
         if child.resolved &&
            hasWindowFunction(aggregateExprs) &&
            a.expressions.forall(_.resolved) =>
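The extracted `resolveHaving` also keeps the old behavior for a plain `Filter` over a resolved `Aggregate`, which is how the DataFrame API expresses a HAVING-style predicate. A sketch of that path (assumes a `SparkSession` named `spark` with its implicits in scope; data and names are illustrative):

```scala
import org.apache.spark.sql.functions.{max, sum}
import spark.implicits._

val df = Seq((1, 2), (1, 3), (2, 4)).toDF("a", "b")

// sum(b) appears only in the filter condition, so resolveHaving aliases it,
// pushes it into the Aggregate, filters on the alias, and projects it back
// out, roughly:
//   Project([a, maxb],
//     Filter(sum(b) > 4,
//       Aggregate([a], [a, max(b) AS maxb, sum(b)], child)))
df.groupBy($"a")
  .agg(max($"b").as("maxb"))
  .where(sum($"b") > 4)
  .show()
```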
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 6048d98033..806cdeb95c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.parser.ParserUtils
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, UnaryNode}
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.util.quoteIdentifier
 import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
@@ -538,3 +538,14 @@ case class UnresolvedOrdinal(ordinal: Int)
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
   override lazy val resolved = false
 }
+
+/**
+ * Represents an unresolved aggregate with a HAVING clause; the analyzer turns it into a Filter.
+ */
+case class AggregateWithHaving(
+    havingCondition: Expression,
+    child: Aggregate)
+  extends UnaryNode {
+  override lazy val resolved: Boolean = false
+  override def output: Seq[Attribute] = child.output
+}
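The key detail of the new node is that it is never resolved, so the analyzer is forced to rewrite it away (via `ResolveAggregateFunctions`) before `CheckAnalysis` runs. A small sketch of that contract using the catalyst DSL (illustrative, not from the patch):

```scala
import org.apache.spark.sql.catalyst.analysis.AggregateWithHaving
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation}

val rel = LocalRelation('a.int, 'b.int)
val agg = Aggregate(Seq('a.attr), Seq('a.attr), rel)
val node = AggregateWithHaving('a.attr > 1, agg)

assert(!node.resolved)             // always false: the analyzer must rewrite it,
                                   // otherwise CheckAnalysis rejects the plan
assert(node.output == agg.output)  // output simply mirrors the child Aggregate
```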
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index cc96d905a8..f135f50493 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -364,6 +364,14 @@ package object dsl {
       Aggregate(groupingExprs, aliasedExprs, logicalPlan)
     }
 
+    def having(
+        groupingExprs: Expression*)(
+        aggregateExprs: Expression*)(
+        havingCondition: Expression): LogicalPlan = {
+      AggregateWithHaving(havingCondition,
+        groupBy(groupingExprs: _*)(aggregateExprs: _*).asInstanceOf[Aggregate])
+    }
+
     def window(
         windowExpressions: Seq[NamedExpression],
         partitionSpec: Seq[Expression],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index e51b8f3b42..146df97d48 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -629,7 +629,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
       case p: Predicate => p
       case e => Cast(e, BooleanType)
     }
-    Filter(predicate, plan)
+    plan match {
+      case aggregate: Aggregate =>
+        AggregateWithHaving(predicate, aggregate)
+      case _ =>
+        Filter(predicate, plan)
+    }
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 0a591ad9cd..bec35ae458 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -208,7 +208,7 @@ class PlanParserSuite extends AnalysisTest {
     assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b))
     assertEqual(
       "select a, b from db.c having x < 1",
-      table("db", "c").groupBy()('a, 'b).where('x < 1))
+      table("db", "c").having()('a, 'b)('x < 1))
     assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b)))
     assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b))
     assertEqual("select from tbl", OneRowRelation().select('from.as("tbl")))
@@ -574,8 +574,7 @@ class PlanParserSuite extends AnalysisTest {
     assertEqual(
       "select g from t group by g having a > (select b from s)",
       table("t")
-        .groupBy('g)('g)
-        .where('a > ScalarSubquery(table("s").select('b))))
+        .having('g)('g)('a > ScalarSubquery(table("s").select('b))))
   }
 
   test("table reference") {
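To see the parser change directly, the catalyst parser exercised by the suite above now yields the new node whenever HAVING follows an aggregate. A quick sketch (the plan shapes in the comments are paraphrased):

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

val plan = CatalystSqlParser.parsePlan(
  "SELECT g, SUM(v) FROM t GROUP BY g HAVING SUM(v) > 1")

// Previously: Filter(SUM(v) > 1, Aggregate([g], [g, SUM(v)], t))
// Now:        AggregateWithHaving(SUM(v) > 1, Aggregate([g], [g, SUM(v)], t))
println(plan.treeString)
```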
diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql
index 868a911e78..6868b59029 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/having.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql
@@ -16,3 +16,6 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
 
 -- SPARK-20329: make sure we handle timezones correctly
 SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;
+
+-- SPARK-31519: Cast in having aggregate expressions returns the wrong result
+SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out
index 5bd185d7b8..aa8ff73723 100644
--- a/sql/core/src/test/resources/sql-tests/results/having.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 5
+-- Number of queries: 6
 
 
 -- !query
@@ -47,3 +47,11 @@ struct<(a + CAST(b AS BIGINT)):bigint>
 -- !query output
 3
 7
+
+
+-- !query
+SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
+-- !query schema
+struct<b:bigint,fake:date>
+-- !query output
+2	2020-01-01
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out
index acce688092..08eba6797b 100644
--- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out
@@ -294,7 +294,7 @@ SELECT * FROM empsalary WHERE row_number() OVER (ORDER BY salary) < 10
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-It is not allowed to use window functions inside WHERE and HAVING clauses;
+It is not allowed to use window functions inside WHERE clause;
 
 
 -- !query
@@ -341,7 +341,7 @@ SELECT * FROM empsalary WHERE (rank() OVER (ORDER BY random())) > 10
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-It is not allowed to use window functions inside WHERE and HAVING clauses;
+It is not allowed to use window functions inside WHERE clause;
 
 
 -- !query
@@ -350,7 +350,7 @@ SELECT * FROM empsalary WHERE rank() OVER (ORDER BY random())
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-It is not allowed to use window functions inside WHERE and HAVING clauses;
+It is not allowed to use window functions inside WHERE clause;
 
 
 -- !query
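The split error message can be observed end to end: both statements below fail analysis, but now with clause-specific wording. A sketch, assuming a `SparkSession` named `spark`:

```scala
// WHERE: still rejected, with the narrower message.
// => AnalysisException: It is not allowed to use window functions inside WHERE clause
spark.sql("SELECT * FROM VALUES (1), (2) AS t(a) WHERE rank() OVER (ORDER BY a) = 1")

// HAVING: now reported as its own case, matched via AggregateWithHaving.
// => AnalysisException: It is not allowed to use window functions inside HAVING clause
spark.sql("SELECT a FROM VALUES (1), (2) AS t(a) GROUP BY a HAVING rank() OVER (ORDER BY a) = 1")
```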
$"avgb" && rank().over(Window.orderBy($"a")) === 1)) + .where($"a" > $"avgb" && rank().over(Window.orderBy($"a")) === 1), "WHERE") checkAnalysisError( testData2.groupBy($"a") .agg(max($"b").as("maxb"), sum($"b").as("sumb")) - .where(rank().over(Window.orderBy($"a")) === 1)) + .where(rank().over(Window.orderBy($"a")) === 1), "WHERE") checkAnalysisError( testData2.groupBy($"a") .agg(max($"b").as("maxb"), sum($"b").as("sumb")) - .where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1)) + .where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1), "WHERE") - checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1")) - checkAnalysisError(sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1")) + checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1"), "WHERE") checkAnalysisError( - sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1")) + sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1"), "WHERE") checkAnalysisError( - sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1")) + sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1"), + "HAVING") + checkAnalysisError( + sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1"), + "HAVING") checkAnalysisError( sql( s"""SELECT a, MAX(b) |FROM testData2 |GROUP BY a - |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin)) + |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin), + "HAVING") } test("window functions in multiple selects") {