[SPARK-14471][SQL] Aliases in SELECT could be used in GROUP BY
## What changes were proposed in this pull request? This pr added a new rule in `Analyzer` to resolve aliases in `GROUP BY`. The current master throws an exception if `GROUP BY` clauses have aliases in `SELECT`; ``` scala> spark.sql("select a a1, a1 + 1 as b, count(1) from t group by a1") org.apache.spark.sql.AnalysisException: cannot resolve '`a1`' given input columns: [a]; line 1 pos 51; 'Aggregate ['a1], [a#83L AS a1#87L, ('a1 + 1) AS b#88, count(1) AS count(1)#90L] +- SubqueryAlias t +- Project [id#80L AS a#83L] +- Range (0, 10, step=1, splits=Some(8)) at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:77) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:74) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) ``` ## How was this patch tested? Added tests in `SQLQuerySuite` and `SQLQueryTestSuite`. Author: Takeshi Yamamuro <yamamuro@apache.org> Closes #17191 from maropu/SPARK-14471.
This commit is contained in:
parent
e3c8160433
commit
59e3a56444
|
@ -136,6 +136,7 @@ class Analyzer(
|
|||
ResolveGroupingAnalytics ::
|
||||
ResolvePivot ::
|
||||
ResolveOrdinalInOrderByAndGroupBy ::
|
||||
ResolveAggAliasInGroupBy ::
|
||||
ResolveMissingReferences ::
|
||||
ExtractGenerator ::
|
||||
ResolveGenerate ::
|
||||
|
@ -172,7 +173,7 @@ class Analyzer(
|
|||
* Analyze cte definitions and substitute child plan with analyzed cte definitions.
|
||||
*/
|
||||
object CTESubstitution extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case With(child, relations) =>
|
||||
substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) {
|
||||
case (resolved, (name, relation)) =>
|
||||
|
@ -200,7 +201,7 @@ class Analyzer(
|
|||
* Substitute child plan with WindowSpecDefinitions.
|
||||
*/
|
||||
object WindowsSubstitution extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
// Lookup WindowSpecDefinitions. This rule works with unresolved children.
|
||||
case WithWindowDefinition(windowDefinitions, child) =>
|
||||
child.transform {
|
||||
|
@ -242,7 +243,7 @@ class Analyzer(
|
|||
private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
|
||||
exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
|
||||
Aggregate(groups, assignAliases(aggs), child)
|
||||
|
||||
|
@ -614,7 +615,7 @@ class Analyzer(
|
|||
case _ => plan
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
|
||||
EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
|
||||
case v: View =>
|
||||
|
@ -786,7 +787,7 @@ class Analyzer(
|
|||
}
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case p: LogicalPlan if !p.childrenResolved => p
|
||||
|
||||
// If the projection list contains Stars, expand it.
|
||||
|
@ -844,11 +845,10 @@ class Analyzer(
|
|||
|
||||
case q: LogicalPlan =>
|
||||
logTrace(s"Attempting to resolve ${q.simpleString}")
|
||||
q transformExpressionsUp {
|
||||
q.transformExpressionsUp {
|
||||
case u @ UnresolvedAttribute(nameParts) =>
|
||||
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
|
||||
val result =
|
||||
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
|
||||
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
|
||||
val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
|
||||
logDebug(s"Resolving $u to $result")
|
||||
result
|
||||
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
|
||||
|
@ -961,7 +961,7 @@ class Analyzer(
|
|||
* have no effect on the results.
|
||||
*/
|
||||
object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case p if !p.childrenResolved => p
|
||||
// Replace the index with the related attribute for ORDER BY,
|
||||
// which is a 1-base position of the projection list.
|
||||
|
@ -997,6 +997,27 @@ class Analyzer(
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses.
|
||||
* This rule is expected to run after [[ResolveReferences]] applied.
|
||||
*/
|
||||
object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] {
|
||||
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case agg @ Aggregate(groups, aggs, child)
|
||||
if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
|
||||
groups.exists(_.isInstanceOf[UnresolvedAttribute]) =>
|
||||
// This is a strict check though, we put this to apply the rule only in alias expressions
|
||||
def notResolvableByChild(attrName: String): Boolean =
|
||||
!child.output.exists(a => resolver(a.name, attrName))
|
||||
agg.copy(groupingExpressions = groups.map {
|
||||
case u: UnresolvedAttribute if notResolvableByChild(u.name) =>
|
||||
aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
|
||||
case e => e
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
|
||||
* clause. This rule detects such queries and adds the required attributes to the original
|
||||
|
@ -1006,7 +1027,7 @@ class Analyzer(
|
|||
* The HAVING clause could also used a grouping columns that is not presented in the SELECT.
|
||||
*/
|
||||
object ResolveMissingReferences extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
|
||||
case sa @ Sort(_, _, child: Aggregate) => sa
|
||||
|
||||
|
@ -1130,7 +1151,7 @@ class Analyzer(
|
|||
* Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
|
||||
*/
|
||||
object ResolveFunctions extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case q: LogicalPlan =>
|
||||
q transformExpressions {
|
||||
case u if !u.childrenResolved => u // Skip until children are resolved.
|
||||
|
@ -1469,7 +1490,7 @@ class Analyzer(
|
|||
/**
|
||||
* Resolve and rewrite all subqueries in an operator tree..
|
||||
*/
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
// In case of HAVING (a filter after an aggregate) we use both the aggregate and
|
||||
// its child for resolution.
|
||||
case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
|
||||
|
@ -1484,7 +1505,7 @@ class Analyzer(
|
|||
* Turns projections that contain aggregate expressions into aggregations.
|
||||
*/
|
||||
object GlobalAggregates extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case Project(projectList, child) if containsAggregates(projectList) =>
|
||||
Aggregate(Nil, projectList, child)
|
||||
}
|
||||
|
@ -1510,7 +1531,7 @@ class Analyzer(
|
|||
* underlying aggregate operator and then projected away after the original operator.
|
||||
*/
|
||||
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case filter @ Filter(havingCondition,
|
||||
aggregate @ Aggregate(grouping, originalAggExprs, child))
|
||||
if aggregate.resolved =>
|
||||
|
@ -1682,7 +1703,7 @@ class Analyzer(
|
|||
}
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
|
||||
val nestedGenerator = projectList.find(hasNestedGenerator).get
|
||||
throw new AnalysisException("Generators are not supported when it's nested in " +
|
||||
|
@ -1740,7 +1761,7 @@ class Analyzer(
|
|||
* that wrap the [[Generator]].
|
||||
*/
|
||||
object ResolveGenerate extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case g: Generate if !g.child.resolved || !g.generator.resolved => g
|
||||
case g: Generate if !g.resolved =>
|
||||
g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
|
||||
|
@ -2057,7 +2078,7 @@ class Analyzer(
|
|||
* put them into an inner Project and finally project them away at the outer Project.
|
||||
*/
|
||||
object PullOutNondeterministic extends Rule[LogicalPlan] {
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case p if !p.resolved => p // Skip unresolved nodes.
|
||||
case p: Project => p
|
||||
case f: Filter => f
|
||||
|
@ -2102,7 +2123,7 @@ class Analyzer(
|
|||
* and we should return null if the input is null.
|
||||
*/
|
||||
object HandleNullInputsForUDF extends Rule[LogicalPlan] {
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case p if !p.resolved => p // Skip unresolved nodes.
|
||||
|
||||
case p => p transformExpressionsUp {
|
||||
|
@ -2167,7 +2188,7 @@ class Analyzer(
|
|||
* Then apply a Project on a normal Join to eliminate natural or using join.
|
||||
*/
|
||||
object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
|
||||
if left.resolved && right.resolved && j.duplicateResolved =>
|
||||
commonNaturalJoinProcessing(left, right, joinType, usingCols, None)
|
||||
|
@ -2232,7 +2253,7 @@ class Analyzer(
|
|||
* to the given input attributes.
|
||||
*/
|
||||
object ResolveDeserializer extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case p if !p.childrenResolved => p
|
||||
case p if p.resolved => p
|
||||
|
||||
|
@ -2318,7 +2339,7 @@ class Analyzer(
|
|||
* constructed is an inner class.
|
||||
*/
|
||||
object ResolveNewInstance extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case p if !p.childrenResolved => p
|
||||
case p if p.resolved => p
|
||||
|
||||
|
@ -2352,7 +2373,7 @@ class Analyzer(
|
|||
"type of the field in the target object")
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case p if !p.childrenResolved => p
|
||||
case p if p.resolved => p
|
||||
|
||||
|
@ -2406,7 +2427,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
|
|||
case other => trimAliases(other)
|
||||
}
|
||||
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case Project(projectList, child) =>
|
||||
val cleanedProjectList =
|
||||
projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
|
||||
|
@ -2474,7 +2495,7 @@ object TimeWindowing extends Rule[LogicalPlan] {
|
|||
* @return the logical plan that will generate the time windows using the Expand operator, with
|
||||
* the Filter operator for correctness and Project for usability.
|
||||
*/
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
|
||||
case p: LogicalPlan if p.children.size == 1 =>
|
||||
val child = p.children.head
|
||||
val windowExpressions =
|
||||
|
|
|
@ -421,6 +421,12 @@ object SQLConf {
|
|||
.booleanConf
|
||||
.createWithDefault(true)
|
||||
|
||||
val GROUP_BY_ALIASES = buildConf("spark.sql.groupByAliases")
|
||||
.doc("When true, aliases in a select list can be used in group by clauses. When false, " +
|
||||
"an analysis exception is thrown in the case.")
|
||||
.booleanConf
|
||||
.createWithDefault(true)
|
||||
|
||||
// The output committer class used by data sources. The specified class needs to be a
|
||||
// subclass of org.apache.hadoop.mapreduce.OutputCommitter.
|
||||
val OUTPUT_COMMITTER_CLASS =
|
||||
|
@ -1003,6 +1009,8 @@ class SQLConf extends Serializable with Logging {
|
|||
|
||||
def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
|
||||
|
||||
def groupByAliases: Boolean = getConf(GROUP_BY_ALIASES)
|
||||
|
||||
def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
|
||||
|
||||
def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
|
||||
|
|
|
@ -49,6 +49,9 @@ select a, count(a) from (select 1 as a) tmp group by 1 order by 1;
|
|||
-- group by ordinal followed by having
|
||||
select count(a), a from (select 1 as a) tmp group by 2 having a > 0;
|
||||
|
||||
-- mixed cases: group-by ordinals and aliases
|
||||
select a, a AS k, count(b) from data group by k, 1;
|
||||
|
||||
-- turn of group by ordinal
|
||||
set spark.sql.groupByOrdinal=false;
|
||||
|
||||
|
|
|
@ -35,3 +35,21 @@ FROM testData;
|
|||
|
||||
-- Aggregate with foldable input and multiple distinct groups.
|
||||
SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a;
|
||||
|
||||
-- Aliases in SELECT could be used in GROUP BY
|
||||
SELECT a AS k, COUNT(b) FROM testData GROUP BY k;
|
||||
SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1;
|
||||
|
||||
-- Aggregate functions cannot be used in GROUP BY
|
||||
SELECT COUNT(b) AS k FROM testData GROUP BY k;
|
||||
|
||||
-- Test data.
|
||||
CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES
|
||||
(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v);
|
||||
SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a;
|
||||
|
||||
-- turn off group by aliases
|
||||
set spark.sql.groupByAliases=false;
|
||||
|
||||
-- Check analysis exceptions
|
||||
SELECT a AS k, COUNT(b) FROM testData GROUP BY k;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
-- Automatically generated by SQLQueryTestSuite
|
||||
-- Number of queries: 19
|
||||
-- Number of queries: 20
|
||||
|
||||
|
||||
-- !query 0
|
||||
|
@ -173,16 +173,26 @@ struct<count(a):bigint,a:int>
|
|||
|
||||
|
||||
-- !query 17
|
||||
set spark.sql.groupByOrdinal=false
|
||||
select a, a AS k, count(b) from data group by k, 1
|
||||
-- !query 17 schema
|
||||
struct<key:string,value:string>
|
||||
struct<a:int,k:int,count(b):bigint>
|
||||
-- !query 17 output
|
||||
spark.sql.groupByOrdinal false
|
||||
1 1 2
|
||||
2 2 2
|
||||
3 3 2
|
||||
|
||||
|
||||
-- !query 18
|
||||
select sum(b) from data group by -1
|
||||
set spark.sql.groupByOrdinal=false
|
||||
-- !query 18 schema
|
||||
struct<sum(b):bigint>
|
||||
struct<key:string,value:string>
|
||||
-- !query 18 output
|
||||
spark.sql.groupByOrdinal false
|
||||
|
||||
|
||||
-- !query 19
|
||||
select sum(b) from data group by -1
|
||||
-- !query 19 schema
|
||||
struct<sum(b):bigint>
|
||||
-- !query 19 output
|
||||
9
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
-- Automatically generated by SQLQueryTestSuite
|
||||
-- Number of queries: 15
|
||||
-- Number of queries: 22
|
||||
|
||||
|
||||
-- !query 0
|
||||
|
@ -139,3 +139,67 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS
|
|||
struct<count(DISTINCT b):bigint,count(DISTINCT b, c):bigint>
|
||||
-- !query 14 output
|
||||
1 1
|
||||
|
||||
|
||||
-- !query 15
|
||||
SELECT a AS k, COUNT(b) FROM testData GROUP BY k
|
||||
-- !query 15 schema
|
||||
struct<k:int,count(b):bigint>
|
||||
-- !query 15 output
|
||||
1 2
|
||||
2 2
|
||||
3 2
|
||||
NULL 1
|
||||
|
||||
|
||||
-- !query 16
|
||||
SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1
|
||||
-- !query 16 schema
|
||||
struct<k:int,count(b):bigint>
|
||||
-- !query 16 output
|
||||
2 2
|
||||
3 2
|
||||
|
||||
|
||||
-- !query 17
|
||||
SELECT COUNT(b) AS k FROM testData GROUP BY k
|
||||
-- !query 17 schema
|
||||
struct<>
|
||||
-- !query 17 output
|
||||
org.apache.spark.sql.AnalysisException
|
||||
aggregate functions are not allowed in GROUP BY, but found count(testdata.`b`);
|
||||
|
||||
|
||||
-- !query 18
|
||||
CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES
|
||||
(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v)
|
||||
-- !query 18 schema
|
||||
struct<>
|
||||
-- !query 18 output
|
||||
|
||||
|
||||
|
||||
-- !query 19
|
||||
SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a
|
||||
-- !query 19 schema
|
||||
struct<>
|
||||
-- !query 19 output
|
||||
org.apache.spark.sql.AnalysisException
|
||||
expression 'testdatahassamenamewithalias.`k`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;
|
||||
|
||||
|
||||
-- !query 20
|
||||
set spark.sql.groupByAliases=false
|
||||
-- !query 20 schema
|
||||
struct<key:string,value:string>
|
||||
-- !query 20 output
|
||||
spark.sql.groupByAliases false
|
||||
|
||||
|
||||
-- !query 21
|
||||
SELECT a AS k, COUNT(b) FROM testData GROUP BY k
|
||||
-- !query 21 schema
|
||||
struct<>
|
||||
-- !query 21 output
|
||||
org.apache.spark.sql.AnalysisException
|
||||
cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47
|
||||
|
|
Loading…
Reference in a new issue