From 7be8d8a164a2bc12887c83361af401d233de3397 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 18 Jun 2021 21:48:44 +0800 Subject: [PATCH] [SPARK-35185][SQL] Improve Distinct statistics estimation ### What changes were proposed in this pull request? This PR improves `Distinct` statistics estimation by rewriting it to `Aggregate`. ### Why are the changes needed? 1. The current implementation lacks column statistics. 2. Some rules that run before `ReplaceDistinctWithAggregate` may use these statistics. For example: https://github.com/apache/spark/pull/31113/files#diff-11264d807efa58054cca2d220aae8fba644ee0f0f2a4722c46d52828394846efR1808 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit test. Closes #32291 from wangyum/SPARK-35185. Authored-by: Yuming Wang Signed-off-by: Yuming Wang --- .../statsEstimation/BasicStatsPlanVisitor.scala | 5 ++++- .../SizeInBytesOnlyStatsPlanVisitor.scala | 2 +- .../statsEstimation/BasicStatsEstimationSuite.scala | 10 +++++++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index d0c9b4c3c3..67e4ad0f20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -43,7 +43,10 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { AggregateEstimation.estimate(p).getOrElse(fallback(p)) } - override def visitDistinct(p: Distinct): Statistics = default(p) + override def visitDistinct(p: Distinct): Statistics = { + val child = p.child + visitAggregate(Aggregate(child.output, child.output, child)) + } override def visitExcept(p: Except): Statistics = fallback(p) 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 6d5d2f7d2c..5c5ccdea55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -67,7 +67,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { } } - override def visitDistinct(p: Distinct): Statistics = default(p) + override def visitDistinct(p: Distinct): Statistics = visitUnaryNode(p) override def visitExcept(p: Except): Statistics = p.left.stats.copy() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 28c00b9eff..31e289e052 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -291,7 +291,6 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { test("SPARK-34121: Intersect operator missing rowCount when enable CBO") { val intersect = Intersect(plan, plan, false) - val childrenSize = intersect.children.size val sizeInBytes = plan.size.get val rowCount = Some(plan.rowCount) checkStats( @@ -300,6 +299,15 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { expectedStatsCboOff = Statistics(sizeInBytes = sizeInBytes)) } + test("SPARK-35185: Improve Distinct statistics estimation") { + val distinct = Distinct(plan) + val sizeInBytes = 
plan.size.get + checkStats( + distinct, + expectedStatsCboOn = Statistics(sizeInBytes, Some(plan.rowCount), plan.attributeStats), + expectedStatsCboOff = Statistics(sizeInBytes = sizeInBytes)) + } + test("row size and column stats estimation for sort") { val columnInfo = AttributeMap( Seq(