[SPARK-35362][SQL] Update null count in the column stats for UNION operator stats estimation
### What changes were proposed in this pull request? This PR updates the column stats for Union operator stats estimation. ### Why are the changes needed? This is a follow-up PR that also updates the null count in the Union operator stats estimation. See https://github.com/apache/spark/pull/30334 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Updated UTs and manual testing. Closes #32494 from shahidki31/shahid/updateNullCountForUnion. Lead-authored-by: shahid <shahidki31@gmail.com> Co-authored-by: Shahid <shahidki31@gmail.com> Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
This commit is contained in:
parent
9283bebbbd
commit
12142130cd
|
@ -225,7 +225,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
|
|||
attr: Attribute,
|
||||
isNull: Boolean,
|
||||
update: Boolean): Option[Double] = {
|
||||
if (!colStatsMap.contains(attr) || !colStatsMap(attr).hasCountStats) {
|
||||
if (!colStatsMap.contains(attr) || colStatsMap(attr).nullCount.isEmpty) {
|
||||
logDebug("[CBO] No statistics for " + attr)
|
||||
return None
|
||||
}
|
||||
|
|
|
@ -17,8 +17,6 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics, Union}
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -70,8 +68,27 @@ object UnionEstimation {
|
|||
None
|
||||
}
|
||||
|
||||
val unionOutput = union.output
|
||||
val newMinMaxStats = computeMinMaxStats(union)
|
||||
val newNullCountStats = computeNullCountStats(union)
|
||||
val newAttrStats = {
|
||||
val baseStats = AttributeMap(newMinMaxStats)
|
||||
val overwriteStats = newNullCountStats.map { case attrStat@(attr, stat) =>
|
||||
baseStats.get(attr).map { baseStat =>
|
||||
attr -> baseStat.copy(nullCount = stat.nullCount)
|
||||
}.getOrElse(attrStat)
|
||||
}
|
||||
AttributeMap(newMinMaxStats ++ overwriteStats)
|
||||
}
|
||||
|
||||
Some(
|
||||
Statistics(
|
||||
sizeInBytes = sizeInBytes,
|
||||
rowCount = outputRows,
|
||||
attributeStats = newAttrStats))
|
||||
}
|
||||
|
||||
private def computeMinMaxStats(union: Union): Seq[(Attribute, ColumnStat)] = {
|
||||
val unionOutput = union.output
|
||||
val attrToComputeMinMaxStats = union.children.map(_.output).transpose.zipWithIndex.filter {
|
||||
case (attrs, outputIndex) => isTypeSupported(unionOutput(outputIndex).dataType) &&
|
||||
// checks if all the children have min/max stats for an attribute
|
||||
|
@ -81,40 +98,50 @@ object UnionEstimation {
|
|||
attrStats.get(attr).isDefined && attrStats(attr).hasMinMaxStats
|
||||
}
|
||||
}
|
||||
|
||||
val newAttrStats = if (attrToComputeMinMaxStats.nonEmpty) {
|
||||
val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]()
|
||||
attrToComputeMinMaxStats.foreach {
|
||||
case (attrs, outputIndex) =>
|
||||
val dataType = unionOutput(outputIndex).dataType
|
||||
val statComparator = createStatComparator(dataType)
|
||||
val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], Option[Any])]((None, None)) {
|
||||
case ((minVal, maxVal), (attr, childIndex)) =>
|
||||
val colStat = union.children(childIndex).stats.attributeStats(attr)
|
||||
val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) {
|
||||
colStat.min
|
||||
} else {
|
||||
minVal
|
||||
}
|
||||
val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) {
|
||||
colStat.max
|
||||
} else {
|
||||
maxVal
|
||||
}
|
||||
(min, max)
|
||||
attrToComputeMinMaxStats.map {
|
||||
case (attrs, outputIndex) =>
|
||||
val dataType = unionOutput(outputIndex).dataType
|
||||
val statComparator = createStatComparator(dataType)
|
||||
val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], Option[Any])]((None, None)) {
|
||||
case ((minVal, maxVal), (attr, childIndex)) =>
|
||||
val colStat = union.children(childIndex).stats.attributeStats(attr)
|
||||
val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) {
|
||||
colStat.min
|
||||
} else {
|
||||
minVal
|
||||
}
|
||||
val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2)
|
||||
outputAttrStats += unionOutput(outputIndex) -> newStat
|
||||
}
|
||||
AttributeMap(outputAttrStats.toSeq)
|
||||
} else {
|
||||
AttributeMap.empty[ColumnStat]
|
||||
val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) {
|
||||
colStat.max
|
||||
} else {
|
||||
maxVal
|
||||
}
|
||||
(min, max)
|
||||
}
|
||||
val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2)
|
||||
unionOutput(outputIndex) -> newStat
|
||||
}
|
||||
}
|
||||
|
||||
Some(
|
||||
Statistics(
|
||||
sizeInBytes = sizeInBytes,
|
||||
rowCount = outputRows,
|
||||
attributeStats = newAttrStats))
|
||||
/**
 * Computes the null-count column statistics for the output of a [[Union]].
 *
 * The children's outputs are aligned by position. An output column gets an entry only when
 * every child has column stats with a defined null count for its attribute at that position;
 * the resulting null count is the sum of the children's null counts. Columns for which any
 * child lacks a null count are left out of the result.
 *
 * @param union the union operator whose children's statistics are combined
 * @return pairs of union output attribute and a [[ColumnStat]] carrying only `nullCount`
 */
private def computeNullCountStats(union: Union): Seq[(Attribute, ColumnStat)] = {
  val outputAttrs = union.output

  // Null count that the child at `childIndex` reports for `attr`, if it has one.
  def childNullCount(attr: Attribute, childIndex: Int): Option[BigInt] =
    union.children(childIndex).stats.attributeStats.get(attr).flatMap(_.nullCount)

  // After transposing, row i holds the i-th output attribute of every child, in child order.
  val columnsByPosition = union.children.map(_.output).transpose.zipWithIndex

  for {
    (attrs, outputIndex) <- columnsByPosition
    attrsByChild = attrs.zipWithIndex
    // Keep the column only if every child knows its null count.
    if attrsByChild.forall { case (attr, childIndex) =>
      childNullCount(attr, childIndex).isDefined
    }
  } yield {
    // Every lookup is defined here because of the guard above, so nothing is dropped.
    val totalNullCount = attrsByChild.flatMap { case (attr, childIndex) =>
      childNullCount(attr, childIndex)
    }.sum
    outputAttrs(outputIndex) -> ColumnStat(nullCount = Some(totalNullCount))
  }
}
|
||||
}
|
||||
|
|
|
@ -273,7 +273,7 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase {
|
|||
val rowCount = Some(plan.rowCount * childrenSize)
|
||||
val attributeStats = AttributeMap(
|
||||
Seq(
|
||||
attribute -> ColumnStat(min = Some(1), max = Some(10))))
|
||||
attribute -> ColumnStat(min = Some(1), max = Some(10), nullCount = Some(0))))
|
||||
checkStats(
|
||||
union,
|
||||
expectedStatsCboOn = Statistics(sizeInBytes = sizeInBytes,
|
||||
|
|
|
@ -68,14 +68,14 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
|
|||
distinctCount = Some(2),
|
||||
min = Some(1),
|
||||
max = Some(4),
|
||||
nullCount = Some(0),
|
||||
nullCount = Some(1),
|
||||
avgLen = Some(4),
|
||||
maxLen = Some(4)),
|
||||
attrDouble -> ColumnStat(
|
||||
distinctCount = Some(2),
|
||||
min = Some(5.0),
|
||||
max = Some(4.0),
|
||||
nullCount = Some(0),
|
||||
nullCount = Some(2),
|
||||
avgLen = Some(4),
|
||||
maxLen = Some(4)),
|
||||
attrShort -> ColumnStat(min = Some(s1), max = Some(s2)),
|
||||
|
@ -96,14 +96,14 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
|
|||
distinctCount = Some(2),
|
||||
min = Some(3),
|
||||
max = Some(6),
|
||||
nullCount = Some(0),
|
||||
nullCount = Some(1),
|
||||
avgLen = Some(8),
|
||||
maxLen = Some(8)),
|
||||
AttributeReference("cdouble1", DoubleType)() -> ColumnStat(
|
||||
distinctCount = Some(2),
|
||||
min = Some(2.0),
|
||||
max = Some(7.0),
|
||||
nullCount = Some(0),
|
||||
nullCount = Some(2),
|
||||
avgLen = Some(8),
|
||||
maxLen = Some(8)),
|
||||
AttributeReference("cshort1", ShortType)() -> ColumnStat(min = Some(s3), max = Some(s4)),
|
||||
|
@ -139,8 +139,8 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
|
|||
rowCount = Some(4),
|
||||
attributeStats = AttributeMap(
|
||||
Seq(
|
||||
attrInt -> ColumnStat(min = Some(1), max = Some(6)),
|
||||
attrDouble -> ColumnStat(min = Some(2.0), max = Some(7.0)),
|
||||
attrInt -> ColumnStat(min = Some(1), max = Some(6), nullCount = Some(2)),
|
||||
attrDouble -> ColumnStat(min = Some(2.0), max = Some(7.0), nullCount = Some(4)),
|
||||
attrShort -> ColumnStat(min = Some(s1), max = Some(s4)),
|
||||
attrLong -> ColumnStat(min = Some(1L), max = Some(6L)),
|
||||
attrByte -> ColumnStat(min = Some(b1), max = Some(b4)),
|
||||
|
@ -188,7 +188,58 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
|
|||
|
||||
val union = Union(Seq(child1, child2))
|
||||
|
||||
val expectedStats = logical.Statistics(sizeInBytes = 2 * 1024, rowCount = Some(4))
|
||||
// Only null count is present in the attribute stats
|
||||
val expectedStats = logical.Statistics(
|
||||
sizeInBytes = 2 * 1024,
|
||||
rowCount = Some(4),
|
||||
attributeStats = AttributeMap(
|
||||
Seq(attrInt -> ColumnStat(nullCount = Some(0)))))
|
||||
assert(union.stats === expectedStats)
|
||||
}
|
||||
|
||||
test("col stats estimation when null count stats are not present for one child") {
  val childSize = Some(BigInt(1024))
  val attrInt = AttributeReference("cint", IntegerType)()

  // Column stats of the first child: every field is set, including the null count.
  val statsWithNullCount = AttributeMap(
    Seq(
      attrInt -> ColumnStat(
        distinctCount = Some(2),
        min = Some(1),
        max = Some(2),
        nullCount = Some(2),
        avgLen = Some(4),
        maxLen = Some(4))))

  // Column stats of the second child: no null count is provided.
  val statsWithoutNullCount = AttributeMap(
    Seq(
      AttributeReference("cint1", IntegerType)() -> ColumnStat(
        distinctCount = Some(2),
        min = Some(3),
        max = Some(4),
        avgLen = Some(8),
        maxLen = Some(8))))

  // Both children have the same row count and size; only their column stats differ.
  val children = Seq(statsWithNullCount, statsWithoutNullCount).map { colStats =>
    StatsTestPlan(
      outputList = colStats.keys.toSeq,
      rowCount = 2,
      attributeStats = colStats,
      size = childSize)
  }

  val union = Union(children)

  // The null count should not be present in the union stats because one child lacks it;
  // min/max are still expected since both children provide them.
  val expectedStats = logical.Statistics(
    sizeInBytes = 2 * 1024,
    rowCount = Some(4),
    attributeStats = AttributeMap(
      Seq(attrInt -> ColumnStat(min = Some(1), max = Some(4), nullCount = None))))
  assert(union.stats === expectedStats)
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue