[SPARK-27986][SQL][FOLLOWUP] Respect filter in sql/toString of AggregateExpression
### What changes were proposed in this pull request? This pr intends to add filter information in the explain output of an aggregate (This is a follow-up of #26656). Without this pr: ``` scala> sql("select k, SUM(v) filter (where v > 3) from t group by k").explain(true) == Parsed Logical Plan == 'Aggregate ['k], ['k, unresolvedalias('SUM('v, ('v > 3)), None)] +- 'UnresolvedRelation [t] == Analyzed Logical Plan == k: int, sum(v): bigint Aggregate [k#0], [k#0, sum(cast(v#1 as bigint)) AS sum(v)#3L] +- SubqueryAlias `default`.`t` +- Relation[k#0,v#1] parquet == Optimized Logical Plan == Aggregate [k#0], [k#0, sum(cast(v#1 as bigint)) AS sum(v)#3L] +- Relation[k#0,v#1] parquet == Physical Plan == HashAggregate(keys=[k#0], functions=[sum(cast(v#1 as bigint))], output=[k#0, sum(v)#3L]) +- Exchange hashpartitioning(k#0, 200), true, [id=#20] +- HashAggregate(keys=[k#0], functions=[partial_sum(cast(v#1 as bigint))], output=[k#0, sum#7L]) +- *(1) ColumnarToRow +- FileScan parquet default.t[k#0,v#1] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/maropu/Repositories/spark/spark-master/spark-warehouse/t], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<k:int,v:int> scala> sql("select k, SUM(v) filter (where v > 3) from t group by k").show() +---+------+ | k|sum(v)| +---+------+ +---+------+ ``` With this pr: ``` scala> sql("select k, SUM(v) filter (where v > 3) from t group by k").explain(true) == Parsed Logical Plan == 'Aggregate ['k], ['k, unresolvedalias('SUM('v, ('v > 3)), None)] +- 'UnresolvedRelation [t] == Analyzed Logical Plan == k: int, sum(v) FILTER (v > 3): bigint Aggregate [k#0], [k#0, sum(cast(v#1 as bigint)) filter (v#1 > 3) AS sum(v) FILTER (v > 3)#5L] +- SubqueryAlias `default`.`t` +- Relation[k#0,v#1] parquet == Optimized Logical Plan == Aggregate [k#0], [k#0, sum(cast(v#1 as bigint)) filter (v#1 > 3) AS sum(v) FILTER (v > 3)#5L] +- Relation[k#0,v#1] parquet == Physical Plan == HashAggregate(keys=[k#0], functions=[sum(cast(v#1 as bigint))], output=[k#0, sum(v) FILTER (v > 3)#5L]) +- Exchange hashpartitioning(k#0, 200), true, [id=#20] +- HashAggregate(keys=[k#0], functions=[partial_sum(cast(v#1 as bigint)) filter (v#1 > 3)], output=[k#0, sum#9L]) +- *(1) ColumnarToRow +- FileScan parquet default.t[k#0,v#1] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/maropu/Repositories/spark/spark-master/spark-warehouse/t], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<k:int,v:int> scala> sql("select k, SUM(v) filter (where v > 3) from t group by k").show() +---+---------------------+ | k|sum(v) FILTER (v > 3)| +---+---------------------+ +---+---------------------+ ``` ### Why are the changes needed? For better usability. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Manually. Closes #27198 from maropu/SPARK-27986-FOLLOWUP. Authored-by: Takeshi Yamamuro <yamamuro@apache.org> Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
This commit is contained in:
parent
883ae331c3
commit
a3a42b30d0
|
@ -137,10 +137,11 @@ case class AggregateExpression(
|
|||
|
||||
@transient
|
||||
override lazy val references: AttributeSet = {
|
||||
mode match {
|
||||
case Partial | Complete => aggregateFunction.references ++ filterAttributes
|
||||
val aggAttributes = mode match {
|
||||
case Partial | Complete => aggregateFunction.references
|
||||
case PartialMerge | Final => AttributeSet(aggregateFunction.aggBufferAttributes)
|
||||
}
|
||||
aggAttributes ++ filterAttributes
|
||||
}
|
||||
|
||||
override def toString: String = {
|
||||
|
@ -149,10 +150,20 @@ case class AggregateExpression(
|
|||
case PartialMerge => "merge_"
|
||||
case Final | Complete => ""
|
||||
}
|
||||
prefix + aggregateFunction.toAggString(isDistinct)
|
||||
val aggFuncStr = prefix + aggregateFunction.toAggString(isDistinct)
|
||||
filter match {
|
||||
case Some(predicate) => s"$aggFuncStr FILTER (WHERE $predicate)"
|
||||
case _ => aggFuncStr
|
||||
}
|
||||
}
|
||||
|
||||
override def sql: String = aggregateFunction.sql(isDistinct)
|
||||
override def sql: String = {
|
||||
val aggFuncStr = aggregateFunction.sql(isDistinct)
|
||||
filter match {
|
||||
case Some(predicate) => s"$aggFuncStr FILTER (WHERE ${predicate.sql})"
|
||||
case _ => aggFuncStr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.aggregate
|
|||
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate._
|
||||
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
|
||||
import org.apache.spark.sql.execution.SparkPlan
|
||||
import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
|
||||
|
||||
|
@ -27,6 +26,22 @@ import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateSto
|
|||
* Utility functions used by the query planner to convert our plan to new aggregation code path.
|
||||
*/
|
||||
object AggUtils {
|
||||
|
||||
private def mayRemoveAggFilters(exprs: Seq[AggregateExpression]): Seq[AggregateExpression] = {
|
||||
exprs.map { ae =>
|
||||
if (ae.filter.isDefined) {
|
||||
ae.mode match {
|
||||
// Aggregate filters are applicable only in partial/complete modes;
|
||||
// this method filters out them, otherwise.
|
||||
case Partial | Complete => ae
|
||||
case _ => ae.copy(filter = None)
|
||||
}
|
||||
} else {
|
||||
ae
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def createAggregate(
|
||||
requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
|
||||
groupingExpressions: Seq[NamedExpression] = Nil,
|
||||
|
@ -41,7 +56,7 @@ object AggUtils {
|
|||
HashAggregateExec(
|
||||
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
|
||||
groupingExpressions = groupingExpressions,
|
||||
aggregateExpressions = aggregateExpressions,
|
||||
aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
|
||||
aggregateAttributes = aggregateAttributes,
|
||||
initialInputBufferOffset = initialInputBufferOffset,
|
||||
resultExpressions = resultExpressions,
|
||||
|
@ -54,7 +69,7 @@ object AggUtils {
|
|||
ObjectHashAggregateExec(
|
||||
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
|
||||
groupingExpressions = groupingExpressions,
|
||||
aggregateExpressions = aggregateExpressions,
|
||||
aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
|
||||
aggregateAttributes = aggregateAttributes,
|
||||
initialInputBufferOffset = initialInputBufferOffset,
|
||||
resultExpressions = resultExpressions,
|
||||
|
@ -63,7 +78,7 @@ object AggUtils {
|
|||
SortAggregateExec(
|
||||
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
|
||||
groupingExpressions = groupingExpressions,
|
||||
aggregateExpressions = aggregateExpressions,
|
||||
aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
|
||||
aggregateAttributes = aggregateAttributes,
|
||||
initialInputBufferOffset = initialInputBufferOffset,
|
||||
resultExpressions = resultExpressions,
|
||||
|
|
|
@ -51,13 +51,13 @@ SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData
|
|||
struct<>
|
||||
-- !query 3 output
|
||||
org.apache.spark.sql.AnalysisException
|
||||
grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) AS `count(b)`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.;
|
||||
grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) FILTER (WHERE (testdata.`a` >= 2)) AS `count(b) FILTER (WHERE (a >= 2))`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.;
|
||||
|
||||
|
||||
-- !query 4
|
||||
SELECT COUNT(a) FILTER (WHERE a = 1), COUNT(b) FILTER (WHERE a > 1) FROM testData
|
||||
-- !query 4 schema
|
||||
struct<count(a):bigint,count(b):bigint>
|
||||
struct<count(a) FILTER (WHERE (a = 1)):bigint,count(b) FILTER (WHERE (a > 1)):bigint>
|
||||
-- !query 4 output
|
||||
2 4
|
||||
|
||||
|
@ -65,7 +65,7 @@ struct<count(a):bigint,count(b):bigint>
|
|||
-- !query 5
|
||||
SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp
|
||||
-- !query 5 schema
|
||||
struct<count(id):bigint>
|
||||
struct<count(id) FILTER (WHERE (hiredate = DATE '2001-01-01')):bigint>
|
||||
-- !query 5 output
|
||||
2
|
||||
|
||||
|
@ -73,7 +73,7 @@ struct<count(id):bigint>
|
|||
-- !query 6
|
||||
SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp
|
||||
-- !query 6 schema
|
||||
struct<count(id):bigint>
|
||||
struct<count(id) FILTER (WHERE (hiredate = to_date('2001-01-01 00:00:00'))):bigint>
|
||||
-- !query 6 output
|
||||
2
|
||||
|
||||
|
@ -81,7 +81,7 @@ struct<count(id):bigint>
|
|||
-- !query 7
|
||||
SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp
|
||||
-- !query 7 schema
|
||||
struct<count(id):bigint>
|
||||
struct<count(id) FILTER (WHERE (CAST(hiredate AS TIMESTAMP) = to_timestamp('2001-01-01 00:00:00'))):bigint>
|
||||
-- !query 7 output
|
||||
2
|
||||
|
||||
|
@ -89,7 +89,7 @@ struct<count(id):bigint>
|
|||
-- !query 8
|
||||
SELECT COUNT(id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") = "2001-01-01") FROM emp
|
||||
-- !query 8 schema
|
||||
struct<count(id):bigint>
|
||||
struct<count(id) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd) = 2001-01-01)):bigint>
|
||||
-- !query 8 output
|
||||
2
|
||||
|
||||
|
@ -97,7 +97,7 @@ struct<count(id):bigint>
|
|||
-- !query 9
|
||||
SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a
|
||||
-- !query 9 schema
|
||||
struct<a:int,count(b):bigint>
|
||||
struct<a:int,count(b) FILTER (WHERE (a >= 2)):bigint>
|
||||
-- !query 9 output
|
||||
1 0
|
||||
2 2
|
||||
|
@ -117,7 +117,7 @@ expression 'testdata.`a`' is neither present in the group by, nor is it an aggre
|
|||
-- !query 11
|
||||
SELECT COUNT(a) FILTER (WHERE a >= 0), COUNT(b) FILTER (WHERE a >= 3) FROM testData GROUP BY a
|
||||
-- !query 11 schema
|
||||
struct<count(a):bigint,count(b):bigint>
|
||||
struct<count(a) FILTER (WHERE (a >= 0)):bigint,count(b) FILTER (WHERE (a >= 3)):bigint>
|
||||
-- !query 11 output
|
||||
0 0
|
||||
2 0
|
||||
|
@ -128,7 +128,7 @@ struct<count(a):bigint,count(b):bigint>
|
|||
-- !query 12
|
||||
SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp GROUP BY dept_id
|
||||
-- !query 12 schema
|
||||
struct<dept_id:int,sum(salary):double>
|
||||
struct<dept_id:int,sum(salary) FILTER (WHERE (hiredate > DATE '2003-01-01')):double>
|
||||
-- !query 12 output
|
||||
10 200.0
|
||||
100 400.0
|
||||
|
@ -141,7 +141,7 @@ NULL NULL
|
|||
-- !query 13
|
||||
SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id
|
||||
-- !query 13 schema
|
||||
struct<dept_id:int,sum(salary):double>
|
||||
struct<dept_id:int,sum(salary) FILTER (WHERE (hiredate > to_date('2003-01-01'))):double>
|
||||
-- !query 13 output
|
||||
10 200.0
|
||||
100 400.0
|
||||
|
@ -154,7 +154,7 @@ NULL NULL
|
|||
-- !query 14
|
||||
SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id
|
||||
-- !query 14 schema
|
||||
struct<dept_id:int,sum(salary):double>
|
||||
struct<dept_id:int,sum(salary) FILTER (WHERE (CAST(hiredate AS TIMESTAMP) > to_timestamp('2003-01-01 00:00:00'))):double>
|
||||
-- !query 14 output
|
||||
10 200.0
|
||||
100 400.0
|
||||
|
@ -167,7 +167,7 @@ NULL NULL
|
|||
-- !query 15
|
||||
SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id
|
||||
-- !query 15 schema
|
||||
struct<dept_id:int,sum(salary):double>
|
||||
struct<dept_id:int,sum(salary) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd) > 2003-01-01)):double>
|
||||
-- !query 15 output
|
||||
10 200.0
|
||||
100 400.0
|
||||
|
@ -180,7 +180,7 @@ NULL NULL
|
|||
-- !query 16
|
||||
SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1
|
||||
-- !query 16 schema
|
||||
struct<foo:string,count(a):bigint>
|
||||
struct<foo:string,count(a) FILTER (WHERE (b <= 2)):bigint>
|
||||
-- !query 16 output
|
||||
foo 6
|
||||
|
||||
|
@ -188,7 +188,7 @@ foo 6
|
|||
-- !query 17
|
||||
SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= date "2003-01-01") FROM emp GROUP BY 1
|
||||
-- !query 17 schema
|
||||
struct<foo:string,sum(salary):double>
|
||||
struct<foo:string,sum(salary) FILTER (WHERE (hiredate >= DATE '2003-01-01')):double>
|
||||
-- !query 17 output
|
||||
foo 1350.0
|
||||
|
||||
|
@ -196,7 +196,7 @@ foo 1350.0
|
|||
-- !query 18
|
||||
SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_date("2003-01-01")) FROM emp GROUP BY 1
|
||||
-- !query 18 schema
|
||||
struct<foo:string,sum(salary):double>
|
||||
struct<foo:string,sum(salary) FILTER (WHERE (hiredate >= to_date('2003-01-01'))):double>
|
||||
-- !query 18 output
|
||||
foo 1350.0
|
||||
|
||||
|
@ -204,7 +204,7 @@ foo 1350.0
|
|||
-- !query 19
|
||||
SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_timestamp("2003-01-01")) FROM emp GROUP BY 1
|
||||
-- !query 19 schema
|
||||
struct<foo:string,sum(salary):double>
|
||||
struct<foo:string,sum(salary) FILTER (WHERE (CAST(hiredate AS TIMESTAMP) >= to_timestamp('2003-01-01'))):double>
|
||||
-- !query 19 output
|
||||
foo 1350.0
|
||||
|
||||
|
@ -212,7 +212,7 @@ foo 1350.0
|
|||
-- !query 20
|
||||
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id
|
||||
-- !query 20 schema
|
||||
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
|
||||
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary) FILTER (WHERE (id > 200)):double>
|
||||
-- !query 20 output
|
||||
10 2 2 400.0 NULL
|
||||
100 2 2 800.0 800.0
|
||||
|
@ -225,7 +225,7 @@ NULL 1 1 400.0 400.0
|
|||
-- !query 21
|
||||
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id
|
||||
-- !query 21 schema
|
||||
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
|
||||
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary) FILTER (WHERE ((id + dept_id) > 500)):double>
|
||||
-- !query 21 output
|
||||
10 2 2 400.0 NULL
|
||||
100 2 2 800.0 800.0
|
||||
|
@ -238,7 +238,7 @@ NULL 1 1 400.0 NULL
|
|||
-- !query 22
|
||||
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id
|
||||
-- !query 22 schema
|
||||
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
|
||||
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double,sum(salary) FILTER (WHERE (id > 200)):double>
|
||||
-- !query 22 output
|
||||
10 2 2 400.0 NULL
|
||||
100 2 2 NULL 800.0
|
||||
|
@ -251,7 +251,7 @@ NULL 1 1 NULL 400.0
|
|||
-- !query 23
|
||||
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id
|
||||
-- !query 23 schema
|
||||
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
|
||||
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double,sum(salary) FILTER (WHERE ((id + dept_id) > 500)):double>
|
||||
-- !query 23 output
|
||||
10 2 2 400.0 NULL
|
||||
100 2 2 NULL 800.0
|
||||
|
@ -264,7 +264,7 @@ NULL 1 1 NULL NULL
|
|||
-- !query 24
|
||||
SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1
|
||||
-- !query 24 schema
|
||||
struct<foo:string,approx_count_distinct(a):bigint>
|
||||
struct<foo:string,approx_count_distinct(a) FILTER (WHERE (b >= 0)):bigint>
|
||||
-- !query 24 output
|
||||
|
||||
|
||||
|
@ -272,7 +272,7 @@ struct<foo:string,approx_count_distinct(a):bigint>
|
|||
-- !query 25
|
||||
SELECT 'foo', MAX(STRUCT(a)) FILTER (WHERE b >= 1) FROM testData WHERE a = 0 GROUP BY 1
|
||||
-- !query 25 schema
|
||||
struct<foo:string,max(named_struct(a, a)):struct<a:int>>
|
||||
struct<foo:string,max(named_struct(a, a)) FILTER (WHERE (b >= 1)):struct<a:int>>
|
||||
-- !query 25 output
|
||||
|
||||
|
||||
|
@ -280,7 +280,7 @@ struct<foo:string,max(named_struct(a, a)):struct<a:int>>
|
|||
-- !query 26
|
||||
SELECT a + b, COUNT(b) FILTER (WHERE b >= 2) FROM testData GROUP BY a + b
|
||||
-- !query 26 schema
|
||||
struct<(a + b):int,count(b):bigint>
|
||||
struct<(a + b):int,count(b) FILTER (WHERE (b >= 2)):bigint>
|
||||
-- !query 26 output
|
||||
2 0
|
||||
3 1
|
||||
|
@ -301,7 +301,7 @@ expression 'testdata.`a`' is neither present in the group by, nor is it an aggre
|
|||
-- !query 28
|
||||
SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1
|
||||
-- !query 28 schema
|
||||
struct<((a + 1) + 1):int,count(b):bigint>
|
||||
struct<((a + 1) + 1):int,count(b) FILTER (WHERE (b > 0)):bigint>
|
||||
-- !query 28 output
|
||||
3 2
|
||||
4 2
|
||||
|
@ -312,7 +312,7 @@ NULL 1
|
|||
-- !query 29
|
||||
SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k
|
||||
-- !query 29 schema
|
||||
struct<k:int,count(b):bigint>
|
||||
struct<k:int,count(b) FILTER (WHERE (b > 0)):bigint>
|
||||
-- !query 29 output
|
||||
1 2
|
||||
2 2
|
||||
|
@ -327,7 +327,7 @@ SELECT emp.dept_id,
|
|||
FROM emp
|
||||
GROUP BY dept_id
|
||||
-- !query 30 schema
|
||||
struct<dept_id:int,avg(salary):double,avg(salary):double>
|
||||
struct<dept_id:int,avg(salary):double,avg(salary) FILTER (WHERE (id > scalarsubquery())):double>
|
||||
-- !query 30 output
|
||||
10 133.33333333333334 NULL
|
||||
100 400.0 400.0
|
||||
|
@ -344,7 +344,7 @@ SELECT emp.dept_id,
|
|||
FROM emp
|
||||
GROUP BY dept_id
|
||||
-- !query 31 schema
|
||||
struct<dept_id:int,avg(salary):double,avg(salary):double>
|
||||
struct<dept_id:int,avg(salary):double,avg(salary) FILTER (WHERE (dept_id = scalarsubquery())):double>
|
||||
-- !query 31 output
|
||||
10 133.33333333333334 133.33333333333334
|
||||
100 400.0 NULL
|
||||
|
@ -366,7 +366,7 @@ GROUP BY dept_id
|
|||
struct<>
|
||||
-- !query 32 output
|
||||
org.apache.spark.sql.AnalysisException
|
||||
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) AS avg(salary)#x]
|
||||
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) FILTER (WHERE exists#x [dept_id#x]) AS avg(salary) FILTER (WHERE exists(dept_id))#x]
|
||||
: +- Project [state#x]
|
||||
: +- Filter (dept_id#x = outer(dept_id#x))
|
||||
: +- SubqueryAlias `dept`
|
||||
|
@ -392,7 +392,7 @@ GROUP BY dept_id
|
|||
struct<>
|
||||
-- !query 33 output
|
||||
org.apache.spark.sql.AnalysisException
|
||||
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) AS sum(salary)#x]
|
||||
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) FILTER (WHERE NOT exists#x [dept_id#x]) AS sum(salary) FILTER (WHERE (NOT exists(dept_id)))#x]
|
||||
: +- Project [state#x]
|
||||
: +- Filter (dept_id#x = outer(dept_id#x))
|
||||
: +- SubqueryAlias `dept`
|
||||
|
@ -417,7 +417,7 @@ GROUP BY dept_id
|
|||
struct<>
|
||||
-- !query 34 output
|
||||
org.apache.spark.sql.AnalysisException
|
||||
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) AS avg(salary)#x]
|
||||
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) FILTER (WHERE dept_id#x IN (list#x [])) AS avg(salary) FILTER (WHERE (dept_id IN (listquery())))#x]
|
||||
: +- Distinct
|
||||
: +- Project [dept_id#x]
|
||||
: +- SubqueryAlias `dept`
|
||||
|
@ -442,7 +442,7 @@ GROUP BY dept_id
|
|||
struct<>
|
||||
-- !query 35 output
|
||||
org.apache.spark.sql.AnalysisException
|
||||
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) AS sum(salary)#x]
|
||||
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) FILTER (WHERE NOT dept_id#x IN (list#x [])) AS sum(salary) FILTER (WHERE (NOT (dept_id IN (listquery()))))#x]
|
||||
: +- Distinct
|
||||
: +- Project [dept_id#x]
|
||||
: +- SubqueryAlias `dept`
|
||||
|
|
|
@ -14,7 +14,7 @@ It is not allowed to use an aggregate function in the argument of another aggreg
|
|||
-- !query 1
|
||||
select min(unique1) filter (where unique1 > 100) from tenk1
|
||||
-- !query 1 schema
|
||||
struct<min(unique1):int>
|
||||
struct<min(unique1) FILTER (WHERE (unique1 > 100)):int>
|
||||
-- !query 1 output
|
||||
101
|
||||
|
||||
|
@ -22,7 +22,7 @@ struct<min(unique1):int>
|
|||
-- !query 2
|
||||
select sum(1/ten) filter (where ten > 0) from tenk1
|
||||
-- !query 2 schema
|
||||
struct<sum((CAST(1 AS DOUBLE) / CAST(ten AS DOUBLE))):double>
|
||||
struct<sum((CAST(1 AS DOUBLE) / CAST(ten AS DOUBLE))) FILTER (WHERE (ten > 0)):double>
|
||||
-- !query 2 output
|
||||
2828.9682539682954
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ import scala.collection.parallel.immutable.ParVector
|
|||
import org.apache.spark.{AccumulatorSuite, SparkException}
|
||||
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
|
||||
import org.apache.spark.sql.catalyst.expressions.GenericRow
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial}
|
||||
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
|
||||
import org.apache.spark.sql.catalyst.util.StringUtils
|
||||
import org.apache.spark.sql.execution.HiveResult.hiveResultString
|
||||
|
@ -2843,16 +2844,18 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
|
|||
val query = s"SELECT ${funcToResult._1} FILTER (WHERE b > 1) FROM testData2"
|
||||
val df = sql(query)
|
||||
val physical = df.queryExecution.sparkPlan
|
||||
val aggregateExpressions = physical.collectFirst {
|
||||
val aggregateExpressions = physical.collect {
|
||||
case agg: HashAggregateExec => agg.aggregateExpressions
|
||||
case agg: ObjectHashAggregateExec => agg.aggregateExpressions
|
||||
}.flatten
|
||||
aggregateExpressions.foreach { expr =>
|
||||
if (expr.mode == Complete || expr.mode == Partial) {
|
||||
assert(expr.filter.isDefined)
|
||||
} else {
|
||||
assert(expr.filter.isEmpty)
|
||||
}
|
||||
}
|
||||
assert(aggregateExpressions.isDefined)
|
||||
assert(aggregateExpressions.get.size == 1)
|
||||
aggregateExpressions.get.foreach { expr =>
|
||||
assert(expr.filter.isDefined)
|
||||
}
|
||||
checkAnswer(df, Row(funcToResult._2) :: Nil)
|
||||
checkAnswer(df, Row(funcToResult._2))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2860,15 +2863,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
|
|||
withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
|
||||
val df = sql("SELECT PERCENTILE(a, 1) FILTER (WHERE b > 1) FROM testData2")
|
||||
val physical = df.queryExecution.sparkPlan
|
||||
val aggregateExpressions = physical.collectFirst {
|
||||
val aggregateExpressions = physical.collect {
|
||||
case agg: SortAggregateExec => agg.aggregateExpressions
|
||||
}.flatten
|
||||
aggregateExpressions.foreach { expr =>
|
||||
if (expr.mode == Complete || expr.mode == Partial) {
|
||||
assert(expr.filter.isDefined)
|
||||
} else {
|
||||
assert(expr.filter.isEmpty)
|
||||
}
|
||||
}
|
||||
assert(aggregateExpressions.isDefined)
|
||||
assert(aggregateExpressions.get.size == 1)
|
||||
aggregateExpressions.get.foreach { expr =>
|
||||
assert(expr.filter.isDefined)
|
||||
}
|
||||
checkAnswer(df, Row(3) :: Nil)
|
||||
checkAnswer(df, Row(3))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue