From 2287f3d0b85730995bedc489a017de5700d6e1e4 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Sat, 1 Apr 2017 22:19:08 +0800 Subject: [PATCH] [SPARK-20186][SQL] BroadcastHint should use child's stats ## What changes were proposed in this pull request? `BroadcastHint` should use child's statistics and set `isBroadcastable` to true. ## How was this patch tested? Added a new stats estimation test for `BroadcastHint`. Author: wangzhenhua Closes #17504 from wzhfy/broadcastHintEstimation. --- .../plans/logical/basicLogicalOperators.scala | 2 +- .../BasicStatsEstimationSuite.scala | 21 ++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 5cbf263d1c..19db42c808 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -383,7 +383,7 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { // set isBroadcastable to true so the child will be broadcasted override def computeStats(conf: CatalystConf): Statistics = - super.computeStats(conf).copy(isBroadcastable = true) + child.stats(conf).copy(isBroadcastable = true) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index e5dc811c8b..0d92c1e355 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -35,6 +35,23 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { // row count * (overhead + column size) size = Some(10 * (8 + 4))) + test("BroadcastHint estimation") { + val filter = Filter(Literal(true), plan) + val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false, + rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> colStat))) + val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false) + checkStats( + filter, + expectedStatsCboOn = filterStatsCboOn, + expectedStatsCboOff = filterStatsCboOff) + + val broadcastHint = BroadcastHint(filter) + checkStats( + broadcastHint, + expectedStatsCboOn = filterStatsCboOn.copy(isBroadcastable = true), + expectedStatsCboOff = filterStatsCboOff.copy(isBroadcastable = true)) + } + test("limit estimation: limit < child's rowCount") { val localLimit = LocalLimit(Literal(2), plan) val globalLimit = GlobalLimit(Literal(2), plan) @@ -97,8 +114,10 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { plan: LogicalPlan, expectedStatsCboOn: Statistics, expectedStatsCboOff: Statistics): Unit = { - assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn) // Invalidate statistics + plan.invalidateStatsCache() + assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn) + plan.invalidateStatsCache() assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff) }