[SPARK-35362][SQL] Update null count in the column stats for UNION operator stats estimation

### What changes were proposed in this pull request?
Updating column stats for Union operator stats estimation
### Why are the changes needed?
This is a followup PR to update the null count also in the Union stats operator estimation. https://github.com/apache/spark/pull/30334

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Updated UTs, manual testing

Closes #32494 from shahidki31/shahid/updateNullCountForUnion.

Lead-authored-by: shahid <shahidki31@gmail.com>
Co-authored-by: Shahid <shahidki31@gmail.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
This commit is contained in:
shahid 2021-05-19 21:23:19 +09:00 committed by Takeshi Yamamuro
parent 9283bebbbd
commit 12142130cd
4 changed files with 122 additions and 44 deletions

View file

@ -225,7 +225,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
attr: Attribute,
isNull: Boolean,
update: Boolean): Option[Double] = {
if (!colStatsMap.contains(attr) || !colStatsMap(attr).hasCountStats) {
if (!colStatsMap.contains(attr) || colStatsMap(attr).nullCount.isEmpty) {
logDebug("[CBO] No statistics for " + attr)
return None
}

View file

@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics, Union}
import org.apache.spark.sql.types._
@ -70,8 +68,27 @@ object UnionEstimation {
None
}
val unionOutput = union.output
val newMinMaxStats = computeMinMaxStats(union)
val newNullCountStats = computeNullCountStats(union)
val newAttrStats = {
val baseStats = AttributeMap(newMinMaxStats)
val overwriteStats = newNullCountStats.map { case attrStat@(attr, stat) =>
baseStats.get(attr).map { baseStat =>
attr -> baseStat.copy(nullCount = stat.nullCount)
}.getOrElse(attrStat)
}
AttributeMap(newMinMaxStats ++ overwriteStats)
}
Some(
Statistics(
sizeInBytes = sizeInBytes,
rowCount = outputRows,
attributeStats = newAttrStats))
}
private def computeMinMaxStats(union: Union): Seq[(Attribute, ColumnStat)] = {
val unionOutput = union.output
val attrToComputeMinMaxStats = union.children.map(_.output).transpose.zipWithIndex.filter {
case (attrs, outputIndex) => isTypeSupported(unionOutput(outputIndex).dataType) &&
// checks if all the children has min/max stats for an attribute
@ -81,40 +98,50 @@ object UnionEstimation {
attrStats.get(attr).isDefined && attrStats(attr).hasMinMaxStats
}
}
val newAttrStats = if (attrToComputeMinMaxStats.nonEmpty) {
val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]()
attrToComputeMinMaxStats.foreach {
case (attrs, outputIndex) =>
val dataType = unionOutput(outputIndex).dataType
val statComparator = createStatComparator(dataType)
val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], Option[Any])]((None, None)) {
case ((minVal, maxVal), (attr, childIndex)) =>
val colStat = union.children(childIndex).stats.attributeStats(attr)
val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) {
colStat.min
} else {
minVal
}
val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) {
colStat.max
} else {
maxVal
}
(min, max)
attrToComputeMinMaxStats.map {
case (attrs, outputIndex) =>
val dataType = unionOutput(outputIndex).dataType
val statComparator = createStatComparator(dataType)
val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], Option[Any])]((None, None)) {
case ((minVal, maxVal), (attr, childIndex)) =>
val colStat = union.children(childIndex).stats.attributeStats(attr)
val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) {
colStat.min
} else {
minVal
}
val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2)
outputAttrStats += unionOutput(outputIndex) -> newStat
}
AttributeMap(outputAttrStats.toSeq)
} else {
AttributeMap.empty[ColumnStat]
val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) {
colStat.max
} else {
maxVal
}
(min, max)
}
val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2)
unionOutput(outputIndex) -> newStat
}
}
Some(
Statistics(
sizeInBytes = sizeInBytes,
rowCount = outputRows,
attributeStats = newAttrStats))
private def computeNullCountStats(union: Union): Seq[(Attribute, ColumnStat)] = {
val unionOutput = union.output
val attrToComputeNullCount = union.children.map(_.output).transpose.zipWithIndex.filter {
case (attrs, _) => attrs.zipWithIndex.forall {
case (attr, childIndex) =>
val attrStats = union.children(childIndex).stats.attributeStats
attrStats.get(attr).isDefined && attrStats(attr).nullCount.isDefined
}
}
attrToComputeNullCount.map {
case (attrs, outputIndex) =>
val firstStat = union.children.head.stats.attributeStats(attrs.head)
val firstNullCount = firstStat.nullCount.get
val colWithNullStatValues = attrs.zipWithIndex.tail.foldLeft[BigInt](firstNullCount) {
case (totalNullCount, (attr, childIndex)) =>
val colStat = union.children(childIndex).stats.attributeStats(attr)
totalNullCount + colStat.nullCount.get
}
val newStat = ColumnStat(nullCount = Some(colWithNullStatValues))
unionOutput(outputIndex) -> newStat
}
}
}

View file

@ -273,7 +273,7 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase {
val rowCount = Some(plan.rowCount * childrenSize)
val attributeStats = AttributeMap(
Seq(
attribute -> ColumnStat(min = Some(1), max = Some(10))))
attribute -> ColumnStat(min = Some(1), max = Some(10), nullCount = Some(0))))
checkStats(
union,
expectedStatsCboOn = Statistics(sizeInBytes = sizeInBytes,

View file

@ -68,14 +68,14 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
distinctCount = Some(2),
min = Some(1),
max = Some(4),
nullCount = Some(0),
nullCount = Some(1),
avgLen = Some(4),
maxLen = Some(4)),
attrDouble -> ColumnStat(
distinctCount = Some(2),
min = Some(5.0),
max = Some(4.0),
nullCount = Some(0),
nullCount = Some(2),
avgLen = Some(4),
maxLen = Some(4)),
attrShort -> ColumnStat(min = Some(s1), max = Some(s2)),
@ -96,14 +96,14 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
distinctCount = Some(2),
min = Some(3),
max = Some(6),
nullCount = Some(0),
nullCount = Some(1),
avgLen = Some(8),
maxLen = Some(8)),
AttributeReference("cdouble1", DoubleType)() -> ColumnStat(
distinctCount = Some(2),
min = Some(2.0),
max = Some(7.0),
nullCount = Some(0),
nullCount = Some(2),
avgLen = Some(8),
maxLen = Some(8)),
AttributeReference("cshort1", ShortType)() -> ColumnStat(min = Some(s3), max = Some(s4)),
@ -139,8 +139,8 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
rowCount = Some(4),
attributeStats = AttributeMap(
Seq(
attrInt -> ColumnStat(min = Some(1), max = Some(6)),
attrDouble -> ColumnStat(min = Some(2.0), max = Some(7.0)),
attrInt -> ColumnStat(min = Some(1), max = Some(6), nullCount = Some(2)),
attrDouble -> ColumnStat(min = Some(2.0), max = Some(7.0), nullCount = Some(4)),
attrShort -> ColumnStat(min = Some(s1), max = Some(s4)),
attrLong -> ColumnStat(min = Some(1L), max = Some(6L)),
attrByte -> ColumnStat(min = Some(b1), max = Some(b4)),
@ -188,7 +188,58 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
val union = Union(Seq(child1, child2))
val expectedStats = logical.Statistics(sizeInBytes = 2 * 1024, rowCount = Some(4))
// Only null count is present in the attribute stats
val expectedStats = logical.Statistics(
sizeInBytes = 2 * 1024,
rowCount = Some(4),
attributeStats = AttributeMap(
Seq(attrInt -> ColumnStat(nullCount = Some(0)))))
assert(union.stats === expectedStats)
}
test("col stats estimation when null count stats are not present for one child") {
val sz = Some(BigInt(1024))
val attrInt = AttributeReference("cint", IntegerType)()
val columnInfo = AttributeMap(
Seq(
attrInt -> ColumnStat(
distinctCount = Some(2),
min = Some(1),
max = Some(2),
nullCount = Some(2),
avgLen = Some(4),
maxLen = Some(4))))
// No null count
val columnInfo1 = AttributeMap(
Seq(
AttributeReference("cint1", IntegerType)() -> ColumnStat(
distinctCount = Some(2),
min = Some(3),
max = Some(4),
avgLen = Some(8),
maxLen = Some(8))))
val child1 = StatsTestPlan(
outputList = columnInfo.keys.toSeq,
rowCount = 2,
attributeStats = columnInfo,
size = sz)
val child2 = StatsTestPlan(
outputList = columnInfo1.keys.toSeq,
rowCount = 2,
attributeStats = columnInfo1,
size = sz)
val union = Union(Seq(child1, child2))
// Null count should not present in the stats.
val expectedStats = logical.Statistics(
sizeInBytes = 2 * 1024,
rowCount = Some(4),
attributeStats = AttributeMap(
Seq(attrInt -> ColumnStat(min = Some(1), max = Some(4), nullCount = None))))
assert(union.stats === expectedStats)
}
}