[SPARK-18522][SQL] Explicit contract for column stats serialization

## What changes were proposed in this pull request?
The current implementation of column stats uses the base64 encoding of the internal UnsafeRow format to persist statistics (in table properties in the Hive metastore). This is an internal format that is not stable across different versions of Spark and should NOT be used for persistence. In addition, it would be better if the statistics stored in the catalog were human readable.

This pull request introduces the following changes:

1. Created a single ColumnStat class for all data types. All data types track the same set of statistics.
2. Updated the stats collection implementation to remove its dependency on internal data structures (e.g. InternalRow, or storing DateType as an int32). For example, dates were previously stored as a single integer but are now stored as java.sql.Date. When we implement the next steps of CBO, we can add code to convert those back into internal types again.
3. Documented clearly which JVM data types are used to store which data.
4. Defined a simple Map[String, String] interface for serializing and deserializing column stats into/from the catalog (see the sketch after this list).
5. Rearranged the method/function structure so it is clearer what the supported data types are, and moved the logic for generating stats into the ColumnStat class so it is easy to find.
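
To make the new contract concrete, here is a minimal round-trip sketch for point 4 (not part of the patch; the table name, column name, and values are made up for illustration, and per the contract min/max for integral columns are stored as longs):

```scala
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.types.{IntegerType, StructField}

// Hypothetical stats for an int column: 2 distinct values, min 1, max 4,
// 1 null, and a fixed 4-byte length.
val original = ColumnStat(
  distinctCount = 2,
  min = Some(1L),
  max = Some(4L),
  nullCount = 1,
  avgLen = 4,
  maxLen = 4)

// Serialize into the human-readable map persisted in the catalog.
// The map always contains "version"; min/max are omitted when None.
val serialized: Map[String, String] = original.toMap

// Deserialize, using the column's schema to interpret the min/max strings.
val restored: Option[ColumnStat] =
  ColumnStat.fromMap("some_table", StructField("cint", IntegerType), serialized)

assert(restored == Some(original))
```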

## How was this patch tested?
Removed most of the original test cases created for column statistics and added three very simple ones that cover all the cases. The three test cases validate:
1. Round-trip serialization works.
2. Behavior when analyzing a non-existent column or a column of an unsupported data type.
3. Stats collection results for all supported data types.

Also moved parser-related tests into a parser test suite and added an explicit serialization test for the Hive external catalog. A condensed sketch of the end-to-end flow these tests exercise follows below.
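
For reference, a rough sketch of that end-to-end flow (illustrative only: the table and column names are hypothetical, and the object is placed in the org.apache.spark.sql package, as the test suites are, so that spark.sessionState is accessible):

```scala
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat

object ColumnStatsExample {
  // Assumes an active SparkSession and an existing table `column_stats_test`
  // with columns `cint` and `cstring` (names are illustrative only).
  def run(spark: SparkSession): Unit = {
    // Collect column-level statistics; they are persisted in the catalog
    // via the Map[String, String] contract introduced in this patch.
    spark.sql("ANALYZE TABLE column_stats_test COMPUTE STATISTICS FOR COLUMNS cint, cstring")

    // Read the deserialized ColumnStat objects back from the table metadata.
    val tableMeta =
      spark.sessionState.catalog.getTableMetadata(TableIdentifier("column_stats_test"))
    val colStats: Map[String, ColumnStat] = tableMeta.stats.get.colStats
    colStats.foreach { case (name, stat) =>
      println(s"$name: distinctCount=${stat.distinctCount}, nullCount=${stat.nullCount}")
    }
  }
}
```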

Author: Reynold Xin <rxin@databricks.com>

Closes #15959 from rxin/SPARK-18522.
Reynold Xin 2016-11-23 20:48:41 +08:00 committed by Wenchen Fan
parent 9785ed40d7
commit 70ad07a9d2
9 changed files with 592 additions and 919 deletions


@@ -17,12 +17,15 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.commons.codec.binary.Base64
import scala.util.control.NonFatal
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.types._
/**
* Estimates of various statistics. The default estimation logic simply lazily multiplies the
* corresponding statistic produced by the children. To override this behavior, override
@@ -58,60 +61,175 @@ case class Statistics(
}
}
/**
* Statistics for a column.
* Statistics collected for a column.
*
* 1. Supported data types are defined in `ColumnStat.supportsType`.
* 2. The JVM data type stored in min/max is the external data type (used in Row) for the
* corresponding Catalyst data type. For example, for DateType we store java.sql.Date, and for
* TimestampType we store java.sql.Timestamp.
* 3. For integral types, they are all upcasted to longs, i.e. shorts are stored as longs.
* 4. There is no guarantee that the statistics collected are accurate. Approximation algorithms
* (sketches) might have been used, and the data collected can also be stale.
*
* @param distinctCount number of distinct values
* @param min minimum value
* @param max maximum value
* @param nullCount number of nulls
* @param avgLen average length of the values. For fixed-length types, this should be a constant.
* @param maxLen maximum length of the values. For fixed-length types, this should be a constant.
*/
case class ColumnStat(statRow: InternalRow) {
case class ColumnStat(
distinctCount: BigInt,
min: Option[Any],
max: Option[Any],
nullCount: BigInt,
avgLen: Long,
maxLen: Long) {
def forNumeric[T <: AtomicType](dataType: T): NumericColumnStat[T] = {
NumericColumnStat(statRow, dataType)
}
def forString: StringColumnStat = StringColumnStat(statRow)
def forBinary: BinaryColumnStat = BinaryColumnStat(statRow)
def forBoolean: BooleanColumnStat = BooleanColumnStat(statRow)
// We currently don't store min/max for binary/string type. This can change in the future and
// then we need to remove this require.
require(min.isEmpty || (!min.get.isInstanceOf[Array[Byte]] && !min.get.isInstanceOf[String]))
require(max.isEmpty || (!max.get.isInstanceOf[Array[Byte]] && !max.get.isInstanceOf[String]))
override def toString: String = {
// use Base64 for encoding
Base64.encodeBase64String(statRow.asInstanceOf[UnsafeRow].getBytes)
/**
* Returns a map from string to string that can be used to serialize the column stats.
* The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string
* representation for the value. The deserialization side is defined in [[ColumnStat.fromMap]].
*
* As part of the protocol, the returned map always contains a key called "version".
* In the case min/max values are null (None), they won't appear in the map.
*/
def toMap: Map[String, String] = {
val map = new scala.collection.mutable.HashMap[String, String]
map.put(ColumnStat.KEY_VERSION, "1")
map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString)
map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString)
map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString)
map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString)
min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, v.toString) }
max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, v.toString) }
map.toMap
}
}
object ColumnStat {
def apply(numFields: Int, str: String): ColumnStat = {
// use Base64 for decoding
val bytes = Base64.decodeBase64(str)
val unsafeRow = new UnsafeRow(numFields)
unsafeRow.pointTo(bytes, bytes.length)
ColumnStat(unsafeRow)
object ColumnStat extends Logging {
// List of string keys used to serialize ColumnStat
val KEY_VERSION = "version"
private val KEY_DISTINCT_COUNT = "distinctCount"
private val KEY_MIN_VALUE = "min"
private val KEY_MAX_VALUE = "max"
private val KEY_NULL_COUNT = "nullCount"
private val KEY_AVG_LEN = "avgLen"
private val KEY_MAX_LEN = "maxLen"
/** Returns true iff we support gathering column statistics on a column of the given type. */
def supportsType(dataType: DataType): Boolean = dataType match {
case _: IntegralType => true
case _: DecimalType => true
case DoubleType | FloatType => true
case BooleanType => true
case DateType => true
case TimestampType => true
case BinaryType | StringType => true
case _ => false
}
}
case class NumericColumnStat[T <: AtomicType](statRow: InternalRow, dataType: T) {
// The indices here must be consistent with `ColumnStatStruct.numericColumnStat`.
val numNulls: Long = statRow.getLong(0)
val max: T#InternalType = statRow.get(1, dataType).asInstanceOf[T#InternalType]
val min: T#InternalType = statRow.get(2, dataType).asInstanceOf[T#InternalType]
val ndv: Long = statRow.getLong(3)
}
/**
* Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats
* from some external storage. The serialization side is defined in [[ColumnStat.toMap]].
*/
def fromMap(table: String, field: StructField, map: Map[String, String])
: Option[ColumnStat] = {
val str2val: (String => Any) = field.dataType match {
case _: IntegralType => _.toLong
case _: DecimalType => new java.math.BigDecimal(_)
case DoubleType | FloatType => _.toDouble
case BooleanType => _.toBoolean
case DateType => java.sql.Date.valueOf
case TimestampType => java.sql.Timestamp.valueOf
// This version of Spark does not use min/max for binary/string types so we ignore it.
case BinaryType | StringType => _ => null
case _ =>
throw new AnalysisException("Column statistics deserialization is not supported for " +
s"column ${field.name} of data type: ${field.dataType}.")
}
case class StringColumnStat(statRow: InternalRow) {
// The indices here must be consistent with `ColumnStatStruct.stringColumnStat`.
val numNulls: Long = statRow.getLong(0)
val avgColLen: Double = statRow.getDouble(1)
val maxColLen: Long = statRow.getInt(2)
val ndv: Long = statRow.getLong(3)
}
try {
Some(ColumnStat(
distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong),
// Note that flatMap(Option.apply) turns Option(null) into None.
min = map.get(KEY_MIN_VALUE).map(str2val).flatMap(Option.apply),
max = map.get(KEY_MAX_VALUE).map(str2val).flatMap(Option.apply),
nullCount = BigInt(map(KEY_NULL_COUNT).toLong),
avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong,
maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong
))
} catch {
case NonFatal(e) =>
logWarning(s"Failed to parse column statistics for column ${field.name} in table $table", e)
None
}
}
case class BinaryColumnStat(statRow: InternalRow) {
// The indices here must be consistent with `ColumnStatStruct.binaryColumnStat`.
val numNulls: Long = statRow.getLong(0)
val avgColLen: Double = statRow.getDouble(1)
val maxColLen: Long = statRow.getInt(2)
}
/**
* Constructs an expression to compute column statistics for a given column.
*
* The expression should create a single struct column with the following schema:
* distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long
*
* Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and
* as a result should stay in sync with it.
*/
def statExprs(col: Attribute, relativeSD: Double): CreateNamedStruct = {
def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr =>
expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() }
})
val one = Literal(1, LongType)
// the approximate ndv (num distinct value) should never be larger than the number of rows
val numNonNulls = if (col.nullable) Count(col) else Count(one)
val ndv = Least(Seq(HyperLogLogPlusPlus(col, relativeSD), numNonNulls))
val numNulls = Subtract(Count(one), numNonNulls)
def fixedLenTypeStruct(castType: DataType) = {
// For fixed width types, avg size should be the same as max size.
val avgSize = Literal(col.dataType.defaultSize, LongType)
struct(ndv, Cast(Min(col), castType), Cast(Max(col), castType), numNulls, avgSize, avgSize)
}
col.dataType match {
case _: IntegralType => fixedLenTypeStruct(LongType)
case _: DecimalType => fixedLenTypeStruct(col.dataType)
case DoubleType | FloatType => fixedLenTypeStruct(DoubleType)
case BooleanType => fixedLenTypeStruct(col.dataType)
case DateType => fixedLenTypeStruct(col.dataType)
case TimestampType => fixedLenTypeStruct(col.dataType)
case BinaryType | StringType =>
// For string and binary type, we don't store min/max.
val nullLit = Literal(null, col.dataType)
struct(
ndv, nullLit, nullLit, numNulls,
Ceil(Average(Length(col))), Cast(Max(Length(col)), LongType))
case _ =>
throw new AnalysisException("Analyzing column statistics is not supported for column " +
s"${col.name} of data type: ${col.dataType}.")
}
}
/** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */
def rowToColumnStat(row: Row): ColumnStat = {
ColumnStat(
distinctCount = BigInt(row.getLong(0)),
min = Option(row.get(1)), // for string/binary min/max, get should return null
max = Option(row.get(2)),
nullCount = BigInt(row.getLong(3)),
avgLen = row.getLong(4),
maxLen = row.getLong(5)
)
}
case class BooleanColumnStat(statRow: InternalRow) {
// The indices here must be consistent with `ColumnStatStruct.booleanColumnStat`.
val numNulls: Long = statRow.getLong(0)
val numTrues: Long = statRow.getLong(1)
val numFalses: Long = statRow.getLong(2)
}


@@ -24,9 +24,8 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnStat, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.types._
/**
@@ -62,7 +61,7 @@ case class AnalyzeColumnCommand(
// Compute stats for each column
val (rowCount, newColStats) =
AnalyzeColumnCommand.computeColStats(sparkSession, relation, columnNames)
AnalyzeColumnCommand.computeColumnStats(sparkSession, tableIdent.table, relation, columnNames)
// We also update table-level stats in order to keep them consistent with column-level stats.
val statistics = Statistics(
@@ -88,8 +87,9 @@ object AnalyzeColumnCommand extends Logging {
*
* This is visible for testing.
*/
def computeColStats(
def computeColumnStats(
sparkSession: SparkSession,
tableName: String,
relation: LogicalPlan,
columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = {
@@ -97,102 +97,33 @@ object AnalyzeColumnCommand extends Logging {
val resolver = sparkSession.sessionState.conf.resolver
val attributesToAnalyze = AttributeSet(columnNames.map { col =>
val exprOption = relation.output.find(attr => resolver(attr.name, col))
exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col."))
exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist."))
}).toSeq
// Make sure the column types are supported for stats gathering.
attributesToAnalyze.foreach { attr =>
if (!ColumnStat.supportsType(attr.dataType)) {
throw new AnalysisException(
s"Column ${attr.name} in table $tableName is of type ${attr.dataType}, " +
"and Spark does not support statistics collection on this column type.")
}
}
// Collect statistics per column.
// The first element in the result will be the overall row count, the following elements
// will be structs containing all column stats.
// The layout of each struct follows the layout of the ColumnStats.
val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError
val expressions = Count(Literal(1)).toAggregateExpression() +:
attributesToAnalyze.map(AnalyzeColumnCommand.createColumnStatStruct(_, ndvMaxErr))
val namedExpressions = expressions.map(e => Alias(e, e.toString)())
val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation))
.queryExecution.toRdd.collect().head
attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr))
val namedExpressions = expressions.map(e => Alias(e, e.toString)())
val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head()
// unwrap the result
// TODO: Get rid of numFields by using the public Dataset API.
val rowCount = statsRow.getLong(0)
val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) =>
val numFields = AnalyzeColumnCommand.numStatFields(expr.dataType)
(expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields)))
(expr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1)))
}.toMap
(rowCount, columnStats)
}
private val zero = Literal(0, LongType)
private val one = Literal(1, LongType)
private def numNulls(e: Expression): Expression = {
if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero
}
private def max(e: Expression): Expression = Max(e)
private def min(e: Expression): Expression = Min(e)
private def ndv(e: Expression, relativeSD: Double): Expression = {
// the approximate ndv should never be larger than the number of rows
Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one)))
}
private def avgLength(e: Expression): Expression = Average(Length(e))
private def maxLength(e: Expression): Expression = Max(Length(e))
private def numTrues(e: Expression): Expression = Sum(If(e, one, zero))
private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero))
/**
* Creates a struct that groups the sequence of expressions together. This is used to create
* one top level struct per column.
*/
private def createStruct(exprs: Seq[Expression]): CreateNamedStruct = {
CreateStruct(exprs.map { expr: Expression =>
expr.transformUp {
case af: AggregateFunction => af.toAggregateExpression()
}
})
}
private def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD))
}
private def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD))
}
private def binaryColumnStat(e: Expression): Seq[Expression] = {
Seq(numNulls(e), avgLength(e), maxLength(e))
}
private def booleanColumnStat(e: Expression): Seq[Expression] = {
Seq(numNulls(e), numTrues(e), numFalses(e))
}
// TODO(rxin): Get rid of this function.
def numStatFields(dataType: DataType): Int = {
dataType match {
case BinaryType | BooleanType => 3
case _ => 4
}
}
/**
* Creates a struct expression that contains the statistics to collect for a column.
*
* @param attr column to collect statistics
* @param relativeSD relative error for approximate number of distinct values.
*/
def createColumnStatStruct(attr: Attribute, relativeSD: Double): CreateNamedStruct = {
attr.dataType match {
case _: NumericType | TimestampType | DateType =>
createStruct(numericColumnStat(attr, relativeSD))
case StringType =>
createStruct(stringColumnStat(attr, relativeSD))
case BinaryType =>
createStruct(binaryColumnStat(attr))
case BooleanType =>
createStruct(booleanColumnStat(attr))
case otherType =>
throw new AnalysisException("Analyzing columns is not supported for column " +
s"${attr.name} of data type: ${attr.dataType}.")
}
}
}


@@ -0,0 +1,218 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import java.{lang => jl}
import java.sql.{Date, Timestamp}
import scala.collection.mutable
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.sql.test.SQLTestData.ArrayData
import org.apache.spark.sql.types._
/**
* End-to-end suite testing statistics collection and use on both entire table and columns.
*/
class StatisticsCollectionSuite extends StatisticsCollectionTestBase with SharedSQLContext {
import testImplicits._
private def checkTableStats(tableName: String, expectedRowCount: Option[Int])
: Option[Statistics] = {
val df = spark.table(tableName)
val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation =>
assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount)
rel.catalogTable.get.stats
}
assert(stats.size == 1)
stats.head
}
test("estimates the size of a limit 0 on outer join") {
withTempView("test") {
Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
.createOrReplaceTempView("test")
val df1 = spark.table("test")
val df2 = spark.table("test").limit(0)
val df = df1.join(df2, Seq("k"), "left")
val sizes = df.queryExecution.analyzed.collect { case g: Join =>
g.statistics.sizeInBytes
}
assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
assert(sizes.head === BigInt(96),
s"expected exact size 96 for table 'test', got: ${sizes.head}")
}
}
test("analyze column command - unsupported types and invalid columns") {
val tableName = "column_stats_test1"
withTable(tableName) {
Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName)
// Test unsupported data types
val err1 = intercept[AnalysisException] {
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data")
}
assert(err1.message.contains("does not support statistics collection"))
// Test invalid columns
val err2 = intercept[AnalysisException] {
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS some_random_column")
}
assert(err2.message.contains("does not exist"))
}
}
test("test table-level statistics for data source table") {
val tableName = "tbl"
withTable(tableName) {
sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet")
Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName)
// noscan won't count the number of rows
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan")
checkTableStats(tableName, expectedRowCount = None)
// without noscan, we count the number of rows
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
checkTableStats(tableName, expectedRowCount = Some(2))
}
}
test("SPARK-15392: DataFrame created from RDD should not be broadcasted") {
val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType))
assert(df.queryExecution.analyzed.statistics.sizeInBytes >
spark.sessionState.conf.autoBroadcastJoinThreshold)
assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes >
spark.sessionState.conf.autoBroadcastJoinThreshold)
}
test("estimates the size of limit") {
withTempView("test") {
Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
.createOrReplaceTempView("test")
Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) =>
val df = sql(s"""SELECT * FROM test limit $limit""")
val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit =>
g.statistics.sizeInBytes
}
assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
assert(sizesGlobalLimit.head === BigInt(expected),
s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}")
val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit =>
l.statistics.sizeInBytes
}
assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
assert(sizesLocalLimit.head === BigInt(expected),
s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}")
}
}
}
}
/**
* The base for test cases that we want to include in both the hive module (for verifying behavior
* when using the Hive external catalog) as well as in the sql/core module.
*/
abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils {
import testImplicits._
private val dec1 = new java.math.BigDecimal("1.000000000000000000")
private val dec2 = new java.math.BigDecimal("8.000000000000000000")
private val d1 = Date.valueOf("2016-05-08")
private val d2 = Date.valueOf("2016-05-09")
private val t1 = Timestamp.valueOf("2016-05-08 00:00:01")
private val t2 = Timestamp.valueOf("2016-05-09 00:00:02")
/**
* Define a very simple 3 row table used for testing column serialization.
* Note: last column is seq[int] which doesn't support stats collection.
*/
protected val data = Seq[
(jl.Boolean, jl.Byte, jl.Short, jl.Integer, jl.Long,
jl.Double, jl.Float, java.math.BigDecimal,
String, Array[Byte], Date, Timestamp,
Seq[Int])](
(false, 1.toByte, 1.toShort, 1, 1L, 1.0, 1.0f, dec1, "s1", "b1".getBytes, d1, t1, null),
(true, 2.toByte, 3.toShort, 4, 5L, 6.0, 7.0f, dec2, "ss9", "bb0".getBytes, d2, t2, null),
(null, null, null, null, null, null, null, null, null, null, null, null, null)
)
/** A mapping from column to the stats collected. */
protected val stats = mutable.LinkedHashMap(
"cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1),
"cbyte" -> ColumnStat(2, Some(1L), Some(2L), 1, 1, 1),
"cshort" -> ColumnStat(2, Some(1L), Some(3L), 1, 2, 2),
"cint" -> ColumnStat(2, Some(1L), Some(4L), 1, 4, 4),
"clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8),
"cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8),
"cfloat" -> ColumnStat(2, Some(1.0), Some(7.0), 1, 4, 4),
"cdecimal" -> ColumnStat(2, Some(dec1), Some(dec2), 1, 16, 16),
"cstring" -> ColumnStat(2, None, None, 1, 3, 3),
"cbinary" -> ColumnStat(2, None, None, 1, 3, 3),
"cdate" -> ColumnStat(2, Some(d1), Some(d2), 1, 4, 4),
"ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8)
)
test("column stats round trip serialization") {
// Make sure that serializing and then deserializing gives back the same stats.
val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
stats.zip(df.schema).foreach { case ((k, v), field) =>
withClue(s"column $k with type ${field.dataType}") {
val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap)
assert(roundtrip == Some(v))
}
}
}
test("analyze column command - result verification") {
val tableName = "column_stats_test2"
// (data.head.productArity - 1) because the last column does not support stats collection.
assert(stats.size == data.head.productArity - 1)
val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
withTable(tableName) {
df.write.saveAsTable(tableName)
// Collect statistics
sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", "))
// Validate statistics
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName))
assert(table.stats.isDefined)
assert(table.stats.get.colStats.size == stats.size)
stats.foreach { case (k, v) =>
withClue(s"column $k") {
assert(table.stats.get.colStats(k) == v)
}
}
}
}
}


@@ -1,334 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.command.AnalyzeColumnCommand
import org.apache.spark.sql.test.SQLTestData.ArrayData
import org.apache.spark.sql.types._
class StatisticsColumnSuite extends StatisticsTest {
import testImplicits._
test("parse analyze column commands") {
val tableName = "tbl"
// we need to specify column names
intercept[ParseException] {
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS")
}
val analyzeSql = s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key, value"
val parsed = spark.sessionState.sqlParser.parsePlan(analyzeSql)
val expected = AnalyzeColumnCommand(TableIdentifier(tableName), Seq("key", "value"))
comparePlans(parsed, expected)
}
test("analyzing columns of non-atomic types is not supported") {
val tableName = "tbl"
withTable(tableName) {
Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName)
val err = intercept[AnalysisException] {
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data")
}
assert(err.message.contains("Analyzing columns is not supported"))
}
}
test("check correctness of columns") {
val table = "tbl"
val colName1 = "abc"
val colName2 = "x.yz"
withTable(table) {
sql(s"CREATE TABLE $table ($colName1 int, `$colName2` string) USING PARQUET")
val invalidColError = intercept[AnalysisException] {
sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key")
}
assert(invalidColError.message == "Invalid column name: key.")
withSQLConf("spark.sql.caseSensitive" -> "true") {
val invalidErr = intercept[AnalysisException] {
sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS ${colName1.toUpperCase}")
}
assert(invalidErr.message == s"Invalid column name: ${colName1.toUpperCase}.")
}
withSQLConf("spark.sql.caseSensitive" -> "false") {
val columnsToAnalyze = Seq(colName2.toUpperCase, colName1, colName2)
val tableIdent = TableIdentifier(table, Some("default"))
val relation = spark.sessionState.catalog.lookupRelation(tableIdent)
val (_, columnStats) =
AnalyzeColumnCommand.computeColStats(spark, relation, columnsToAnalyze)
assert(columnStats.contains(colName1))
assert(columnStats.contains(colName2))
// check deduplication
assert(columnStats.size == 2)
assert(!columnStats.contains(colName2.toUpperCase))
}
}
}
private def getNonNullValues[T](values: Seq[Option[T]]): Seq[T] = {
values.filter(_.isDefined).map(_.get)
}
test("column-level statistics for integral type columns") {
val values = (0 to 5).map { i =>
if (i % 2 == 0) None else Some(i)
}
val data = values.map { i =>
(i.map(_.toByte), i.map(_.toShort), i.map(_.toInt), i.map(_.toLong))
}
val df = data.toDF("c1", "c2", "c3", "c4")
val nonNullValues = getNonNullValues[Int](values)
val expectedColStatsSeq = df.schema.map { f =>
val colStat = ColumnStat(InternalRow(
values.count(_.isEmpty).toLong,
nonNullValues.max,
nonNullValues.min,
nonNullValues.distinct.length.toLong))
(f, colStat)
}
checkColStats(df, expectedColStatsSeq)
}
test("column-level statistics for fractional type columns") {
val values: Seq[Option[Decimal]] = (0 to 5).map { i =>
if (i == 0) None else Some(Decimal(i + i * 0.01))
}
val data = values.map { i =>
(i.map(_.toFloat), i.map(_.toDouble), i)
}
val df = data.toDF("c1", "c2", "c3")
val nonNullValues = getNonNullValues[Decimal](values)
val numNulls = values.count(_.isEmpty).toLong
val ndv = nonNullValues.distinct.length.toLong
val expectedColStatsSeq = df.schema.map { f =>
val colStat = f.dataType match {
case floatType: FloatType =>
ColumnStat(InternalRow(numNulls, nonNullValues.max.toFloat, nonNullValues.min.toFloat,
ndv))
case doubleType: DoubleType =>
ColumnStat(InternalRow(numNulls, nonNullValues.max.toDouble, nonNullValues.min.toDouble,
ndv))
case decimalType: DecimalType =>
ColumnStat(InternalRow(numNulls, nonNullValues.max, nonNullValues.min, ndv))
}
(f, colStat)
}
checkColStats(df, expectedColStatsSeq)
}
test("column-level statistics for string column") {
val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some(""))
val df = values.toDF("c1")
val nonNullValues = getNonNullValues[String](values)
val expectedColStatsSeq = df.schema.map { f =>
val colStat = ColumnStat(InternalRow(
values.count(_.isEmpty).toLong,
nonNullValues.map(_.length).sum / nonNullValues.length.toDouble,
nonNullValues.map(_.length).max.toInt,
nonNullValues.distinct.length.toLong))
(f, colStat)
}
checkColStats(df, expectedColStatsSeq)
}
test("column-level statistics for binary column") {
val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")).map(_.map(_.getBytes))
val df = values.toDF("c1")
val nonNullValues = getNonNullValues[Array[Byte]](values)
val expectedColStatsSeq = df.schema.map { f =>
val colStat = ColumnStat(InternalRow(
values.count(_.isEmpty).toLong,
nonNullValues.map(_.length).sum / nonNullValues.length.toDouble,
nonNullValues.map(_.length).max.toInt))
(f, colStat)
}
checkColStats(df, expectedColStatsSeq)
}
test("column-level statistics for boolean column") {
val values = Seq(None, Some(true), Some(false), Some(true))
val df = values.toDF("c1")
val nonNullValues = getNonNullValues[Boolean](values)
val expectedColStatsSeq = df.schema.map { f =>
val colStat = ColumnStat(InternalRow(
values.count(_.isEmpty).toLong,
nonNullValues.count(_.equals(true)).toLong,
nonNullValues.count(_.equals(false)).toLong))
(f, colStat)
}
checkColStats(df, expectedColStatsSeq)
}
test("column-level statistics for date column") {
val values = Seq(None, Some("1970-01-01"), Some("1970-02-02")).map(_.map(Date.valueOf))
val df = values.toDF("c1")
val nonNullValues = getNonNullValues[Date](values)
val expectedColStatsSeq = df.schema.map { f =>
val colStat = ColumnStat(InternalRow(
values.count(_.isEmpty).toLong,
// Internally, DateType is represented as the number of days from 1970-01-01.
nonNullValues.map(DateTimeUtils.fromJavaDate).max,
nonNullValues.map(DateTimeUtils.fromJavaDate).min,
nonNullValues.distinct.length.toLong))
(f, colStat)
}
checkColStats(df, expectedColStatsSeq)
}
test("column-level statistics for timestamp column") {
val values = Seq(None, Some("1970-01-01 00:00:00"), Some("1970-01-01 00:00:05")).map { i =>
i.map(Timestamp.valueOf)
}
val df = values.toDF("c1")
val nonNullValues = getNonNullValues[Timestamp](values)
val expectedColStatsSeq = df.schema.map { f =>
val colStat = ColumnStat(InternalRow(
values.count(_.isEmpty).toLong,
// Internally, TimestampType is represented as the number of days from 1970-01-01
nonNullValues.map(DateTimeUtils.fromJavaTimestamp).max,
nonNullValues.map(DateTimeUtils.fromJavaTimestamp).min,
nonNullValues.distinct.length.toLong))
(f, colStat)
}
checkColStats(df, expectedColStatsSeq)
}
test("column-level statistics for null columns") {
val values = Seq(None, None)
val data = values.map { i =>
(i.map(_.toString), i.map(_.toString.toInt))
}
val df = data.toDF("c1", "c2")
val expectedColStatsSeq = df.schema.map { f =>
(f, ColumnStat(InternalRow(values.count(_.isEmpty).toLong, null, null, 0L)))
}
checkColStats(df, expectedColStatsSeq)
}
test("column-level statistics for columns with different types") {
val intSeq = Seq(1, 2)
val doubleSeq = Seq(1.01d, 2.02d)
val stringSeq = Seq("a", "bb")
val binarySeq = Seq("a", "bb").map(_.getBytes)
val booleanSeq = Seq(true, false)
val dateSeq = Seq("1970-01-01", "1970-02-02").map(Date.valueOf)
val timestampSeq = Seq("1970-01-01 00:00:00", "1970-01-01 00:00:05").map(Timestamp.valueOf)
val longSeq = Seq(5L, 4L)
val data = intSeq.indices.map { i =>
(intSeq(i), doubleSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i), dateSeq(i),
timestampSeq(i), longSeq(i))
}
val df = data.toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8")
val expectedColStatsSeq = df.schema.map { f =>
val colStat = f.dataType match {
case IntegerType =>
ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
case DoubleType =>
ColumnStat(InternalRow(0L, doubleSeq.max, doubleSeq.min,
doubleSeq.distinct.length.toLong))
case StringType =>
ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong))
case BinaryType =>
ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble,
binarySeq.map(_.length).max.toInt))
case BooleanType =>
ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
booleanSeq.count(_.equals(false)).toLong))
case DateType =>
ColumnStat(InternalRow(0L, dateSeq.map(DateTimeUtils.fromJavaDate).max,
dateSeq.map(DateTimeUtils.fromJavaDate).min, dateSeq.distinct.length.toLong))
case TimestampType =>
ColumnStat(InternalRow(0L, timestampSeq.map(DateTimeUtils.fromJavaTimestamp).max,
timestampSeq.map(DateTimeUtils.fromJavaTimestamp).min,
timestampSeq.distinct.length.toLong))
case LongType =>
ColumnStat(InternalRow(0L, longSeq.max, longSeq.min, longSeq.distinct.length.toLong))
}
(f, colStat)
}
checkColStats(df, expectedColStatsSeq)
}
test("update table-level stats while collecting column-level stats") {
val table = "tbl"
withTable(table) {
sql(s"CREATE TABLE $table (c1 int) USING PARQUET")
sql(s"INSERT INTO $table SELECT 1")
sql(s"ANALYZE TABLE $table COMPUTE STATISTICS")
checkTableStats(tableName = table, expectedRowCount = Some(1))
// update table-level stats between analyze table and analyze column commands
sql(s"INSERT INTO $table SELECT 1")
sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1")
val fetchedStats = checkTableStats(tableName = table, expectedRowCount = Some(2))
val colStat = fetchedStats.get.colStats("c1")
StatisticsTest.checkColStat(
dataType = IntegerType,
colStat = colStat,
expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
rsd = spark.sessionState.conf.ndvMaxError)
}
}
test("analyze column stats independently") {
val table = "tbl"
withTable(table) {
sql(s"CREATE TABLE $table (c1 int, c2 long) USING PARQUET")
sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1")
val fetchedStats1 = checkTableStats(tableName = table, expectedRowCount = Some(0))
assert(fetchedStats1.get.colStats.size == 1)
val expected1 = ColumnStat(InternalRow(0L, null, null, 0L))
val rsd = spark.sessionState.conf.ndvMaxError
StatisticsTest.checkColStat(
dataType = IntegerType,
colStat = fetchedStats1.get.colStats("c1"),
expectedColStat = expected1,
rsd = rsd)
sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2")
val fetchedStats2 = checkTableStats(tableName = table, expectedRowCount = Some(0))
// column c1 is kept in the stats
assert(fetchedStats2.get.colStats.size == 2)
StatisticsTest.checkColStat(
dataType = IntegerType,
colStat = fetchedStats2.get.colStats("c1"),
expectedColStat = expected1,
rsd = rsd)
val expected2 = ColumnStat(InternalRow(0L, null, null, 0L))
StatisticsTest.checkColStat(
dataType = LongType,
colStat = fetchedStats2.get.colStats("c2"),
expectedColStat = expected2,
rsd = rsd)
}
}
}


@@ -1,92 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit}
import org.apache.spark.sql.types._
class StatisticsSuite extends StatisticsTest {
import testImplicits._
test("SPARK-15392: DataFrame created from RDD should not be broadcasted") {
val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType))
assert(df.queryExecution.analyzed.statistics.sizeInBytes >
spark.sessionState.conf.autoBroadcastJoinThreshold)
assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes >
spark.sessionState.conf.autoBroadcastJoinThreshold)
}
test("estimates the size of limit") {
withTempView("test") {
Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
.createOrReplaceTempView("test")
Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) =>
val df = sql(s"""SELECT * FROM test limit $limit""")
val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit =>
g.statistics.sizeInBytes
}
assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
assert(sizesGlobalLimit.head === BigInt(expected),
s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}")
val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit =>
l.statistics.sizeInBytes
}
assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
assert(sizesLocalLimit.head === BigInt(expected),
s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}")
}
}
}
test("estimates the size of a limit 0 on outer join") {
withTempView("test") {
Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
.createOrReplaceTempView("test")
val df1 = spark.table("test")
val df2 = spark.table("test").limit(0)
val df = df1.join(df2, Seq("k"), "left")
val sizes = df.queryExecution.analyzed.collect { case g: Join =>
g.statistics.sizeInBytes
}
assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
assert(sizes.head === BigInt(96),
s"expected exact size 96 for table 'test', got: ${sizes.head}")
}
}
test("test table-level statistics for data source table created in InMemoryCatalog") {
val tableName = "tbl"
withTable(tableName) {
sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet")
Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName)
// noscan won't count the number of rows
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan")
checkTableStats(tableName, expectedRowCount = None)
// without noscan, we count the number of rows
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
checkTableStats(tableName, expectedRowCount = Some(2))
}
}
}


@@ -1,130 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
import org.apache.spark.sql.execution.command.AnalyzeColumnCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
trait StatisticsTest extends QueryTest with SharedSQLContext {
def checkColStats(
df: DataFrame,
expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = {
val table = "tbl"
withTable(table) {
df.write.format("json").saveAsTable(table)
val columns = expectedColStatsSeq.map(_._1)
val tableIdent = TableIdentifier(table, Some("default"))
val relation = spark.sessionState.catalog.lookupRelation(tableIdent)
val (_, columnStats) =
AnalyzeColumnCommand.computeColStats(spark, relation, columns.map(_.name))
expectedColStatsSeq.foreach { case (field, expectedColStat) =>
assert(columnStats.contains(field.name))
val colStat = columnStats(field.name)
StatisticsTest.checkColStat(
dataType = field.dataType,
colStat = colStat,
expectedColStat = expectedColStat,
rsd = spark.sessionState.conf.ndvMaxError)
// check if we get the same colStat after encoding and decoding
val encodedCS = colStat.toString
val numFields = AnalyzeColumnCommand.numStatFields(field.dataType)
val decodedCS = ColumnStat(numFields, encodedCS)
StatisticsTest.checkColStat(
dataType = field.dataType,
colStat = decodedCS,
expectedColStat = expectedColStat,
rsd = spark.sessionState.conf.ndvMaxError)
}
}
}
def checkTableStats(tableName: String, expectedRowCount: Option[Int]): Option[Statistics] = {
val df = spark.table(tableName)
val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation =>
assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount)
rel.catalogTable.get.stats
}
assert(stats.size == 1)
stats.head
}
}
object StatisticsTest {
def checkColStat(
dataType: DataType,
colStat: ColumnStat,
expectedColStat: ColumnStat,
rsd: Double): Unit = {
dataType match {
case StringType =>
val cs = colStat.forString
val expectedCS = expectedColStat.forString
assert(cs.numNulls == expectedCS.numNulls)
assert(cs.avgColLen == expectedCS.avgColLen)
assert(cs.maxColLen == expectedCS.maxColLen)
checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd)
case BinaryType =>
val cs = colStat.forBinary
val expectedCS = expectedColStat.forBinary
assert(cs.numNulls == expectedCS.numNulls)
assert(cs.avgColLen == expectedCS.avgColLen)
assert(cs.maxColLen == expectedCS.maxColLen)
case BooleanType =>
val cs = colStat.forBoolean
val expectedCS = expectedColStat.forBoolean
assert(cs.numNulls == expectedCS.numNulls)
assert(cs.numTrues == expectedCS.numTrues)
assert(cs.numFalses == expectedCS.numFalses)
case atomicType: AtomicType =>
checkNumericColStats(
dataType = atomicType, colStat = colStat, expectedColStat = expectedColStat, rsd = rsd)
}
}
private def checkNumericColStats(
dataType: AtomicType,
colStat: ColumnStat,
expectedColStat: ColumnStat,
rsd: Double): Unit = {
val cs = colStat.forNumeric(dataType)
val expectedCS = expectedColStat.forNumeric(dataType)
assert(cs.numNulls == expectedCS.numNulls)
assert(cs.max == expectedCS.max)
assert(cs.min == expectedCS.min)
checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd)
}
private def checkNdv(ndv: Long, expectedNdv: Long, rsd: Double): Unit = {
// ndv is an approximate value, so we make sure we have the value, and it should be
// within 3*SD's of the given rsd.
if (expectedNdv == 0) {
assert(ndv == 0)
} else if (expectedNdv > 0) {
assert(ndv > 0)
val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d)
assert(error <= rsd * 3.0d, "Error should be within 3 std. errors.")
}
}
}


@@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat,
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DescribeFunctionCommand,
DescribeTableCommand, ShowFunctionsCommand}
import org.apache.spark.sql.execution.datasources.{CreateTable, CreateTempViewUsing}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.CreateTable
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
@@ -221,12 +220,22 @@ class SparkSqlParserSuite extends PlanTest {
intercept("explain describe tables x", "Unsupported SQL statement")
}
test("SPARK-18106 analyze table") {
test("analyze table statistics") {
assertEqual("analyze table t compute statistics",
AnalyzeTableCommand(TableIdentifier("t"), noscan = false))
assertEqual("analyze table t compute statistics noscan",
AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
assertEqual("analyze table t partition (a) compute statistics noscan",
assertEqual("analyze table t partition (a) compute statistics nOscAn",
AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
// Partitions specified - we currently parse them but don't do anything with them
assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS",
AnalyzeTableCommand(TableIdentifier("t"), noscan = false))
assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan",
AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS",
AnalyzeTableCommand(TableIdentifier("t"), noscan = false))
assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS noscan",
AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
intercept("analyze table t compute statistics xxxx",
@@ -234,4 +243,11 @@ class SparkSqlParserSuite extends PlanTest {
intercept("analyze table t partition (a) compute statistics xxxx",
"Expected `NOSCAN` instead of `xxxx`")
}
test("analyze table column statistics") {
intercept("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS", "")
assertEqual("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, value",
AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value")))
}
}


@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, DDLUtils}
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.HiveSerDe
import org.apache.spark.sql.internal.StaticSQLConf._
@@ -514,7 +514,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString()
}
stats.colStats.foreach { case (colName, colStat) =>
statsProperties += (STATISTICS_COL_STATS_PREFIX + colName) -> colStat.toString
colStat.toMap.foreach { case (k, v) =>
statsProperties += (columnStatKeyPropName(colName, k) -> v)
}
}
tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties)
} else {
@@ -605,48 +607,65 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
* It reads table schema, provider, partition column names and bucket specification from table
* properties, and filter out these special entries from table properties.
*/
private def restoreTableMetadata(table: CatalogTable): CatalogTable = {
private def restoreTableMetadata(inputTable: CatalogTable): CatalogTable = {
if (conf.get(DEBUG_MODE)) {
return table
return inputTable
}
val tableWithSchema = if (table.tableType == VIEW) {
table
} else {
getProviderFromTableProperties(table) match {
var table = inputTable
if (table.tableType != VIEW) {
table.properties.get(DATASOURCE_PROVIDER) match {
// No provider in table properties, which means this table is created by Spark prior to 2.1,
// or is created at Hive side.
case None =>
table.copy(provider = Some(DDLUtils.HIVE_PROVIDER), tracksPartitionsInCatalog = true)
table = table.copy(
provider = Some(DDLUtils.HIVE_PROVIDER), tracksPartitionsInCatalog = true)
// This is a Hive serde table created by Spark 2.1 or higher versions.
case Some(DDLUtils.HIVE_PROVIDER) => restoreHiveSerdeTable(table)
case Some(DDLUtils.HIVE_PROVIDER) =>
table = restoreHiveSerdeTable(table)
// This is a regular data source table.
case Some(provider) => restoreDataSourceTable(table, provider)
case Some(provider) =>
table = restoreDataSourceTable(table, provider)
}
}
// construct Spark's statistics from information in Hive metastore
val statsProps = tableWithSchema.properties.filterKeys(_.startsWith(STATISTICS_PREFIX))
val tableWithStats = if (statsProps.nonEmpty) {
val colStatsProps = statsProps.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX))
.map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) }
val colStats: Map[String, ColumnStat] = tableWithSchema.schema.collect {
case f if colStatsProps.contains(f.name) =>
val numFields = AnalyzeColumnCommand.numStatFields(f.dataType)
(f.name, ColumnStat(numFields, colStatsProps(f.name)))
}.toMap
tableWithSchema.copy(
val statsProps = table.properties.filterKeys(_.startsWith(STATISTICS_PREFIX))
if (statsProps.nonEmpty) {
val colStats = new scala.collection.mutable.HashMap[String, ColumnStat]
// For each column, recover its column stats. Note that this is currently an O(n^2) operation,
// but given that the number of columns is usually not enormous, this is probably OK as a start.
// If we want to make this a linear operation, we'd need a stronger contract for the
// naming convention used for serialization.
table.schema.foreach { field =>
if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) {
// If "version" field is defined, then the column stat is defined.
val keyPrefix = columnStatKeyPropName(field.name, "")
val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) =>
(k.drop(keyPrefix.length), v)
}
ColumnStat.fromMap(table.identifier.table, field, colStatMap).foreach {
colStat => colStats += field.name -> colStat
}
}
}
table = table.copy(
stats = Some(Statistics(
sizeInBytes = BigInt(tableWithSchema.properties(STATISTICS_TOTAL_SIZE)),
rowCount = tableWithSchema.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)),
colStats = colStats)))
} else {
tableWithSchema
sizeInBytes = BigInt(table.properties(STATISTICS_TOTAL_SIZE)),
rowCount = table.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)),
colStats = colStats.toMap)))
}
tableWithStats.copy(properties = getOriginalTableProperties(table))
// Get the original table properties as defined by the user.
table.copy(
properties = table.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) })
}
private def restoreHiveSerdeTable(table: CatalogTable): CatalogTable = {
@@ -1020,17 +1039,17 @@ object HiveExternalCatalog {
val TABLE_PARTITION_PROVIDER_CATALOG = "catalog"
val TABLE_PARTITION_PROVIDER_FILESYSTEM = "filesystem"
def getProviderFromTableProperties(metadata: CatalogTable): Option[String] = {
metadata.properties.get(DATASOURCE_PROVIDER)
}
def getOriginalTableProperties(metadata: CatalogTable): Map[String, String] = {
metadata.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) }
/**
* Returns the fully qualified name used in table properties for a particular column stat.
* For example, for column "mycol", and "min" stat, this should return
* "spark.sql.statistics.colStats.mycol.min".
*/
private def columnStatKeyPropName(columnName: String, statKey: String): String = {
STATISTICS_COL_STATS_PREFIX + columnName + "." + statKey
}
// A persisted data source table always store its schema in the catalog.
def getSchemaFromTableProperties(metadata: CatalogTable): StructType = {
private def getSchemaFromTableProperties(metadata: CatalogTable): StructType = {
val errorMessage = "Could not read schema from the hive metastore because it is corrupted."
val props = metadata.properties
val schema = props.get(DATASOURCE_SCHEMA)
@@ -1078,11 +1097,11 @@
)
}
def getPartitionColumnsFromTableProperties(metadata: CatalogTable): Seq[String] = {
private def getPartitionColumnsFromTableProperties(metadata: CatalogTable): Seq[String] = {
getColumnNamesByType(metadata.properties, "part", "partitioning columns")
}
def getBucketSpecFromTableProperties(metadata: CatalogTable): Option[BucketSpec] = {
private def getBucketSpecFromTableProperties(metadata: CatalogTable): Option[BucketSpec] = {
metadata.properties.get(DATASOURCE_SCHEMA_NUMBUCKETS).map { numBuckets =>
BucketSpec(
numBuckets.toInt,


@@ -22,56 +22,16 @@ import java.io.{File, PrintWriter}
import scala.reflect.ClassTag
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
test("parse analyze commands") {
def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) {
val parsed = spark.sessionState.sqlParser.parsePlan(analyzeCommand)
val operators = parsed.collect {
case a: AnalyzeTableCommand => a
case o => o
}
assert(operators.size === 1)
if (operators(0).getClass() != c) {
fail(
s"""$analyzeCommand expected command: $c, but got ${operators(0)}
|parsed command:
|$parsed
""".stripMargin)
}
}
assertAnalyzeCommand(
"ANALYZE TABLE Table1 COMPUTE STATISTICS",
classOf[AnalyzeTableCommand])
assertAnalyzeCommand(
"ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS",
classOf[AnalyzeTableCommand])
assertAnalyzeCommand(
"ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan",
classOf[AnalyzeTableCommand])
assertAnalyzeCommand(
"ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS",
classOf[AnalyzeTableCommand])
assertAnalyzeCommand(
"ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS noscan",
classOf[AnalyzeTableCommand])
assertAnalyzeCommand(
"ANALYZE TABLE Table1 COMPUTE STATISTICS nOscAn",
classOf[AnalyzeTableCommand])
}
class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton {
test("MetastoreRelations fallback to HDFS for size estimation") {
val enableFallBackToHdfsForStats = spark.sessionState.conf.fallBackToHdfsForStatsEnabled
@@ -310,6 +270,110 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
}
}
test("verify serialized column stats after analyzing columns") {
import testImplicits._
val tableName = "column_stats_test2"
// (data.head.productArity - 1) because the last column does not support stats collection.
assert(stats.size == data.head.productArity - 1)
val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
withTable(tableName) {
df.write.saveAsTable(tableName)
// Collect statistics
sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", "))
// Validate statistics
val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client
val table = hiveClient.getTable("default", tableName)
val props = table.properties.filterKeys(_.startsWith("spark.sql.statistics.colStats"))
assert(props == Map(
"spark.sql.statistics.colStats.cbinary.avgLen" -> "3",
"spark.sql.statistics.colStats.cbinary.distinctCount" -> "2",
"spark.sql.statistics.colStats.cbinary.maxLen" -> "3",
"spark.sql.statistics.colStats.cbinary.nullCount" -> "1",
"spark.sql.statistics.colStats.cbinary.version" -> "1",
"spark.sql.statistics.colStats.cbool.avgLen" -> "1",
"spark.sql.statistics.colStats.cbool.distinctCount" -> "2",
"spark.sql.statistics.colStats.cbool.max" -> "true",
"spark.sql.statistics.colStats.cbool.maxLen" -> "1",
"spark.sql.statistics.colStats.cbool.min" -> "false",
"spark.sql.statistics.colStats.cbool.nullCount" -> "1",
"spark.sql.statistics.colStats.cbool.version" -> "1",
"spark.sql.statistics.colStats.cbyte.avgLen" -> "1",
"spark.sql.statistics.colStats.cbyte.distinctCount" -> "2",
"spark.sql.statistics.colStats.cbyte.max" -> "2",
"spark.sql.statistics.colStats.cbyte.maxLen" -> "1",
"spark.sql.statistics.colStats.cbyte.min" -> "1",
"spark.sql.statistics.colStats.cbyte.nullCount" -> "1",
"spark.sql.statistics.colStats.cbyte.version" -> "1",
"spark.sql.statistics.colStats.cdate.avgLen" -> "4",
"spark.sql.statistics.colStats.cdate.distinctCount" -> "2",
"spark.sql.statistics.colStats.cdate.max" -> "2016-05-09",
"spark.sql.statistics.colStats.cdate.maxLen" -> "4",
"spark.sql.statistics.colStats.cdate.min" -> "2016-05-08",
"spark.sql.statistics.colStats.cdate.nullCount" -> "1",
"spark.sql.statistics.colStats.cdate.version" -> "1",
"spark.sql.statistics.colStats.cdecimal.avgLen" -> "16",
"spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2",
"spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000",
"spark.sql.statistics.colStats.cdecimal.maxLen" -> "16",
"spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000",
"spark.sql.statistics.colStats.cdecimal.nullCount" -> "1",
"spark.sql.statistics.colStats.cdecimal.version" -> "1",
"spark.sql.statistics.colStats.cdouble.avgLen" -> "8",
"spark.sql.statistics.colStats.cdouble.distinctCount" -> "2",
"spark.sql.statistics.colStats.cdouble.max" -> "6.0",
"spark.sql.statistics.colStats.cdouble.maxLen" -> "8",
"spark.sql.statistics.colStats.cdouble.min" -> "1.0",
"spark.sql.statistics.colStats.cdouble.nullCount" -> "1",
"spark.sql.statistics.colStats.cdouble.version" -> "1",
"spark.sql.statistics.colStats.cfloat.avgLen" -> "4",
"spark.sql.statistics.colStats.cfloat.distinctCount" -> "2",
"spark.sql.statistics.colStats.cfloat.max" -> "7.0",
"spark.sql.statistics.colStats.cfloat.maxLen" -> "4",
"spark.sql.statistics.colStats.cfloat.min" -> "1.0",
"spark.sql.statistics.colStats.cfloat.nullCount" -> "1",
"spark.sql.statistics.colStats.cfloat.version" -> "1",
"spark.sql.statistics.colStats.cint.avgLen" -> "4",
"spark.sql.statistics.colStats.cint.distinctCount" -> "2",
"spark.sql.statistics.colStats.cint.max" -> "4",
"spark.sql.statistics.colStats.cint.maxLen" -> "4",
"spark.sql.statistics.colStats.cint.min" -> "1",
"spark.sql.statistics.colStats.cint.nullCount" -> "1",
"spark.sql.statistics.colStats.cint.version" -> "1",
"spark.sql.statistics.colStats.clong.avgLen" -> "8",
"spark.sql.statistics.colStats.clong.distinctCount" -> "2",
"spark.sql.statistics.colStats.clong.max" -> "5",
"spark.sql.statistics.colStats.clong.maxLen" -> "8",
"spark.sql.statistics.colStats.clong.min" -> "1",
"spark.sql.statistics.colStats.clong.nullCount" -> "1",
"spark.sql.statistics.colStats.clong.version" -> "1",
"spark.sql.statistics.colStats.cshort.avgLen" -> "2",
"spark.sql.statistics.colStats.cshort.distinctCount" -> "2",
"spark.sql.statistics.colStats.cshort.max" -> "3",
"spark.sql.statistics.colStats.cshort.maxLen" -> "2",
"spark.sql.statistics.colStats.cshort.min" -> "1",
"spark.sql.statistics.colStats.cshort.nullCount" -> "1",
"spark.sql.statistics.colStats.cshort.version" -> "1",
"spark.sql.statistics.colStats.cstring.avgLen" -> "3",
"spark.sql.statistics.colStats.cstring.distinctCount" -> "2",
"spark.sql.statistics.colStats.cstring.maxLen" -> "3",
"spark.sql.statistics.colStats.cstring.nullCount" -> "1",
"spark.sql.statistics.colStats.cstring.version" -> "1",
"spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8",
"spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2",
"spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0",
"spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8",
"spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0",
"spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1",
"spark.sql.statistics.colStats.ctimestamp.version" -> "1"
))
}
}
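As a reverse-direction illustration of the layout verified above, a hypothetical helper (not the patch's actual deserialization code) that regroups one column's flat properties into a stat-name -> value map:

  def colStatsFromProps(colName: String, props: Map[String, String]): Map[String, String] = {
    val prefix = s"spark.sql.statistics.colStats.$colName."
    // Keep only this column's keys and strip the shared prefix.
    props.collect { case (k, v) if k.startsWith(prefix) => k.stripPrefix(prefix) -> v }
  }
  // e.g. colStatsFromProps("cint", props) == Map("avgLen" -> "4", "distinctCount" -> "2",
  //   "max" -> "4", "maxLen" -> "4", "min" -> "1", "nullCount" -> "1", "version" -> "1")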
private def testUpdatingTableStats(tableDescription: String, createTableCmd: String): Unit = {
test("test table-level statistics for " + tableDescription) {
val parquetTable = "parquetTable"
@@ -319,7 +383,8 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
TableIdentifier(parquetTable))
assert(DDLUtils.isDatasourceTable(catalogTable))
sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src")
// Add a filter to avoid creating too many partitions
sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10")
checkTableStats(
parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None)
@@ -328,7 +393,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
val fetchedStats1 = checkTableStats(
parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None)
sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src")
sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10")
sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan")
val fetchedStats2 = checkTableStats(
parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None)
@@ -340,7 +405,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
parquetTable,
isDataSourceTable = true,
hasSizeInBytes = true,
expectedRowCounts = Some(1000))
expectedRowCounts = Some(20))
assert(fetchedStats3.get.sizeInBytes == fetchedStats2.get.sizeInBytes)
}
}
@@ -369,6 +434,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
}
}
/** Used to test refreshing cached metadata once table stats are updated. */
private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean): (Statistics, Statistics) = {
val tableName = "tbl"
var statsBeforeUpdate: Statistics = null
@@ -411,145 +477,6 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
assert(statsAfterUpdate.rowCount == Some(2))
}
test("test refreshing column stats of cached data source table by `ANALYZE TABLE` statement") {
val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = true)
assert(statsBeforeUpdate.sizeInBytes > 0)
assert(statsBeforeUpdate.rowCount == Some(1))
StatisticsTest.checkColStat(
dataType = IntegerType,
colStat = statsBeforeUpdate.colStats("key"),
expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
rsd = spark.sessionState.conf.ndvMaxError)
assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes)
assert(statsAfterUpdate.rowCount == Some(2))
StatisticsTest.checkColStat(
dataType = IntegerType,
colStat = statsAfterUpdate.colStats("key"),
expectedColStat = ColumnStat(InternalRow(0L, 2, 1, 2L)),
rsd = spark.sessionState.conf.ndvMaxError)
}
private lazy val (testDataFrame, expectedColStatsSeq) = {
import testImplicits._
val intSeq = Seq(1, 2)
val stringSeq = Seq("a", "bb")
val binarySeq = Seq("a", "bb").map(_.getBytes)
val booleanSeq = Seq(true, false)
val data = intSeq.indices.map { i =>
(intSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i))
}
val df: DataFrame = data.toDF("c1", "c2", "c3", "c4")
val expectedColStatsSeq: Seq[(StructField, ColumnStat)] = df.schema.map { f =>
val colStat = f.dataType match {
case IntegerType =>
ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
case StringType =>
ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong))
case BinaryType =>
ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble,
binarySeq.map(_.length).max.toInt))
case BooleanType =>
ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
booleanSeq.count(_.equals(false)).toLong))
}
(f, colStat)
}
(df, expectedColStatsSeq)
}
private def checkColStats(
tableName: String,
isDataSourceTable: Boolean,
expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = {
val readback = spark.table(tableName)
val stats = readback.queryExecution.analyzed.collect {
case rel: MetastoreRelation =>
assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table")
rel.catalogTable.stats.get
case rel: LogicalRelation =>
assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table")
rel.catalogTable.get.stats.get
}
assert(stats.length == 1)
val columnStats = stats.head.colStats
assert(columnStats.size == expectedColStatsSeq.length)
expectedColStatsSeq.foreach { case (field, expectedColStat) =>
StatisticsTest.checkColStat(
dataType = field.dataType,
colStat = columnStats(field.name),
expectedColStat = expectedColStat,
rsd = spark.sessionState.conf.ndvMaxError)
}
}
test("generate and load column-level stats for data source table") {
val dsTable = "dsTable"
withTable(dsTable) {
testDataFrame.write.format("parquet").saveAsTable(dsTable)
sql(s"ANALYZE TABLE $dsTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
checkColStats(dsTable, isDataSourceTable = true, expectedColStatsSeq)
}
}
test("generate and load column-level stats for hive serde table") {
val hTable = "hTable"
val tmp = "tmp"
withTable(hTable, tmp) {
testDataFrame.write.format("parquet").saveAsTable(tmp)
sql(s"CREATE TABLE $hTable (c1 int, c2 string, c3 binary, c4 boolean) STORED AS TEXTFILE")
sql(s"INSERT INTO $hTable SELECT * FROM $tmp")
sql(s"ANALYZE TABLE $hTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
checkColStats(hTable, isDataSourceTable = false, expectedColStatsSeq)
}
}
// When caseSensitive is on, columns whose names differ only in case are distinct columns,
// and we should generate column stats for all of them.
private def checkCaseSensitiveColStats(columnName: String): Unit = {
val tableName = "tbl"
withTable(tableName) {
val column1 = columnName.toLowerCase
val column2 = columnName.toUpperCase
withSQLConf("spark.sql.caseSensitive" -> "true") {
sql(s"CREATE TABLE $tableName (`$column1` int, `$column2` double) USING PARQUET")
sql(s"INSERT INTO $tableName SELECT 1, 3.0")
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS `$column1`, `$column2`")
val readback = spark.table(tableName)
val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation =>
val columnStats = rel.catalogTable.get.stats.get.colStats
assert(columnStats.size == 2)
StatisticsTest.checkColStat(
dataType = IntegerType,
colStat = columnStats(column1),
expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
rsd = spark.sessionState.conf.ndvMaxError)
StatisticsTest.checkColStat(
dataType = DoubleType,
colStat = columnStats(column2),
expectedColStat = ColumnStat(InternalRow(0L, 3.0d, 3.0d, 1L)),
rsd = spark.sessionState.conf.ndvMaxError)
rel
}
assert(relations.size == 1)
}
}
}
test("check column statistics for case sensitive column names") {
checkCaseSensitiveColStats(columnName = "c1")
}
test("check column statistics for case sensitive non-ascii column names") {
// scalastyle:off
// non ascii characters are not allowed in the source code, so we disable the scalastyle.
checkCaseSensitiveColStats(columnName = "列c")
// scalastyle:on
}
test("estimates the size of a test MetastoreRelation") {
val df = sql("""SELECT * FROM src""")
val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation =>