[SPARK-28646][SQL] Fix bug of Count so it is consistent with mainstream databases
### What changes were proposed in this pull request?

Currently, Spark allows calling `count` with no arguments, even though `count` is not a parameterless aggregate function. For example, the following query actually works:

```sql
SELECT count() FROM tenk1;
```

Mainstream databases, on the other hand, throw an error:

**Oracle**
```
ORA-00909: invalid number of arguments
```

**PgSQL**
```
ERROR: count(*) must be used to call a parameterless aggregate function
```

**MySQL**
```
1064 - You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near ')
```

### Why are the changes needed?

This fixes a bug so that Spark is consistent with mainstream databases. Here is an example query output with and without this fix:

```sql
SELECT count() FROM testData;
```

The output before this fix: `0`

The output after this fix:
```
org.apache.spark.sql.AnalysisException
cannot resolve 'count()' due to data type mismatch: count requires at least one argument.; line 1 pos 7
```

### Does this PR introduce _any_ user-facing change?

Yes. Calling `count` without any argument now throws an error.

### How was this patch tested?

Jenkins test.

Closes #30541 from beliefer/SPARK-28646.

Lead-authored-by: gengjiaan <gengjiaan@360.cn>
Co-authored-by: beliefer <beliefer@163.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
parent 225c2e2815
commit b665d58819
In the `Count` aggregate expression:

```diff
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
@@ -43,11 +44,20 @@ import org.apache.spark.sql.types._
   since = "1.0.0")
 // scalastyle:on line.size.limit
 case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
 
   override def nullable: Boolean = false
 
   // Return data type.
   override def dataType: DataType = LongType
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.isEmpty) {
+      TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least one argument.")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
   protected lazy val count = AttributeReference("count", LongType, nullable = false)()
 
   override lazy val aggBufferAttributes = count :: Nil
```
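The new argument-count check can be illustrated with a minimal standalone sketch. Note the `TypeCheckResult` and `Count` types below are simplified stand-ins for illustration only, not Spark's actual classes (Spark's `Count` operates on `Expression` children, not strings):

```scala
// Simplified stand-in for org.apache.spark.sql.catalyst.analysis.TypeCheckResult.
sealed trait TypeCheckResult
case object TypeCheckSuccess extends TypeCheckResult
final case class TypeCheckFailure(message: String) extends TypeCheckResult

// Simplified stand-in for the Count aggregate; `children` models its input expressions.
final case class Count(children: Seq[String]) {
  def prettyName: String = "count"

  // Mirrors the patch: zero arguments is now a type-check failure.
  def checkInputDataTypes(): TypeCheckResult =
    if (children.isEmpty) {
      TypeCheckFailure(s"$prettyName requires at least one argument.")
    } else {
      TypeCheckSuccess
    }
}

object Demo extends App {
  // count() with no arguments fails the check...
  println(Count(Nil).checkInputDataTypes())
  // ...while count(a) passes.
  println(Count(Seq("a")).checkInputDataTypes())
}
```

Because the failure surfaces through `checkInputDataTypes()`, the analyzer reports it as a data-type mismatch at analysis time, which is what produces the `AnalysisException` shown in the test output below.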
In the `count` SQL test input:

```diff
@@ -35,3 +35,6 @@ SELECT count(DISTINCT a), count(DISTINCT 3,2) FROM testData;
 SELECT count(DISTINCT a), count(DISTINCT 2), count(DISTINCT 2,3) FROM testData;
 SELECT count(DISTINCT a), count(DISTINCT 2), count(DISTINCT 3,2) FROM testData;
 SELECT count(distinct 0.8), percentile_approx(distinct a, 0.8) FROM testData;
+
+-- count without expressions
+SELECT count() FROM testData;
```
In the generated golden output file:

```diff
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 13
+-- Number of queries: 14
 
 
 -- !query
```
```diff
@@ -117,3 +117,12 @@ SELECT count(distinct 0.8), percentile_approx(distinct a, 0.8) FROM testData
 struct<count(DISTINCT 0.8):bigint,percentile_approx(DISTINCT a, CAST(0.8 AS DOUBLE), 10000):int>
 -- !query output
 1	2
+
+
+-- !query
+SELECT count() FROM testData
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'count()' due to data type mismatch: count requires at least one argument.; line 1 pos 7
```