[SPARK-17075][SQL][FOLLOWUP] fix some minor issues and clean up the code
## What changes were proposed in this pull request?

This is a follow-up of https://github.com/apache/spark/pull/16395. It fixes some code style issues, naming issues, some missing cases in pattern match, etc.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17065 from cloud-fan/follow-up.
This commit is contained in:
parent
6ab60542e8
commit
89608cf262
|
@ -28,7 +28,7 @@ object AttributeMap {
|
|||
}
|
||||
}
|
||||
|
||||
class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
|
||||
class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
|
||||
extends Map[Attribute, A] with Serializable {
|
||||
|
||||
override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
|
||||
|
|
|
@ -17,9 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
|
||||
|
||||
import java.sql.{Date, Timestamp}
|
||||
|
||||
import scala.collection.immutable.{HashSet, Map}
|
||||
import scala.collection.immutable.HashSet
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.apache.spark.internal.Logging
|
||||
|
@ -31,15 +29,16 @@ import org.apache.spark.sql.types._
|
|||
|
||||
case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
|
||||
|
||||
private val childStats = plan.child.stats(catalystConf)
|
||||
|
||||
/**
|
||||
* We use a mutable colStats because we need to update the corresponding ColumnStat
|
||||
* for a column after we apply a predicate condition. For example, column c has
|
||||
* [min, max] value as [0, 100]. In a range condition such as (c > 40 AND c <= 50),
|
||||
* we need to set the column's [min, max] value to [40, 100] after we evaluate the
|
||||
* first condition c > 40. We need to set the column's [min, max] value to [40, 50]
|
||||
* We will update the corresponding ColumnStats for a column after we apply a predicate condition.
|
||||
* For example, column c has [min, max] value as [0, 100]. In a range condition such as
|
||||
* (c > 40 AND c <= 50), we need to set the column's [min, max] value to [40, 100] after we
|
||||
* evaluate the first condition c > 40. We need to set the column's [min, max] value to [40, 50]
|
||||
* after we evaluate the second condition c <= 50.
|
||||
*/
|
||||
private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty
|
||||
private val colStatsMap = new ColumnStatsMap
|
||||
|
||||
/**
|
||||
* Returns an option of Statistics for a Filter logical plan node.
|
||||
|
@ -51,12 +50,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
* @return Option[Statistics] When no statistics have been collected, it returns None.
|
||||
*/
|
||||
def estimate: Option[Statistics] = {
|
||||
// We first copy child node's statistics and then modify it based on filter selectivity.
|
||||
val stats: Statistics = plan.child.stats(catalystConf)
|
||||
if (stats.rowCount.isEmpty) return None
|
||||
if (childStats.rowCount.isEmpty) return None
|
||||
|
||||
// save a mutable copy of colStats so that we can later change it recursively
|
||||
mutableColStats = mutable.Map(stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq: _*)
|
||||
colStatsMap.setInitValues(childStats.attributeStats)
|
||||
|
||||
// estimate selectivity of this filter predicate
|
||||
val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match {
|
||||
|
@ -65,22 +62,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
case None => 1.0
|
||||
}
|
||||
|
||||
// attributeStats has mapping Attribute-to-ColumnStat.
|
||||
// mutableColStats has mapping ExprId-to-ColumnStat.
|
||||
// We use an ExprId-to-Attribute map to facilitate the mapping Attribute-to-ColumnStat
|
||||
val expridToAttrMap: Map[ExprId, Attribute] =
|
||||
stats.attributeStats.map(kv => (kv._1.exprId, kv._1))
|
||||
// copy mutableColStats contents to an immutable AttributeMap.
|
||||
val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] =
|
||||
mutableColStats.map(kv => expridToAttrMap(kv._1) -> kv._2)
|
||||
val newColStats = AttributeMap(mutableAttributeStats.toSeq)
|
||||
val newColStats = colStatsMap.toColumnStats
|
||||
|
||||
val filteredRowCount: BigInt =
|
||||
EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity)
|
||||
val filteredSizeInBytes =
|
||||
EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity)
|
||||
val filteredSizeInBytes: BigInt =
|
||||
EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats)
|
||||
|
||||
Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount),
|
||||
Some(childStats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount),
|
||||
attributeStats = newColStats))
|
||||
}
|
||||
|
||||
|
@ -95,15 +84,16 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
* @param condition the compound logical expression
|
||||
* @param update a boolean flag to specify if we need to update ColumnStat of a column
|
||||
* for subsequent conditions
|
||||
* @return a double value to show the percentage of rows meeting a given condition.
|
||||
* @return an optional double value to show the percentage of rows meeting a given condition.
|
||||
* It returns None if the condition is not supported.
|
||||
*/
|
||||
def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = {
|
||||
|
||||
condition match {
|
||||
case And(cond1, cond2) =>
|
||||
(calculateFilterSelectivity(cond1, update), calculateFilterSelectivity(cond2, update))
|
||||
match {
|
||||
// For ease of debugging, we compute percent1 and percent2 in 2 statements.
|
||||
val percent1 = calculateFilterSelectivity(cond1, update)
|
||||
val percent2 = calculateFilterSelectivity(cond2, update)
|
||||
(percent1, percent2) match {
|
||||
case (Some(p1), Some(p2)) => Some(p1 * p2)
|
||||
case (Some(p1), None) => Some(p1)
|
||||
case (None, Some(p2)) => Some(p2)
|
||||
|
@ -127,8 +117,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
case None => None
|
||||
}
|
||||
|
||||
case _ =>
|
||||
calculateSingleCondition(condition, update)
|
||||
case _ => calculateSingleCondition(condition, update)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -140,7 +129,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
* @param condition a single logical expression
|
||||
* @param update a boolean flag to specify if we need to update ColumnStat of a column
|
||||
* for subsequent conditions
|
||||
* @return Option[Double] value to show the percentage of rows meeting a given condition.
|
||||
* @return an optional double value to show the percentage of rows meeting a given condition.
|
||||
* It returns None if the condition is not supported.
|
||||
*/
|
||||
def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = {
|
||||
|
@ -148,33 +137,33 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
// For evaluateBinary method, we assume the literal on the right side of an operator.
|
||||
// So we will change the order if not.
|
||||
|
||||
// EqualTo does not care about the order
|
||||
case op @ EqualTo(ar: AttributeReference, l: Literal) =>
|
||||
evaluateBinary(op, ar, l, update)
|
||||
case op @ EqualTo(l: Literal, ar: AttributeReference) =>
|
||||
evaluateBinary(op, ar, l, update)
|
||||
// EqualTo/EqualNullSafe does not care about the order
|
||||
case op @ Equality(ar: Attribute, l: Literal) =>
|
||||
evaluateEquality(ar, l, update)
|
||||
case op @ Equality(l: Literal, ar: Attribute) =>
|
||||
evaluateEquality(ar, l, update)
|
||||
|
||||
case op @ LessThan(ar: AttributeReference, l: Literal) =>
|
||||
case op @ LessThan(ar: Attribute, l: Literal) =>
|
||||
evaluateBinary(op, ar, l, update)
|
||||
case op @ LessThan(l: Literal, ar: AttributeReference) =>
|
||||
case op @ LessThan(l: Literal, ar: Attribute) =>
|
||||
evaluateBinary(GreaterThan(ar, l), ar, l, update)
|
||||
|
||||
case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) =>
|
||||
case op @ LessThanOrEqual(ar: Attribute, l: Literal) =>
|
||||
evaluateBinary(op, ar, l, update)
|
||||
case op @ LessThanOrEqual(l: Literal, ar: AttributeReference) =>
|
||||
case op @ LessThanOrEqual(l: Literal, ar: Attribute) =>
|
||||
evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update)
|
||||
|
||||
case op @ GreaterThan(ar: AttributeReference, l: Literal) =>
|
||||
case op @ GreaterThan(ar: Attribute, l: Literal) =>
|
||||
evaluateBinary(op, ar, l, update)
|
||||
case op @ GreaterThan(l: Literal, ar: AttributeReference) =>
|
||||
case op @ GreaterThan(l: Literal, ar: Attribute) =>
|
||||
evaluateBinary(LessThan(ar, l), ar, l, update)
|
||||
|
||||
case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
|
||||
case op @ GreaterThanOrEqual(ar: Attribute, l: Literal) =>
|
||||
evaluateBinary(op, ar, l, update)
|
||||
case op @ GreaterThanOrEqual(l: Literal, ar: AttributeReference) =>
|
||||
case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) =>
|
||||
evaluateBinary(LessThanOrEqual(ar, l), ar, l, update)
|
||||
|
||||
case In(ar: AttributeReference, expList)
|
||||
case In(ar: Attribute, expList)
|
||||
if expList.forall(e => e.isInstanceOf[Literal]) =>
|
||||
// Expression [In (value, seq[Literal])] will be replaced with optimized version
|
||||
// [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10.
|
||||
|
@ -182,14 +171,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
val hSet = expList.map(e => e.eval())
|
||||
evaluateInSet(ar, HashSet() ++ hSet, update)
|
||||
|
||||
case InSet(ar: AttributeReference, set) =>
|
||||
case InSet(ar: Attribute, set) =>
|
||||
evaluateInSet(ar, set, update)
|
||||
|
||||
case IsNull(ar: AttributeReference) =>
|
||||
evaluateIsNull(ar, isNull = true, update)
|
||||
case IsNull(ar: Attribute) =>
|
||||
evaluateNullCheck(ar, isNull = true, update)
|
||||
|
||||
case IsNotNull(ar: AttributeReference) =>
|
||||
evaluateIsNull(ar, isNull = false, update)
|
||||
case IsNotNull(ar: Attribute) =>
|
||||
evaluateNullCheck(ar, isNull = false, update)
|
||||
|
||||
case _ =>
|
||||
// TODO: it's difficult to support string operators without advanced statistics.
|
||||
|
@ -203,44 +192,43 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
/**
|
||||
* Returns a percentage of rows meeting "IS NULL" or "IS NOT NULL" condition.
|
||||
*
|
||||
* @param attrRef an AttributeReference (or a column)
|
||||
* @param attr an Attribute (or a column)
|
||||
* @param isNull set to true for "IS NULL" condition. set to false for "IS NOT NULL" condition
|
||||
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
|
||||
* for subsequent conditions
|
||||
* @return an optional double value to show the percentage of rows meeting a given condition
|
||||
* It returns None if no statistics are collected for a given column.
|
||||
*/
|
||||
def evaluateIsNull(
|
||||
attrRef: AttributeReference,
|
||||
def evaluateNullCheck(
|
||||
attr: Attribute,
|
||||
isNull: Boolean,
|
||||
update: Boolean)
|
||||
: Option[Double] = {
|
||||
if (!mutableColStats.contains(attrRef.exprId)) {
|
||||
logDebug("[CBO] No statistics for " + attrRef)
|
||||
update: Boolean): Option[Double] = {
|
||||
if (!colStatsMap.contains(attr)) {
|
||||
logDebug("[CBO] No statistics for " + attr)
|
||||
return None
|
||||
}
|
||||
val aColStat = mutableColStats(attrRef.exprId)
|
||||
val rowCountValue = plan.child.stats(catalystConf).rowCount.get
|
||||
val nullPercent: BigDecimal =
|
||||
if (rowCountValue == 0) 0.0
|
||||
else BigDecimal(aColStat.nullCount) / BigDecimal(rowCountValue)
|
||||
|
||||
if (update) {
|
||||
val newStats =
|
||||
if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None)
|
||||
else aColStat.copy(nullCount = 0)
|
||||
|
||||
mutableColStats += (attrRef.exprId -> newStats)
|
||||
val colStat = colStatsMap(attr)
|
||||
val rowCountValue = childStats.rowCount.get
|
||||
val nullPercent: BigDecimal = if (rowCountValue == 0) {
|
||||
0
|
||||
} else {
|
||||
BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)
|
||||
}
|
||||
|
||||
val percent =
|
||||
if (isNull) {
|
||||
nullPercent.toDouble
|
||||
}
|
||||
else {
|
||||
/** ISNOTNULL(column) */
|
||||
1.0 - nullPercent.toDouble
|
||||
if (update) {
|
||||
val newStats = if (isNull) {
|
||||
colStat.copy(distinctCount = 0, min = None, max = None)
|
||||
} else {
|
||||
colStat.copy(nullCount = 0)
|
||||
}
|
||||
colStatsMap(attr) = newStats
|
||||
}
|
||||
|
||||
val percent = if (isNull) {
|
||||
nullPercent.toDouble
|
||||
} else {
|
||||
1.0 - nullPercent.toDouble
|
||||
}
|
||||
|
||||
Some(percent)
|
||||
}
|
||||
|
@ -249,7 +237,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
* Returns a percentage of rows meeting a binary comparison expression.
|
||||
*
|
||||
* @param op a binary comparison operator such as =, <, <=, >, >=
|
||||
* @param attrRef an AttributeReference (or a column)
|
||||
* @param attr an Attribute (or a column)
|
||||
* @param literal a literal value (or constant)
|
||||
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
|
||||
* for subsequent conditions
|
||||
|
@ -258,27 +246,20 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
*/
|
||||
def evaluateBinary(
|
||||
op: BinaryComparison,
|
||||
attrRef: AttributeReference,
|
||||
attr: Attribute,
|
||||
literal: Literal,
|
||||
update: Boolean)
|
||||
: Option[Double] = {
|
||||
if (!mutableColStats.contains(attrRef.exprId)) {
|
||||
logDebug("[CBO] No statistics for " + attrRef)
|
||||
return None
|
||||
}
|
||||
|
||||
op match {
|
||||
case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update)
|
||||
update: Boolean): Option[Double] = {
|
||||
attr.dataType match {
|
||||
case _: NumericType | DateType | TimestampType =>
|
||||
evaluateBinaryForNumeric(op, attr, literal, update)
|
||||
case StringType | BinaryType =>
|
||||
// TODO: It is difficult to support other binary comparisons for String/Binary
|
||||
// type without min/max and advanced statistics like histogram.
|
||||
logDebug("[CBO] No range comparison statistics for String/Binary type " + attr)
|
||||
None
|
||||
case _ =>
|
||||
attrRef.dataType match {
|
||||
case _: NumericType | DateType | TimestampType =>
|
||||
evaluateBinaryForNumeric(op, attrRef, literal, update)
|
||||
case StringType | BinaryType =>
|
||||
// TODO: It is difficult to support other binary comparisons for String/Binary
|
||||
// type without min/max and advanced statistics like histogram.
|
||||
logDebug("[CBO] No range comparison statistics for String/Binary type " + attrRef)
|
||||
None
|
||||
}
|
||||
// TODO: support boolean type.
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -297,6 +278,8 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
Some(DateTimeUtils.toJavaDate(litValue.toString.toInt))
|
||||
case TimestampType =>
|
||||
Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong))
|
||||
case _: DecimalType =>
|
||||
Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal)
|
||||
case StringType | BinaryType =>
|
||||
None
|
||||
case _ =>
|
||||
|
@ -308,37 +291,36 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
* Returns a percentage of rows meeting an equality (=) expression.
|
||||
* This method evaluates the equality predicate for all data types.
|
||||
*
|
||||
* @param attrRef an AttributeReference (or a column)
|
||||
* @param attr an Attribute (or a column)
|
||||
* @param literal a literal value (or constant)
|
||||
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
|
||||
* for subsequent conditions
|
||||
* @return an optional double value to show the percentage of rows meeting a given condition
|
||||
*/
|
||||
def evaluateEqualTo(
|
||||
attrRef: AttributeReference,
|
||||
def evaluateEquality(
|
||||
attr: Attribute,
|
||||
literal: Literal,
|
||||
update: Boolean)
|
||||
: Option[Double] = {
|
||||
|
||||
val aColStat = mutableColStats(attrRef.exprId)
|
||||
val ndv = aColStat.distinctCount
|
||||
update: Boolean): Option[Double] = {
|
||||
if (!colStatsMap.contains(attr)) {
|
||||
logDebug("[CBO] No statistics for " + attr)
|
||||
return None
|
||||
}
|
||||
val colStat = colStatsMap(attr)
|
||||
val ndv = colStat.distinctCount
|
||||
|
||||
// decide if the value is in [min, max] of the column.
|
||||
// We currently don't store min/max for binary/string type.
|
||||
// Hence, we assume it is in boundary for binary/string type.
|
||||
val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType)
|
||||
val inBoundary: Boolean = Range.rangeContainsLiteral(statsRange, literal)
|
||||
|
||||
if (inBoundary) {
|
||||
|
||||
val statsRange = Range(colStat.min, colStat.max, attr.dataType)
|
||||
if (statsRange.contains(literal)) {
|
||||
if (update) {
|
||||
// We update ColumnStat structure after applying this equality predicate.
|
||||
// Set distinctCount to 1. Set nullCount to 0.
|
||||
// Need to save new min/max using the external type value of the literal
|
||||
val newValue = convertBoundValue(attrRef.dataType, literal.value)
|
||||
val newStats = aColStat.copy(distinctCount = 1, min = newValue,
|
||||
val newValue = convertBoundValue(attr.dataType, literal.value)
|
||||
val newStats = colStat.copy(distinctCount = 1, min = newValue,
|
||||
max = newValue, nullCount = 0)
|
||||
mutableColStats += (attrRef.exprId -> newStats)
|
||||
colStatsMap(attr) = newStats
|
||||
}
|
||||
|
||||
Some(1.0 / ndv.toDouble)
|
||||
|
@ -352,7 +334,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
* Returns a percentage of rows meeting "IN" operator expression.
|
||||
* This method evaluates the equality predicate for all data types.
|
||||
*
|
||||
* @param attrRef an AttributeReference (or a column)
|
||||
* @param attr an Attribute (or a column)
|
||||
* @param hSet a set of literal values
|
||||
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
|
||||
* for subsequent conditions
|
||||
|
@ -361,57 +343,52 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
*/
|
||||
|
||||
def evaluateInSet(
|
||||
attrRef: AttributeReference,
|
||||
attr: Attribute,
|
||||
hSet: Set[Any],
|
||||
update: Boolean)
|
||||
: Option[Double] = {
|
||||
if (!mutableColStats.contains(attrRef.exprId)) {
|
||||
logDebug("[CBO] No statistics for " + attrRef)
|
||||
update: Boolean): Option[Double] = {
|
||||
if (!colStatsMap.contains(attr)) {
|
||||
logDebug("[CBO] No statistics for " + attr)
|
||||
return None
|
||||
}
|
||||
|
||||
val aColStat = mutableColStats(attrRef.exprId)
|
||||
val ndv = aColStat.distinctCount
|
||||
val aType = attrRef.dataType
|
||||
var newNdv: Long = 0
|
||||
val colStat = colStatsMap(attr)
|
||||
val ndv = colStat.distinctCount
|
||||
val dataType = attr.dataType
|
||||
var newNdv = ndv
|
||||
|
||||
// use [min, max] to filter the original hSet
|
||||
aType match {
|
||||
case _: NumericType | DateType | TimestampType =>
|
||||
val statsRange =
|
||||
Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange]
|
||||
|
||||
// To facilitate finding the min and max values in hSet, we map hSet values to BigDecimal.
|
||||
// Using hSetBigdec, we can find the min and max values quickly in the ordered hSetBigdec.
|
||||
val hSetBigdec = hSet.map(e => BigDecimal(e.toString))
|
||||
val validQuerySet = hSetBigdec.filter(e => e >= statsRange.min && e <= statsRange.max)
|
||||
// We use hSetBigdecToAnyMap to help us find the original hSet value.
|
||||
val hSetBigdecToAnyMap: Map[BigDecimal, Any] =
|
||||
hSet.map(e => BigDecimal(e.toString) -> e).toMap
|
||||
dataType match {
|
||||
case _: NumericType | BooleanType | DateType | TimestampType =>
|
||||
val statsRange = Range(colStat.min, colStat.max, dataType).asInstanceOf[NumericRange]
|
||||
val validQuerySet = hSet.filter { v =>
|
||||
v != null && statsRange.contains(Literal(v, dataType))
|
||||
}
|
||||
|
||||
if (validQuerySet.isEmpty) {
|
||||
return Some(0.0)
|
||||
}
|
||||
|
||||
// Need to save new min/max using the external type value of the literal
|
||||
val newMax = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.max))
|
||||
val newMin = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.min))
|
||||
val newMax = convertBoundValue(
|
||||
attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString)))
|
||||
val newMin = convertBoundValue(
|
||||
attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString)))
|
||||
|
||||
// newNdv should not be greater than the old ndv. For example, column has only 2 values
|
||||
// 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5.
|
||||
newNdv = math.min(validQuerySet.size.toLong, ndv.longValue())
|
||||
newNdv = ndv.min(BigInt(validQuerySet.size))
|
||||
if (update) {
|
||||
val newStats = aColStat.copy(distinctCount = newNdv, min = newMin,
|
||||
val newStats = colStat.copy(distinctCount = newNdv, min = newMin,
|
||||
max = newMax, nullCount = 0)
|
||||
mutableColStats += (attrRef.exprId -> newStats)
|
||||
colStatsMap(attr) = newStats
|
||||
}
|
||||
|
||||
// We assume the whole set since there is no min/max information for String/Binary type
|
||||
case StringType | BinaryType =>
|
||||
newNdv = math.min(hSet.size.toLong, ndv.longValue())
|
||||
newNdv = ndv.min(BigInt(hSet.size))
|
||||
if (update) {
|
||||
val newStats = aColStat.copy(distinctCount = newNdv, nullCount = 0)
|
||||
mutableColStats += (attrRef.exprId -> newStats)
|
||||
val newStats = colStat.copy(distinctCount = newNdv, nullCount = 0)
|
||||
colStatsMap(attr) = newStats
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -425,7 +402,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
* This method evaluates expressions for Numeric columns only.
|
||||
*
|
||||
* @param op a binary comparison operator such as =, <, <=, >, >=
|
||||
* @param attrRef an AttributeReference (or a column)
|
||||
* @param attr an Attribute (or a column)
|
||||
* @param literal a literal value (or constant)
|
||||
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
|
||||
* for subsequent conditions
|
||||
|
@ -433,16 +410,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
*/
|
||||
def evaluateBinaryForNumeric(
|
||||
op: BinaryComparison,
|
||||
attrRef: AttributeReference,
|
||||
attr: Attribute,
|
||||
literal: Literal,
|
||||
update: Boolean)
|
||||
: Option[Double] = {
|
||||
update: Boolean): Option[Double] = {
|
||||
|
||||
var percent = 1.0
|
||||
val aColStat = mutableColStats(attrRef.exprId)
|
||||
val ndv = aColStat.distinctCount
|
||||
val colStat = colStatsMap(attr)
|
||||
val statsRange =
|
||||
Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange]
|
||||
Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
|
||||
|
||||
// determine the overlapping degree between predicate range and column's range
|
||||
val literalValueBD = BigDecimal(literal.value.toString)
|
||||
|
@ -463,33 +438,37 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
percent = 1.0
|
||||
} else {
|
||||
// this is partial overlap case
|
||||
var newMax = aColStat.max
|
||||
var newMin = aColStat.min
|
||||
var newNdv = ndv
|
||||
val literalToDouble = literalValueBD.toDouble
|
||||
val maxToDouble = BigDecimal(statsRange.max).toDouble
|
||||
val minToDouble = BigDecimal(statsRange.min).toDouble
|
||||
val literalDouble = literalValueBD.toDouble
|
||||
val maxDouble = BigDecimal(statsRange.max).toDouble
|
||||
val minDouble = BigDecimal(statsRange.min).toDouble
|
||||
|
||||
// Without advanced statistics like histogram, we assume uniform data distribution.
|
||||
// We just prorate the adjusted range over the initial range to compute filter selectivity.
|
||||
// For ease of computation, we convert all relevant numeric values to Double.
|
||||
percent = op match {
|
||||
case _: LessThan =>
|
||||
(literalToDouble - minToDouble) / (maxToDouble - minToDouble)
|
||||
(literalDouble - minDouble) / (maxDouble - minDouble)
|
||||
case _: LessThanOrEqual =>
|
||||
if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble
|
||||
else (literalToDouble - minToDouble) / (maxToDouble - minToDouble)
|
||||
if (literalValueBD == BigDecimal(statsRange.min)) {
|
||||
1.0 / colStat.distinctCount.toDouble
|
||||
} else {
|
||||
(literalDouble - minDouble) / (maxDouble - minDouble)
|
||||
}
|
||||
case _: GreaterThan =>
|
||||
(maxToDouble - literalToDouble) / (maxToDouble - minToDouble)
|
||||
(maxDouble - literalDouble) / (maxDouble - minDouble)
|
||||
case _: GreaterThanOrEqual =>
|
||||
if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble
|
||||
else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble)
|
||||
if (literalValueBD == BigDecimal(statsRange.max)) {
|
||||
1.0 / colStat.distinctCount.toDouble
|
||||
} else {
|
||||
(maxDouble - literalDouble) / (maxDouble - minDouble)
|
||||
}
|
||||
}
|
||||
|
||||
// Need to save new min/max using the external type value of the literal
|
||||
val newValue = convertBoundValue(attrRef.dataType, literal.value)
|
||||
|
||||
if (update) {
|
||||
// Need to save new min/max using the external type value of the literal
|
||||
val newValue = convertBoundValue(attr.dataType, literal.value)
|
||||
var newMax = colStat.max
|
||||
var newMin = colStat.min
|
||||
op match {
|
||||
case _: GreaterThan => newMin = newValue
|
||||
case _: GreaterThanOrEqual => newMin = newValue
|
||||
|
@ -497,11 +476,11 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
case _: LessThanOrEqual => newMax = newValue
|
||||
}
|
||||
|
||||
newNdv = math.max(math.round(ndv.toDouble * percent), 1)
|
||||
val newStats = aColStat.copy(distinctCount = newNdv, min = newMin,
|
||||
val newNdv = math.max(math.round(colStat.distinctCount.toDouble * percent), 1)
|
||||
val newStats = colStat.copy(distinctCount = newNdv, min = newMin,
|
||||
max = newMax, nullCount = 0)
|
||||
|
||||
mutableColStats += (attrRef.exprId -> newStats)
|
||||
colStatsMap(attr) = newStats
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -509,3 +488,20 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
class ColumnStatsMap {
|
||||
private val baseMap: mutable.Map[ExprId, (Attribute, ColumnStat)] = mutable.HashMap.empty
|
||||
|
||||
def setInitValues(colStats: AttributeMap[ColumnStat]): Unit = {
|
||||
baseMap.clear()
|
||||
baseMap ++= colStats.baseMap
|
||||
}
|
||||
|
||||
def contains(a: Attribute): Boolean = baseMap.contains(a.exprId)
|
||||
|
||||
def apply(a: Attribute): ColumnStat = baseMap(a.exprId)._2
|
||||
|
||||
def update(a: Attribute, stats: ColumnStat): Unit = baseMap.update(a.exprId, a -> stats)
|
||||
|
||||
def toColumnStats: AttributeMap[ColumnStat] = AttributeMap(baseMap.values.toSeq)
|
||||
}
|
||||
|
|
|
@ -59,7 +59,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
|
|||
case _ if !rowCountsExist(conf, join.left, join.right) =>
|
||||
None
|
||||
|
||||
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
|
||||
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) =>
|
||||
// 1. Compute join selectivity
|
||||
val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys)
|
||||
val selectivity = joinSelectivity(joinKeyPairs)
|
||||
|
@ -94,9 +94,9 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
|
|||
val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) {
|
||||
// The output is empty, we don't need to keep column stats.
|
||||
Nil
|
||||
} else if (innerJoinedRows == 0) {
|
||||
} else if (selectivity == 0) {
|
||||
joinType match {
|
||||
// For outer joins, if the inner join part is empty, the number of output rows is the
|
||||
// For outer joins, if the join selectivity is 0, the number of output rows is the
|
||||
// same as that of the outer side. And column stats of join keys from the outer side
|
||||
// keep unchanged, while column stats of join keys from the other side should be updated
|
||||
// based on added null values.
|
||||
|
@ -116,6 +116,9 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
|
|||
}
|
||||
case _ => Nil
|
||||
}
|
||||
} else if (selectivity == 1) {
|
||||
// Cartesian product, just propagate the original column stats
|
||||
inputAttrStats.toSeq
|
||||
} else {
|
||||
val joinKeyStats = getIntersectedStats(joinKeyPairs)
|
||||
join.joinType match {
|
||||
|
@ -138,8 +141,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
|
|||
Some(Statistics(
|
||||
sizeInBytes = getOutputSize(join.output, outputRows, outputAttrStats),
|
||||
rowCount = Some(outputRows),
|
||||
attributeStats = outputAttrStats,
|
||||
isBroadcastable = false))
|
||||
attributeStats = outputAttrStats))
|
||||
|
||||
case _ =>
|
||||
// When there is no equi-join condition, we do estimation like cartesian product.
|
||||
|
@ -150,8 +152,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
|
|||
Some(Statistics(
|
||||
sizeInBytes = getOutputSize(join.output, outputRows, inputAttrStats),
|
||||
rowCount = Some(outputRows),
|
||||
attributeStats = inputAttrStats,
|
||||
isBroadcastable = false))
|
||||
attributeStats = inputAttrStats))
|
||||
}
|
||||
|
||||
// scalastyle:off
|
||||
|
@ -189,8 +190,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
|
|||
}
|
||||
|
||||
if (ndvDenom < 0) {
|
||||
// There isn't join keys or column stats for any of the join key pairs, we do estimation like
|
||||
// cartesian product.
|
||||
// We can't find any join key pairs with column stats, so estimate it as a cartesian join.
|
||||
1
|
||||
} else if (ndvDenom == 0) {
|
||||
// One of the join key pairs is disjoint, thus the two sides of the join are disjoint.
|
||||
|
@ -202,9 +202,6 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
|
|||
|
||||
/**
|
||||
* Propagate or update column stats for output attributes.
|
||||
* 1. For cartesian product, all values are preserved, so there's no need to change column stats.
|
||||
* 2. For other cases, a) update max/min of join keys based on their intersected range. b) update
|
||||
* distinct count of other attributes based on output rows after join.
|
||||
*/
|
||||
private def updateAttrStats(
|
||||
outputRows: BigInt,
|
||||
|
@ -214,35 +211,38 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
|
|||
val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]()
|
||||
val leftRows = leftStats.rowCount.get
|
||||
val rightRows = rightStats.rowCount.get
|
||||
if (outputRows == leftRows * rightRows) {
|
||||
// Cartesian product, just propagate the original column stats
|
||||
attributes.foreach(a => outputAttrStats += a -> oldAttrStats(a))
|
||||
} else {
|
||||
val leftRatio =
|
||||
if (leftRows != 0) BigDecimal(outputRows) / BigDecimal(leftRows) else BigDecimal(0)
|
||||
val rightRatio =
|
||||
if (rightRows != 0) BigDecimal(outputRows) / BigDecimal(rightRows) else BigDecimal(0)
|
||||
attributes.foreach { a =>
|
||||
// check if this attribute is a join key
|
||||
if (joinKeyStats.contains(a)) {
|
||||
outputAttrStats += a -> joinKeyStats(a)
|
||||
|
||||
attributes.foreach { a =>
|
||||
// check if this attribute is a join key
|
||||
if (joinKeyStats.contains(a)) {
|
||||
outputAttrStats += a -> joinKeyStats(a)
|
||||
} else {
|
||||
val leftRatio = if (leftRows != 0) {
|
||||
BigDecimal(outputRows) / BigDecimal(leftRows)
|
||||
} else {
|
||||
val oldColStat = oldAttrStats(a)
|
||||
val oldNdv = oldColStat.distinctCount
|
||||
// We only change (scale down) the number of distinct values if the number of rows
|
||||
// decreases after join, because join won't produce new values even if the number of
|
||||
// rows increases.
|
||||
val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) {
|
||||
ceil(BigDecimal(oldNdv) * leftRatio)
|
||||
} else if (join.right.outputSet.contains(a) && rightRatio < 1) {
|
||||
ceil(BigDecimal(oldNdv) * rightRatio)
|
||||
} else {
|
||||
oldNdv
|
||||
}
|
||||
// TODO: support nullCount updates for specific outer joins
|
||||
outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv)
|
||||
BigDecimal(0)
|
||||
}
|
||||
val rightRatio = if (rightRows != 0) {
|
||||
BigDecimal(outputRows) / BigDecimal(rightRows)
|
||||
} else {
|
||||
BigDecimal(0)
|
||||
}
|
||||
val oldColStat = oldAttrStats(a)
|
||||
val oldNdv = oldColStat.distinctCount
|
||||
// We only change (scale down) the number of distinct values if the number of rows
|
||||
// decreases after join, because join won't produce new values even if the number of
|
||||
// rows increases.
|
||||
val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) {
|
||||
ceil(BigDecimal(oldNdv) * leftRatio)
|
||||
} else if (join.right.outputSet.contains(a) && rightRatio < 1) {
|
||||
ceil(BigDecimal(oldNdv) * rightRatio)
|
||||
} else {
|
||||
oldNdv
|
||||
}
|
||||
// TODO: support nullCount updates for specific outer joins
|
||||
outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv)
|
||||
}
|
||||
|
||||
}
|
||||
outputAttrStats
|
||||
}
|
||||
|
@ -263,12 +263,14 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
|
|||
|
||||
// Update intersected column stats
|
||||
assert(leftKey.dataType.sameType(rightKey.dataType))
|
||||
val minNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount)
|
||||
val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount)
|
||||
val (newMin, newMax) = Range.intersect(lRange, rRange, leftKey.dataType)
|
||||
intersectedStats.put(leftKey,
|
||||
leftKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0))
|
||||
intersectedStats.put(rightKey,
|
||||
rightKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0))
|
||||
val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen)
|
||||
val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2
|
||||
val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen)
|
||||
|
||||
intersectedStats.put(leftKey, newStats)
|
||||
intersectedStats.put(rightKey, newStats)
|
||||
}
|
||||
AttributeMap(intersectedStats.toSeq)
|
||||
}
|
||||
|
@ -298,8 +300,7 @@ case class LeftSemiAntiEstimation(conf: CatalystConf, join: Join) {
|
|||
Some(Statistics(
|
||||
sizeInBytes = getOutputSize(join.output, outputRows, leftStats.attributeStats),
|
||||
rowCount = Some(outputRows),
|
||||
attributeStats = leftStats.attributeStats,
|
||||
isBroadcastable = false))
|
||||
attributeStats = leftStats.attributeStats))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
|
|
@ -26,19 +26,33 @@ import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _}
|
|||
|
||||
|
||||
/** Value range of a column. */
|
||||
trait Range
|
||||
trait Range {
|
||||
def contains(l: Literal): Boolean
|
||||
}
|
||||
|
||||
/** For simplicity we use decimal to unify operations of numeric ranges. */
|
||||
case class NumericRange(min: JDecimal, max: JDecimal) extends Range
|
||||
case class NumericRange(min: JDecimal, max: JDecimal) extends Range {
|
||||
override def contains(l: Literal): Boolean = {
|
||||
val decimal = l.dataType match {
|
||||
case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
|
||||
case _ => new JDecimal(l.value.toString)
|
||||
}
|
||||
min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This version of Spark does not have min/max for binary/string types, we define their default
|
||||
* behaviors by this class.
|
||||
*/
|
||||
class DefaultRange extends Range
|
||||
class DefaultRange extends Range {
|
||||
override def contains(l: Literal): Boolean = true
|
||||
}
|
||||
|
||||
/** This is for columns with only null values. */
|
||||
class NullRange extends Range
|
||||
class NullRange extends Range {
|
||||
override def contains(l: Literal): Boolean = false
|
||||
}
|
||||
|
||||
object Range {
|
||||
def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match {
|
||||
|
@ -58,20 +72,6 @@ object Range {
|
|||
n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0
|
||||
}
|
||||
|
||||
def rangeContainsLiteral(r: Range, lit: Literal): Boolean = r match {
|
||||
case _: DefaultRange => true
|
||||
case _: NullRange => false
|
||||
case n: NumericRange =>
|
||||
val literalValue = if (lit.dataType.isInstanceOf[BooleanType]) {
|
||||
if (lit.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
|
||||
} else {
|
||||
assert(lit.dataType.isInstanceOf[NumericType] || lit.dataType.isInstanceOf[DateType] ||
|
||||
lit.dataType.isInstanceOf[TimestampType])
|
||||
new JDecimal(lit.value.toString)
|
||||
}
|
||||
n.min.compareTo(literalValue) <= 0 && n.max.compareTo(literalValue) >= 0
|
||||
}
|
||||
|
||||
/**
|
||||
* Intersected results of two ranges. This is only for two overlapped ranges.
|
||||
* The outputs are the intersected min/max values.
|
||||
|
|
|
@ -17,12 +17,11 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.statsEstimation
|
||||
|
||||
import java.sql.{Date, Timestamp}
|
||||
import java.sql.Date
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
|
@ -38,6 +37,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4)
|
||||
|
||||
// only 2 values
|
||||
val arBool = AttributeReference("cbool", BooleanType)()
|
||||
val childColStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
|
||||
nullCount = 0, avgLen = 1, maxLen = 1)
|
||||
|
||||
// Second column cdate has 10 values from 2017-01-01 through 2017-01-10.
|
||||
val dMin = Date.valueOf("2017-01-01")
|
||||
val dMax = Date.valueOf("2017-01-10")
|
||||
|
@ -45,14 +49,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4)
|
||||
|
||||
// Third column ctimestamp has 10 values from "2017-01-01 01:00:00" through
|
||||
// "2017-01-01 10:00:00" for 10 distinct timestamps (or hours).
|
||||
val tsMin = Timestamp.valueOf("2017-01-01 01:00:00")
|
||||
val tsMax = Timestamp.valueOf("2017-01-01 10:00:00")
|
||||
val arTimestamp = AttributeReference("ctimestamp", TimestampType)()
|
||||
val childColStatTimestamp = ColumnStat(distinctCount = 10, min = Some(tsMin), max = Some(tsMax),
|
||||
nullCount = 0, avgLen = 8, maxLen = 8)
|
||||
|
||||
// Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
|
||||
val decMin = new java.math.BigDecimal("0.200000000000000000")
|
||||
val decMax = new java.math.BigDecimal("0.800000000000000000")
|
||||
|
@ -77,8 +73,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)),
|
||||
ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(1L)
|
||||
)
|
||||
1)
|
||||
}
|
||||
|
||||
test("cint <=> 2") {
|
||||
validateEstimatedStats(
|
||||
arInt,
|
||||
Filter(EqualNullSafe(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)),
|
||||
ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
1)
|
||||
}
|
||||
|
||||
test("cint = 0") {
|
||||
|
@ -88,8 +92,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
|
||||
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(0L)
|
||||
)
|
||||
0)
|
||||
}
|
||||
|
||||
test("cint < 3") {
|
||||
|
@ -98,8 +101,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
|
||||
ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(3L)
|
||||
)
|
||||
3)
|
||||
}
|
||||
|
||||
test("cint < 0") {
|
||||
|
@ -109,8 +111,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
|
||||
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(0L)
|
||||
)
|
||||
0)
|
||||
}
|
||||
|
||||
test("cint <= 3") {
|
||||
|
@ -119,8 +120,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
|
||||
ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(3L)
|
||||
)
|
||||
3)
|
||||
}
|
||||
|
||||
test("cint > 6") {
|
||||
|
@ -129,8 +129,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
|
||||
ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(5L)
|
||||
)
|
||||
5)
|
||||
}
|
||||
|
||||
test("cint > 10") {
|
||||
|
@ -140,8 +139,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)),
|
||||
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(0L)
|
||||
)
|
||||
0)
|
||||
}
|
||||
|
||||
test("cint >= 6") {
|
||||
|
@ -150,8 +148,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
|
||||
ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(5L)
|
||||
)
|
||||
5)
|
||||
}
|
||||
|
||||
test("cint IS NULL") {
|
||||
|
@ -160,8 +157,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
|
||||
ColumnStat(distinctCount = 0, min = None, max = None,
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(0L)
|
||||
)
|
||||
0)
|
||||
}
|
||||
|
||||
test("cint IS NOT NULL") {
|
||||
|
@ -170,8 +166,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
|
||||
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(10L)
|
||||
)
|
||||
10)
|
||||
}
|
||||
|
||||
test("cint > 3 AND cint <= 6") {
|
||||
|
@ -181,8 +176,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
|
||||
ColumnStat(distinctCount = 3, min = Some(3), max = Some(6),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(4L)
|
||||
)
|
||||
4)
|
||||
}
|
||||
|
||||
test("cint = 3 OR cint = 6") {
|
||||
|
@ -192,8 +186,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
|
||||
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(2L)
|
||||
)
|
||||
2)
|
||||
}
|
||||
|
||||
test("cint IN (3, 4, 5)") {
|
||||
|
@ -202,8 +195,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)),
|
||||
ColumnStat(distinctCount = 3, min = Some(3), max = Some(5),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(3L)
|
||||
)
|
||||
3)
|
||||
}
|
||||
|
||||
test("cint NOT IN (3, 4, 5)") {
|
||||
|
@ -212,8 +204,26 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)),
|
||||
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(7L)
|
||||
)
|
||||
7)
|
||||
}
|
||||
|
||||
test("cbool = true") {
|
||||
validateEstimatedStats(
|
||||
arBool,
|
||||
Filter(EqualTo(arBool, Literal(true)), childStatsTestPlan(Seq(arBool), 10L)),
|
||||
ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
|
||||
nullCount = 0, avgLen = 1, maxLen = 1),
|
||||
5)
|
||||
}
|
||||
|
||||
test("cbool > false") {
|
||||
// bool comparison is not supported yet, so stats remain same.
|
||||
validateEstimatedStats(
|
||||
arBool,
|
||||
Filter(GreaterThan(arBool, Literal(false)), childStatsTestPlan(Seq(arBool), 10L)),
|
||||
ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
|
||||
nullCount = 0, avgLen = 1, maxLen = 1),
|
||||
10)
|
||||
}
|
||||
|
||||
test("cdate = cast('2017-01-02' AS DATE)") {
|
||||
|
@ -224,8 +234,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
childStatsTestPlan(Seq(arDate), 10L)),
|
||||
ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(1L)
|
||||
)
|
||||
1)
|
||||
}
|
||||
|
||||
test("cdate < cast('2017-01-03' AS DATE)") {
|
||||
|
@ -236,8 +245,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
childStatsTestPlan(Seq(arDate), 10L)),
|
||||
ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(3L)
|
||||
)
|
||||
3)
|
||||
}
|
||||
|
||||
test("""cdate IN ( cast('2017-01-03' AS DATE),
|
||||
|
@ -251,32 +259,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
childStatsTestPlan(Seq(arDate), 10L)),
|
||||
ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(3L)
|
||||
)
|
||||
}
|
||||
|
||||
test("ctimestamp = cast('2017-01-01 02:00:00' AS TIMESTAMP)") {
|
||||
val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00")
|
||||
validateEstimatedStats(
|
||||
arTimestamp,
|
||||
Filter(EqualTo(arTimestamp, Literal(ts2017010102)),
|
||||
childStatsTestPlan(Seq(arTimestamp), 10L)),
|
||||
ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102),
|
||||
nullCount = 0, avgLen = 8, maxLen = 8),
|
||||
Some(1L)
|
||||
)
|
||||
}
|
||||
|
||||
test("ctimestamp < cast('2017-01-01 03:00:00' AS TIMESTAMP)") {
|
||||
val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00")
|
||||
validateEstimatedStats(
|
||||
arTimestamp,
|
||||
Filter(LessThan(arTimestamp, Literal(ts2017010103)),
|
||||
childStatsTestPlan(Seq(arTimestamp), 10L)),
|
||||
ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103),
|
||||
nullCount = 0, avgLen = 8, maxLen = 8),
|
||||
Some(3L)
|
||||
)
|
||||
3)
|
||||
}
|
||||
|
||||
test("cdecimal = 0.400000000000000000") {
|
||||
|
@ -287,20 +270,18 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
childStatsTestPlan(Seq(arDecimal), 4L)),
|
||||
ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40),
|
||||
nullCount = 0, avgLen = 8, maxLen = 8),
|
||||
Some(1L)
|
||||
)
|
||||
1)
|
||||
}
|
||||
|
||||
test("cdecimal < 0.60 ") {
|
||||
val dec_0_60 = new java.math.BigDecimal("0.600000000000000000")
|
||||
validateEstimatedStats(
|
||||
arDecimal,
|
||||
Filter(LessThan(arDecimal, Literal(dec_0_60, DecimalType(12, 2))),
|
||||
Filter(LessThan(arDecimal, Literal(dec_0_60)),
|
||||
childStatsTestPlan(Seq(arDecimal), 4L)),
|
||||
ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60),
|
||||
nullCount = 0, avgLen = 8, maxLen = 8),
|
||||
Some(3L)
|
||||
)
|
||||
3)
|
||||
}
|
||||
|
||||
test("cdouble < 3.0") {
|
||||
|
@ -309,8 +290,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)),
|
||||
ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0),
|
||||
nullCount = 0, avgLen = 8, maxLen = 8),
|
||||
Some(3L)
|
||||
)
|
||||
3)
|
||||
}
|
||||
|
||||
test("cstring = 'A2'") {
|
||||
|
@ -319,8 +299,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
|
||||
ColumnStat(distinctCount = 1, min = None, max = None,
|
||||
nullCount = 0, avgLen = 2, maxLen = 2),
|
||||
Some(1L)
|
||||
)
|
||||
1)
|
||||
}
|
||||
|
||||
// There is no min/max statistics for String type. We estimate 10 rows returned.
|
||||
|
@ -330,8 +309,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
|
||||
ColumnStat(distinctCount = 10, min = None, max = None,
|
||||
nullCount = 0, avgLen = 2, maxLen = 2),
|
||||
Some(10L)
|
||||
)
|
||||
10)
|
||||
}
|
||||
|
||||
// This is a corner test case. We want to test if we can handle the case when the number of
|
||||
|
@ -351,8 +329,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan),
|
||||
ColumnStat(distinctCount = 2, min = Some(1), max = Some(5),
|
||||
nullCount = 0, avgLen = 4, maxLen = 4),
|
||||
Some(2L)
|
||||
)
|
||||
2)
|
||||
}
|
||||
|
||||
private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = {
|
||||
|
@ -361,8 +338,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
rowCount = tableRowCount,
|
||||
attributeStats = AttributeMap(Seq(
|
||||
arInt -> childColStatInt,
|
||||
arBool -> childColStatBool,
|
||||
arDate -> childColStatDate,
|
||||
arTimestamp -> childColStatTimestamp,
|
||||
arDecimal -> childColStatDecimal,
|
||||
arDouble -> childColStatDouble,
|
||||
arString -> childColStatString
|
||||
|
@ -374,46 +351,31 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
ar: AttributeReference,
|
||||
filterNode: Filter,
|
||||
expectedColStats: ColumnStat,
|
||||
rowCount: Option[BigInt] = None)
|
||||
: Unit = {
|
||||
rowCount: Int): Unit = {
|
||||
|
||||
val expectedRowCount: BigInt = rowCount.getOrElse(0L)
|
||||
val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode)
|
||||
val expectedSizeInBytes = getOutputSize(filterNode.output, expectedRowCount, expectedAttrStats)
|
||||
val expectedSizeInBytes = getOutputSize(filterNode.output, rowCount, expectedAttrStats)
|
||||
|
||||
val filteredStats = filterNode.stats(conf)
|
||||
assert(filteredStats.sizeInBytes == expectedSizeInBytes)
|
||||
assert(filteredStats.rowCount == rowCount)
|
||||
ar.dataType match {
|
||||
case DecimalType() =>
|
||||
// Due to the internal transformation for DecimalType within engine, the new min/max
|
||||
// in ColumnStat may have a different structure even if it contains the right values.
|
||||
// We convert them to Java BigDecimal values so that we can compare the entire object.
|
||||
val generatedColumnStats = filteredStats.attributeStats(ar)
|
||||
val newMax = new java.math.BigDecimal(generatedColumnStats.max.getOrElse(0).toString)
|
||||
val newMin = new java.math.BigDecimal(generatedColumnStats.min.getOrElse(0).toString)
|
||||
val outputColStats = generatedColumnStats.copy(min = Some(newMin), max = Some(newMax))
|
||||
assert(outputColStats == expectedColStats)
|
||||
case _ =>
|
||||
// For all other SQL types, we compare the entire object directly.
|
||||
assert(filteredStats.attributeStats(ar) == expectedColStats)
|
||||
}
|
||||
assert(filteredStats.rowCount.get == rowCount)
|
||||
assert(filteredStats.attributeStats(ar) == expectedColStats)
|
||||
|
||||
// If the filter has a binary operator (including those nested inside
|
||||
// AND/OR/NOT), swap the sides of the attribute and the literal, reverse the
|
||||
// operator, and then check again.
|
||||
val rewrittenFilter = filterNode transformExpressionsDown {
|
||||
case op @ EqualTo(ar: AttributeReference, l: Literal) =>
|
||||
case EqualTo(ar: AttributeReference, l: Literal) =>
|
||||
EqualTo(l, ar)
|
||||
|
||||
case op @ LessThan(ar: AttributeReference, l: Literal) =>
|
||||
case LessThan(ar: AttributeReference, l: Literal) =>
|
||||
GreaterThan(l, ar)
|
||||
case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) =>
|
||||
case LessThanOrEqual(ar: AttributeReference, l: Literal) =>
|
||||
GreaterThanOrEqual(l, ar)
|
||||
|
||||
case op @ GreaterThan(ar: AttributeReference, l: Literal) =>
|
||||
case GreaterThan(ar: AttributeReference, l: Literal) =>
|
||||
LessThan(l, ar)
|
||||
case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
|
||||
case GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
|
||||
LessThanOrEqual(l, ar)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue