[SPARK-31519][SQL] Cast in having aggregate expressions returns the wrong result

### What changes were proposed in this pull request?
Add a new logical node, AggregateWithHaving; the parser now creates this plan for a HAVING clause, and the analyzer resolves it to Filter(..., Aggregate(...)).
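
In short, the parser emits the new node for HAVING over an aggregate, and ResolveAggregateFunctions rewrites it only once the child Aggregate is fully resolved (a simplified excerpt of the rule in the diff below):
```
// ResolveAggregateFunctions: fires only when the child Aggregate is resolved,
// so no earlier rule can wrongly resolve the having condition.
case AggregateWithHaving(cond, agg: Aggregate) if agg.resolved =>
  resolveHaving(Filter(cond, agg), agg)
```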

### Why are the changes needed?
The SQL parser in Spark creates Filter(..., Aggregate(...)) for a HAVING query, and Spark has a special analyzer rule, ResolveAggregateFunctions, to resolve the aggregate functions and grouping columns in the Filter operator.

It works for simple cases, but in a very tricky way, as it relies on the rule execution order:
1. Rule ResolveReferences hits the Aggregate operator and resolves attributes inside aggregate functions, but the function itself is still unresolved as it's an UnresolvedFunction. This stops the Filter operator from being resolved, as its child Aggregate operator is still unresolved.
2. Rule ResolveFunctions resolves the UnresolvedFunction. This makes the Aggregate operator resolved.
3. Rule ResolveAggregateFunctions resolves the Filter operator if its child is a resolved Aggregate. This rule can correctly resolve the grouping columns.

In the example query, I put a CAST, which needs to be resolved by the rule ResolveTimeZone, and that rule runs after ResolveAggregateFunctions. This breaks step 3, as the Aggregate operator is still unresolved at that time. The analyzer then starts the next round, and the Filter operator is resolved by ResolveReferences, which wrongly resolves the grouping columns.
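
The failing plan, sketched (simplified, not verbatim analyzer output):
```
Filter('b > 10,                              // skipped by ResolveAggregateFunctions,
  Aggregate(['b],                            // because this Aggregate is still
    [sum('a) AS b,                           // unresolved: the cast below waits
     cast('2020-01-01' AS date) AS fake],    // for ResolveTimeZone
    T))
```
In the next round, ResolveReferences binds 'b in the Filter to the Aggregate's output alias b = SUM(a) instead of the grouping column T.b, so the condition effectively becomes SUM(a) > 10 and every group is filtered out.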

See the demo below:
```
SELECT SUM(a) AS b, '2020-01-01' AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
```
The query's result is
```
+---+----------+
|  b|      fake|
+---+----------+
|  2|2020-01-01|
+---+----------+
```
But if we add a CAST, it wrongly returns an empty result:
```
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
```
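
With this fix, the second query returns the same row as the first, as the new golden result further down in this diff shows:
```
+---+----------+
|  b|      fake|
+---+----------+
|  2|2020-01-01|
+---+----------+
```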

### Does this PR introduce any user-facing change?
Yes, it fixes wrong results for queries that use CAST in HAVING aggregate expressions.

### How was this patch tested?
New UT added.

Closes #28294 from xuanyuanking/SPARK-31519.

Authored-by: Yuanjian Li <xyliyuanjian@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>

@@ -1239,13 +1239,13 @@ class Analyzer(
/**
* Resolves the attribute and extract value expressions(s) by traversing the
* input expression in top down manner. The traversal is done in top-down manner as
- * we need to skip over unbound lamda function expression. The lamda expressions are
+ * we need to skip over unbound lambda function expression. The lambda expressions are
* resolved in a different rule [[ResolveLambdaVariables]]
*
* Example :
* SELECT transform(array(1, 2, 3), (x, i) -> x + i)"
*
- * In the case above, x and i are resolved as lamda variables in [[ResolveLambdaVariables]]
+ * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]
*
* Note : In this routine, the unresolved attributes are resolved from the input plan's
* children attributes.
@@ -1400,6 +1400,9 @@ class Analyzer(
notMatchedActions = newNotMatchedActions)
}
+ // Skip the having clause here; this will be handled in ResolveAggregateFunctions.
+ case h: AggregateWithHaving => h
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}")
q.mapExpressions(resolveExpressionTopDown(_, q))
@@ -2040,62 +2043,14 @@ class Analyzer(
*/
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
- case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved =>
+ // Resolve aggregate with having clause to Filter(..., Aggregate()). Note: to avoid wrongly
+ // resolving the having condition expression, we skip resolving it in ResolveReferences
+ // and transform it to Filter after the aggregate is resolved. See more details in SPARK-31519.
+ case AggregateWithHaving(cond, agg: Aggregate) if agg.resolved =>
+ resolveHaving(Filter(cond, agg), agg)
- // Try resolving the condition of the filter as though it is in the aggregate clause
- try {
- val aggregatedCondition =
- Aggregate(
- grouping,
- Alias(cond, "havingCondition")() :: Nil,
- child)
- val resolvedOperator = executeSameContext(aggregatedCondition)
- def resolvedAggregateFilter =
- resolvedOperator
- .asInstanceOf[Aggregate]
- .aggregateExpressions.head
- // If resolution was successful and we see the filter has an aggregate in it, add it to
- // the original aggregate operator.
- if (resolvedOperator.resolved) {
- // Try to replace all aggregate expressions in the filter by an alias.
- val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
- val transformedAggregateFilter = resolvedAggregateFilter.transform {
- case ae: AggregateExpression =>
- val alias = Alias(ae, ae.toString)()
- aggregateExpressions += alias
- alias.toAttribute
- // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
- case e: Expression if grouping.exists(_.semanticEquals(e)) &&
- !ResolveGroupingAnalytics.hasGroupingFunction(e) &&
- !agg.output.exists(_.semanticEquals(e)) =>
- e match {
- case ne: NamedExpression =>
- aggregateExpressions += ne
- ne.toAttribute
- case _ =>
- val alias = Alias(e, e.toString)()
- aggregateExpressions += alias
- alias.toAttribute
- }
- }
- // Push the aggregate expressions into the aggregate (if any).
- if (aggregateExpressions.nonEmpty) {
- Project(agg.output,
- Filter(transformedAggregateFilter,
- agg.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions)))
- } else {
- f
- }
- } else {
- f
- }
- } catch {
- // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
- // just return the original plan.
- case ae: AnalysisException => f
- }
+ case f @ Filter(_, agg: Aggregate) if agg.resolved =>
+ resolveHaving(f, agg)
case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
@@ -2166,6 +2121,63 @@ class Analyzer(
def containsAggregate(condition: Expression): Boolean = {
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
+ def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
+ // Try resolving the condition of the filter as though it is in the aggregate clause
+ try {
+ val aggregatedCondition =
+ Aggregate(
+ agg.groupingExpressions,
+ Alias(filter.condition, "havingCondition")() :: Nil,
+ agg.child)
+ val resolvedOperator = executeSameContext(aggregatedCondition)
+ def resolvedAggregateFilter =
+ resolvedOperator
+ .asInstanceOf[Aggregate]
+ .aggregateExpressions.head
+ // If resolution was successful and we see the filter has an aggregate in it, add it to
+ // the original aggregate operator.
+ if (resolvedOperator.resolved) {
+ // Try to replace all aggregate expressions in the filter by an alias.
+ val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
+ val transformedAggregateFilter = resolvedAggregateFilter.transform {
+ case ae: AggregateExpression =>
+ val alias = Alias(ae, ae.toString)()
+ aggregateExpressions += alias
+ alias.toAttribute
+ // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
+ case e: Expression if agg.groupingExpressions.exists(_.semanticEquals(e)) &&
+ !ResolveGroupingAnalytics.hasGroupingFunction(e) &&
+ !agg.output.exists(_.semanticEquals(e)) =>
+ e match {
+ case ne: NamedExpression =>
+ aggregateExpressions += ne
+ ne.toAttribute
+ case _ =>
+ val alias = Alias(e, e.toString)()
+ aggregateExpressions += alias
+ alias.toAttribute
+ }
+ }
+ // Push the aggregate expressions into the aggregate (if any).
+ if (aggregateExpressions.nonEmpty) {
+ Project(agg.output,
+ Filter(transformedAggregateFilter,
+ agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
+ } else {
+ filter
+ }
+ } else {
+ filter
+ }
+ } catch {
+ // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
+ // just return the original plan.
+ case ae: AnalysisException => filter
+ }
+ }
}
/**
@@ -2590,11 +2602,14 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
case Filter(condition, _) if hasWindowFunction(condition) =>
failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses")
failAnalysis("It is not allowed to use window functions inside WHERE clause")
case AggregateWithHaving(condition, _) if hasWindowFunction(condition) =>
failAnalysis("It is not allowed to use window functions inside HAVING clause")
// Aggregate with Having clause. This rule works with an unresolved Aggregate because
// a resolved Aggregate will not have Window Functions.
- case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
+ case f @ AggregateWithHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
if child.resolved &&
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>

@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.parser.ParserUtils
- import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode}
+ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
@@ -538,3 +538,14 @@ case class UnresolvedOrdinal(ordinal: Int)
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
}
+ /**
+ * Represents an unresolved aggregate with a HAVING clause, which the analyzer turns into a Filter.
+ */
+ case class AggregateWithHaving(
+ havingCondition: Expression,
+ child: Aggregate)
+ extends UnaryNode {
+ override lazy val resolved: Boolean = false
+ override def output: Seq[Attribute] = child.output
+ }

@@ -364,6 +364,14 @@ package object dsl {
Aggregate(groupingExprs, aliasedExprs, logicalPlan)
}
+ def having(
+ groupingExprs: Expression*)(
+ aggregateExprs: Expression*)(
+ havingCondition: Expression): LogicalPlan = {
+ AggregateWithHaving(havingCondition,
+ groupBy(groupingExprs: _*)(aggregateExprs: _*).asInstanceOf[Aggregate])
+ }
def window(
windowExpressions: Seq[NamedExpression],
partitionSpec: Seq[Expression],

@@ -629,7 +629,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case p: Predicate => p
case e => Cast(e, BooleanType)
}
- Filter(predicate, plan)
+ plan match {
+ case aggregate: Aggregate =>
+ AggregateWithHaving(predicate, aggregate)
+ case _ =>
+ Filter(predicate, plan)
+ }
}
/**

@@ -208,7 +208,7 @@ class PlanParserSuite extends AnalysisTest {
assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b))
assertEqual(
"select a, b from db.c having x < 1",
table("db", "c").groupBy()('a, 'b).where('x < 1))
table("db", "c").having()('a, 'b)('x < 1))
assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b)))
assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b))
assertEqual("select from tbl", OneRowRelation().select('from.as("tbl")))
@@ -574,8 +574,7 @@ class PlanParserSuite extends AnalysisTest {
assertEqual(
"select g from t group by g having a > (select b from s)",
table("t")
- .groupBy('g)('g)
- .where('a > ScalarSubquery(table("s").select('b))))
+ .having('g)('g)('a > ScalarSubquery(table("s").select('b))))
}
test("table reference") {

@@ -16,3 +16,6 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
-- SPARK-20329: make sure we handle timezones correctly
SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;
+ -- SPARK-31519: Cast in having aggregate expressions returns the wrong result
+ SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10;

@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
- -- Number of queries: 5
+ -- Number of queries: 6
-- !query
@@ -47,3 +47,11 @@ struct<(a + CAST(b AS BIGINT)):bigint>
-- !query output
3
7
+ -- !query
+ SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
+ -- !query schema
+ struct<b:bigint,fake:date>
+ -- !query output
+ 2 2020-01-01

@@ -294,7 +294,7 @@ SELECT * FROM empsalary WHERE row_number() OVER (ORDER BY salary) < 10
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
- It is not allowed to use window functions inside WHERE and HAVING clauses;
+ It is not allowed to use window functions inside WHERE clause;
-- !query
@@ -341,7 +341,7 @@ SELECT * FROM empsalary WHERE (rank() OVER (ORDER BY random())) > 10
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
- It is not allowed to use window functions inside WHERE and HAVING clauses;
+ It is not allowed to use window functions inside WHERE clause;
-- !query
@@ -350,7 +350,7 @@ SELECT * FROM empsalary WHERE rank() OVER (ORDER BY random())
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
- It is not allowed to use window functions inside WHERE and HAVING clauses;
+ It is not allowed to use window functions inside WHERE clause;
-- !query

@@ -665,40 +665,46 @@ class DataFrameWindowFunctionsSuite extends QueryTest
}
test("SPARK-24575: Window functions inside WHERE and HAVING clauses") {
- def checkAnalysisError(df: => DataFrame): Unit = {
+ def checkAnalysisError(df: => DataFrame, clause: String): Unit = {
val thrownException = the[AnalysisException] thrownBy {
df.queryExecution.analyzed
}
assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses"))
assert(thrownException.message.contains(s"window functions inside $clause clause"))
}
- checkAnalysisError(testData2.select("a").where(rank().over(Window.orderBy($"b")) === 1))
- checkAnalysisError(testData2.where($"b" === 2 && rank().over(Window.orderBy($"b")) === 1))
+ checkAnalysisError(
+ testData2.select("a").where(rank().over(Window.orderBy($"b")) === 1), "WHERE")
+ checkAnalysisError(
+ testData2.where($"b" === 2 && rank().over(Window.orderBy($"b")) === 1), "WHERE")
checkAnalysisError(
testData2.groupBy($"a")
.agg(avg($"b").as("avgb"))
.where($"a" > $"avgb" && rank().over(Window.orderBy($"a")) === 1))
.where($"a" > $"avgb" && rank().over(Window.orderBy($"a")) === 1), "WHERE")
checkAnalysisError(
testData2.groupBy($"a")
.agg(max($"b").as("maxb"), sum($"b").as("sumb"))
- .where(rank().over(Window.orderBy($"a")) === 1))
+ .where(rank().over(Window.orderBy($"a")) === 1), "WHERE")
checkAnalysisError(
testData2.groupBy($"a")
.agg(max($"b").as("maxb"), sum($"b").as("sumb"))
.where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1))
.where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1), "WHERE")
checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1"))
checkAnalysisError(sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1"))
checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1"), "WHERE")
checkAnalysisError(
sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1"))
sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1"), "WHERE")
checkAnalysisError(
sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1"))
sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1"),
"HAVING")
checkAnalysisError(
sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1"),
"HAVING")
checkAnalysisError(
sql(
s"""SELECT a, MAX(b)
|FROM testData2
|GROUP BY a
- |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin))
+ |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin),
+ "HAVING")
}
test("window functions in multiple selects") {