[SPARK-2066][SQL] Adds checks for non-aggregate attributes with aggregation
This PR adds a new rule `CheckAggregation` to the analyzer to provide better error message for non-aggregate attributes with aggregation. Author: Cheng Lian <lian.cs.zju@gmail.com> Closes #2774 from liancheng/non-aggregate-attr and squashes the following commits: 5246004 [Cheng Lian] Passes test suites bf1878d [Cheng Lian] Adds checks for non-aggregate attributes with aggregation
This commit is contained in:
parent
2ac40da3f9
commit
56102dc2d8
|
@ -63,7 +63,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
|
|||
typeCoercionRules ++
|
||||
extendedRules : _*),
|
||||
Batch("Check Analysis", Once,
|
||||
CheckResolution),
|
||||
CheckResolution,
|
||||
CheckAggregation),
|
||||
Batch("AnalysisOperators", fixedPoint,
|
||||
EliminateAnalysisOperators)
|
||||
)
|
||||
|
@ -88,6 +89,32 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks for non-aggregated attributes with aggregation
|
||||
*/
|
||||
object CheckAggregation extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = {
|
||||
plan.transform {
|
||||
case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) =>
|
||||
def isValidAggregateExpression(expr: Expression): Boolean = expr match {
|
||||
case _: AggregateExpression => true
|
||||
case e: Attribute => groupingExprs.contains(e)
|
||||
case e if groupingExprs.contains(e) => true
|
||||
case e if e.references.isEmpty => true
|
||||
case e => e.children.forall(isValidAggregateExpression)
|
||||
}
|
||||
|
||||
aggregateExprs.foreach { e =>
|
||||
if (!isValidAggregateExpression(e)) {
|
||||
throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
|
||||
}
|
||||
}
|
||||
|
||||
aggregatePlan
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
|
||||
*/
|
||||
|
@ -204,18 +231,17 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
|
|||
*/
|
||||
object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
|
||||
case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
|
||||
case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
|
||||
if aggregate.resolved && containsAggregate(havingCondition) => {
|
||||
val evaluatedCondition = Alias(havingCondition, "havingCondition")()
|
||||
val aggExprsWithHaving = evaluatedCondition +: originalAggExprs
|
||||
|
||||
|
||||
Project(aggregate.output,
|
||||
Filter(evaluatedCondition.toAttribute,
|
||||
aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
protected def containsAggregate(condition: Expression): Boolean =
|
||||
condition
|
||||
.collect { case ae: AggregateExpression => ae }
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.sql
|
|||
|
||||
import org.apache.spark.sql.catalyst.errors.TreeNodeException
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
|
||||
import org.apache.spark.sql.test._
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
@ -694,4 +695,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
|
|||
checkAnswer(
|
||||
sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1)
|
||||
}
|
||||
|
||||
test("throw errors for non-aggregate attributes with aggregation") {
|
||||
def checkAggregation(query: String, isInvalidQuery: Boolean = true) {
|
||||
val logicalPlan = sql(query).queryExecution.logical
|
||||
|
||||
if (isInvalidQuery) {
|
||||
val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed)
|
||||
assert(
|
||||
e.getMessage.startsWith("Expression not in GROUP BY"),
|
||||
"Non-aggregate attribute(s) not detected\n" + logicalPlan)
|
||||
} else {
|
||||
// Should not throw
|
||||
sql(query).queryExecution.analyzed
|
||||
}
|
||||
}
|
||||
|
||||
checkAggregation("SELECT key, COUNT(*) FROM testData")
|
||||
checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false)
|
||||
|
||||
checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key")
|
||||
checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false)
|
||||
|
||||
checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1")
|
||||
checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue