diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index 111c594a53..7ef22fafc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -39,8 +39,16 @@ object AggregateEstimation { // Multiply distinct counts of group-by columns. This is an upper bound, which assumes // the data contains all combinations of distinct values of group-by columns. var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))( - (res, expr) => res * - childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount.get) + (res, expr) => { + val columnStat = childStats.attributeStats(expr.asInstanceOf[Attribute]) + val distinctCount = columnStat.distinctCount.get + val distinctValue: BigInt = if (distinctCount == 0 && columnStat.nullCount.get > 0) { + 1 + } else { + distinctCount + } + res * distinctValue + }) outputRows = if (agg.groupingExpressions.isEmpty) { // If there's no group-by columns, the output is a single row containing values of aggregate diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index 8213d568fe..6bdf8cd88c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -38,7 +38,9 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { attr("key22") -> ColumnStat(distinctCount = Some(2), min = Some(10), max = Some(20), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), attr("key31") -> ColumnStat(distinctCount = Some(0), min = None, max = None, - nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key32") -> ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(4), avgLen = Some(4), maxLen = Some(4)) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -92,6 +94,14 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { expectedOutputRowCount = 0) } + test("group-by column with only null value") { + checkAggStats( + tableColumns = Seq("key22", "key32"), + tableRowCount = 6, + groupByColumns = Seq("key22", "key32"), + expectedOutputRowCount = nameToColInfo("key22")._2.distinctCount.get) + } + test("non-cbo estimation") { val attributes = Seq("key12").map(nameToAttr) val child = StatsTestPlan(