[SPARK-17075][SQL][FOLLOWUP] fix some minor issues and clean up the code

## What changes were proposed in this pull request?

This is a follow-up of https://github.com/apache/spark/pull/16395. It fixes some code style issues, naming issues, some missing cases in pattern match, etc.

## How was this patch tested?

existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17065 from cloud-fan/follow-up.
This commit is contained in:
Wenchen Fan 2017-02-25 23:01:44 -08:00
parent 6ab60542e8
commit 89608cf262
5 changed files with 297 additions and 338 deletions

View file

@ -28,7 +28,7 @@ object AttributeMap {
}
}
class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
extends Map[Attribute, A] with Serializable {
override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)

View file

@ -17,9 +17,7 @@
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
import java.sql.{Date, Timestamp}
import scala.collection.immutable.{HashSet, Map}
import scala.collection.immutable.HashSet
import scala.collection.mutable
import org.apache.spark.internal.Logging
@ -31,15 +29,16 @@ import org.apache.spark.sql.types._
case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
private val childStats = plan.child.stats(catalystConf)
/**
* We use a mutable colStats because we need to update the corresponding ColumnStat
* for a column after we apply a predicate condition. For example, column c has
* [min, max] value as [0, 100]. In a range condition such as (c > 40 AND c <= 50),
* we need to set the column's [min, max] value to [40, 100] after we evaluate the
* first condition c > 40. We need to set the column's [min, max] value to [40, 50]
* We will update the corresponding ColumnStats for a column after we apply a predicate condition.
* For example, column c has [min, max] value as [0, 100]. In a range condition such as
* (c > 40 AND c <= 50), we need to set the column's [min, max] value to [40, 100] after we
* evaluate the first condition c > 40. We need to set the column's [min, max] value to [40, 50]
* after we evaluate the second condition c <= 50.
*/
private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty
private val colStatsMap = new ColumnStatsMap
/**
* Returns an option of Statistics for a Filter logical plan node.
@ -51,12 +50,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
* @return Option[Statistics] When there is no statistics collected, it returns None.
*/
def estimate: Option[Statistics] = {
// We first copy child node's statistics and then modify it based on filter selectivity.
val stats: Statistics = plan.child.stats(catalystConf)
if (stats.rowCount.isEmpty) return None
if (childStats.rowCount.isEmpty) return None
// save a mutable copy of colStats so that we can later change it recursively
mutableColStats = mutable.Map(stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq: _*)
colStatsMap.setInitValues(childStats.attributeStats)
// estimate selectivity of this filter predicate
val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match {
@ -65,22 +62,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
case None => 1.0
}
// attributeStats has mapping Attribute-to-ColumnStat.
// mutableColStats has mapping ExprId-to-ColumnStat.
// We use an ExprId-to-Attribute map to facilitate the mapping Attribute-to-ColumnStat
val expridToAttrMap: Map[ExprId, Attribute] =
stats.attributeStats.map(kv => (kv._1.exprId, kv._1))
// copy mutableColStats contents to an immutable AttributeMap.
val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] =
mutableColStats.map(kv => expridToAttrMap(kv._1) -> kv._2)
val newColStats = AttributeMap(mutableAttributeStats.toSeq)
val newColStats = colStatsMap.toColumnStats
val filteredRowCount: BigInt =
EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity)
val filteredSizeInBytes =
EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity)
val filteredSizeInBytes: BigInt =
EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats)
Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount),
Some(childStats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount),
attributeStats = newColStats))
}
@ -95,15 +84,16 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
* @param condition the compound logical expression
* @param update a boolean flag to specify if we need to update ColumnStat of a column
* for subsequent conditions
* @return a double value to show the percentage of rows meeting a given condition.
* @return an optional double value to show the percentage of rows meeting a given condition.
* It returns None if the condition is not supported.
*/
def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = {
condition match {
case And(cond1, cond2) =>
(calculateFilterSelectivity(cond1, update), calculateFilterSelectivity(cond2, update))
match {
// For ease of debugging, we compute percent1 and percent2 in 2 statements.
val percent1 = calculateFilterSelectivity(cond1, update)
val percent2 = calculateFilterSelectivity(cond2, update)
(percent1, percent2) match {
case (Some(p1), Some(p2)) => Some(p1 * p2)
case (Some(p1), None) => Some(p1)
case (None, Some(p2)) => Some(p2)
@ -127,8 +117,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
case None => None
}
case _ =>
calculateSingleCondition(condition, update)
case _ => calculateSingleCondition(condition, update)
}
}
@ -140,7 +129,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
* @param condition a single logical expression
* @param update a boolean flag to specify if we need to update ColumnStat of a column
* for subsequent conditions
* @return Option[Double] value to show the percentage of rows meeting a given condition.
* @return an optional double value to show the percentage of rows meeting a given condition.
* It returns None if the condition is not supported.
*/
def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = {
@ -148,33 +137,33 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
// For evaluateBinary method, we assume the literal on the right side of an operator.
// So we will change the order if not.
// EqualTo does not care about the order
case op @ EqualTo(ar: AttributeReference, l: Literal) =>
evaluateBinary(op, ar, l, update)
case op @ EqualTo(l: Literal, ar: AttributeReference) =>
evaluateBinary(op, ar, l, update)
// EqualTo/EqualNullSafe does not care about the order
case op @ Equality(ar: Attribute, l: Literal) =>
evaluateEquality(ar, l, update)
case op @ Equality(l: Literal, ar: Attribute) =>
evaluateEquality(ar, l, update)
case op @ LessThan(ar: AttributeReference, l: Literal) =>
case op @ LessThan(ar: Attribute, l: Literal) =>
evaluateBinary(op, ar, l, update)
case op @ LessThan(l: Literal, ar: AttributeReference) =>
case op @ LessThan(l: Literal, ar: Attribute) =>
evaluateBinary(GreaterThan(ar, l), ar, l, update)
case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) =>
case op @ LessThanOrEqual(ar: Attribute, l: Literal) =>
evaluateBinary(op, ar, l, update)
case op @ LessThanOrEqual(l: Literal, ar: AttributeReference) =>
case op @ LessThanOrEqual(l: Literal, ar: Attribute) =>
evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update)
case op @ GreaterThan(ar: AttributeReference, l: Literal) =>
case op @ GreaterThan(ar: Attribute, l: Literal) =>
evaluateBinary(op, ar, l, update)
case op @ GreaterThan(l: Literal, ar: AttributeReference) =>
case op @ GreaterThan(l: Literal, ar: Attribute) =>
evaluateBinary(LessThan(ar, l), ar, l, update)
case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
case op @ GreaterThanOrEqual(ar: Attribute, l: Literal) =>
evaluateBinary(op, ar, l, update)
case op @ GreaterThanOrEqual(l: Literal, ar: AttributeReference) =>
case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) =>
evaluateBinary(LessThanOrEqual(ar, l), ar, l, update)
case In(ar: AttributeReference, expList)
case In(ar: Attribute, expList)
if expList.forall(e => e.isInstanceOf[Literal]) =>
// Expression [In (value, seq[Literal])] will be replaced with optimized version
// [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10.
@ -182,14 +171,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
val hSet = expList.map(e => e.eval())
evaluateInSet(ar, HashSet() ++ hSet, update)
case InSet(ar: AttributeReference, set) =>
case InSet(ar: Attribute, set) =>
evaluateInSet(ar, set, update)
case IsNull(ar: AttributeReference) =>
evaluateIsNull(ar, isNull = true, update)
case IsNull(ar: Attribute) =>
evaluateNullCheck(ar, isNull = true, update)
case IsNotNull(ar: AttributeReference) =>
evaluateIsNull(ar, isNull = false, update)
case IsNotNull(ar: Attribute) =>
evaluateNullCheck(ar, isNull = false, update)
case _ =>
// TODO: it's difficult to support string operators without advanced statistics.
@ -203,44 +192,43 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
/**
* Returns a percentage of rows meeting "IS NULL" or "IS NOT NULL" condition.
*
* @param attrRef an AttributeReference (or a column)
* @param attr an Attribute (or a column)
* @param isNull set to true for "IS NULL" condition. set to false for "IS NOT NULL" condition
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
* for subsequent conditions
* @return an optional double value to show the percentage of rows meeting a given condition
* It returns None if no statistics collected for a given column.
*/
def evaluateIsNull(
attrRef: AttributeReference,
def evaluateNullCheck(
attr: Attribute,
isNull: Boolean,
update: Boolean)
: Option[Double] = {
if (!mutableColStats.contains(attrRef.exprId)) {
logDebug("[CBO] No statistics for " + attrRef)
update: Boolean): Option[Double] = {
if (!colStatsMap.contains(attr)) {
logDebug("[CBO] No statistics for " + attr)
return None
}
val aColStat = mutableColStats(attrRef.exprId)
val rowCountValue = plan.child.stats(catalystConf).rowCount.get
val nullPercent: BigDecimal =
if (rowCountValue == 0) 0.0
else BigDecimal(aColStat.nullCount) / BigDecimal(rowCountValue)
if (update) {
val newStats =
if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None)
else aColStat.copy(nullCount = 0)
mutableColStats += (attrRef.exprId -> newStats)
val colStat = colStatsMap(attr)
val rowCountValue = childStats.rowCount.get
val nullPercent: BigDecimal = if (rowCountValue == 0) {
0
} else {
BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)
}
val percent =
if (isNull) {
nullPercent.toDouble
}
else {
/** ISNOTNULL(column) */
1.0 - nullPercent.toDouble
if (update) {
val newStats = if (isNull) {
colStat.copy(distinctCount = 0, min = None, max = None)
} else {
colStat.copy(nullCount = 0)
}
colStatsMap(attr) = newStats
}
val percent = if (isNull) {
nullPercent.toDouble
} else {
1.0 - nullPercent.toDouble
}
Some(percent)
}
@ -249,7 +237,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
* Returns a percentage of rows meeting a binary comparison expression.
*
* @param op a binary comparison operator uch as =, <, <=, >, >=
* @param attrRef an AttributeReference (or a column)
* @param attr an Attribute (or a column)
* @param literal a literal value (or constant)
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
* for subsequent conditions
@ -258,27 +246,20 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
*/
def evaluateBinary(
op: BinaryComparison,
attrRef: AttributeReference,
attr: Attribute,
literal: Literal,
update: Boolean)
: Option[Double] = {
if (!mutableColStats.contains(attrRef.exprId)) {
logDebug("[CBO] No statistics for " + attrRef)
return None
}
op match {
case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update)
update: Boolean): Option[Double] = {
attr.dataType match {
case _: NumericType | DateType | TimestampType =>
evaluateBinaryForNumeric(op, attr, literal, update)
case StringType | BinaryType =>
// TODO: It is difficult to support other binary comparisons for String/Binary
// type without min/max and advanced statistics like histogram.
logDebug("[CBO] No range comparison statistics for String/Binary type " + attr)
None
case _ =>
attrRef.dataType match {
case _: NumericType | DateType | TimestampType =>
evaluateBinaryForNumeric(op, attrRef, literal, update)
case StringType | BinaryType =>
// TODO: It is difficult to support other binary comparisons for String/Binary
// type without min/max and advanced statistics like histogram.
logDebug("[CBO] No range comparison statistics for String/Binary type " + attrRef)
None
}
// TODO: support boolean type.
None
}
}
@ -297,6 +278,8 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
Some(DateTimeUtils.toJavaDate(litValue.toString.toInt))
case TimestampType =>
Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong))
case _: DecimalType =>
Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal)
case StringType | BinaryType =>
None
case _ =>
@ -308,37 +291,36 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
* Returns a percentage of rows meeting an equality (=) expression.
* This method evaluates the equality predicate for all data types.
*
* @param attrRef an AttributeReference (or a column)
* @param attr an Attribute (or a column)
* @param literal a literal value (or constant)
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
* for subsequent conditions
* @return an optional double value to show the percentage of rows meeting a given condition
*/
def evaluateEqualTo(
attrRef: AttributeReference,
def evaluateEquality(
attr: Attribute,
literal: Literal,
update: Boolean)
: Option[Double] = {
val aColStat = mutableColStats(attrRef.exprId)
val ndv = aColStat.distinctCount
update: Boolean): Option[Double] = {
if (!colStatsMap.contains(attr)) {
logDebug("[CBO] No statistics for " + attr)
return None
}
val colStat = colStatsMap(attr)
val ndv = colStat.distinctCount
// decide if the value is in [min, max] of the column.
// We currently don't store min/max for binary/string type.
// Hence, we assume it is in boundary for binary/string type.
val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType)
val inBoundary: Boolean = Range.rangeContainsLiteral(statsRange, literal)
if (inBoundary) {
val statsRange = Range(colStat.min, colStat.max, attr.dataType)
if (statsRange.contains(literal)) {
if (update) {
// We update ColumnStat structure after apply this equality predicate.
// Set distinctCount to 1. Set nullCount to 0.
// Need to save new min/max using the external type value of the literal
val newValue = convertBoundValue(attrRef.dataType, literal.value)
val newStats = aColStat.copy(distinctCount = 1, min = newValue,
val newValue = convertBoundValue(attr.dataType, literal.value)
val newStats = colStat.copy(distinctCount = 1, min = newValue,
max = newValue, nullCount = 0)
mutableColStats += (attrRef.exprId -> newStats)
colStatsMap(attr) = newStats
}
Some(1.0 / ndv.toDouble)
@ -352,7 +334,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
* Returns a percentage of rows meeting "IN" operator expression.
* This method evaluates the equality predicate for all data types.
*
* @param attrRef an AttributeReference (or a column)
* @param attr an Attribute (or a column)
* @param hSet a set of literal values
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
* for subsequent conditions
@ -361,57 +343,52 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
*/
def evaluateInSet(
attrRef: AttributeReference,
attr: Attribute,
hSet: Set[Any],
update: Boolean)
: Option[Double] = {
if (!mutableColStats.contains(attrRef.exprId)) {
logDebug("[CBO] No statistics for " + attrRef)
update: Boolean): Option[Double] = {
if (!colStatsMap.contains(attr)) {
logDebug("[CBO] No statistics for " + attr)
return None
}
val aColStat = mutableColStats(attrRef.exprId)
val ndv = aColStat.distinctCount
val aType = attrRef.dataType
var newNdv: Long = 0
val colStat = colStatsMap(attr)
val ndv = colStat.distinctCount
val dataType = attr.dataType
var newNdv = ndv
// use [min, max] to filter the original hSet
aType match {
case _: NumericType | DateType | TimestampType =>
val statsRange =
Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange]
// To facilitate finding the min and max values in hSet, we map hSet values to BigDecimal.
// Using hSetBigdec, we can find the min and max values quickly in the ordered hSetBigdec.
val hSetBigdec = hSet.map(e => BigDecimal(e.toString))
val validQuerySet = hSetBigdec.filter(e => e >= statsRange.min && e <= statsRange.max)
// We use hSetBigdecToAnyMap to help us find the original hSet value.
val hSetBigdecToAnyMap: Map[BigDecimal, Any] =
hSet.map(e => BigDecimal(e.toString) -> e).toMap
dataType match {
case _: NumericType | BooleanType | DateType | TimestampType =>
val statsRange = Range(colStat.min, colStat.max, dataType).asInstanceOf[NumericRange]
val validQuerySet = hSet.filter { v =>
v != null && statsRange.contains(Literal(v, dataType))
}
if (validQuerySet.isEmpty) {
return Some(0.0)
}
// Need to save new min/max using the external type value of the literal
val newMax = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.max))
val newMin = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.min))
val newMax = convertBoundValue(
attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString)))
val newMin = convertBoundValue(
attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString)))
// newNdv should not be greater than the old ndv. For example, column has only 2 values
// 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5.
newNdv = math.min(validQuerySet.size.toLong, ndv.longValue())
newNdv = ndv.min(BigInt(validQuerySet.size))
if (update) {
val newStats = aColStat.copy(distinctCount = newNdv, min = newMin,
val newStats = colStat.copy(distinctCount = newNdv, min = newMin,
max = newMax, nullCount = 0)
mutableColStats += (attrRef.exprId -> newStats)
colStatsMap(attr) = newStats
}
// We assume the whole set since there is no min/max information for String/Binary type
case StringType | BinaryType =>
newNdv = math.min(hSet.size.toLong, ndv.longValue())
newNdv = ndv.min(BigInt(hSet.size))
if (update) {
val newStats = aColStat.copy(distinctCount = newNdv, nullCount = 0)
mutableColStats += (attrRef.exprId -> newStats)
val newStats = colStat.copy(distinctCount = newNdv, nullCount = 0)
colStatsMap(attr) = newStats
}
}
@ -425,7 +402,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
* This method evaluate expression for Numeric columns only.
*
* @param op a binary comparison operator uch as =, <, <=, >, >=
* @param attrRef an AttributeReference (or a column)
* @param attr an Attribute (or a column)
* @param literal a literal value (or constant)
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
* for subsequent conditions
@ -433,16 +410,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
*/
def evaluateBinaryForNumeric(
op: BinaryComparison,
attrRef: AttributeReference,
attr: Attribute,
literal: Literal,
update: Boolean)
: Option[Double] = {
update: Boolean): Option[Double] = {
var percent = 1.0
val aColStat = mutableColStats(attrRef.exprId)
val ndv = aColStat.distinctCount
val colStat = colStatsMap(attr)
val statsRange =
Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange]
Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
// determine the overlapping degree between predicate range and column's range
val literalValueBD = BigDecimal(literal.value.toString)
@ -463,33 +438,37 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
percent = 1.0
} else {
// this is partial overlap case
var newMax = aColStat.max
var newMin = aColStat.min
var newNdv = ndv
val literalToDouble = literalValueBD.toDouble
val maxToDouble = BigDecimal(statsRange.max).toDouble
val minToDouble = BigDecimal(statsRange.min).toDouble
val literalDouble = literalValueBD.toDouble
val maxDouble = BigDecimal(statsRange.max).toDouble
val minDouble = BigDecimal(statsRange.min).toDouble
// Without advanced statistics like histogram, we assume uniform data distribution.
// We just prorate the adjusted range over the initial range to compute filter selectivity.
// For ease of computation, we convert all relevant numeric values to Double.
percent = op match {
case _: LessThan =>
(literalToDouble - minToDouble) / (maxToDouble - minToDouble)
(literalDouble - minDouble) / (maxDouble - minDouble)
case _: LessThanOrEqual =>
if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble
else (literalToDouble - minToDouble) / (maxToDouble - minToDouble)
if (literalValueBD == BigDecimal(statsRange.min)) {
1.0 / colStat.distinctCount.toDouble
} else {
(literalDouble - minDouble) / (maxDouble - minDouble)
}
case _: GreaterThan =>
(maxToDouble - literalToDouble) / (maxToDouble - minToDouble)
(maxDouble - literalDouble) / (maxDouble - minDouble)
case _: GreaterThanOrEqual =>
if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble
else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble)
if (literalValueBD == BigDecimal(statsRange.max)) {
1.0 / colStat.distinctCount.toDouble
} else {
(maxDouble - literalDouble) / (maxDouble - minDouble)
}
}
// Need to save new min/max using the external type value of the literal
val newValue = convertBoundValue(attrRef.dataType, literal.value)
if (update) {
// Need to save new min/max using the external type value of the literal
val newValue = convertBoundValue(attr.dataType, literal.value)
var newMax = colStat.max
var newMin = colStat.min
op match {
case _: GreaterThan => newMin = newValue
case _: GreaterThanOrEqual => newMin = newValue
@ -497,11 +476,11 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
case _: LessThanOrEqual => newMax = newValue
}
newNdv = math.max(math.round(ndv.toDouble * percent), 1)
val newStats = aColStat.copy(distinctCount = newNdv, min = newMin,
val newNdv = math.max(math.round(colStat.distinctCount.toDouble * percent), 1)
val newStats = colStat.copy(distinctCount = newNdv, min = newMin,
max = newMax, nullCount = 0)
mutableColStats += (attrRef.exprId -> newStats)
colStatsMap(attr) = newStats
}
}
@ -509,3 +488,20 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
}
}
class ColumnStatsMap {
private val baseMap: mutable.Map[ExprId, (Attribute, ColumnStat)] = mutable.HashMap.empty
def setInitValues(colStats: AttributeMap[ColumnStat]): Unit = {
baseMap.clear()
baseMap ++= colStats.baseMap
}
def contains(a: Attribute): Boolean = baseMap.contains(a.exprId)
def apply(a: Attribute): ColumnStat = baseMap(a.exprId)._2
def update(a: Attribute, stats: ColumnStat): Unit = baseMap.update(a.exprId, a -> stats)
def toColumnStats: AttributeMap[ColumnStat] = AttributeMap(baseMap.values.toSeq)
}

View file

@ -59,7 +59,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
case _ if !rowCountsExist(conf, join.left, join.right) =>
None
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) =>
// 1. Compute join selectivity
val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys)
val selectivity = joinSelectivity(joinKeyPairs)
@ -94,9 +94,9 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) {
// The output is empty, we don't need to keep column stats.
Nil
} else if (innerJoinedRows == 0) {
} else if (selectivity == 0) {
joinType match {
// For outer joins, if the inner join part is empty, the number of output rows is the
// For outer joins, if the join selectivity is 0, the number of output rows is the
// same as that of the outer side. And column stats of join keys from the outer side
// keep unchanged, while column stats of join keys from the other side should be updated
// based on added null values.
@ -116,6 +116,9 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
}
case _ => Nil
}
} else if (selectivity == 1) {
// Cartesian product, just propagate the original column stats
inputAttrStats.toSeq
} else {
val joinKeyStats = getIntersectedStats(joinKeyPairs)
join.joinType match {
@ -138,8 +141,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
Some(Statistics(
sizeInBytes = getOutputSize(join.output, outputRows, outputAttrStats),
rowCount = Some(outputRows),
attributeStats = outputAttrStats,
isBroadcastable = false))
attributeStats = outputAttrStats))
case _ =>
// When there is no equi-join condition, we do estimation like cartesian product.
@ -150,8 +152,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
Some(Statistics(
sizeInBytes = getOutputSize(join.output, outputRows, inputAttrStats),
rowCount = Some(outputRows),
attributeStats = inputAttrStats,
isBroadcastable = false))
attributeStats = inputAttrStats))
}
// scalastyle:off
@ -189,8 +190,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
}
if (ndvDenom < 0) {
// There isn't join keys or column stats for any of the join key pairs, we do estimation like
// cartesian product.
// We can't find any join key pairs with column stats, estimate it as cartesian join.
1
} else if (ndvDenom == 0) {
// One of the join key pairs is disjoint, thus the two sides of join is disjoint.
@ -202,9 +202,6 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
/**
* Propagate or update column stats for output attributes.
* 1. For cartesian product, all values are preserved, so there's no need to change column stats.
* 2. For other cases, a) update max/min of join keys based on their intersected range. b) update
* distinct count of other attributes based on output rows after join.
*/
private def updateAttrStats(
outputRows: BigInt,
@ -214,35 +211,38 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]()
val leftRows = leftStats.rowCount.get
val rightRows = rightStats.rowCount.get
if (outputRows == leftRows * rightRows) {
// Cartesian product, just propagate the original column stats
attributes.foreach(a => outputAttrStats += a -> oldAttrStats(a))
} else {
val leftRatio =
if (leftRows != 0) BigDecimal(outputRows) / BigDecimal(leftRows) else BigDecimal(0)
val rightRatio =
if (rightRows != 0) BigDecimal(outputRows) / BigDecimal(rightRows) else BigDecimal(0)
attributes.foreach { a =>
// check if this attribute is a join key
if (joinKeyStats.contains(a)) {
outputAttrStats += a -> joinKeyStats(a)
attributes.foreach { a =>
// check if this attribute is a join key
if (joinKeyStats.contains(a)) {
outputAttrStats += a -> joinKeyStats(a)
} else {
val leftRatio = if (leftRows != 0) {
BigDecimal(outputRows) / BigDecimal(leftRows)
} else {
val oldColStat = oldAttrStats(a)
val oldNdv = oldColStat.distinctCount
// We only change (scale down) the number of distinct values if the number of rows
// decreases after join, because join won't produce new values even if the number of
// rows increases.
val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) {
ceil(BigDecimal(oldNdv) * leftRatio)
} else if (join.right.outputSet.contains(a) && rightRatio < 1) {
ceil(BigDecimal(oldNdv) * rightRatio)
} else {
oldNdv
}
// TODO: support nullCount updates for specific outer joins
outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv)
BigDecimal(0)
}
val rightRatio = if (rightRows != 0) {
BigDecimal(outputRows) / BigDecimal(rightRows)
} else {
BigDecimal(0)
}
val oldColStat = oldAttrStats(a)
val oldNdv = oldColStat.distinctCount
// We only change (scale down) the number of distinct values if the number of rows
// decreases after join, because join won't produce new values even if the number of
// rows increases.
val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) {
ceil(BigDecimal(oldNdv) * leftRatio)
} else if (join.right.outputSet.contains(a) && rightRatio < 1) {
ceil(BigDecimal(oldNdv) * rightRatio)
} else {
oldNdv
}
// TODO: support nullCount updates for specific outer joins
outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv)
}
}
outputAttrStats
}
@ -263,12 +263,14 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
// Update intersected column stats
assert(leftKey.dataType.sameType(rightKey.dataType))
val minNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount)
val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount)
val (newMin, newMax) = Range.intersect(lRange, rRange, leftKey.dataType)
intersectedStats.put(leftKey,
leftKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0))
intersectedStats.put(rightKey,
rightKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0))
val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen)
val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2
val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen)
intersectedStats.put(leftKey, newStats)
intersectedStats.put(rightKey, newStats)
}
AttributeMap(intersectedStats.toSeq)
}
@ -298,8 +300,7 @@ case class LeftSemiAntiEstimation(conf: CatalystConf, join: Join) {
Some(Statistics(
sizeInBytes = getOutputSize(join.output, outputRows, leftStats.attributeStats),
rowCount = Some(outputRows),
attributeStats = leftStats.attributeStats,
isBroadcastable = false))
attributeStats = leftStats.attributeStats))
} else {
None
}

View file

@ -26,19 +26,33 @@ import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _}
/** Value range of a column. */
trait Range
trait Range {
def contains(l: Literal): Boolean
}
/** For simplicity we use decimal to unify operations of numeric ranges. */
case class NumericRange(min: JDecimal, max: JDecimal) extends Range
case class NumericRange(min: JDecimal, max: JDecimal) extends Range {
override def contains(l: Literal): Boolean = {
val decimal = l.dataType match {
case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
case _ => new JDecimal(l.value.toString)
}
min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0
}
}
/**
* This version of Spark does not have min/max for binary/string types, we define their default
* behaviors by this class.
*/
class DefaultRange extends Range
class DefaultRange extends Range {
override def contains(l: Literal): Boolean = true
}
/** This is for columns with only null values. */
class NullRange extends Range
class NullRange extends Range {
override def contains(l: Literal): Boolean = false
}
object Range {
def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match {
@ -58,20 +72,6 @@ object Range {
n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0
}
def rangeContainsLiteral(r: Range, lit: Literal): Boolean = r match {
case _: DefaultRange => true
case _: NullRange => false
case n: NumericRange =>
val literalValue = if (lit.dataType.isInstanceOf[BooleanType]) {
if (lit.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
} else {
assert(lit.dataType.isInstanceOf[NumericType] || lit.dataType.isInstanceOf[DateType] ||
lit.dataType.isInstanceOf[TimestampType])
new JDecimal(lit.value.toString)
}
n.min.compareTo(literalValue) <= 0 && n.max.compareTo(literalValue) >= 0
}
/**
* Intersected results of two ranges. This is only for two overlapped ranges.
* The outputs are the intersected min/max values.

View file

@ -17,12 +17,11 @@
package org.apache.spark.sql.catalyst.statsEstimation
import java.sql.{Date, Timestamp}
import java.sql.Date
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
/**
@ -38,6 +37,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4)
// only 2 values
val arBool = AttributeReference("cbool", BooleanType)()
val childColStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
nullCount = 0, avgLen = 1, maxLen = 1)
// Second column cdate has 10 values from 2017-01-01 through 2017-01-10.
val dMin = Date.valueOf("2017-01-01")
val dMax = Date.valueOf("2017-01-10")
@ -45,14 +49,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
nullCount = 0, avgLen = 4, maxLen = 4)
// Third column ctimestamp has 10 values from "2017-01-01 01:00:00" through
// "2017-01-01 10:00:00" for 10 distinct timestamps (or hours).
val tsMin = Timestamp.valueOf("2017-01-01 01:00:00")
val tsMax = Timestamp.valueOf("2017-01-01 10:00:00")
val arTimestamp = AttributeReference("ctimestamp", TimestampType)()
val childColStatTimestamp = ColumnStat(distinctCount = 10, min = Some(tsMin), max = Some(tsMax),
nullCount = 0, avgLen = 8, maxLen = 8)
// Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
val decMin = new java.math.BigDecimal("0.200000000000000000")
val decMax = new java.math.BigDecimal("0.800000000000000000")
@ -77,8 +73,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(1L)
)
1)
}
test("cint <=> 2") {
validateEstimatedStats(
arInt,
Filter(EqualNullSafe(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
nullCount = 0, avgLen = 4, maxLen = 4),
1)
}
test("cint = 0") {
@ -88,8 +92,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(0L)
)
0)
}
test("cint < 3") {
@ -98,8 +101,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(3L)
)
3)
}
test("cint < 0") {
@ -109,8 +111,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(0L)
)
0)
}
test("cint <= 3") {
@ -119,8 +120,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(3L)
)
3)
}
test("cint > 6") {
@ -129,8 +129,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(5L)
)
5)
}
test("cint > 10") {
@ -140,8 +139,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(0L)
)
0)
}
test("cint >= 6") {
@ -150,8 +148,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(5L)
)
5)
}
test("cint IS NULL") {
@ -160,8 +157,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 0, min = None, max = None,
nullCount = 0, avgLen = 4, maxLen = 4),
Some(0L)
)
0)
}
test("cint IS NOT NULL") {
@ -170,8 +166,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(10L)
)
10)
}
test("cint > 3 AND cint <= 6") {
@ -181,8 +176,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 3, min = Some(3), max = Some(6),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(4L)
)
4)
}
test("cint = 3 OR cint = 6") {
@ -192,8 +186,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(2L)
)
2)
}
test("cint IN (3, 4, 5)") {
@ -202,8 +195,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 3, min = Some(3), max = Some(5),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(3L)
)
3)
}
test("cint NOT IN (3, 4, 5)") {
@ -212,8 +204,26 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(7L)
)
7)
}
test("cbool = true") {
validateEstimatedStats(
arBool,
Filter(EqualTo(arBool, Literal(true)), childStatsTestPlan(Seq(arBool), 10L)),
ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
nullCount = 0, avgLen = 1, maxLen = 1),
5)
}
test("cbool > false") {
// bool comparison is not supported yet, so stats remain same.
validateEstimatedStats(
arBool,
Filter(GreaterThan(arBool, Literal(false)), childStatsTestPlan(Seq(arBool), 10L)),
ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
nullCount = 0, avgLen = 1, maxLen = 1),
10)
}
test("cdate = cast('2017-01-02' AS DATE)") {
@ -224,8 +234,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
childStatsTestPlan(Seq(arDate), 10L)),
ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(1L)
)
1)
}
test("cdate < cast('2017-01-03' AS DATE)") {
@ -236,8 +245,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
childStatsTestPlan(Seq(arDate), 10L)),
ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(3L)
)
3)
}
test("""cdate IN ( cast('2017-01-03' AS DATE),
@ -251,32 +259,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
childStatsTestPlan(Seq(arDate), 10L)),
ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(3L)
)
}
test("ctimestamp = cast('2017-01-01 02:00:00' AS TIMESTAMP)") {
val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00")
validateEstimatedStats(
arTimestamp,
Filter(EqualTo(arTimestamp, Literal(ts2017010102)),
childStatsTestPlan(Seq(arTimestamp), 10L)),
ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102),
nullCount = 0, avgLen = 8, maxLen = 8),
Some(1L)
)
}
test("ctimestamp < cast('2017-01-01 03:00:00' AS TIMESTAMP)") {
val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00")
validateEstimatedStats(
arTimestamp,
Filter(LessThan(arTimestamp, Literal(ts2017010103)),
childStatsTestPlan(Seq(arTimestamp), 10L)),
ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103),
nullCount = 0, avgLen = 8, maxLen = 8),
Some(3L)
)
3)
}
test("cdecimal = 0.400000000000000000") {
@ -287,20 +270,18 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
childStatsTestPlan(Seq(arDecimal), 4L)),
ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40),
nullCount = 0, avgLen = 8, maxLen = 8),
Some(1L)
)
1)
}
test("cdecimal < 0.60 ") {
val dec_0_60 = new java.math.BigDecimal("0.600000000000000000")
validateEstimatedStats(
arDecimal,
Filter(LessThan(arDecimal, Literal(dec_0_60, DecimalType(12, 2))),
Filter(LessThan(arDecimal, Literal(dec_0_60)),
childStatsTestPlan(Seq(arDecimal), 4L)),
ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60),
nullCount = 0, avgLen = 8, maxLen = 8),
Some(3L)
)
3)
}
test("cdouble < 3.0") {
@ -309,8 +290,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)),
ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0),
nullCount = 0, avgLen = 8, maxLen = 8),
Some(3L)
)
3)
}
test("cstring = 'A2'") {
@ -319,8 +299,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
ColumnStat(distinctCount = 1, min = None, max = None,
nullCount = 0, avgLen = 2, maxLen = 2),
Some(1L)
)
1)
}
// There is no min/max statistics for String type. We estimate 10 rows returned.
@ -330,8 +309,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
ColumnStat(distinctCount = 10, min = None, max = None,
nullCount = 0, avgLen = 2, maxLen = 2),
Some(10L)
)
10)
}
// This is a corner test case. We want to test if we can handle the case when the number of
@ -351,8 +329,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan),
ColumnStat(distinctCount = 2, min = Some(1), max = Some(5),
nullCount = 0, avgLen = 4, maxLen = 4),
Some(2L)
)
2)
}
private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = {
@ -361,8 +338,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
rowCount = tableRowCount,
attributeStats = AttributeMap(Seq(
arInt -> childColStatInt,
arBool -> childColStatBool,
arDate -> childColStatDate,
arTimestamp -> childColStatTimestamp,
arDecimal -> childColStatDecimal,
arDouble -> childColStatDouble,
arString -> childColStatString
@ -374,46 +351,31 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
ar: AttributeReference,
filterNode: Filter,
expectedColStats: ColumnStat,
rowCount: Option[BigInt] = None)
: Unit = {
rowCount: Int): Unit = {
val expectedRowCount: BigInt = rowCount.getOrElse(0L)
val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode)
val expectedSizeInBytes = getOutputSize(filterNode.output, expectedRowCount, expectedAttrStats)
val expectedSizeInBytes = getOutputSize(filterNode.output, rowCount, expectedAttrStats)
val filteredStats = filterNode.stats(conf)
assert(filteredStats.sizeInBytes == expectedSizeInBytes)
assert(filteredStats.rowCount == rowCount)
ar.dataType match {
case DecimalType() =>
// Due to the internal transformation for DecimalType within engine, the new min/max
// in ColumnStat may have a different structure even it contains the right values.
// We convert them to Java BigDecimal values so that we can compare the entire object.
val generatedColumnStats = filteredStats.attributeStats(ar)
val newMax = new java.math.BigDecimal(generatedColumnStats.max.getOrElse(0).toString)
val newMin = new java.math.BigDecimal(generatedColumnStats.min.getOrElse(0).toString)
val outputColStats = generatedColumnStats.copy(min = Some(newMin), max = Some(newMax))
assert(outputColStats == expectedColStats)
case _ =>
// For all other SQL types, we compare the entire object directly.
assert(filteredStats.attributeStats(ar) == expectedColStats)
}
assert(filteredStats.rowCount.get == rowCount)
assert(filteredStats.attributeStats(ar) == expectedColStats)
// If the filter has a binary operator (including those nested inside
// AND/OR/NOT), swap the sides of the attribte and the literal, reverse the
// operator, and then check again.
val rewrittenFilter = filterNode transformExpressionsDown {
case op @ EqualTo(ar: AttributeReference, l: Literal) =>
case EqualTo(ar: AttributeReference, l: Literal) =>
EqualTo(l, ar)
case op @ LessThan(ar: AttributeReference, l: Literal) =>
case LessThan(ar: AttributeReference, l: Literal) =>
GreaterThan(l, ar)
case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) =>
case LessThanOrEqual(ar: AttributeReference, l: Literal) =>
GreaterThanOrEqual(l, ar)
case op @ GreaterThan(ar: AttributeReference, l: Literal) =>
case GreaterThan(ar: AttributeReference, l: Literal) =>
LessThan(l, ar)
case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
case GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
LessThanOrEqual(l, ar)
}