[SPARK-2066][SQL] Adds checks for non-aggregate attributes with aggregation

This PR adds a new rule `CheckAggregation` to the analyzer to provide better error message for non-aggregate attributes with aggregation.

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #2774 from liancheng/non-aggregate-attr and squashes the following commits:

5246004 [Cheng Lian] Passes test suites
bf1878d [Cheng Lian] Adds checks for non-aggregate attributes with aggregation
This commit is contained in:
Cheng Lian 2014-10-13 13:36:39 -07:00 committed by Michael Armbrust
parent 2ac40da3f9
commit 56102dc2d8
2 changed files with 57 additions and 5 deletions

View file

@ -63,7 +63,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
typeCoercionRules ++
extendedRules : _*),
Batch("Check Analysis", Once,
CheckResolution),
CheckResolution,
CheckAggregation),
Batch("AnalysisOperators", fixedPoint,
EliminateAnalysisOperators)
)
@ -88,6 +89,32 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
}
}
/**
* Checks for non-aggregated attributes with aggregation
*/
object CheckAggregation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan.transform {
case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) =>
def isValidAggregateExpression(expr: Expression): Boolean = expr match {
case _: AggregateExpression => true
case e: Attribute => groupingExprs.contains(e)
case e if groupingExprs.contains(e) => true
case e if e.references.isEmpty => true
case e => e.children.forall(isValidAggregateExpression)
}
aggregateExprs.foreach { e =>
if (!isValidAggregateExpression(e)) {
throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
}
}
aggregatePlan
}
}
}
/**
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
@ -204,18 +231,17 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
*/
object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
if aggregate.resolved && containsAggregate(havingCondition) => {
val evaluatedCondition = Alias(havingCondition, "havingCondition")()
val aggExprsWithHaving = evaluatedCondition +: originalAggExprs
Project(aggregate.output,
Filter(evaluatedCondition.toAttribute,
aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
}
}
protected def containsAggregate(condition: Expression): Boolean =
condition
.collect { case ae: AggregateExpression => ae }

View file

@ -19,6 +19,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.test._
import org.scalatest.BeforeAndAfterAll
@ -694,4 +695,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(
sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1)
}
test("throw errors for non-aggregate attributes with aggregation") {
def checkAggregation(query: String, isInvalidQuery: Boolean = true) {
val logicalPlan = sql(query).queryExecution.logical
if (isInvalidQuery) {
val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed)
assert(
e.getMessage.startsWith("Expression not in GROUP BY"),
"Non-aggregate attribute(s) not detected\n" + logicalPlan)
} else {
// Should not throw
sql(query).queryExecution.analyzed
}
}
checkAggregation("SELECT key, COUNT(*) FROM testData")
checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false)
checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key")
checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false)
checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1")
checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false)
}
}