[SPARK-20186][SQL] BroadcastHint should use child's stats
## What changes were proposed in this pull request?

`BroadcastHint` should use child's statistics and set `isBroadcastable` to true.

## How was this patch tested?

Added a new stats estimation test for `BroadcastHint`.

Author: wangzhenhua <wangzhenhua@huawei.com>

Closes #17504 from wzhfy/broadcastHintEstimation.
This commit is contained in:
parent
89d6822f72
commit
2287f3d0b8
|
@ -383,7 +383,7 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
|
||||||
|
|
||||||
// set isBroadcastable to true so the child will be broadcasted
|
// set isBroadcastable to true so the child will be broadcasted
|
||||||
override def computeStats(conf: CatalystConf): Statistics =
|
override def computeStats(conf: CatalystConf): Statistics =
|
||||||
super.computeStats(conf).copy(isBroadcastable = true)
|
child.stats(conf).copy(isBroadcastable = true)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -35,6 +35,23 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
|
||||||
// row count * (overhead + column size)
|
// row count * (overhead + column size)
|
||||||
size = Some(10 * (8 + 4)))
|
size = Some(10 * (8 + 4)))
|
||||||
|
|
||||||
|
test("BroadcastHint estimation") {
|
||||||
|
val filter = Filter(Literal(true), plan)
|
||||||
|
val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false,
|
||||||
|
rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> colStat)))
|
||||||
|
val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false)
|
||||||
|
checkStats(
|
||||||
|
filter,
|
||||||
|
expectedStatsCboOn = filterStatsCboOn,
|
||||||
|
expectedStatsCboOff = filterStatsCboOff)
|
||||||
|
|
||||||
|
val broadcastHint = BroadcastHint(filter)
|
||||||
|
checkStats(
|
||||||
|
broadcastHint,
|
||||||
|
expectedStatsCboOn = filterStatsCboOn.copy(isBroadcastable = true),
|
||||||
|
expectedStatsCboOff = filterStatsCboOff.copy(isBroadcastable = true))
|
||||||
|
}
|
||||||
|
|
||||||
test("limit estimation: limit < child's rowCount") {
|
test("limit estimation: limit < child's rowCount") {
|
||||||
val localLimit = LocalLimit(Literal(2), plan)
|
val localLimit = LocalLimit(Literal(2), plan)
|
||||||
val globalLimit = GlobalLimit(Literal(2), plan)
|
val globalLimit = GlobalLimit(Literal(2), plan)
|
||||||
|
@ -97,8 +114,10 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
|
||||||
plan: LogicalPlan,
|
plan: LogicalPlan,
|
||||||
expectedStatsCboOn: Statistics,
|
expectedStatsCboOn: Statistics,
|
||||||
expectedStatsCboOff: Statistics): Unit = {
|
expectedStatsCboOff: Statistics): Unit = {
|
||||||
assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn)
|
|
||||||
// Invalidate statistics
|
// Invalidate statistics
|
||||||
|
plan.invalidateStatsCache()
|
||||||
|
assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn)
|
||||||
|
|
||||||
plan.invalidateStatsCache()
|
plan.invalidateStatsCache()
|
||||||
assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff)
|
assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue