[SPARK-27986][SQL][FOLLOWUP] Respect filter in sql/toString of AggregateExpression

### What changes were proposed in this pull request?

This PR adds filter information to the explain output of aggregate expressions (a follow-up of #26656).

Without this PR:
```
scala> sql("select k, SUM(v) filter (where v > 3) from t group by k").explain(true)
== Parsed Logical Plan ==
'Aggregate ['k], ['k, unresolvedalias('SUM('v, ('v > 3)), None)]
+- 'UnresolvedRelation [t]

== Analyzed Logical Plan ==
k: int, sum(v): bigint
Aggregate [k#0], [k#0, sum(cast(v#1 as bigint)) AS sum(v)#3L]
+- SubqueryAlias `default`.`t`
   +- Relation[k#0,v#1] parquet

== Optimized Logical Plan ==
Aggregate [k#0], [k#0, sum(cast(v#1 as bigint)) AS sum(v)#3L]
+- Relation[k#0,v#1] parquet

== Physical Plan ==
HashAggregate(keys=[k#0], functions=[sum(cast(v#1 as bigint))], output=[k#0, sum(v)#3L])
+- Exchange hashpartitioning(k#0, 200), true, [id=#20]
   +- HashAggregate(keys=[k#0], functions=[partial_sum(cast(v#1 as bigint))], output=[k#0, sum#7L])
      +- *(1) ColumnarToRow
         +- FileScan parquet default.t[k#0,v#1] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/maropu/Repositories/spark/spark-master/spark-warehouse/t], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<k:int,v:int>

scala> sql("select k, SUM(v) filter (where v > 3) from t group by k").show()
+---+------+
|  k|sum(v)|
+---+------+
+---+------+
```

With this PR:
```
scala> sql("select k, SUM(v) filter (where v > 3) from t group by k").explain(true)
== Parsed Logical Plan ==
'Aggregate ['k], ['k, unresolvedalias('SUM('v, ('v > 3)), None)]
+- 'UnresolvedRelation [t]

== Analyzed Logical Plan ==
k: int, sum(v) FILTER (v > 3): bigint
Aggregate [k#0], [k#0, sum(cast(v#1 as bigint)) filter (v#1 > 3) AS sum(v) FILTER (v > 3)#5L]
+- SubqueryAlias `default`.`t`
   +- Relation[k#0,v#1] parquet

== Optimized Logical Plan ==
Aggregate [k#0], [k#0, sum(cast(v#1 as bigint)) filter (v#1 > 3) AS sum(v) FILTER (v > 3)#5L]
+- Relation[k#0,v#1] parquet

== Physical Plan ==
HashAggregate(keys=[k#0], functions=[sum(cast(v#1 as bigint))], output=[k#0, sum(v) FILTER (v > 3)#5L])
+- Exchange hashpartitioning(k#0, 200), true, [id=#20]
   +- HashAggregate(keys=[k#0], functions=[partial_sum(cast(v#1 as bigint)) filter (v#1 > 3)], output=[k#0, sum#9L])
      +- *(1) ColumnarToRow
         +- FileScan parquet default.t[k#0,v#1] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/maropu/Repositories/spark/spark-master/spark-warehouse/t], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<k:int,v:int>

scala> sql("select k, SUM(v) filter (where v > 3) from t group by k").show()
+---+---------------------+
|  k|sum(v) FILTER (v > 3)|
+---+---------------------+
+---+---------------------+
```
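
For context, an aggregate `FILTER` clause restricts which input rows feed that one aggregate, so it behaves like the standard `CASE WHEN` rewrite. A hedged spark-shell sketch against the same table `t(k int, v int)` as above (the two result columns should always match):

```
scala> // sum(v) filter (where v > 3) aggregates only the rows passing the
scala> // predicate, which is equivalent to nulling out the other rows:
scala> sql("""
     |   select k,
     |          sum(v) filter (where v > 3) as filtered_sum,
     |          sum(case when v > 3 then v end) as case_when_sum
     |   from t group by k
     | """).show()
```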

### Why are the changes needed?

For better usability.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Manually; existing tests and golden files were also updated to match the new output.

Closes #27198 from maropu/SPARK-27986-FOLLOWUP.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>

```diff
@@ -137,10 +137,11 @@ case class AggregateExpression(
   @transient
   override lazy val references: AttributeSet = {
-    mode match {
-      case Partial | Complete => aggregateFunction.references ++ filterAttributes
+    val aggAttributes = mode match {
+      case Partial | Complete => aggregateFunction.references
       case PartialMerge | Final => AttributeSet(aggregateFunction.aggBufferAttributes)
     }
+    aggAttributes ++ filterAttributes
   }

   override def toString: String = {
@@ -149,10 +150,20 @@ case class AggregateExpression(
       case PartialMerge => "merge_"
       case Final | Complete => ""
     }
-    prefix + aggregateFunction.toAggString(isDistinct)
+    val aggFuncStr = prefix + aggregateFunction.toAggString(isDistinct)
+    filter match {
+      case Some(predicate) => s"$aggFuncStr FILTER (WHERE $predicate)"
+      case _ => aggFuncStr
+    }
   }

-  override def sql: String = aggregateFunction.sql(isDistinct)
+  override def sql: String = {
+    val aggFuncStr = aggregateFunction.sql(isDistinct)
+    filter match {
+      case Some(predicate) => s"$aggFuncStr FILTER (WHERE ${predicate.sql})"
+      case _ => aggFuncStr
+    }
+  }
 }

 /**
```
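
The `toString`/`sql` changes above follow one pattern: render the aggregate call first, then append the optional FILTER clause. A minimal self-contained sketch of that pattern (`AggCall` is a hypothetical stand-in, not Spark's `AggregateExpression`):

```scala
// Hypothetical stand-in illustrating the render-then-append-FILTER pattern.
case class AggCall(name: String, arg: String, filter: Option[String]) {
  // Build the aggregate call string, then append "FILTER (WHERE ...)"
  // only when a predicate is present, as the new sql/toString do.
  def sql: String = {
    val aggFuncStr = s"$name($arg)"
    filter match {
      case Some(predicate) => s"$aggFuncStr FILTER (WHERE $predicate)"
      case None => aggFuncStr
    }
  }
}

object RenderDemo extends App {
  println(AggCall("sum", "v", Some("v > 3")).sql) // sum(v) FILTER (WHERE v > 3)
  println(AggCall("count", "b", None).sql)        // count(b)
}
```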


```diff
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
@@ -27,6 +26,22 @@ import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateSto
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
  */
 object AggUtils {
+
+  private def mayRemoveAggFilters(exprs: Seq[AggregateExpression]): Seq[AggregateExpression] = {
+    exprs.map { ae =>
+      if (ae.filter.isDefined) {
+        ae.mode match {
+          // Aggregate filters are applicable only in partial/complete modes;
+          // this method filters out them, otherwise.
+          case Partial | Complete => ae
+          case _ => ae.copy(filter = None)
+        }
+      } else {
+        ae
+      }
+    }
+  }
+
   private def createAggregate(
       requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
       groupingExpressions: Seq[NamedExpression] = Nil,
@@ -41,7 +56,7 @@ object AggUtils {
     HashAggregateExec(
       requiredChildDistributionExpressions = requiredChildDistributionExpressions,
       groupingExpressions = groupingExpressions,
-      aggregateExpressions = aggregateExpressions,
+      aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
       aggregateAttributes = aggregateAttributes,
       initialInputBufferOffset = initialInputBufferOffset,
       resultExpressions = resultExpressions,
@@ -54,7 +69,7 @@ object AggUtils {
     ObjectHashAggregateExec(
       requiredChildDistributionExpressions = requiredChildDistributionExpressions,
       groupingExpressions = groupingExpressions,
-      aggregateExpressions = aggregateExpressions,
+      aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
       aggregateAttributes = aggregateAttributes,
       initialInputBufferOffset = initialInputBufferOffset,
       resultExpressions = resultExpressions,
@@ -63,7 +78,7 @@ object AggUtils {
     SortAggregateExec(
       requiredChildDistributionExpressions = requiredChildDistributionExpressions,
       groupingExpressions = groupingExpressions,
-      aggregateExpressions = aggregateExpressions,
+      aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
       aggregateAttributes = aggregateAttributes,
       initialInputBufferOffset = initialInputBufferOffset,
       resultExpressions = resultExpressions,
```


```diff
@@ -51,13 +51,13 @@ SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData
 struct<>
 -- !query 3 output
 org.apache.spark.sql.AnalysisException
-grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) AS `count(b)`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.;
+grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) FILTER (WHERE (testdata.`a` >= 2)) AS `count(b) FILTER (WHERE (a >= 2))`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.;


 -- !query 4
 SELECT COUNT(a) FILTER (WHERE a = 1), COUNT(b) FILTER (WHERE a > 1) FROM testData
 -- !query 4 schema
-struct<count(a):bigint,count(b):bigint>
+struct<count(a) FILTER (WHERE (a = 1)):bigint,count(b) FILTER (WHERE (a > 1)):bigint>
 -- !query 4 output
 2 4
@@ -65,7 +65,7 @@ struct<count(a):bigint,count(b):bigint>
 -- !query 5
 SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp
 -- !query 5 schema
-struct<count(id):bigint>
+struct<count(id) FILTER (WHERE (hiredate = DATE '2001-01-01')):bigint>
 -- !query 5 output
 2
@@ -73,7 +73,7 @@ struct<count(id):bigint>
 -- !query 6
 SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp
 -- !query 6 schema
-struct<count(id):bigint>
+struct<count(id) FILTER (WHERE (hiredate = to_date('2001-01-01 00:00:00'))):bigint>
 -- !query 6 output
 2
@@ -81,7 +81,7 @@ struct<count(id):bigint>
 -- !query 7
 SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp
 -- !query 7 schema
-struct<count(id):bigint>
+struct<count(id) FILTER (WHERE (CAST(hiredate AS TIMESTAMP) = to_timestamp('2001-01-01 00:00:00'))):bigint>
 -- !query 7 output
 2
@@ -89,7 +89,7 @@ struct<count(id):bigint>
 -- !query 8
 SELECT COUNT(id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") = "2001-01-01") FROM emp
 -- !query 8 schema
-struct<count(id):bigint>
+struct<count(id) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd) = 2001-01-01)):bigint>
 -- !query 8 output
 2
@@ -97,7 +97,7 @@ struct<count(id):bigint>
 -- !query 9
 SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a
 -- !query 9 schema
-struct<a:int,count(b):bigint>
+struct<a:int,count(b) FILTER (WHERE (a >= 2)):bigint>
 -- !query 9 output
 1 0
 2 2
@@ -117,7 +117,7 @@ expression 'testdata.`a`' is neither present in the group by, nor is it an aggre
 -- !query 11
 SELECT COUNT(a) FILTER (WHERE a >= 0), COUNT(b) FILTER (WHERE a >= 3) FROM testData GROUP BY a
 -- !query 11 schema
-struct<count(a):bigint,count(b):bigint>
+struct<count(a) FILTER (WHERE (a >= 0)):bigint,count(b) FILTER (WHERE (a >= 3)):bigint>
 -- !query 11 output
 0 0
 2 0
@@ -128,7 +128,7 @@ struct<count(a):bigint,count(b):bigint>
 -- !query 12
 SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp GROUP BY dept_id
 -- !query 12 schema
-struct<dept_id:int,sum(salary):double>
+struct<dept_id:int,sum(salary) FILTER (WHERE (hiredate > DATE '2003-01-01')):double>
 -- !query 12 output
 10 200.0
 100 400.0
@@ -141,7 +141,7 @@ NULL NULL
 -- !query 13
 SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id
 -- !query 13 schema
-struct<dept_id:int,sum(salary):double>
+struct<dept_id:int,sum(salary) FILTER (WHERE (hiredate > to_date('2003-01-01'))):double>
 -- !query 13 output
 10 200.0
 100 400.0
@@ -154,7 +154,7 @@ NULL NULL
 -- !query 14
 SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id
 -- !query 14 schema
-struct<dept_id:int,sum(salary):double>
+struct<dept_id:int,sum(salary) FILTER (WHERE (CAST(hiredate AS TIMESTAMP) > to_timestamp('2003-01-01 00:00:00'))):double>
 -- !query 14 output
 10 200.0
 100 400.0
@@ -167,7 +167,7 @@ NULL NULL
 -- !query 15
 SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id
 -- !query 15 schema
-struct<dept_id:int,sum(salary):double>
+struct<dept_id:int,sum(salary) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd) > 2003-01-01)):double>
 -- !query 15 output
 10 200.0
 100 400.0
@@ -180,7 +180,7 @@ NULL NULL
 -- !query 16
 SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1
 -- !query 16 schema
-struct<foo:string,count(a):bigint>
+struct<foo:string,count(a) FILTER (WHERE (b <= 2)):bigint>
 -- !query 16 output
 foo 6
@@ -188,7 +188,7 @@ foo 6
 -- !query 17
 SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= date "2003-01-01") FROM emp GROUP BY 1
 -- !query 17 schema
-struct<foo:string,sum(salary):double>
+struct<foo:string,sum(salary) FILTER (WHERE (hiredate >= DATE '2003-01-01')):double>
 -- !query 17 output
 foo 1350.0
@@ -196,7 +196,7 @@ foo 1350.0
 -- !query 18
 SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_date("2003-01-01")) FROM emp GROUP BY 1
 -- !query 18 schema
-struct<foo:string,sum(salary):double>
+struct<foo:string,sum(salary) FILTER (WHERE (hiredate >= to_date('2003-01-01'))):double>
 -- !query 18 output
 foo 1350.0
@@ -204,7 +204,7 @@ foo 1350.0
 -- !query 19
 SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_timestamp("2003-01-01")) FROM emp GROUP BY 1
 -- !query 19 schema
-struct<foo:string,sum(salary):double>
+struct<foo:string,sum(salary) FILTER (WHERE (CAST(hiredate AS TIMESTAMP) >= to_timestamp('2003-01-01'))):double>
 -- !query 19 output
 foo 1350.0
@@ -212,7 +212,7 @@ foo 1350.0
 -- !query 20
 select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id
 -- !query 20 schema
-struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
+struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary) FILTER (WHERE (id > 200)):double>
 -- !query 20 output
 10 2 2 400.0 NULL
 100 2 2 800.0 800.0
@@ -225,7 +225,7 @@ NULL 1 1 400.0 400.0
 -- !query 21
 select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id
 -- !query 21 schema
-struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
+struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary) FILTER (WHERE ((id + dept_id) > 500)):double>
 -- !query 21 output
 10 2 2 400.0 NULL
 100 2 2 800.0 800.0
@@ -238,7 +238,7 @@ NULL 1 1 400.0 NULL
 -- !query 22
 select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id
 -- !query 22 schema
-struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
+struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double,sum(salary) FILTER (WHERE (id > 200)):double>
 -- !query 22 output
 10 2 2 400.0 NULL
 100 2 2 NULL 800.0
@@ -251,7 +251,7 @@ NULL 1 1 NULL 400.0
 -- !query 23
 select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id
 -- !query 23 schema
-struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
+struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double,sum(salary) FILTER (WHERE ((id + dept_id) > 500)):double>
 -- !query 23 output
 10 2 2 400.0 NULL
 100 2 2 NULL 800.0
@@ -264,7 +264,7 @@ NULL 1 1 NULL NULL
 -- !query 24
 SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1
 -- !query 24 schema
-struct<foo:string,approx_count_distinct(a):bigint>
+struct<foo:string,approx_count_distinct(a) FILTER (WHERE (b >= 0)):bigint>
 -- !query 24 output

@@ -272,7 +272,7 @@ struct<foo:string,approx_count_distinct(a):bigint>
 -- !query 25
 SELECT 'foo', MAX(STRUCT(a)) FILTER (WHERE b >= 1) FROM testData WHERE a = 0 GROUP BY 1
 -- !query 25 schema
-struct<foo:string,max(named_struct(a, a)):struct<a:int>>
+struct<foo:string,max(named_struct(a, a)) FILTER (WHERE (b >= 1)):struct<a:int>>
 -- !query 25 output

@@ -280,7 +280,7 @@ struct<foo:string,max(named_struct(a, a)):struct<a:int>>
 -- !query 26
 SELECT a + b, COUNT(b) FILTER (WHERE b >= 2) FROM testData GROUP BY a + b
 -- !query 26 schema
-struct<(a + b):int,count(b):bigint>
+struct<(a + b):int,count(b) FILTER (WHERE (b >= 2)):bigint>
 -- !query 26 output
 2 0
 3 1
@@ -301,7 +301,7 @@ expression 'testdata.`a`' is neither present in the group by, nor is it an aggre
 -- !query 28
 SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1
 -- !query 28 schema
-struct<((a + 1) + 1):int,count(b):bigint>
+struct<((a + 1) + 1):int,count(b) FILTER (WHERE (b > 0)):bigint>
 -- !query 28 output
 3 2
 4 2
@@ -312,7 +312,7 @@ NULL 1
 -- !query 29
 SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k
 -- !query 29 schema
-struct<k:int,count(b):bigint>
+struct<k:int,count(b) FILTER (WHERE (b > 0)):bigint>
 -- !query 29 output
 1 2
 2 2
@@ -327,7 +327,7 @@ SELECT emp.dept_id,
 FROM emp
 GROUP BY dept_id
 -- !query 30 schema
-struct<dept_id:int,avg(salary):double,avg(salary):double>
+struct<dept_id:int,avg(salary):double,avg(salary) FILTER (WHERE (id > scalarsubquery())):double>
 -- !query 30 output
 10 133.33333333333334 NULL
 100 400.0 400.0
@@ -344,7 +344,7 @@ SELECT emp.dept_id,
 FROM emp
 GROUP BY dept_id
 -- !query 31 schema
-struct<dept_id:int,avg(salary):double,avg(salary):double>
+struct<dept_id:int,avg(salary):double,avg(salary) FILTER (WHERE (dept_id = scalarsubquery())):double>
 -- !query 31 output
 10 133.33333333333334 133.33333333333334
 100 400.0 NULL
@@ -366,7 +366,7 @@ GROUP BY dept_id
 struct<>
 -- !query 32 output
 org.apache.spark.sql.AnalysisException
-IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) AS avg(salary)#x]
+IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) FILTER (WHERE exists#x [dept_id#x]) AS avg(salary) FILTER (WHERE exists(dept_id))#x]
 :  +- Project [state#x]
 :     +- Filter (dept_id#x = outer(dept_id#x))
 :        +- SubqueryAlias `dept`
@@ -392,7 +392,7 @@ GROUP BY dept_id
 struct<>
 -- !query 33 output
 org.apache.spark.sql.AnalysisException
-IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) AS sum(salary)#x]
+IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) FILTER (WHERE NOT exists#x [dept_id#x]) AS sum(salary) FILTER (WHERE (NOT exists(dept_id)))#x]
 :  +- Project [state#x]
 :     +- Filter (dept_id#x = outer(dept_id#x))
 :        +- SubqueryAlias `dept`
@@ -417,7 +417,7 @@ GROUP BY dept_id
 struct<>
 -- !query 34 output
 org.apache.spark.sql.AnalysisException
-IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) AS avg(salary)#x]
+IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) FILTER (WHERE dept_id#x IN (list#x [])) AS avg(salary) FILTER (WHERE (dept_id IN (listquery())))#x]
 :  +- Distinct
 :     +- Project [dept_id#x]
 :        +- SubqueryAlias `dept`
@@ -442,7 +442,7 @@ GROUP BY dept_id
 struct<>
 -- !query 35 output
 org.apache.spark.sql.AnalysisException
-IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) AS sum(salary)#x]
+IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) FILTER (WHERE NOT dept_id#x IN (list#x [])) AS sum(salary) FILTER (WHERE (NOT (dept_id IN (listquery()))))#x]
 :  +- Distinct
 :     +- Project [dept_id#x]
 :        +- SubqueryAlias `dept`
```
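
All of the golden-file updates above are one kind of change: the auto-generated output column name now embeds the FILTER clause. A hedged spark-shell sketch of how that surfaces, assuming the suite's `testData` table:

```
scala> // Per the golden file above, the generated column name now carries the filter:
scala> sql("select count(a) filter (where a = 1) from testData").columns
// expected: Array(count(a) FILTER (WHERE (a = 1)))
```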


```diff
@@ -14,7 +14,7 @@ It is not allowed to use an aggregate function in the argument of another aggreg
 -- !query 1
 select min(unique1) filter (where unique1 > 100) from tenk1
 -- !query 1 schema
-struct<min(unique1):int>
+struct<min(unique1) FILTER (WHERE (unique1 > 100)):int>
 -- !query 1 output
 101
@@ -22,7 +22,7 @@ struct<min(unique1):int>
 -- !query 2
 select sum(1/ten) filter (where ten > 0) from tenk1
 -- !query 2 schema
-struct<sum((CAST(1 AS DOUBLE) / CAST(ten AS DOUBLE))):double>
+struct<sum((CAST(1 AS DOUBLE) / CAST(ten AS DOUBLE))) FILTER (WHERE (ten > 0)):double>
 -- !query 2 output
 2828.9682539682954
```


```diff
@@ -27,6 +27,7 @@ import scala.collection.parallel.immutable.ParVector
 import org.apache.spark.{AccumulatorSuite, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial}
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.HiveResult.hiveResultString
@@ -2843,16 +2844,18 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       val query = s"SELECT ${funcToResult._1} FILTER (WHERE b > 1) FROM testData2"
       val df = sql(query)
       val physical = df.queryExecution.sparkPlan
-      val aggregateExpressions = physical.collectFirst {
+      val aggregateExpressions = physical.collect {
         case agg: HashAggregateExec => agg.aggregateExpressions
         case agg: ObjectHashAggregateExec => agg.aggregateExpressions
-      }
-      assert(aggregateExpressions.isDefined)
-      assert(aggregateExpressions.get.size == 1)
-      aggregateExpressions.get.foreach { expr =>
-        assert(expr.filter.isDefined)
-      }
-      checkAnswer(df, Row(funcToResult._2) :: Nil)
+      }.flatten
+      aggregateExpressions.foreach { expr =>
+        if (expr.mode == Complete || expr.mode == Partial) {
+          assert(expr.filter.isDefined)
+        } else {
+          assert(expr.filter.isEmpty)
+        }
+      }
+      checkAnswer(df, Row(funcToResult._2))
     }
   }
@@ -2860,15 +2863,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
     withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
       val df = sql("SELECT PERCENTILE(a, 1) FILTER (WHERE b > 1) FROM testData2")
       val physical = df.queryExecution.sparkPlan
-      val aggregateExpressions = physical.collectFirst {
+      val aggregateExpressions = physical.collect {
         case agg: SortAggregateExec => agg.aggregateExpressions
-      }
-      assert(aggregateExpressions.isDefined)
-      assert(aggregateExpressions.get.size == 1)
-      aggregateExpressions.get.foreach { expr =>
-        assert(expr.filter.isDefined)
-      }
-      checkAnswer(df, Row(3) :: Nil)
+      }.flatten
+      aggregateExpressions.foreach { expr =>
+        if (expr.mode == Complete || expr.mode == Partial) {
+          assert(expr.filter.isDefined)
+        } else {
+          assert(expr.filter.isEmpty)
+        }
+      }
+      checkAnswer(df, Row(3))
     }
   }
```
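
The reworked assertions collect every aggregate expression in the physical plan (partial and final alike) and check that a filter survives only in the modes where it applies. A hedged sketch of running the same check interactively, assuming a `SparkSession` named `spark` and the suite's `testData2` table:

```scala
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec

// Gather aggregate expressions from every HashAggregateExec in the plan and
// verify the filter is kept exactly in partial/complete modes.
val df = spark.sql("SELECT sum(a) FILTER (WHERE b > 1) FROM testData2")
val aggExprs = df.queryExecution.sparkPlan.collect {
  case agg: HashAggregateExec => agg.aggregateExpressions
}.flatten

aggExprs.foreach { expr =>
  val filterApplies = expr.mode == Complete || expr.mode == Partial
  assert(expr.filter.isDefined == filterApplies)
}
```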