[SPARK-14471][SQL] Aliases in SELECT could be used in GROUP BY

## What changes were proposed in this pull request?
This PR adds a new rule in `Analyzer` to resolve aliases in `GROUP BY`.
The current master throws an exception if a `GROUP BY` clause refers to an alias defined in the `SELECT` list:
```
scala> spark.sql("select a a1, a1 + 1 as b, count(1) from t group by a1")
org.apache.spark.sql.AnalysisException: cannot resolve '`a1`' given input columns: [a]; line 1 pos 51;
'Aggregate ['a1], [a#83L AS a1#87L, ('a1 + 1) AS b#88, count(1) AS count(1)#90L]
+- SubqueryAlias t
   +- Project [id#80L AS a#83L]
      +- Range (0, 10, step=1, splits=Some(8))

  at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
  at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:77)
  at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:74)
  at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289)
```
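
With the new `ResolveAggAliasInGroupBy` rule (guarded by the new `spark.sql.groupByAliases` option, which defaults to `true`), an alias in `GROUP BY` is substituted with the matching named expression from the `SELECT` list, and setting the option to `false` restores the old error. A rough sketch of the expected behavior, using a simplified form of the query above (the rule only fires when every expression in the `SELECT` list is resolved, so the `a1 + 1` column is left out here; exact rows depend on the data in `t`):
```
scala> spark.sql("select a a1, count(1) as cnt from t group by a1").show()
// resolves: `a1` in GROUP BY is rewritten to the aliased expression `a` from the SELECT list

scala> spark.sql("set spark.sql.groupByAliases=false")
scala> spark.sql("select a a1, count(1) as cnt from t group by a1")
// org.apache.spark.sql.AnalysisException: cannot resolve '`a1`' given input columns: [a]; ...
```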

## How was this patch tested?
Added tests in `SQLQuerySuite` and `SQLQueryTestSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #17191 from maropu/SPARK-14471.
Takeshi Yamamuro authored on 2017-04-28 14:41:53 +08:00; committed by Wenchen Fan
parent e3c8160433
commit 59e3a56444
6 changed files with 156 additions and 32 deletions


@ -136,6 +136,7 @@ class Analyzer(
ResolveGroupingAnalytics ::
ResolvePivot ::
ResolveOrdinalInOrderByAndGroupBy ::
+ResolveAggAliasInGroupBy ::
ResolveMissingReferences ::
ExtractGenerator ::
ResolveGenerate ::
@ -172,7 +173,7 @@ class Analyzer(
* Analyze cte definitions and substitute child plan with analyzed cte definitions.
*/
object CTESubstitution extends Rule[LogicalPlan] {
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case With(child, relations) =>
substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) {
case (resolved, (name, relation)) =>
@ -200,7 +201,7 @@ class Analyzer(
* Substitute child plan with WindowSpecDefinitions.
*/
object WindowsSubstitution extends Rule[LogicalPlan] {
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
// Lookup WindowSpecDefinitions. This rule works with unresolved children.
case WithWindowDefinition(windowDefinitions, child) =>
child.transform {
@ -242,7 +243,7 @@ class Analyzer(
private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
Aggregate(groups, assignAliases(aggs), child)
@ -614,7 +615,7 @@ class Analyzer(
case _ => plan
}
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
case v: View =>
@ -786,7 +787,7 @@ class Analyzer(
}
}
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p
// If the projection list contains Stars, expand it.
@ -844,11 +845,10 @@ class Analyzer(
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
-q transformExpressionsUp {
+q.transformExpressionsUp {
case u @ UnresolvedAttribute(nameParts) =>
-// Leave unchanged if resolution fails. Hopefully will be resolved next round.
-val result =
-withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
+// Leave unchanged if resolution fails. Hopefully will be resolved next round.
+val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@ -961,7 +961,7 @@ class Analyzer(
* have no effect on the results.
*/
object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.childrenResolved => p
// Replace the index with the related attribute for ORDER BY,
// which is a 1-base position of the projection list.
@ -997,6 +997,27 @@ class Analyzer(
}
}
+/**
+* Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses.
+* This rule is expected to run after [[ResolveReferences]] applied.
+*/
+object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] {
+override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+case agg @ Aggregate(groups, aggs, child)
+if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
+groups.exists(_.isInstanceOf[UnresolvedAttribute]) =>
+// This is a strict check though, we put this to apply the rule only in alias expressions
+def notResolvableByChild(attrName: String): Boolean =
+!child.output.exists(a => resolver(a.name, attrName))
+agg.copy(groupingExpressions = groups.map {
+case u: UnresolvedAttribute if notResolvableByChild(u.name) =>
+aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
+case e => e
+})
+}
+}
/**
* In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
* clause. This rule detects such queries and adds the required attributes to the original
@ -1006,7 +1027,7 @@ class Analyzer(
* The HAVING clause could also used a grouping columns that is not presented in the SELECT.
*/
object ResolveMissingReferences extends Rule[LogicalPlan] {
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, child: Aggregate) => sa
@ -1130,7 +1151,7 @@ class Analyzer(
* Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
*/
object ResolveFunctions extends Rule[LogicalPlan] {
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case q: LogicalPlan =>
q transformExpressions {
case u if !u.childrenResolved => u // Skip until children are resolved.
@ -1469,7 +1490,7 @@ class Analyzer(
/**
* Resolve and rewrite all subqueries in an operator tree..
*/
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
// In case of HAVING (a filter after an aggregate) we use both the aggregate and
// its child for resolution.
case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
@ -1484,7 +1505,7 @@ class Analyzer(
* Turns projections that contain aggregate expressions into aggregations.
*/
object GlobalAggregates extends Rule[LogicalPlan] {
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case Project(projectList, child) if containsAggregates(projectList) =>
Aggregate(Nil, projectList, child)
}
@ -1510,7 +1531,7 @@ class Analyzer(
* underlying aggregate operator and then projected away after the original operator.
*/
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case filter @ Filter(havingCondition,
aggregate @ Aggregate(grouping, originalAggExprs, child))
if aggregate.resolved =>
@ -1682,7 +1703,7 @@ class Analyzer(
}
}
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
val nestedGenerator = projectList.find(hasNestedGenerator).get
throw new AnalysisException("Generators are not supported when it's nested in " +
@ -1740,7 +1761,7 @@ class Analyzer(
* that wrap the [[Generator]].
*/
object ResolveGenerate extends Rule[LogicalPlan] {
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case g: Generate if !g.child.resolved || !g.generator.resolved => g
case g: Generate if !g.resolved =>
g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
@ -2057,7 +2078,7 @@ class Analyzer(
* put them into an inner Project and finally project them away at the outer Project.
*/
object PullOutNondeterministic extends Rule[LogicalPlan] {
-override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.resolved => p // Skip unresolved nodes.
case p: Project => p
case f: Filter => f
@ -2102,7 +2123,7 @@ class Analyzer(
* and we should return null if the input is null.
*/
object HandleNullInputsForUDF extends Rule[LogicalPlan] {
-override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.resolved => p // Skip unresolved nodes.
case p => p transformExpressionsUp {
@ -2167,7 +2188,7 @@ class Analyzer(
* Then apply a Project on a normal Join to eliminate natural or using join.
*/
object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
-override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
if left.resolved && right.resolved && j.duplicateResolved =>
commonNaturalJoinProcessing(left, right, joinType, usingCols, None)
@ -2232,7 +2253,7 @@ class Analyzer(
* to the given input attributes.
*/
object ResolveDeserializer extends Rule[LogicalPlan] {
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.childrenResolved => p
case p if p.resolved => p
@ -2318,7 +2339,7 @@ class Analyzer(
* constructed is an inner class.
*/
object ResolveNewInstance extends Rule[LogicalPlan] {
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.childrenResolved => p
case p if p.resolved => p
@ -2352,7 +2373,7 @@ class Analyzer(
"type of the field in the target object")
}
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.childrenResolved => p
case p if p.resolved => p
@ -2406,7 +2427,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
case other => trimAliases(other)
}
-override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case Project(projectList, child) =>
val cleanedProjectList =
projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
@ -2474,7 +2495,7 @@ object TimeWindowing extends Rule[LogicalPlan] {
* @return the logical plan that will generate the time windows using the Expand operator, with
* the Filter operator for correctness and Project for usability.
*/
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p: LogicalPlan if p.children.size == 1 =>
val child = p.children.head
val windowExpressions =


@ -421,6 +421,12 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+val GROUP_BY_ALIASES = buildConf("spark.sql.groupByAliases")
+.doc("When true, aliases in a select list can be used in group by clauses. When false, " +
+"an analysis exception is thrown in the case.")
+.booleanConf
+.createWithDefault(true)
// The output committer class used by data sources. The specified class needs to be a
// subclass of org.apache.hadoop.mapreduce.OutputCommitter.
val OUTPUT_COMMITTER_CLASS =
@ -1003,6 +1009,8 @@ class SQLConf extends Serializable with Logging {
def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
+def groupByAliases: Boolean = getConf(GROUP_BY_ALIASES)
def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE)


@ -49,6 +49,9 @@ select a, count(a) from (select 1 as a) tmp group by 1 order by 1;
-- group by ordinal followed by having
select count(a), a from (select 1 as a) tmp group by 2 having a > 0;
+-- mixed cases: group-by ordinals and aliases
+select a, a AS k, count(b) from data group by k, 1;
-- turn of group by ordinal
set spark.sql.groupByOrdinal=false;


@ -35,3 +35,21 @@ FROM testData;
-- Aggregate with foldable input and multiple distinct groups.
SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a;
+-- Aliases in SELECT could be used in GROUP BY
+SELECT a AS k, COUNT(b) FROM testData GROUP BY k;
+SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1;
+-- Aggregate functions cannot be used in GROUP BY
+SELECT COUNT(b) AS k FROM testData GROUP BY k;
+-- Test data.
+CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES
+(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v);
+SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a;
+-- turn off group by aliases
+set spark.sql.groupByAliases=false;
+-- Check analysis exceptions
+SELECT a AS k, COUNT(b) FROM testData GROUP BY k;


@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 19
+-- Number of queries: 20
 -- !query 0
@ -173,16 +173,26 @@ struct<count(a):bigint,a:int>
 -- !query 17
-set spark.sql.groupByOrdinal=false
+select a, a AS k, count(b) from data group by k, 1
 -- !query 17 schema
-struct<key:string,value:string>
+struct<a:int,k:int,count(b):bigint>
 -- !query 17 output
-spark.sql.groupByOrdinal false
+1 1 2
+2 2 2
+3 3 2
 -- !query 18
-select sum(b) from data group by -1
+set spark.sql.groupByOrdinal=false
 -- !query 18 schema
-struct<sum(b):bigint>
+struct<key:string,value:string>
 -- !query 18 output
+spark.sql.groupByOrdinal false
+-- !query 19
+select sum(b) from data group by -1
+-- !query 19 schema
+struct<sum(b):bigint>
+-- !query 19 output
 9


@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 15
+-- Number of queries: 22
 -- !query 0
@ -139,3 +139,67 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS
 struct<count(DISTINCT b):bigint,count(DISTINCT b, c):bigint>
 -- !query 14 output
 1 1
+-- !query 15
+SELECT a AS k, COUNT(b) FROM testData GROUP BY k
+-- !query 15 schema
+struct<k:int,count(b):bigint>
+-- !query 15 output
+1 2
+2 2
+3 2
+NULL 1
+-- !query 16
+SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1
+-- !query 16 schema
+struct<k:int,count(b):bigint>
+-- !query 16 output
+2 2
+3 2
+-- !query 17
+SELECT COUNT(b) AS k FROM testData GROUP BY k
+-- !query 17 schema
+struct<>
+-- !query 17 output
+org.apache.spark.sql.AnalysisException
+aggregate functions are not allowed in GROUP BY, but found count(testdata.`b`);
+-- !query 18
+CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES
+(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v)
+-- !query 18 schema
+struct<>
+-- !query 18 output
+-- !query 19
+SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a
+-- !query 19 schema
+struct<>
+-- !query 19 output
+org.apache.spark.sql.AnalysisException
+expression 'testdatahassamenamewithalias.`k`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;
+-- !query 20
+set spark.sql.groupByAliases=false
+-- !query 20 schema
+struct<key:string,value:string>
+-- !query 20 output
+spark.sql.groupByAliases false
+-- !query 21
+SELECT a AS k, COUNT(b) FROM testData GROUP BY k
+-- !query 21 schema
+struct<>
+-- !query 21 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47