diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index e043c81975..e4488b26f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -43,11 +44,20 @@ import org.apache.spark.sql.types._ since = "1.0.0") // scalastyle:on line.size.limit case class Count(children: Seq[Expression]) extends DeclarativeAggregate { + override def nullable: Boolean = false // Return data type. override def dataType: DataType = LongType + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least one argument.") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + protected lazy val count = AttributeReference("count", LongType, nullable = false)() override lazy val aggBufferAttributes = count :: Nil diff --git a/sql/core/src/test/resources/sql-tests/inputs/count.sql b/sql/core/src/test/resources/sql-tests/inputs/count.sql index 203f04c589..fc0d66258e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/count.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/count.sql @@ -35,3 +35,6 @@ SELECT count(DISTINCT a), count(DISTINCT 3,2) FROM testData; SELECT count(DISTINCT a), count(DISTINCT 2), count(DISTINCT 2,3) FROM testData; SELECT count(DISTINCT a), count(DISTINCT 2), count(DISTINCT 3,2) FROM testData; SELECT count(distinct 0.8), percentile_approx(distinct a, 0.8) FROM testData; + +-- count without expressions +SELECT count() FROM testData; diff --git a/sql/core/src/test/resources/sql-tests/results/count.sql.out b/sql/core/src/test/resources/sql-tests/results/count.sql.out index c0cdd0d697..64614b5b67 100644 --- a/sql/core/src/test/resources/sql-tests/results/count.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/count.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 13 +-- Number of queries: 14 -- !query @@ -116,4 +116,13 @@ SELECT count(distinct 0.8), percentile_approx(distinct a, 0.8) FROM testData -- !query schema struct -- !query output -1 2 \ No newline at end of file +1 2 + + +-- !query +SELECT count() FROM testData +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'count()' due to data type mismatch: count requires at least one argument.; line 1 pos 7 \ No newline at end of file