[SPARK-27255][SQL] Report error when illegal expressions are hosted by a plan operator.
## What changes were proposed in this pull request? In the PR, we raise an AnalysisException when we detect the presence of aggregate expressions in the WHERE clause. Here is the problem description from the JIRA. Aggregate functions should not be allowed in the WHERE clause. But Spark SQL throws an exception when generating code. It is supposed to throw an exception during parsing or analyzing. Here is an example: ``` val df = spark.sql("select * from t where sum(ta) > 0") df.explain(true) df.show() ``` Resulting exception: ``` Exception in thread "main" java.lang.UnsupportedOperationException: Cannot generate code for expression: sum(cast(input[0, int, false] as bigint)) at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode(Expression.scala:291) at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode$(Expression.scala:290) at org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression.doGenCode(interfaces.scala:87) at org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:138) at scala.Option.getOrElse(Option.scala:138) ``` Checked the behaviour of other databases and all of them return an exception: **Postgres** ``` select * from foo where max(c1) > 0; Error ERROR: aggregate functions are not allowed in WHERE Position: 25 ``` **DB2** ``` db2 => select * from foo where max(c1) > 0; SQL0120N Invalid use of an aggregate function or OLAP function. ``` **Oracle** ``` select * from foo where max(c1) > 0; ORA-00934: group function is not allowed here ``` **MySQL** ``` select * from foo where max(c1) > 0; Invalid use of group function ``` **Update** This PR has been enhanced to report an error when expressions such as Aggregate, Window, Generate are hosted by operators where they are invalid. ## How was this patch tested? Added tests in AnalysisErrorSuite and group-by.sql Closes #24209 from dilipbiswal/SPARK-27255. Authored-by: Dilip Biswal <dbiswal@us.ibm.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
1d20d13149
commit
3286bff942
|
@ -178,7 +178,7 @@ trait CheckAnalysis extends PredicateHelper {
|
|||
s"of type ${condition.dataType.catalogString} is not a boolean.")
|
||||
|
||||
case Aggregate(groupingExprs, aggregateExprs, child) =>
|
||||
def isAggregateExpression(expr: Expression) = {
|
||||
def isAggregateExpression(expr: Expression): Boolean = {
|
||||
expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
|
||||
}
|
||||
|
||||
|
@ -376,6 +376,25 @@ trait CheckAnalysis extends PredicateHelper {
|
|||
throw new IllegalStateException(
|
||||
"Internal error: logical hint operator should have been removed during analysis")
|
||||
|
||||
case f @ Filter(condition, _)
|
||||
if PlanHelper.specialExpressionsInUnsupportedOperator(f).nonEmpty =>
|
||||
val invalidExprSqls = PlanHelper.specialExpressionsInUnsupportedOperator(f).map(_.sql)
|
||||
failAnalysis(
|
||||
s"""
|
||||
|Aggregate/Window/Generate expressions are not valid in where clause of the query.
|
||||
|Expression in where clause: [${condition.sql}]
|
||||
|Invalid expressions: [${invalidExprSqls.mkString(", ")}]""".stripMargin)
|
||||
|
||||
case other if PlanHelper.specialExpressionsInUnsupportedOperator(other).nonEmpty =>
|
||||
val invalidExprSqls =
|
||||
PlanHelper.specialExpressionsInUnsupportedOperator(other).map(_.sql)
|
||||
failAnalysis(
|
||||
s"""
|
||||
|The query operator `${other.nodeName}` contains one or more unsupported
|
||||
|expression types Aggregate, Window or Generate.
|
||||
|Invalid expressions: [${invalidExprSqls.mkString(", ")}]""".stripMargin
|
||||
)
|
||||
|
||||
case _ => // Analysis successful!
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,38 +43,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
|
|||
// - is still resolved
|
||||
// - only host special expressions in supported operators
|
||||
override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
|
||||
!Utils.isTesting || (plan.resolved && checkSpecialExpressionIntegrity(plan))
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if all operators in this plan hold structural integrity with regards to hosting special
|
||||
* expressions.
|
||||
* Returns true when all operators are integral.
|
||||
*/
|
||||
private def checkSpecialExpressionIntegrity(plan: LogicalPlan): Boolean = {
|
||||
plan.find(specialExpressionInUnsupportedOperator).isEmpty
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if there's any expression in this query plan operator that is
|
||||
* - A WindowExpression but the plan is not Window
|
||||
* - An AggregateExpression but the plan is not Aggregate or Window
|
||||
* - A Generator but the plan is not Generate
|
||||
* Returns true when this operator breaks structural integrity with one of the cases above.
|
||||
*/
|
||||
private def specialExpressionInUnsupportedOperator(plan: LogicalPlan): Boolean = {
|
||||
val exprs = plan.expressions
|
||||
exprs.flatMap { root =>
|
||||
root.find {
|
||||
case e: WindowExpression
|
||||
if !plan.isInstanceOf[Window] => true
|
||||
case e: AggregateExpression
|
||||
if !(plan.isInstanceOf[Aggregate] || plan.isInstanceOf[Window]) => true
|
||||
case e: Generator
|
||||
if !plan.isInstanceOf[Generate] => true
|
||||
case _ => false
|
||||
}
|
||||
}.nonEmpty
|
||||
!Utils.isTesting || (plan.resolved &&
|
||||
plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty)
|
||||
}
|
||||
|
||||
protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations)
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.plans.logical
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator, WindowExpression}
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
|
||||
|
||||
/**
|
||||
* [[PlanHelper]] contains utility methods that can be used by Analyzer and Optimizer.
|
||||
* It can also be a container of methods that are common across multiple rules in Analyzer
|
||||
* and Optimizer.
|
||||
*/
|
||||
object PlanHelper {
  /**
   * Collects every expression hosted by the given plan operator that the operator is not
   * allowed to host:
   * - a [[WindowExpression]] when the plan is not a [[Window]]
   * - an [[AggregateExpression]] when the plan is neither an [[Aggregate]] nor a [[Window]]
   * - a [[Generator]] when the plan is not a [[Generate]]
   *
   * Such expressions can show up when
   * 1. The input query from users contains invalid expressions.
   *    Example : SELECT * FROM tab WHERE max(c1) > 0
   * 2. Query rewrites inadvertently produce plans that are invalid.
   *
   * @param plan the plan operator whose expressions are inspected
   * @return the invalid expressions found; empty when the operator hosts none
   */
  def specialExpressionsInUnsupportedOperator(plan: LogicalPlan): Seq[Expression] = {
    // The operator kind does not change while its expressions are walked, so classify it once.
    val hostsWindows = plan.isInstanceOf[Window]
    val hostsAggregates = hostsWindows || plan.isInstanceOf[Aggregate]
    val hostsGenerators = plan.isInstanceOf[Generate]

    plan.expressions.flatMap { root =>
      root.collect {
        case window: WindowExpression if !hostsWindows => window
        case aggregate: AggregateExpression if !hostsAggregates => aggregate
        case generator: Generator if !hostsGenerators => generator
      }
    }
  }
}
|
|
@ -599,4 +599,12 @@ class AnalysisErrorSuite extends AnalysisTest {
|
|||
assertAnalysisError(plan5,
|
||||
"Accessing outer query column is not allowed in" :: Nil)
|
||||
}
|
||||
|
||||
test("Error on filter condition containing aggregate expressions") {
  // Two integer columns backing the relation; the filter refers to `a` by name via the symbol 'a.
  val colA = AttributeReference("a", IntegerType)()
  val colB = AttributeReference("b", IntegerType)()
  val relation = LocalRelation(colA, colB)

  // The predicate compares `a` with an unresolved max(b) call placed directly in a Filter.
  // NOTE(review): the trailing `true` is presumably the DISTINCT flag — confirm against
  // UnresolvedFunction's signature.
  val condition = 'a === UnresolvedFunction("max", Seq(colB), true)

  val expectedErrors =
    "Aggregate/Window/Generate expressions are not valid in where clause of the query" :: Nil
  assertAnalysisError(Filter(condition, relation), expectedErrors)
}
|
||||
}
|
||||
|
|
|
@ -141,3 +141,16 @@ SELECT every("true");
|
|||
SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
|
||||
SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
|
||||
SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
|
||||
|
||||
-- Having referencing aggregate expressions is ok.
|
||||
SELECT count(*) FROM test_agg HAVING count(*) > 1L;
|
||||
SELECT k, max(v) FROM test_agg GROUP BY k HAVING max(v) = true;
|
||||
|
||||
-- Aggregate expressions can be referenced through an alias
|
||||
SELECT * FROM (SELECT COUNT(*) AS cnt FROM test_agg) WHERE cnt > 1L;
|
||||
|
||||
-- Error when aggregate expressions are in where clause directly
|
||||
SELECT count(*) FROM test_agg WHERE count(*) > 1L;
|
||||
SELECT count(*) FROM test_agg WHERE count(*) + 1L > 1L;
|
||||
SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(k) > 1;
|
||||
|
||||
|
|
|
@ -46,9 +46,10 @@ WHERE t1a IN (SELECT min(t2a)
|
|||
SELECT t1a
|
||||
FROM t1
|
||||
GROUP BY 1
|
||||
HAVING EXISTS (SELECT 1
|
||||
HAVING EXISTS (SELECT t2a
|
||||
FROM t2
|
||||
WHERE t2a < min(t1a + t2a));
|
||||
GROUP BY 1
|
||||
HAVING t2a < min(t1a + t2a));
|
||||
|
||||
-- TC 01.04
|
||||
-- Invalid due to mixture of outer and local references under an AggregateExpression
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
-- Automatically generated by SQLQueryTestSuite
|
||||
-- Number of queries: 46
|
||||
-- Number of queries: 52
|
||||
|
||||
|
||||
-- !query 0
|
||||
|
@ -459,3 +459,65 @@ struct<k:int,v:boolean,any(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RA
|
|||
5 NULL NULL
|
||||
5 false false
|
||||
5 true true
|
||||
|
||||
|
||||
-- !query 46
|
||||
SELECT count(*) FROM test_agg HAVING count(*) > 1L
|
||||
-- !query 46 schema
|
||||
struct<count(1):bigint>
|
||||
-- !query 46 output
|
||||
10
|
||||
|
||||
|
||||
-- !query 47
|
||||
SELECT k, max(v) FROM test_agg GROUP BY k HAVING max(v) = true
|
||||
-- !query 47 schema
|
||||
struct<k:int,max(v):boolean>
|
||||
-- !query 47 output
|
||||
1 true
|
||||
2 true
|
||||
5 true
|
||||
|
||||
|
||||
-- !query 48
|
||||
SELECT * FROM (SELECT COUNT(*) AS cnt FROM test_agg) WHERE cnt > 1L
|
||||
-- !query 48 schema
|
||||
struct<cnt:bigint>
|
||||
-- !query 48 output
|
||||
10
|
||||
|
||||
|
||||
-- !query 49
|
||||
SELECT count(*) FROM test_agg WHERE count(*) > 1L
|
||||
-- !query 49 schema
|
||||
struct<>
|
||||
-- !query 49 output
|
||||
org.apache.spark.sql.AnalysisException
|
||||
|
||||
Aggregate/Window/Generate expressions are not valid in where clause of the query.
|
||||
Expression in where clause: [(count(1) > 1L)]
|
||||
Invalid expressions: [count(1)];
|
||||
|
||||
|
||||
-- !query 50
|
||||
SELECT count(*) FROM test_agg WHERE count(*) + 1L > 1L
|
||||
-- !query 50 schema
|
||||
struct<>
|
||||
-- !query 50 output
|
||||
org.apache.spark.sql.AnalysisException
|
||||
|
||||
Aggregate/Window/Generate expressions are not valid in where clause of the query.
|
||||
Expression in where clause: [((count(1) + 1L) > 1L)]
|
||||
Invalid expressions: [count(1)];
|
||||
|
||||
|
||||
-- !query 51
|
||||
SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(k) > 1
|
||||
-- !query 51 schema
|
||||
struct<>
|
||||
-- !query 51 output
|
||||
org.apache.spark.sql.AnalysisException
|
||||
|
||||
Aggregate/Window/Generate expressions are not valid in where clause of the query.
|
||||
Expression in where clause: [(((test_agg.`k` = 1) OR (test_agg.`k` = 2)) OR (((count(1) + 1L) > 1L) OR (max(test_agg.`k`) > 1)))]
|
||||
Invalid expressions: [count(1), max(test_agg.`k`)];
|
||||
|
|
|
@ -70,9 +70,10 @@ Resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2
|
|||
SELECT t1a
|
||||
FROM t1
|
||||
GROUP BY 1
|
||||
HAVING EXISTS (SELECT 1
|
||||
HAVING EXISTS (SELECT t2a
|
||||
FROM t2
|
||||
WHERE t2a < min(t1a + t2a))
|
||||
GROUP BY 1
|
||||
HAVING t2a < min(t1a + t2a))
|
||||
-- !query 5 schema
|
||||
struct<>
|
||||
-- !query 5 output
|
||||
|
|
Loading…
Reference in a new issue