[SPARK-20318][SQL] Use Catalyst type for min/max in ColumnStat for ease of estimation
## What changes were proposed in this pull request? Currently when estimating predicates like col > literal or col = literal, we will update min or max in column stats based on literal value. However, literal value is of Catalyst type (internal type), while min/max is of external type. Then for the next predicate, we again need to do type conversion to compare and update column stats. This is awkward and causes many unnecessary conversions in estimation. To solve this, we use Catalyst type for min/max in `ColumnStat`. Note that the persistent format in metastore is still of external type, so there's no inconsistency for statistics in metastore. This pr also fixes a bug for boolean type in `IN` condition. ## How was this patch tested? The changes for ColumnStat are covered by existing tests. For bug fix, a new test for boolean type in IN condition is added Author: wangzhenhua <wangzhenhua@huawei.com> Closes #17630 from wzhfy/refactorColumnStat.
This commit is contained in:
parent
7536e2849d
commit
fb036c4413
|
@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging
|
||||||
import org.apache.spark.sql.{AnalysisException, Row}
|
import org.apache.spark.sql.{AnalysisException, Row}
|
||||||
import org.apache.spark.sql.catalyst.expressions._
|
import org.apache.spark.sql.catalyst.expressions._
|
||||||
import org.apache.spark.sql.catalyst.expressions.aggregate._
|
import org.apache.spark.sql.catalyst.expressions.aggregate._
|
||||||
|
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.util.Utils
|
import org.apache.spark.util.Utils
|
||||||
|
|
||||||
|
@ -74,11 +75,10 @@ case class Statistics(
|
||||||
* Statistics collected for a column.
|
* Statistics collected for a column.
|
||||||
*
|
*
|
||||||
* 1. Supported data types are defined in `ColumnStat.supportsType`.
|
* 1. Supported data types are defined in `ColumnStat.supportsType`.
|
||||||
* 2. The JVM data type stored in min/max is the external data type (used in Row) for the
|
* 2. The JVM data type stored in min/max is the internal data type for the corresponding
|
||||||
* corresponding Catalyst data type. For example, for DateType we store java.sql.Date, and for
|
* Catalyst data type. For example, the internal type of DateType is Int, and that the internal
|
||||||
* TimestampType we store java.sql.Timestamp.
|
* type of TimestampType is Long.
|
||||||
* 3. For integral types, they are all upcasted to longs, i.e. shorts are stored as longs.
|
* 3. There is no guarantee that the statistics collected are accurate. Approximation algorithms
|
||||||
* 4. There is no guarantee that the statistics collected are accurate. Approximation algorithms
|
|
||||||
* (sketches) might have been used, and the data collected can also be stale.
|
* (sketches) might have been used, and the data collected can also be stale.
|
||||||
*
|
*
|
||||||
* @param distinctCount number of distinct values
|
* @param distinctCount number of distinct values
|
||||||
|
@ -104,22 +104,43 @@ case class ColumnStat(
|
||||||
/**
|
/**
|
||||||
* Returns a map from string to string that can be used to serialize the column stats.
|
* Returns a map from string to string that can be used to serialize the column stats.
|
||||||
* The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string
|
* The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string
|
||||||
* representation for the value. The deserialization side is defined in [[ColumnStat.fromMap]].
|
* representation for the value. min/max values are converted to the external data type. For
|
||||||
|
* example, for DateType we store java.sql.Date, and for TimestampType we store
|
||||||
|
* java.sql.Timestamp. The deserialization side is defined in [[ColumnStat.fromMap]].
|
||||||
*
|
*
|
||||||
* As part of the protocol, the returned map always contains a key called "version".
|
* As part of the protocol, the returned map always contains a key called "version".
|
||||||
* In the case min/max values are null (None), they won't appear in the map.
|
* In the case min/max values are null (None), they won't appear in the map.
|
||||||
*/
|
*/
|
||||||
def toMap: Map[String, String] = {
|
def toMap(colName: String, dataType: DataType): Map[String, String] = {
|
||||||
val map = new scala.collection.mutable.HashMap[String, String]
|
val map = new scala.collection.mutable.HashMap[String, String]
|
||||||
map.put(ColumnStat.KEY_VERSION, "1")
|
map.put(ColumnStat.KEY_VERSION, "1")
|
||||||
map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString)
|
map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString)
|
||||||
map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString)
|
map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString)
|
||||||
map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString)
|
map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString)
|
||||||
map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString)
|
map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString)
|
||||||
min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, v.toString) }
|
min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) }
|
||||||
max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, v.toString) }
|
max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) }
|
||||||
map.toMap
|
map.toMap
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts the given value from Catalyst data type to string representation of external
|
||||||
|
* data type.
|
||||||
|
*/
|
||||||
|
private def toExternalString(v: Any, colName: String, dataType: DataType): String = {
|
||||||
|
val externalValue = dataType match {
|
||||||
|
case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int])
|
||||||
|
case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long])
|
||||||
|
case BooleanType | _: IntegralType | FloatType | DoubleType => v
|
||||||
|
case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal
|
||||||
|
// This version of Spark does not use min/max for binary/string types so we ignore it.
|
||||||
|
case _ =>
|
||||||
|
throw new AnalysisException("Column statistics deserialization is not supported for " +
|
||||||
|
s"column $colName of data type: $dataType.")
|
||||||
|
}
|
||||||
|
externalValue.toString
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -150,28 +171,15 @@ object ColumnStat extends Logging {
|
||||||
* Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats
|
* Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats
|
||||||
* from some external storage. The serialization side is defined in [[ColumnStat.toMap]].
|
* from some external storage. The serialization side is defined in [[ColumnStat.toMap]].
|
||||||
*/
|
*/
|
||||||
def fromMap(table: String, field: StructField, map: Map[String, String])
|
def fromMap(table: String, field: StructField, map: Map[String, String]): Option[ColumnStat] = {
|
||||||
: Option[ColumnStat] = {
|
|
||||||
val str2val: (String => Any) = field.dataType match {
|
|
||||||
case _: IntegralType => _.toLong
|
|
||||||
case _: DecimalType => new java.math.BigDecimal(_)
|
|
||||||
case DoubleType | FloatType => _.toDouble
|
|
||||||
case BooleanType => _.toBoolean
|
|
||||||
case DateType => java.sql.Date.valueOf
|
|
||||||
case TimestampType => java.sql.Timestamp.valueOf
|
|
||||||
// This version of Spark does not use min/max for binary/string types so we ignore it.
|
|
||||||
case BinaryType | StringType => _ => null
|
|
||||||
case _ =>
|
|
||||||
throw new AnalysisException("Column statistics deserialization is not supported for " +
|
|
||||||
s"column ${field.name} of data type: ${field.dataType}.")
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
Some(ColumnStat(
|
Some(ColumnStat(
|
||||||
distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong),
|
distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong),
|
||||||
// Note that flatMap(Option.apply) turns Option(null) into None.
|
// Note that flatMap(Option.apply) turns Option(null) into None.
|
||||||
min = map.get(KEY_MIN_VALUE).map(str2val).flatMap(Option.apply),
|
min = map.get(KEY_MIN_VALUE)
|
||||||
max = map.get(KEY_MAX_VALUE).map(str2val).flatMap(Option.apply),
|
.map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply),
|
||||||
|
max = map.get(KEY_MAX_VALUE)
|
||||||
|
.map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply),
|
||||||
nullCount = BigInt(map(KEY_NULL_COUNT).toLong),
|
nullCount = BigInt(map(KEY_NULL_COUNT).toLong),
|
||||||
avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong,
|
avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong,
|
||||||
maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong
|
maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong
|
||||||
|
@ -183,6 +191,30 @@ object ColumnStat extends Logging {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts from string representation of external data type to the corresponding Catalyst data
|
||||||
|
* type.
|
||||||
|
*/
|
||||||
|
private def fromExternalString(s: String, name: String, dataType: DataType): Any = {
|
||||||
|
dataType match {
|
||||||
|
case BooleanType => s.toBoolean
|
||||||
|
case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s))
|
||||||
|
case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s))
|
||||||
|
case ByteType => s.toByte
|
||||||
|
case ShortType => s.toShort
|
||||||
|
case IntegerType => s.toInt
|
||||||
|
case LongType => s.toLong
|
||||||
|
case FloatType => s.toFloat
|
||||||
|
case DoubleType => s.toDouble
|
||||||
|
case _: DecimalType => Decimal(s)
|
||||||
|
// This version of Spark does not use min/max for binary/string types so we ignore it.
|
||||||
|
case BinaryType | StringType => null
|
||||||
|
case _ =>
|
||||||
|
throw new AnalysisException("Column statistics deserialization is not supported for " +
|
||||||
|
s"column $name of data type: $dataType.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructs an expression to compute column statistics for a given column.
|
* Constructs an expression to compute column statistics for a given column.
|
||||||
*
|
*
|
||||||
|
@ -232,11 +264,14 @@ object ColumnStat extends Logging {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */
|
/** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */
|
||||||
def rowToColumnStat(row: Row): ColumnStat = {
|
def rowToColumnStat(row: Row, attr: Attribute): ColumnStat = {
|
||||||
ColumnStat(
|
ColumnStat(
|
||||||
distinctCount = BigInt(row.getLong(0)),
|
distinctCount = BigInt(row.getLong(0)),
|
||||||
min = Option(row.get(1)), // for string/binary min/max, get should return null
|
// for string/binary min/max, get should return null
|
||||||
max = Option(row.get(2)),
|
min = Option(row.get(1))
|
||||||
|
.map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply),
|
||||||
|
max = Option(row.get(2))
|
||||||
|
.map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply),
|
||||||
nullCount = BigInt(row.getLong(3)),
|
nullCount = BigInt(row.getLong(3)),
|
||||||
avgLen = row.getLong(4),
|
avgLen = row.getLong(4),
|
||||||
maxLen = row.getLong(5)
|
maxLen = row.getLong(5)
|
||||||
|
|
|
@ -22,7 +22,7 @@ import scala.math.BigDecimal.RoundingMode
|
||||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
|
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
|
||||||
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
|
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
|
||||||
import org.apache.spark.sql.internal.SQLConf
|
import org.apache.spark.sql.internal.SQLConf
|
||||||
import org.apache.spark.sql.types.{DataType, StringType}
|
import org.apache.spark.sql.types.{DecimalType, _}
|
||||||
|
|
||||||
|
|
||||||
object EstimationUtils {
|
object EstimationUtils {
|
||||||
|
@ -75,4 +75,32 @@ object EstimationUtils {
|
||||||
// (simple computation of statistics returns product of children).
|
// (simple computation of statistics returns product of children).
|
||||||
if (outputRowCount > 0) outputRowCount * sizePerRow else 1
|
if (outputRowCount > 0) outputRowCount * sizePerRow else 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* For simplicity we use Decimal to unify operations for data types whose min/max values can be
|
||||||
|
* represented as numbers, e.g. Boolean can be represented as 0 (false) or 1 (true).
|
||||||
|
* The two methods below are the contract of conversion.
|
||||||
|
*/
|
||||||
|
def toDecimal(value: Any, dataType: DataType): Decimal = {
|
||||||
|
dataType match {
|
||||||
|
case _: NumericType | DateType | TimestampType => Decimal(value.toString)
|
||||||
|
case BooleanType => if (value.asInstanceOf[Boolean]) Decimal(1) else Decimal(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def fromDecimal(dec: Decimal, dataType: DataType): Any = {
|
||||||
|
dataType match {
|
||||||
|
case BooleanType => dec.toLong == 1
|
||||||
|
case DateType => dec.toInt
|
||||||
|
case TimestampType => dec.toLong
|
||||||
|
case ByteType => dec.toByte
|
||||||
|
case ShortType => dec.toShort
|
||||||
|
case IntegerType => dec.toInt
|
||||||
|
case LongType => dec.toLong
|
||||||
|
case FloatType => dec.toFloat
|
||||||
|
case DoubleType => dec.toDouble
|
||||||
|
case _: DecimalType => dec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,6 @@ import org.apache.spark.internal.Logging
|
||||||
import org.apache.spark.sql.catalyst.expressions._
|
import org.apache.spark.sql.catalyst.expressions._
|
||||||
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
|
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
|
||||||
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics}
|
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics}
|
||||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
|
||||||
import org.apache.spark.sql.internal.SQLConf
|
import org.apache.spark.sql.internal.SQLConf
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
|
@ -301,30 +300,6 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* For a SQL data type, its internal data type may be different from its external type.
|
|
||||||
* For DateType, its internal type is Int, and its external data type is Java Date type.
|
|
||||||
* The min/max values in ColumnStat are saved in their corresponding external type.
|
|
||||||
*
|
|
||||||
* @param attrDataType the column data type
|
|
||||||
* @param litValue the literal value
|
|
||||||
* @return a BigDecimal value
|
|
||||||
*/
|
|
||||||
def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = {
|
|
||||||
attrDataType match {
|
|
||||||
case DateType =>
|
|
||||||
Some(DateTimeUtils.toJavaDate(litValue.toString.toInt))
|
|
||||||
case TimestampType =>
|
|
||||||
Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong))
|
|
||||||
case _: DecimalType =>
|
|
||||||
Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal)
|
|
||||||
case StringType | BinaryType =>
|
|
||||||
None
|
|
||||||
case _ =>
|
|
||||||
Some(litValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a percentage of rows meeting an equality (=) expression.
|
* Returns a percentage of rows meeting an equality (=) expression.
|
||||||
* This method evaluates the equality predicate for all data types.
|
* This method evaluates the equality predicate for all data types.
|
||||||
|
@ -356,12 +331,16 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
|
||||||
val statsRange = Range(colStat.min, colStat.max, attr.dataType)
|
val statsRange = Range(colStat.min, colStat.max, attr.dataType)
|
||||||
if (statsRange.contains(literal)) {
|
if (statsRange.contains(literal)) {
|
||||||
if (update) {
|
if (update) {
|
||||||
// We update ColumnStat structure after apply this equality predicate.
|
// We update ColumnStat structure after apply this equality predicate:
|
||||||
// Set distinctCount to 1. Set nullCount to 0.
|
// Set distinctCount to 1, nullCount to 0, and min/max values (if exist) to the literal
|
||||||
// Need to save new min/max using the external type value of the literal
|
// value.
|
||||||
val newValue = convertBoundValue(attr.dataType, literal.value)
|
val newStats = attr.dataType match {
|
||||||
val newStats = colStat.copy(distinctCount = 1, min = newValue,
|
case StringType | BinaryType =>
|
||||||
max = newValue, nullCount = 0)
|
colStat.copy(distinctCount = 1, nullCount = 0)
|
||||||
|
case _ =>
|
||||||
|
colStat.copy(distinctCount = 1, min = Some(literal.value),
|
||||||
|
max = Some(literal.value), nullCount = 0)
|
||||||
|
}
|
||||||
colStatsMap(attr) = newStats
|
colStatsMap(attr) = newStats
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -430,18 +409,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
|
||||||
return Some(0.0)
|
return Some(0.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Need to save new min/max using the external type value of the literal
|
val newMax = validQuerySet.maxBy(EstimationUtils.toDecimal(_, dataType))
|
||||||
val newMax = convertBoundValue(
|
val newMin = validQuerySet.minBy(EstimationUtils.toDecimal(_, dataType))
|
||||||
attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString)))
|
|
||||||
val newMin = convertBoundValue(
|
|
||||||
attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString)))
|
|
||||||
|
|
||||||
// newNdv should not be greater than the old ndv. For example, column has only 2 values
|
// newNdv should not be greater than the old ndv. For example, column has only 2 values
|
||||||
// 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5.
|
// 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5.
|
||||||
newNdv = ndv.min(BigInt(validQuerySet.size))
|
newNdv = ndv.min(BigInt(validQuerySet.size))
|
||||||
if (update) {
|
if (update) {
|
||||||
val newStats = colStat.copy(distinctCount = newNdv, min = newMin,
|
val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin),
|
||||||
max = newMax, nullCount = 0)
|
max = Some(newMax), nullCount = 0)
|
||||||
colStatsMap(attr) = newStats
|
colStatsMap(attr) = newStats
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -478,8 +453,8 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
|
||||||
|
|
||||||
val colStat = colStatsMap(attr)
|
val colStat = colStatsMap(attr)
|
||||||
val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
|
val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
|
||||||
val max = BigDecimal(statsRange.max)
|
val max = statsRange.max.toBigDecimal
|
||||||
val min = BigDecimal(statsRange.min)
|
val min = statsRange.min.toBigDecimal
|
||||||
val ndv = BigDecimal(colStat.distinctCount)
|
val ndv = BigDecimal(colStat.distinctCount)
|
||||||
|
|
||||||
// determine the overlapping degree between predicate range and column's range
|
// determine the overlapping degree between predicate range and column's range
|
||||||
|
@ -540,8 +515,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
|
||||||
}
|
}
|
||||||
|
|
||||||
if (update) {
|
if (update) {
|
||||||
// Need to save new min/max using the external type value of the literal
|
val newValue = Some(literal.value)
|
||||||
val newValue = convertBoundValue(attr.dataType, literal.value)
|
|
||||||
var newMax = colStat.max
|
var newMax = colStat.max
|
||||||
var newMin = colStat.min
|
var newMin = colStat.min
|
||||||
var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt()
|
var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt()
|
||||||
|
@ -606,14 +580,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
|
||||||
val colStatLeft = colStatsMap(attrLeft)
|
val colStatLeft = colStatsMap(attrLeft)
|
||||||
val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType)
|
val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType)
|
||||||
.asInstanceOf[NumericRange]
|
.asInstanceOf[NumericRange]
|
||||||
val maxLeft = BigDecimal(statsRangeLeft.max)
|
val maxLeft = statsRangeLeft.max
|
||||||
val minLeft = BigDecimal(statsRangeLeft.min)
|
val minLeft = statsRangeLeft.min
|
||||||
|
|
||||||
val colStatRight = colStatsMap(attrRight)
|
val colStatRight = colStatsMap(attrRight)
|
||||||
val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType)
|
val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType)
|
||||||
.asInstanceOf[NumericRange]
|
.asInstanceOf[NumericRange]
|
||||||
val maxRight = BigDecimal(statsRangeRight.max)
|
val maxRight = statsRangeRight.max
|
||||||
val minRight = BigDecimal(statsRangeRight.min)
|
val minRight = statsRangeRight.min
|
||||||
|
|
||||||
// determine the overlapping degree between predicate range and column's range
|
// determine the overlapping degree between predicate range and column's range
|
||||||
val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0)
|
val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0)
|
||||||
|
|
|
@ -17,12 +17,8 @@
|
||||||
|
|
||||||
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
|
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
|
||||||
|
|
||||||
import java.math.{BigDecimal => JDecimal}
|
|
||||||
import java.sql.{Date, Timestamp}
|
|
||||||
|
|
||||||
import org.apache.spark.sql.catalyst.expressions.Literal
|
import org.apache.spark.sql.catalyst.expressions.Literal
|
||||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _}
|
|
||||||
|
|
||||||
|
|
||||||
/** Value range of a column. */
|
/** Value range of a column. */
|
||||||
|
@ -31,13 +27,10 @@ trait Range {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** For simplicity we use decimal to unify operations of numeric ranges. */
|
/** For simplicity we use decimal to unify operations of numeric ranges. */
|
||||||
case class NumericRange(min: JDecimal, max: JDecimal) extends Range {
|
case class NumericRange(min: Decimal, max: Decimal) extends Range {
|
||||||
override def contains(l: Literal): Boolean = {
|
override def contains(l: Literal): Boolean = {
|
||||||
val decimal = l.dataType match {
|
val lit = EstimationUtils.toDecimal(l.value, l.dataType)
|
||||||
case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
|
min <= lit && max >= lit
|
||||||
case _ => new JDecimal(l.value.toString)
|
|
||||||
}
|
|
||||||
min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,7 +51,10 @@ object Range {
|
||||||
def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match {
|
def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match {
|
||||||
case StringType | BinaryType => new DefaultRange()
|
case StringType | BinaryType => new DefaultRange()
|
||||||
case _ if min.isEmpty || max.isEmpty => new NullRange()
|
case _ if min.isEmpty || max.isEmpty => new NullRange()
|
||||||
case _ => toNumericRange(min.get, max.get, dataType)
|
case _ =>
|
||||||
|
NumericRange(
|
||||||
|
min = EstimationUtils.toDecimal(min.get, dataType),
|
||||||
|
max = EstimationUtils.toDecimal(max.get, dataType))
|
||||||
}
|
}
|
||||||
|
|
||||||
def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match {
|
def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match {
|
||||||
|
@ -82,51 +78,11 @@ object Range {
|
||||||
// binary/string types don't support intersecting.
|
// binary/string types don't support intersecting.
|
||||||
(None, None)
|
(None, None)
|
||||||
case (n1: NumericRange, n2: NumericRange) =>
|
case (n1: NumericRange, n2: NumericRange) =>
|
||||||
val newRange = NumericRange(n1.min.max(n2.min), n1.max.min(n2.max))
|
// Choose the maximum of two min values, and the minimum of two max values.
|
||||||
val (newMin, newMax) = fromNumericRange(newRange, dt)
|
val newMin = if (n1.min <= n2.min) n2.min else n1.min
|
||||||
(Some(newMin), Some(newMax))
|
val newMax = if (n1.max <= n2.max) n1.max else n2.max
|
||||||
|
(Some(EstimationUtils.fromDecimal(newMin, dt)),
|
||||||
|
Some(EstimationUtils.fromDecimal(newMax, dt)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* For simplicity we use decimal to unify operations of numeric types, the two methods below
|
|
||||||
* are the contract of conversion.
|
|
||||||
*/
|
|
||||||
private def toNumericRange(min: Any, max: Any, dataType: DataType): NumericRange = {
|
|
||||||
dataType match {
|
|
||||||
case _: NumericType =>
|
|
||||||
NumericRange(new JDecimal(min.toString), new JDecimal(max.toString))
|
|
||||||
case BooleanType =>
|
|
||||||
val min1 = if (min.asInstanceOf[Boolean]) 1 else 0
|
|
||||||
val max1 = if (max.asInstanceOf[Boolean]) 1 else 0
|
|
||||||
NumericRange(new JDecimal(min1), new JDecimal(max1))
|
|
||||||
case DateType =>
|
|
||||||
val min1 = DateTimeUtils.fromJavaDate(min.asInstanceOf[Date])
|
|
||||||
val max1 = DateTimeUtils.fromJavaDate(max.asInstanceOf[Date])
|
|
||||||
NumericRange(new JDecimal(min1), new JDecimal(max1))
|
|
||||||
case TimestampType =>
|
|
||||||
val min1 = DateTimeUtils.fromJavaTimestamp(min.asInstanceOf[Timestamp])
|
|
||||||
val max1 = DateTimeUtils.fromJavaTimestamp(max.asInstanceOf[Timestamp])
|
|
||||||
NumericRange(new JDecimal(min1), new JDecimal(max1))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private def fromNumericRange(n: NumericRange, dataType: DataType): (Any, Any) = {
|
|
||||||
dataType match {
|
|
||||||
case _: IntegralType =>
|
|
||||||
(n.min.longValue(), n.max.longValue())
|
|
||||||
case FloatType | DoubleType =>
|
|
||||||
(n.min.doubleValue(), n.max.doubleValue())
|
|
||||||
case _: DecimalType =>
|
|
||||||
(n.min, n.max)
|
|
||||||
case BooleanType =>
|
|
||||||
(n.min.longValue() == 1, n.max.longValue() == 1)
|
|
||||||
case DateType =>
|
|
||||||
(DateTimeUtils.toJavaDate(n.min.intValue()), DateTimeUtils.toJavaDate(n.max.intValue()))
|
|
||||||
case TimestampType =>
|
|
||||||
(DateTimeUtils.toJavaTimestamp(n.min.longValue()),
|
|
||||||
DateTimeUtils.toJavaTimestamp(n.max.longValue()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
|
||||||
import org.apache.spark.sql.catalyst.plans.LeftOuter
|
import org.apache.spark.sql.catalyst.plans.LeftOuter
|
||||||
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics}
|
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics}
|
||||||
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
|
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
|
||||||
|
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -45,15 +46,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
||||||
nullCount = 0, avgLen = 1, maxLen = 1)
|
nullCount = 0, avgLen = 1, maxLen = 1)
|
||||||
|
|
||||||
// column cdate has 10 values from 2017-01-01 through 2017-01-10.
|
// column cdate has 10 values from 2017-01-01 through 2017-01-10.
|
||||||
val dMin = Date.valueOf("2017-01-01")
|
val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01"))
|
||||||
val dMax = Date.valueOf("2017-01-10")
|
val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10"))
|
||||||
val attrDate = AttributeReference("cdate", DateType)()
|
val attrDate = AttributeReference("cdate", DateType)()
|
||||||
val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
|
val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
|
||||||
nullCount = 0, avgLen = 4, maxLen = 4)
|
nullCount = 0, avgLen = 4, maxLen = 4)
|
||||||
|
|
||||||
// column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
|
// column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
|
||||||
val decMin = new java.math.BigDecimal("0.200000000000000000")
|
val decMin = Decimal("0.200000000000000000")
|
||||||
val decMax = new java.math.BigDecimal("0.800000000000000000")
|
val decMax = Decimal("0.800000000000000000")
|
||||||
val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))()
|
val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))()
|
||||||
val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax),
|
val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax),
|
||||||
nullCount = 0, avgLen = 8, maxLen = 8)
|
nullCount = 0, avgLen = 8, maxLen = 8)
|
||||||
|
@ -147,7 +148,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
||||||
|
|
||||||
test("cint < 3 OR null") {
|
test("cint < 3 OR null") {
|
||||||
val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))
|
val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))
|
||||||
val m = Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)).stats(conf)
|
|
||||||
validateEstimatedStats(
|
validateEstimatedStats(
|
||||||
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
|
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
|
||||||
Seq(attrInt -> colStatInt),
|
Seq(attrInt -> colStatInt),
|
||||||
|
@ -341,6 +341,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
||||||
expectedRowCount = 7)
|
expectedRowCount = 7)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("cbool IN (true)") {
|
||||||
|
validateEstimatedStats(
|
||||||
|
Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)),
|
||||||
|
Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
|
||||||
|
nullCount = 0, avgLen = 1, maxLen = 1)),
|
||||||
|
expectedRowCount = 5)
|
||||||
|
}
|
||||||
|
|
||||||
test("cbool = true") {
|
test("cbool = true") {
|
||||||
validateEstimatedStats(
|
validateEstimatedStats(
|
||||||
Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)),
|
Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)),
|
||||||
|
@ -358,9 +366,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
test("cdate = cast('2017-01-02' AS DATE)") {
|
test("cdate = cast('2017-01-02' AS DATE)") {
|
||||||
val d20170102 = Date.valueOf("2017-01-02")
|
val d20170102 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-02"))
|
||||||
validateEstimatedStats(
|
validateEstimatedStats(
|
||||||
Filter(EqualTo(attrDate, Literal(d20170102)),
|
Filter(EqualTo(attrDate, Literal(d20170102, DateType)),
|
||||||
childStatsTestPlan(Seq(attrDate), 10L)),
|
childStatsTestPlan(Seq(attrDate), 10L)),
|
||||||
Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
|
Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
|
||||||
nullCount = 0, avgLen = 4, maxLen = 4)),
|
nullCount = 0, avgLen = 4, maxLen = 4)),
|
||||||
|
@ -368,9 +376,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
test("cdate < cast('2017-01-03' AS DATE)") {
|
test("cdate < cast('2017-01-03' AS DATE)") {
|
||||||
val d20170103 = Date.valueOf("2017-01-03")
|
val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03"))
|
||||||
validateEstimatedStats(
|
validateEstimatedStats(
|
||||||
Filter(LessThan(attrDate, Literal(d20170103)),
|
Filter(LessThan(attrDate, Literal(d20170103, DateType)),
|
||||||
childStatsTestPlan(Seq(attrDate), 10L)),
|
childStatsTestPlan(Seq(attrDate), 10L)),
|
||||||
Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
|
Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
|
||||||
nullCount = 0, avgLen = 4, maxLen = 4)),
|
nullCount = 0, avgLen = 4, maxLen = 4)),
|
||||||
|
@ -379,19 +387,19 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
||||||
|
|
||||||
test("""cdate IN ( cast('2017-01-03' AS DATE),
|
test("""cdate IN ( cast('2017-01-03' AS DATE),
|
||||||
cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") {
|
cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") {
|
||||||
val d20170103 = Date.valueOf("2017-01-03")
|
val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03"))
|
||||||
val d20170104 = Date.valueOf("2017-01-04")
|
val d20170104 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-04"))
|
||||||
val d20170105 = Date.valueOf("2017-01-05")
|
val d20170105 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-05"))
|
||||||
validateEstimatedStats(
|
validateEstimatedStats(
|
||||||
Filter(In(attrDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))),
|
Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType),
|
||||||
childStatsTestPlan(Seq(attrDate), 10L)),
|
Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)),
|
||||||
Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
|
Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
|
||||||
nullCount = 0, avgLen = 4, maxLen = 4)),
|
nullCount = 0, avgLen = 4, maxLen = 4)),
|
||||||
expectedRowCount = 3)
|
expectedRowCount = 3)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("cdecimal = 0.400000000000000000") {
|
test("cdecimal = 0.400000000000000000") {
|
||||||
val dec_0_40 = new java.math.BigDecimal("0.400000000000000000")
|
val dec_0_40 = Decimal("0.400000000000000000")
|
||||||
validateEstimatedStats(
|
validateEstimatedStats(
|
||||||
Filter(EqualTo(attrDecimal, Literal(dec_0_40)),
|
Filter(EqualTo(attrDecimal, Literal(dec_0_40)),
|
||||||
childStatsTestPlan(Seq(attrDecimal), 4L)),
|
childStatsTestPlan(Seq(attrDecimal), 4L)),
|
||||||
|
@ -401,7 +409,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
test("cdecimal < 0.60 ") {
|
test("cdecimal < 0.60 ") {
|
||||||
val dec_0_60 = new java.math.BigDecimal("0.600000000000000000")
|
val dec_0_60 = Decimal("0.600000000000000000")
|
||||||
validateEstimatedStats(
|
validateEstimatedStats(
|
||||||
Filter(LessThan(attrDecimal, Literal(dec_0_60)),
|
Filter(LessThan(attrDecimal, Literal(dec_0_60)),
|
||||||
childStatsTestPlan(Seq(attrDecimal), 4L)),
|
childStatsTestPlan(Seq(attrDecimal), 4L)),
|
||||||
|
@ -532,7 +540,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
||||||
|
|
||||||
test("cint = cint3") {
|
test("cint = cint3") {
|
||||||
// no records qualify due to no overlap
|
// no records qualify due to no overlap
|
||||||
val emptyColStats = Seq[(Attribute, ColumnStat)]()
|
|
||||||
validateEstimatedStats(
|
validateEstimatedStats(
|
||||||
Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)),
|
Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)),
|
||||||
Nil, // set to empty
|
Nil, // set to empty
|
||||||
|
|
|
@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap,
|
||||||
import org.apache.spark.sql.catalyst.plans._
|
import org.apache.spark.sql.catalyst.plans._
|
||||||
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics}
|
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics}
|
||||||
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
|
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
|
||||||
|
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||||
import org.apache.spark.sql.types.{DateType, TimestampType, _}
|
import org.apache.spark.sql.types.{DateType, TimestampType, _}
|
||||||
|
|
||||||
|
|
||||||
|
@ -254,24 +255,24 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
|
||||||
test("test join keys of different types") {
|
test("test join keys of different types") {
|
||||||
/** Columns in a table with only one row */
|
/** Columns in a table with only one row */
|
||||||
def genColumnData: mutable.LinkedHashMap[Attribute, ColumnStat] = {
|
def genColumnData: mutable.LinkedHashMap[Attribute, ColumnStat] = {
|
||||||
val dec = new java.math.BigDecimal("1.000000000000000000")
|
val dec = Decimal("1.000000000000000000")
|
||||||
val date = Date.valueOf("2016-05-08")
|
val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08"))
|
||||||
val timestamp = Timestamp.valueOf("2016-05-08 00:00:01")
|
val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01"))
|
||||||
mutable.LinkedHashMap[Attribute, ColumnStat](
|
mutable.LinkedHashMap[Attribute, ColumnStat](
|
||||||
AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1,
|
AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1,
|
||||||
min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1),
|
min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1),
|
||||||
AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1,
|
AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1,
|
||||||
min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 1, maxLen = 1),
|
min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1),
|
||||||
AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1,
|
AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1,
|
||||||
min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 2, maxLen = 2),
|
min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2),
|
||||||
AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1,
|
AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1,
|
||||||
min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 4, maxLen = 4),
|
min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4),
|
||||||
AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1,
|
AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1,
|
||||||
min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8),
|
min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8),
|
||||||
AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1,
|
AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1,
|
||||||
min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8),
|
min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8),
|
||||||
AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1,
|
AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1,
|
||||||
min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 4, maxLen = 4),
|
min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4),
|
||||||
AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1,
|
AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1,
|
||||||
min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16),
|
min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16),
|
||||||
AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1,
|
AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1,
|
||||||
|
|
|
@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
|
||||||
|
|
||||||
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference}
|
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference}
|
||||||
import org.apache.spark.sql.catalyst.plans.logical._
|
import org.apache.spark.sql.catalyst.plans.logical._
|
||||||
|
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
|
|
||||||
|
@ -62,28 +63,28 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
test("test row size estimation") {
|
test("test row size estimation") {
|
||||||
val dec1 = new java.math.BigDecimal("1.000000000000000000")
|
val dec1 = Decimal("1.000000000000000000")
|
||||||
val dec2 = new java.math.BigDecimal("8.000000000000000000")
|
val dec2 = Decimal("8.000000000000000000")
|
||||||
val d1 = Date.valueOf("2016-05-08")
|
val d1 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08"))
|
||||||
val d2 = Date.valueOf("2016-05-09")
|
val d2 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-09"))
|
||||||
val t1 = Timestamp.valueOf("2016-05-08 00:00:01")
|
val t1 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01"))
|
||||||
val t2 = Timestamp.valueOf("2016-05-09 00:00:02")
|
val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02"))
|
||||||
|
|
||||||
val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
|
val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
|
||||||
AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2,
|
AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2,
|
||||||
min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1),
|
min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1),
|
||||||
AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2,
|
AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2,
|
||||||
min = Some(1L), max = Some(2L), nullCount = 0, avgLen = 1, maxLen = 1),
|
min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1),
|
||||||
AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2,
|
AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2,
|
||||||
min = Some(1L), max = Some(3L), nullCount = 0, avgLen = 2, maxLen = 2),
|
min = Some(1.toShort), max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2),
|
||||||
AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2,
|
AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2,
|
||||||
min = Some(1L), max = Some(4L), nullCount = 0, avgLen = 4, maxLen = 4),
|
min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4),
|
||||||
AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2,
|
AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2,
|
||||||
min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8),
|
min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8),
|
||||||
AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2,
|
AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2,
|
||||||
min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8),
|
min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8),
|
||||||
AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2,
|
AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2,
|
||||||
min = Some(1.0), max = Some(7.0), nullCount = 0, avgLen = 4, maxLen = 4),
|
min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4),
|
||||||
AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2,
|
AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2,
|
||||||
min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16),
|
min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16),
|
||||||
AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2,
|
AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2,
|
||||||
|
|
|
@ -73,10 +73,10 @@ case class AnalyzeColumnCommand(
|
||||||
val relation = sparkSession.table(tableIdent).logicalPlan
|
val relation = sparkSession.table(tableIdent).logicalPlan
|
||||||
// Resolve the column names and dedup using AttributeSet
|
// Resolve the column names and dedup using AttributeSet
|
||||||
val resolver = sparkSession.sessionState.conf.resolver
|
val resolver = sparkSession.sessionState.conf.resolver
|
||||||
val attributesToAnalyze = AttributeSet(columnNames.map { col =>
|
val attributesToAnalyze = columnNames.map { col =>
|
||||||
val exprOption = relation.output.find(attr => resolver(attr.name, col))
|
val exprOption = relation.output.find(attr => resolver(attr.name, col))
|
||||||
exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist."))
|
exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist."))
|
||||||
}).toSeq
|
}
|
||||||
|
|
||||||
// Make sure the column types are supported for stats gathering.
|
// Make sure the column types are supported for stats gathering.
|
||||||
attributesToAnalyze.foreach { attr =>
|
attributesToAnalyze.foreach { attr =>
|
||||||
|
@ -99,8 +99,8 @@ case class AnalyzeColumnCommand(
|
||||||
val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head()
|
val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head()
|
||||||
|
|
||||||
val rowCount = statsRow.getLong(0)
|
val rowCount = statsRow.getLong(0)
|
||||||
val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) =>
|
val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) =>
|
||||||
(expr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1)))
|
(attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1), attr))
|
||||||
}.toMap
|
}.toMap
|
||||||
(rowCount, columnStats)
|
(rowCount, columnStats)
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@ import scala.util.Random
|
||||||
import org.apache.spark.sql.catalyst.TableIdentifier
|
import org.apache.spark.sql.catalyst.TableIdentifier
|
||||||
import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics}
|
import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics}
|
||||||
import org.apache.spark.sql.catalyst.plans.logical._
|
import org.apache.spark.sql.catalyst.plans.logical._
|
||||||
|
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||||
import org.apache.spark.sql.execution.datasources.LogicalRelation
|
import org.apache.spark.sql.execution.datasources.LogicalRelation
|
||||||
import org.apache.spark.sql.internal.StaticSQLConf
|
import org.apache.spark.sql.internal.StaticSQLConf
|
||||||
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
|
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
|
||||||
|
@ -117,7 +118,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
|
||||||
val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
|
val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
|
||||||
stats.zip(df.schema).foreach { case ((k, v), field) =>
|
stats.zip(df.schema).foreach { case ((k, v), field) =>
|
||||||
withClue(s"column $k with type ${field.dataType}") {
|
withClue(s"column $k with type ${field.dataType}") {
|
||||||
val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap)
|
val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType))
|
||||||
assert(roundtrip == Some(v))
|
assert(roundtrip == Some(v))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -201,17 +202,19 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
|
||||||
/** A mapping from column to the stats collected. */
|
/** A mapping from column to the stats collected. */
|
||||||
protected val stats = mutable.LinkedHashMap(
|
protected val stats = mutable.LinkedHashMap(
|
||||||
"cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1),
|
"cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1),
|
||||||
"cbyte" -> ColumnStat(2, Some(1L), Some(2L), 1, 1, 1),
|
"cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1),
|
||||||
"cshort" -> ColumnStat(2, Some(1L), Some(3L), 1, 2, 2),
|
"cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2),
|
||||||
"cint" -> ColumnStat(2, Some(1L), Some(4L), 1, 4, 4),
|
"cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4),
|
||||||
"clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8),
|
"clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8),
|
||||||
"cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8),
|
"cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8),
|
||||||
"cfloat" -> ColumnStat(2, Some(1.0), Some(7.0), 1, 4, 4),
|
"cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4),
|
||||||
"cdecimal" -> ColumnStat(2, Some(dec1), Some(dec2), 1, 16, 16),
|
"cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16),
|
||||||
"cstring" -> ColumnStat(2, None, None, 1, 3, 3),
|
"cstring" -> ColumnStat(2, None, None, 1, 3, 3),
|
||||||
"cbinary" -> ColumnStat(2, None, None, 1, 3, 3),
|
"cbinary" -> ColumnStat(2, None, None, 1, 3, 3),
|
||||||
"cdate" -> ColumnStat(2, Some(d1), Some(d2), 1, 4, 4),
|
"cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)),
|
||||||
"ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8)
|
Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4),
|
||||||
|
"ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)),
|
||||||
|
Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8)
|
||||||
)
|
)
|
||||||
|
|
||||||
private val randomName = new Random(31)
|
private val randomName = new Random(31)
|
||||||
|
|
|
@ -526,8 +526,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
|
||||||
if (stats.rowCount.isDefined) {
|
if (stats.rowCount.isDefined) {
|
||||||
statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString()
|
statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString()
|
||||||
}
|
}
|
||||||
|
val colNameTypeMap: Map[String, DataType] =
|
||||||
|
tableDefinition.schema.fields.map(f => (f.name, f.dataType)).toMap
|
||||||
stats.colStats.foreach { case (colName, colStat) =>
|
stats.colStats.foreach { case (colName, colStat) =>
|
||||||
colStat.toMap.foreach { case (k, v) =>
|
colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) =>
|
||||||
statsProperties += (columnStatKeyPropName(colName, k) -> v)
|
statsProperties += (columnStatKeyPropName(colName, k) -> v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue