diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index f3e2147b8f..79865609cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.commons.codec.binary.Base64 +import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.types._ + /** * Estimates of various statistics. The default estimation logic simply lazily multiplies the * corresponding statistic produced by the children. To override this behavior, override @@ -58,60 +61,175 @@ case class Statistics( } } + /** - * Statistics for a column. + * Statistics collected for a column. + * + * 1. Supported data types are defined in `ColumnStat.supportsType`. + * 2. The JVM data type stored in min/max is the external data type (used in Row) for the + * corresponding Catalyst data type. For example, for DateType we store java.sql.Date, and for + * TimestampType we store java.sql.Timestamp. + * 3. For integral types, they are all upcasted to longs, i.e. shorts are stored as longs. + * 4. There is no guarantee that the statistics collected are accurate. Approximation algorithms + * (sketches) might have been used, and the data collected can also be stale. + * + * @param distinctCount number of distinct values + * @param min minimum value + * @param max maximum value + * @param nullCount number of nulls + * @param avgLen average length of the values. For fixed-length types, this should be a constant. + * @param maxLen maximum length of the values. For fixed-length types, this should be a constant. */ -case class ColumnStat(statRow: InternalRow) { +case class ColumnStat( + distinctCount: BigInt, + min: Option[Any], + max: Option[Any], + nullCount: BigInt, + avgLen: Long, + maxLen: Long) { - def forNumeric[T <: AtomicType](dataType: T): NumericColumnStat[T] = { - NumericColumnStat(statRow, dataType) - } - def forString: StringColumnStat = StringColumnStat(statRow) - def forBinary: BinaryColumnStat = BinaryColumnStat(statRow) - def forBoolean: BooleanColumnStat = BooleanColumnStat(statRow) + // We currently don't store min/max for binary/string type. This can change in the future and + // then we need to remove this require. + require(min.isEmpty || (!min.get.isInstanceOf[Array[Byte]] && !min.get.isInstanceOf[String])) + require(max.isEmpty || (!max.get.isInstanceOf[Array[Byte]] && !max.get.isInstanceOf[String])) - override def toString: String = { - // use Base64 for encoding - Base64.encodeBase64String(statRow.asInstanceOf[UnsafeRow].getBytes) + /** + * Returns a map from string to string that can be used to serialize the column stats. + * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string + * representation for the value. The deserialization side is defined in [[ColumnStat.fromMap]]. + * + * As part of the protocol, the returned map always contains a key called "version". 
+ * In the case min/max values are null (None), they won't appear in the map. + */ + def toMap: Map[String, String] = { + val map = new scala.collection.mutable.HashMap[String, String] + map.put(ColumnStat.KEY_VERSION, "1") + map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString) + map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString) + map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString) + map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString) + min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, v.toString) } + max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, v.toString) } + map.toMap + } } -object ColumnStat { - def apply(numFields: Int, str: String): ColumnStat = { - // use Base64 for decoding - val bytes = Base64.decodeBase64(str) - val unsafeRow = new UnsafeRow(numFields) - unsafeRow.pointTo(bytes, bytes.length) - ColumnStat(unsafeRow) + +object ColumnStat extends Logging { + + // List of string keys used to serialize ColumnStat + val KEY_VERSION = "version" + private val KEY_DISTINCT_COUNT = "distinctCount" + private val KEY_MIN_VALUE = "min" + private val KEY_MAX_VALUE = "max" + private val KEY_NULL_COUNT = "nullCount" + private val KEY_AVG_LEN = "avgLen" + private val KEY_MAX_LEN = "maxLen" + + /** Returns true iff we support gathering column statistics on a column of the given type. */ + def supportsType(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case BooleanType => true + case DateType => true + case TimestampType => true + case BinaryType | StringType => true + case _ => false } -} -case class NumericColumnStat[T <: AtomicType](statRow: InternalRow, dataType: T) { - // The indices here must be consistent with `ColumnStatStruct.numericColumnStat`. - val numNulls: Long = statRow.getLong(0) - val max: T#InternalType = statRow.get(1, dataType).asInstanceOf[T#InternalType] - val min: T#InternalType = statRow.get(2, dataType).asInstanceOf[T#InternalType] - val ndv: Long = statRow.getLong(3) -} + /** + * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats + * from some external storage. The serialization side is defined in [[ColumnStat.toMap]]. + */ + def fromMap(table: String, field: StructField, map: Map[String, String]) + : Option[ColumnStat] = { + val str2val: (String => Any) = field.dataType match { + case _: IntegralType => _.toLong + case _: DecimalType => new java.math.BigDecimal(_) + case DoubleType | FloatType => _.toDouble + case BooleanType => _.toBoolean + case DateType => java.sql.Date.valueOf + case TimestampType => java.sql.Timestamp.valueOf + // This version of Spark does not use min/max for binary/string types so we ignore it. + case BinaryType | StringType => _ => null + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column ${field.name} of data type: ${field.dataType}.") + } -case class StringColumnStat(statRow: InternalRow) { - // The indices here must be consistent with `ColumnStatStruct.stringColumnStat`. - val numNulls: Long = statRow.getLong(0) - val avgColLen: Double = statRow.getDouble(1) - val maxColLen: Long = statRow.getInt(2) - val ndv: Long = statRow.getLong(3) -} + try { + Some(ColumnStat( + distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong), + // Note that flatMap(Option.apply) turns Option(null) into None.
+ min = map.get(KEY_MIN_VALUE).map(str2val).flatMap(Option.apply), + max = map.get(KEY_MAX_VALUE).map(str2val).flatMap(Option.apply), + nullCount = BigInt(map(KEY_NULL_COUNT).toLong), + avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong, + maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong + )) + } catch { + case NonFatal(e) => + logWarning(s"Failed to parse column statistics for column ${field.name} in table $table", e) + None + } + } -case class BinaryColumnStat(statRow: InternalRow) { - // The indices here must be consistent with `ColumnStatStruct.binaryColumnStat`. - val numNulls: Long = statRow.getLong(0) - val avgColLen: Double = statRow.getDouble(1) - val maxColLen: Long = statRow.getInt(2) -} + /** + * Constructs an expression to compute column statistics for a given column. + * + * The expression should create a single struct column with the following schema: + * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long + * + * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and + * as a result should stay in sync with it. + */ + def statExprs(col: Attribute, relativeSD: Double): CreateNamedStruct = { + def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => + expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } + }) + val one = Literal(1, LongType) + + // the approximate ndv (num distinct value) should never be larger than the number of rows + val numNonNulls = if (col.nullable) Count(col) else Count(one) + val ndv = Least(Seq(HyperLogLogPlusPlus(col, relativeSD), numNonNulls)) + val numNulls = Subtract(Count(one), numNonNulls) + + def fixedLenTypeStruct(castType: DataType) = { + // For fixed width types, avg size should be the same as max size. + val avgSize = Literal(col.dataType.defaultSize, LongType) + struct(ndv, Cast(Min(col), castType), Cast(Max(col), castType), numNulls, avgSize, avgSize) + } + + col.dataType match { + case _: IntegralType => fixedLenTypeStruct(LongType) + case _: DecimalType => fixedLenTypeStruct(col.dataType) + case DoubleType | FloatType => fixedLenTypeStruct(DoubleType) + case BooleanType => fixedLenTypeStruct(col.dataType) + case DateType => fixedLenTypeStruct(col.dataType) + case TimestampType => fixedLenTypeStruct(col.dataType) + case BinaryType | StringType => + // For string and binary type, we don't store min/max. + val nullLit = Literal(null, col.dataType) + struct( + ndv, nullLit, nullLit, numNulls, + Ceil(Average(Length(col))), Cast(Max(Length(col)), LongType)) + case _ => + throw new AnalysisException("Analyzing column statistics is not supported for column " + + s"${col.name} of data type: ${col.dataType}.") + } + } + + /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */ + def rowToColumnStat(row: Row): ColumnStat = { + ColumnStat( + distinctCount = BigInt(row.getLong(0)), + min = Option(row.get(1)), // for string/binary min/max, get should return null + max = Option(row.get(2)), + nullCount = BigInt(row.getLong(3)), + avgLen = row.getLong(4), + maxLen = row.getLong(5) + ) + } -case class BooleanColumnStat(statRow: InternalRow) { - // The indices here must be consistent with `ColumnStatStruct.booleanColumnStat`. 
- val numNulls: Long = statRow.getLong(0) - val numTrues: Long = statRow.getLong(1) - val numFalses: Long = statRow.getLong(2) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 7fc57d09e9..9dffe3614a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -24,9 +24,8 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnStat, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.types._ /** @@ -62,7 +61,7 @@ case class AnalyzeColumnCommand( // Compute stats for each column val (rowCount, newColStats) = - AnalyzeColumnCommand.computeColStats(sparkSession, relation, columnNames) + AnalyzeColumnCommand.computeColumnStats(sparkSession, tableIdent.table, relation, columnNames) // We also update table-level stats in order to keep them consistent with column-level stats. val statistics = Statistics( @@ -88,8 +87,9 @@ object AnalyzeColumnCommand extends Logging { * * This is visible for testing. */ - def computeColStats( + def computeColumnStats( sparkSession: SparkSession, + tableName: String, relation: LogicalPlan, columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { @@ -97,102 +97,33 @@ object AnalyzeColumnCommand extends Logging { val resolver = sparkSession.sessionState.conf.resolver val attributesToAnalyze = AttributeSet(columnNames.map { col => val exprOption = relation.output.find(attr => resolver(attr.name, col)) - exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) + exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist.")) }).toSeq + // Make sure the column types are supported for stats gathering. + attributesToAnalyze.foreach { attr => + if (!ColumnStat.supportsType(attr.dataType)) { + throw new AnalysisException( + s"Column ${attr.name} in table $tableName is of type ${attr.dataType}, " + + "and Spark does not support statistics collection on this column type.") + } + } + // Collect statistics per column. // The first element in the result will be the overall row count, the following elements // will be structs containing all column stats. // The layout of each struct follows the layout of the ColumnStats. val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(AnalyzeColumnCommand.createColumnStatStruct(_, ndvMaxErr)) - val namedExpressions = expressions.map(e => Alias(e, e.toString)()) - val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)) - .queryExecution.toRdd.collect().head + attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr)) + + val namedExpressions = expressions.map(e => Alias(e, e.toString)()) + val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head() - // unwrap the result - // TODO: Get rid of numFields by using the public Dataset API. 
val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => - val numFields = AnalyzeColumnCommand.numStatFields(expr.dataType) - (expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields))) + (expr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1))) }.toMap (rowCount, columnStats) } - - private val zero = Literal(0, LongType) - private val one = Literal(1, LongType) - - private def numNulls(e: Expression): Expression = { - if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero - } - private def max(e: Expression): Expression = Max(e) - private def min(e: Expression): Expression = Min(e) - private def ndv(e: Expression, relativeSD: Double): Expression = { - // the approximate ndv should never be larger than the number of rows - Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one))) - } - private def avgLength(e: Expression): Expression = Average(Length(e)) - private def maxLength(e: Expression): Expression = Max(Length(e)) - private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) - private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - - /** - * Creates a struct that groups the sequence of expressions together. This is used to create - * one top level struct per column. - */ - private def createStruct(exprs: Seq[Expression]): CreateNamedStruct = { - CreateStruct(exprs.map { expr: Expression => - expr.transformUp { - case af: AggregateFunction => af.toAggregateExpression() - } - }) - } - - private def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { - Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD)) - } - - private def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { - Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD)) - } - - private def binaryColumnStat(e: Expression): Seq[Expression] = { - Seq(numNulls(e), avgLength(e), maxLength(e)) - } - - private def booleanColumnStat(e: Expression): Seq[Expression] = { - Seq(numNulls(e), numTrues(e), numFalses(e)) - } - - // TODO(rxin): Get rid of this function. - def numStatFields(dataType: DataType): Int = { - dataType match { - case BinaryType | BooleanType => 3 - case _ => 4 - } - } - - /** - * Creates a struct expression that contains the statistics to collect for a column. - * - * @param attr column to collect statistics - * @param relativeSD relative error for approximate number of distinct values. - */ - def createColumnStatStruct(attr: Attribute, relativeSD: Double): CreateNamedStruct = { - attr.dataType match { - case _: NumericType | TimestampType | DateType => - createStruct(numericColumnStat(attr, relativeSD)) - case StringType => - createStruct(stringColumnStat(attr, relativeSD)) - case BinaryType => - createStruct(binaryColumnStat(attr)) - case BooleanType => - createStruct(booleanColumnStat(attr)) - case otherType => - throw new AnalysisException("Analyzing columns is not supported for column " + - s"${attr.name} of data type: ${attr.dataType}.") - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala new file mode 100644 index 0000000000..1fcccd0610 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.{lang => jl} +import java.sql.{Date, Timestamp} + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.test.SQLTestData.ArrayData +import org.apache.spark.sql.types._ + + +/** + * End-to-end suite testing statistics collection and use on both entire table and columns. + */ +class StatisticsCollectionSuite extends StatisticsCollectionTestBase with SharedSQLContext { + import testImplicits._ + + private def checkTableStats(tableName: String, expectedRowCount: Option[Int]) + : Option[Statistics] = { + val df = spark.table(tableName) + val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => + assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) + rel.catalogTable.get.stats + } + assert(stats.size == 1) + stats.head + } + + test("estimates the size of a limit 0 on outer join") { + withTempView("test") { + Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") + .createOrReplaceTempView("test") + val df1 = spark.table("test") + val df2 = spark.table("test").limit(0) + val df = df1.join(df2, Seq("k"), "left") + + val sizes = df.queryExecution.analyzed.collect { case g: Join => + g.statistics.sizeInBytes + } + + assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") + assert(sizes.head === BigInt(96), + s"expected exact size 96 for table 'test', got: ${sizes.head}") + } + } + + test("analyze column command - unsupported types and invalid columns") { + val tableName = "column_stats_test1" + withTable(tableName) { + Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName) + + // Test unsupported data types + val err1 = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data") + } + assert(err1.message.contains("does not support statistics collection")) + + // Test invalid columns + val err2 = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS some_random_column") + } + assert(err2.message.contains("does not exist")) + } + } + + test("test table-level statistics for data source table") { + val tableName = "tbl" + withTable(tableName) { + sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet") + Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName) + + // noscan won't count the number of rows + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") + checkTableStats(tableName, expectedRowCount = None) + + // without noscan, we count the number of rows + sql(s"ANALYZE TABLE 
$tableName COMPUTE STATISTICS") + checkTableStats(tableName, expectedRowCount = Some(2)) + } + } + + test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { + val rdd = sparkContext.range(1, 100).map(i => Row(i, i)) + val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType)) + assert(df.queryExecution.analyzed.statistics.sizeInBytes > + spark.sessionState.conf.autoBroadcastJoinThreshold) + assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes > + spark.sessionState.conf.autoBroadcastJoinThreshold) + } + + test("estimates the size of limit") { + withTempView("test") { + Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") + .createOrReplaceTempView("test") + Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) => + val df = sql(s"""SELECT * FROM test limit $limit""") + + val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit => + g.statistics.sizeInBytes + } + assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") + assert(sizesGlobalLimit.head === BigInt(expected), + s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}") + + val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit => + l.statistics.sizeInBytes + } + assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") + assert(sizesLocalLimit.head === BigInt(expected), + s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}") + } + } + } + +} + + +/** + * The base for test cases that we want to include in both the hive module (for verifying behavior + * when using the Hive external catalog) as well as in the sql/core module. + */ +abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils { + import testImplicits._ + + private val dec1 = new java.math.BigDecimal("1.000000000000000000") + private val dec2 = new java.math.BigDecimal("8.000000000000000000") + private val d1 = Date.valueOf("2016-05-08") + private val d2 = Date.valueOf("2016-05-09") + private val t1 = Timestamp.valueOf("2016-05-08 00:00:01") + private val t2 = Timestamp.valueOf("2016-05-09 00:00:02") + + /** + * Define a very simple 3 row table used for testing column serialization. + * Note: last column is seq[int] which doesn't support stats collection. + */ + protected val data = Seq[ + (jl.Boolean, jl.Byte, jl.Short, jl.Integer, jl.Long, + jl.Double, jl.Float, java.math.BigDecimal, + String, Array[Byte], Date, Timestamp, + Seq[Int])]( + (false, 1.toByte, 1.toShort, 1, 1L, 1.0, 1.0f, dec1, "s1", "b1".getBytes, d1, t1, null), + (true, 2.toByte, 3.toShort, 4, 5L, 6.0, 7.0f, dec2, "ss9", "bb0".getBytes, d2, t2, null), + (null, null, null, null, null, null, null, null, null, null, null, null, null) + ) + + /** A mapping from column to the stats collected. 
*/ + protected val stats = mutable.LinkedHashMap( + "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), + "cbyte" -> ColumnStat(2, Some(1L), Some(2L), 1, 1, 1), + "cshort" -> ColumnStat(2, Some(1L), Some(3L), 1, 2, 2), + "cint" -> ColumnStat(2, Some(1L), Some(4L), 1, 4, 4), + "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), + "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), + "cfloat" -> ColumnStat(2, Some(1.0), Some(7.0), 1, 4, 4), + "cdecimal" -> ColumnStat(2, Some(dec1), Some(dec2), 1, 16, 16), + "cstring" -> ColumnStat(2, None, None, 1, 3, 3), + "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), + "cdate" -> ColumnStat(2, Some(d1), Some(d2), 1, 4, 4), + "ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8) + ) + + test("column stats round trip serialization") { + // Make sure we serialize and then deserialize and we will get the result data + val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + stats.zip(df.schema).foreach { case ((k, v), field) => + withClue(s"column $k with type ${field.dataType}") { + val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap) + assert(roundtrip == Some(v)) + } + } + } + + test("analyze column command - result verification") { + val tableName = "column_stats_test2" + // (data.head.productArity - 1) because the last column does not support stats collection. + assert(stats.size == data.head.productArity - 1) + val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + + withTable(tableName) { + df.write.saveAsTable(tableName) + + // Collect statistics + sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) + + // Validate statistics + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + assert(table.stats.isDefined) + assert(table.stats.get.colStats.size == stats.size) + + stats.foreach { case (k, v) => + withClue(s"column $k") { + assert(table.stats.get.colStats(k) == v) + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala deleted file mode 100644 index e866ac2cb3..0000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ /dev/null @@ -1,334 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import java.sql.{Date, Timestamp} - -import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.command.AnalyzeColumnCommand -import org.apache.spark.sql.test.SQLTestData.ArrayData -import org.apache.spark.sql.types._ - -class StatisticsColumnSuite extends StatisticsTest { - import testImplicits._ - - test("parse analyze column commands") { - val tableName = "tbl" - - // we need to specify column names - intercept[ParseException] { - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS") - } - - val analyzeSql = s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key, value" - val parsed = spark.sessionState.sqlParser.parsePlan(analyzeSql) - val expected = AnalyzeColumnCommand(TableIdentifier(tableName), Seq("key", "value")) - comparePlans(parsed, expected) - } - - test("analyzing columns of non-atomic types is not supported") { - val tableName = "tbl" - withTable(tableName) { - Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName) - val err = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data") - } - assert(err.message.contains("Analyzing columns is not supported")) - } - } - - test("check correctness of columns") { - val table = "tbl" - val colName1 = "abc" - val colName2 = "x.yz" - withTable(table) { - sql(s"CREATE TABLE $table ($colName1 int, `$colName2` string) USING PARQUET") - - val invalidColError = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key") - } - assert(invalidColError.message == "Invalid column name: key.") - - withSQLConf("spark.sql.caseSensitive" -> "true") { - val invalidErr = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS ${colName1.toUpperCase}") - } - assert(invalidErr.message == s"Invalid column name: ${colName1.toUpperCase}.") - } - - withSQLConf("spark.sql.caseSensitive" -> "false") { - val columnsToAnalyze = Seq(colName2.toUpperCase, colName1, colName2) - val tableIdent = TableIdentifier(table, Some("default")) - val relation = spark.sessionState.catalog.lookupRelation(tableIdent) - val (_, columnStats) = - AnalyzeColumnCommand.computeColStats(spark, relation, columnsToAnalyze) - assert(columnStats.contains(colName1)) - assert(columnStats.contains(colName2)) - // check deduplication - assert(columnStats.size == 2) - assert(!columnStats.contains(colName2.toUpperCase)) - } - } - } - - private def getNonNullValues[T](values: Seq[Option[T]]): Seq[T] = { - values.filter(_.isDefined).map(_.get) - } - - test("column-level statistics for integral type columns") { - val values = (0 to 5).map { i => - if (i % 2 == 0) None else Some(i) - } - val data = values.map { i => - (i.map(_.toByte), i.map(_.toShort), i.map(_.toInt), i.map(_.toLong)) - } - - val df = data.toDF("c1", "c2", "c3", "c4") - val nonNullValues = getNonNullValues[Int](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - nonNullValues.max, - nonNullValues.min, - nonNullValues.distinct.length.toLong)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for fractional type columns") { - val values: Seq[Option[Decimal]] = (0 to 5).map { i => - 
if (i == 0) None else Some(Decimal(i + i * 0.01)) - } - val data = values.map { i => - (i.map(_.toFloat), i.map(_.toDouble), i) - } - - val df = data.toDF("c1", "c2", "c3") - val nonNullValues = getNonNullValues[Decimal](values) - val numNulls = values.count(_.isEmpty).toLong - val ndv = nonNullValues.distinct.length.toLong - val expectedColStatsSeq = df.schema.map { f => - val colStat = f.dataType match { - case floatType: FloatType => - ColumnStat(InternalRow(numNulls, nonNullValues.max.toFloat, nonNullValues.min.toFloat, - ndv)) - case doubleType: DoubleType => - ColumnStat(InternalRow(numNulls, nonNullValues.max.toDouble, nonNullValues.min.toDouble, - ndv)) - case decimalType: DecimalType => - ColumnStat(InternalRow(numNulls, nonNullValues.max, nonNullValues.min, ndv)) - } - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for string column") { - val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")) - val df = values.toDF("c1") - val nonNullValues = getNonNullValues[String](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, - nonNullValues.map(_.length).max.toInt, - nonNullValues.distinct.length.toLong)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for binary column") { - val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")).map(_.map(_.getBytes)) - val df = values.toDF("c1") - val nonNullValues = getNonNullValues[Array[Byte]](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, - nonNullValues.map(_.length).max.toInt)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for boolean column") { - val values = Seq(None, Some(true), Some(false), Some(true)) - val df = values.toDF("c1") - val nonNullValues = getNonNullValues[Boolean](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - nonNullValues.count(_.equals(true)).toLong, - nonNullValues.count(_.equals(false)).toLong)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for date column") { - val values = Seq(None, Some("1970-01-01"), Some("1970-02-02")).map(_.map(Date.valueOf)) - val df = values.toDF("c1") - val nonNullValues = getNonNullValues[Date](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - // Internally, DateType is represented as the number of days from 1970-01-01. 
- nonNullValues.map(DateTimeUtils.fromJavaDate).max, - nonNullValues.map(DateTimeUtils.fromJavaDate).min, - nonNullValues.distinct.length.toLong)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for timestamp column") { - val values = Seq(None, Some("1970-01-01 00:00:00"), Some("1970-01-01 00:00:05")).map { i => - i.map(Timestamp.valueOf) - } - val df = values.toDF("c1") - val nonNullValues = getNonNullValues[Timestamp](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - // Internally, TimestampType is represented as the number of days from 1970-01-01 - nonNullValues.map(DateTimeUtils.fromJavaTimestamp).max, - nonNullValues.map(DateTimeUtils.fromJavaTimestamp).min, - nonNullValues.distinct.length.toLong)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for null columns") { - val values = Seq(None, None) - val data = values.map { i => - (i.map(_.toString), i.map(_.toString.toInt)) - } - val df = data.toDF("c1", "c2") - val expectedColStatsSeq = df.schema.map { f => - (f, ColumnStat(InternalRow(values.count(_.isEmpty).toLong, null, null, 0L))) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for columns with different types") { - val intSeq = Seq(1, 2) - val doubleSeq = Seq(1.01d, 2.02d) - val stringSeq = Seq("a", "bb") - val binarySeq = Seq("a", "bb").map(_.getBytes) - val booleanSeq = Seq(true, false) - val dateSeq = Seq("1970-01-01", "1970-02-02").map(Date.valueOf) - val timestampSeq = Seq("1970-01-01 00:00:00", "1970-01-01 00:00:05").map(Timestamp.valueOf) - val longSeq = Seq(5L, 4L) - - val data = intSeq.indices.map { i => - (intSeq(i), doubleSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i), dateSeq(i), - timestampSeq(i), longSeq(i)) - } - val df = data.toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8") - val expectedColStatsSeq = df.schema.map { f => - val colStat = f.dataType match { - case IntegerType => - ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) - case DoubleType => - ColumnStat(InternalRow(0L, doubleSeq.max, doubleSeq.min, - doubleSeq.distinct.length.toLong)) - case StringType => - ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, - stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong)) - case BinaryType => - ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, - binarySeq.map(_.length).max.toInt)) - case BooleanType => - ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, - booleanSeq.count(_.equals(false)).toLong)) - case DateType => - ColumnStat(InternalRow(0L, dateSeq.map(DateTimeUtils.fromJavaDate).max, - dateSeq.map(DateTimeUtils.fromJavaDate).min, dateSeq.distinct.length.toLong)) - case TimestampType => - ColumnStat(InternalRow(0L, timestampSeq.map(DateTimeUtils.fromJavaTimestamp).max, - timestampSeq.map(DateTimeUtils.fromJavaTimestamp).min, - timestampSeq.distinct.length.toLong)) - case LongType => - ColumnStat(InternalRow(0L, longSeq.max, longSeq.min, longSeq.distinct.length.toLong)) - } - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("update table-level stats while collecting column-level stats") { - val table = "tbl" - withTable(table) { - sql(s"CREATE TABLE $table (c1 int) USING PARQUET") - sql(s"INSERT INTO $table SELECT 1") - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") - 
checkTableStats(tableName = table, expectedRowCount = Some(1)) - - // update table-level stats between analyze table and analyze column commands - sql(s"INSERT INTO $table SELECT 1") - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") - val fetchedStats = checkTableStats(tableName = table, expectedRowCount = Some(2)) - - val colStat = fetchedStats.get.colStats("c1") - StatisticsTest.checkColStat( - dataType = IntegerType, - colStat = colStat, - expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), - rsd = spark.sessionState.conf.ndvMaxError) - } - } - - test("analyze column stats independently") { - val table = "tbl" - withTable(table) { - sql(s"CREATE TABLE $table (c1 int, c2 long) USING PARQUET") - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") - val fetchedStats1 = checkTableStats(tableName = table, expectedRowCount = Some(0)) - assert(fetchedStats1.get.colStats.size == 1) - val expected1 = ColumnStat(InternalRow(0L, null, null, 0L)) - val rsd = spark.sessionState.conf.ndvMaxError - StatisticsTest.checkColStat( - dataType = IntegerType, - colStat = fetchedStats1.get.colStats("c1"), - expectedColStat = expected1, - rsd = rsd) - - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") - val fetchedStats2 = checkTableStats(tableName = table, expectedRowCount = Some(0)) - // column c1 is kept in the stats - assert(fetchedStats2.get.colStats.size == 2) - StatisticsTest.checkColStat( - dataType = IntegerType, - colStat = fetchedStats2.get.colStats("c1"), - expectedColStat = expected1, - rsd = rsd) - val expected2 = ColumnStat(InternalRow(0L, null, null, 0L)) - StatisticsTest.checkColStat( - dataType = LongType, - colStat = fetchedStats2.get.colStats("c2"), - expectedColStat = expected2, - rsd = rsd) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala deleted file mode 100644 index 8cf42e9248..0000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit} -import org.apache.spark.sql.types._ - -class StatisticsSuite extends StatisticsTest { - import testImplicits._ - - test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { - val rdd = sparkContext.range(1, 100).map(i => Row(i, i)) - val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType)) - assert(df.queryExecution.analyzed.statistics.sizeInBytes > - spark.sessionState.conf.autoBroadcastJoinThreshold) - assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes > - spark.sessionState.conf.autoBroadcastJoinThreshold) - } - - test("estimates the size of limit") { - withTempView("test") { - Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") - .createOrReplaceTempView("test") - Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) => - val df = sql(s"""SELECT * FROM test limit $limit""") - - val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit => - g.statistics.sizeInBytes - } - assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesGlobalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}") - - val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit => - l.statistics.sizeInBytes - } - assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesLocalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}") - } - } - } - - test("estimates the size of a limit 0 on outer join") { - withTempView("test") { - Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") - .createOrReplaceTempView("test") - val df1 = spark.table("test") - val df2 = spark.table("test").limit(0) - val df = df1.join(df2, Seq("k"), "left") - - val sizes = df.queryExecution.analyzed.collect { case g: Join => - g.statistics.sizeInBytes - } - - assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") - assert(sizes.head === BigInt(96), - s"expected exact size 96 for table 'test', got: ${sizes.head}") - } - } - - test("test table-level statistics for data source table created in InMemoryCatalog") { - val tableName = "tbl" - withTable(tableName) { - sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet") - Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName) - - // noscan won't count the number of rows - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") - checkTableStats(tableName, expectedRowCount = None) - - // without noscan, we count the number of rows - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") - checkTableStats(tableName, expectedRowCount = Some(2)) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala deleted file mode 100644 index 915ee0d31b..0000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} -import org.apache.spark.sql.execution.command.AnalyzeColumnCommand -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ - - -trait StatisticsTest extends QueryTest with SharedSQLContext { - - def checkColStats( - df: DataFrame, - expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = { - val table = "tbl" - withTable(table) { - df.write.format("json").saveAsTable(table) - val columns = expectedColStatsSeq.map(_._1) - val tableIdent = TableIdentifier(table, Some("default")) - val relation = spark.sessionState.catalog.lookupRelation(tableIdent) - val (_, columnStats) = - AnalyzeColumnCommand.computeColStats(spark, relation, columns.map(_.name)) - expectedColStatsSeq.foreach { case (field, expectedColStat) => - assert(columnStats.contains(field.name)) - val colStat = columnStats(field.name) - StatisticsTest.checkColStat( - dataType = field.dataType, - colStat = colStat, - expectedColStat = expectedColStat, - rsd = spark.sessionState.conf.ndvMaxError) - - // check if we get the same colStat after encoding and decoding - val encodedCS = colStat.toString - val numFields = AnalyzeColumnCommand.numStatFields(field.dataType) - val decodedCS = ColumnStat(numFields, encodedCS) - StatisticsTest.checkColStat( - dataType = field.dataType, - colStat = decodedCS, - expectedColStat = expectedColStat, - rsd = spark.sessionState.conf.ndvMaxError) - } - } - } - - def checkTableStats(tableName: String, expectedRowCount: Option[Int]): Option[Statistics] = { - val df = spark.table(tableName) - val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => - assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) - rel.catalogTable.get.stats - } - assert(stats.size == 1) - stats.head - } -} - -object StatisticsTest { - def checkColStat( - dataType: DataType, - colStat: ColumnStat, - expectedColStat: ColumnStat, - rsd: Double): Unit = { - dataType match { - case StringType => - val cs = colStat.forString - val expectedCS = expectedColStat.forString - assert(cs.numNulls == expectedCS.numNulls) - assert(cs.avgColLen == expectedCS.avgColLen) - assert(cs.maxColLen == expectedCS.maxColLen) - checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd) - case BinaryType => - val cs = colStat.forBinary - val expectedCS = expectedColStat.forBinary - assert(cs.numNulls == expectedCS.numNulls) - assert(cs.avgColLen == expectedCS.avgColLen) - assert(cs.maxColLen == expectedCS.maxColLen) - case BooleanType => - val cs = colStat.forBoolean - val expectedCS = expectedColStat.forBoolean - assert(cs.numNulls == expectedCS.numNulls) - assert(cs.numTrues == expectedCS.numTrues) - assert(cs.numFalses == expectedCS.numFalses) - case atomicType: 
AtomicType => - checkNumericColStats( - dataType = atomicType, colStat = colStat, expectedColStat = expectedColStat, rsd = rsd) - } - } - - private def checkNumericColStats( - dataType: AtomicType, - colStat: ColumnStat, - expectedColStat: ColumnStat, - rsd: Double): Unit = { - val cs = colStat.forNumeric(dataType) - val expectedCS = expectedColStat.forNumeric(dataType) - assert(cs.numNulls == expectedCS.numNulls) - assert(cs.max == expectedCS.max) - assert(cs.min == expectedCS.min) - checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd) - } - - private def checkNdv(ndv: Long, expectedNdv: Long, rsd: Double): Unit = { - // ndv is an approximate value, so we make sure we have the value, and it should be - // within 3*SD's of the given rsd. - if (expectedNdv == 0) { - assert(ndv == 0) - } else if (expectedNdv > 0) { - assert(ndv > 0) - val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d) - assert(error <= rsd * 3.0d, "Error should be within 3 std. errors.") - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 797fe9ffa8..b070138be0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DescribeFunctionCommand, - DescribeTableCommand, ShowFunctionsCommand} -import org.apache.spark.sql.execution.datasources.{CreateTable, CreateTempViewUsing} +import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} @@ -221,12 +220,22 @@ class SparkSqlParserSuite extends PlanTest { intercept("explain describe tables x", "Unsupported SQL statement") } - test("SPARK-18106 analyze table") { + test("analyze table statistics") { assertEqual("analyze table t compute statistics", AnalyzeTableCommand(TableIdentifier("t"), noscan = false)) assertEqual("analyze table t compute statistics noscan", AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) - assertEqual("analyze table t partition (a) compute statistics noscan", + assertEqual("analyze table t partition (a) compute statistics nOscAn", + AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + + // Partitions specified - we currently parse them but don't do anything with it + assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS", + AnalyzeTableCommand(TableIdentifier("t"), noscan = false)) + assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan", + AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS", + AnalyzeTableCommand(TableIdentifier("t"), noscan = false)) + assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS noscan", AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) intercept("analyze table t compute statistics xxxx", @@ -234,4 +243,11 @@ class SparkSqlParserSuite extends PlanTest { 
intercept("analyze table t partition (a) compute statistics xxxx", "Expected `NOSCAN` instead of `xxxx`") } + + test("analyze table column statistics") { + intercept("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS", "") + + assertEqual("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, value", + AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value"))) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index ff0923f048..fd9dc32063 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, DDLUtils} +import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.internal.StaticSQLConf._ @@ -514,7 +514,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } stats.colStats.foreach { case (colName, colStat) => - statsProperties += (STATISTICS_COL_STATS_PREFIX + colName) -> colStat.toString + colStat.toMap.foreach { case (k, v) => + statsProperties += (columnStatKeyPropName(colName, k) -> v) + } } tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties) } else { @@ -605,48 +607,65 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * It reads table schema, provider, partition column names and bucket specification from table * properties, and filter out these special entries from table properties. */ - private def restoreTableMetadata(table: CatalogTable): CatalogTable = { + private def restoreTableMetadata(inputTable: CatalogTable): CatalogTable = { if (conf.get(DEBUG_MODE)) { - return table + return inputTable } - val tableWithSchema = if (table.tableType == VIEW) { - table - } else { - getProviderFromTableProperties(table) match { + var table = inputTable + + if (table.tableType != VIEW) { + table.properties.get(DATASOURCE_PROVIDER) match { // No provider in table properties, which means this table is created by Spark prior to 2.1, // or is created at Hive side. case None => - table.copy(provider = Some(DDLUtils.HIVE_PROVIDER), tracksPartitionsInCatalog = true) + table = table.copy( + provider = Some(DDLUtils.HIVE_PROVIDER), tracksPartitionsInCatalog = true) // This is a Hive serde table created by Spark 2.1 or higher versions. - case Some(DDLUtils.HIVE_PROVIDER) => restoreHiveSerdeTable(table) + case Some(DDLUtils.HIVE_PROVIDER) => + table = restoreHiveSerdeTable(table) // This is a regular data source table. 
- case Some(provider) => restoreDataSourceTable(table, provider) + case Some(provider) => + table = restoreDataSourceTable(table, provider) } } // construct Spark's statistics from information in Hive metastore - val statsProps = tableWithSchema.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) - val tableWithStats = if (statsProps.nonEmpty) { - val colStatsProps = statsProps.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)) - .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) } - val colStats: Map[String, ColumnStat] = tableWithSchema.schema.collect { - case f if colStatsProps.contains(f.name) => - val numFields = AnalyzeColumnCommand.numStatFields(f.dataType) - (f.name, ColumnStat(numFields, colStatsProps(f.name))) - }.toMap - tableWithSchema.copy( + val statsProps = table.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + + if (statsProps.nonEmpty) { + val colStats = new scala.collection.mutable.HashMap[String, ColumnStat] + + // For each column, recover its column stats. Note that this is currently an O(n^2) operation, + // but given that the number of columns is usually not enormous, this is probably OK as a start. + // If we want to make this a linear operation, we'd need a stronger contract on the + // naming convention used for serialization. + table.schema.foreach { field => + if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) { + // If the "version" field is defined, then the column stat is defined. + val keyPrefix = columnStatKeyPropName(field.name, "") + val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) => + (k.drop(keyPrefix.length), v) + } + + ColumnStat.fromMap(table.identifier.table, field, colStatMap).foreach { + colStat => colStats += field.name -> colStat + } + } + } + + table = table.copy( stats = Some(Statistics( - sizeInBytes = BigInt(tableWithSchema.properties(STATISTICS_TOTAL_SIZE)), - rowCount = tableWithSchema.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), - colStats = colStats))) - } else { - tableWithSchema + sizeInBytes = BigInt(table.properties(STATISTICS_TOTAL_SIZE)), + rowCount = table.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), + colStats = colStats.toMap))) } - tableWithStats.copy(properties = getOriginalTableProperties(table)) + // Get the original table properties as defined by the user. + table.copy( + properties = table.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) }) } private def restoreHiveSerdeTable(table: CatalogTable): CatalogTable = { @@ -1020,17 +1039,17 @@ object HiveExternalCatalog { val TABLE_PARTITION_PROVIDER_CATALOG = "catalog" val TABLE_PARTITION_PROVIDER_FILESYSTEM = "filesystem" - - def getProviderFromTableProperties(metadata: CatalogTable): Option[String] = { - metadata.properties.get(DATASOURCE_PROVIDER) - } - - def getOriginalTableProperties(metadata: CatalogTable): Map[String, String] = { - metadata.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) } + /** + * Returns the fully qualified name used in table properties for a particular column stat. + * For example, for column "mycol" and the "min" stat, this should return + * "spark.sql.statistics.colStats.mycol.min". + */ + private def columnStatKeyPropName(columnName: String, statKey: String): String = { + STATISTICS_COL_STATS_PREFIX + columnName + "." + statKey } // A persisted data source table always stores its schema in the catalog.
-  def getSchemaFromTableProperties(metadata: CatalogTable): StructType = {
+  private def getSchemaFromTableProperties(metadata: CatalogTable): StructType = {
     val errorMessage = "Could not read schema from the hive metastore because it is corrupted."
     val props = metadata.properties
     val schema = props.get(DATASOURCE_SCHEMA)
@@ -1078,11 +1097,11 @@ object HiveExternalCatalog {
     )
   }
 
-  def getPartitionColumnsFromTableProperties(metadata: CatalogTable): Seq[String] = {
+  private def getPartitionColumnsFromTableProperties(metadata: CatalogTable): Seq[String] = {
     getColumnNamesByType(metadata.properties, "part", "partitioning columns")
   }
 
-  def getBucketSpecFromTableProperties(metadata: CatalogTable): Option[BucketSpec] = {
+  private def getBucketSpecFromTableProperties(metadata: CatalogTable): Option[BucketSpec] = {
     metadata.properties.get(DATASOURCE_SCHEMA_NUMBUCKETS).map { numBuckets =>
       BucketSpec(
         numBuckets.toInt,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 4f5ebc3d83..5ae202fdc9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -22,56 +22,16 @@ import java.io.{File, PrintWriter}
 import scala.reflect.ClassTag
 
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
-import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types._
 
-class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
-
-  test("parse analyze commands") {
-    def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) {
-      val parsed = spark.sessionState.sqlParser.parsePlan(analyzeCommand)
-      val operators = parsed.collect {
-        case a: AnalyzeTableCommand => a
-        case o => o
-      }
-
-      assert(operators.size === 1)
-      if (operators(0).getClass() != c) {
-        fail(
-          s"""$analyzeCommand expected command: $c, but got ${operators(0)}
-             |parsed command:
-             |$parsed
-           """.stripMargin)
-      }
-    }
-
-    assertAnalyzeCommand(
-      "ANALYZE TABLE Table1 COMPUTE STATISTICS",
-      classOf[AnalyzeTableCommand])
-    assertAnalyzeCommand(
-      "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS",
-      classOf[AnalyzeTableCommand])
-    assertAnalyzeCommand(
-      "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan",
-      classOf[AnalyzeTableCommand])
-    assertAnalyzeCommand(
-      "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS",
-      classOf[AnalyzeTableCommand])
-    assertAnalyzeCommand(
-      "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS noscan",
-      classOf[AnalyzeTableCommand])
-
-    assertAnalyzeCommand(
-      "ANALYZE TABLE Table1 COMPUTE STATISTICS nOscAn",
-      classOf[AnalyzeTableCommand])
-  }
+class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton {
 
   test("MetastoreRelations fallback to HDFS for size estimation") {
     val enableFallBackToHdfsForStats = spark.sessionState.conf.fallBackToHdfsForStatsEnabled
@@ -310,6 +270,110 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
     }
   }
 
+  test("verify serialized column stats after analyzing columns") {
+    import testImplicits._
+
+    val tableName = "column_stats_test2"
+    // (data.head.productArity - 1) because the last column does not support stats collection.
+    assert(stats.size == data.head.productArity - 1)
+    val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
+
+    withTable(tableName) {
+      df.write.saveAsTable(tableName)
+
+      // Collect statistics
+      sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", "))
+
+      // Validate statistics
+      val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client
+      val table = hiveClient.getTable("default", tableName)
+
+      val props = table.properties.filterKeys(_.startsWith("spark.sql.statistics.colStats"))
+      assert(props == Map(
+        "spark.sql.statistics.colStats.cbinary.avgLen" -> "3",
+        "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2",
+        "spark.sql.statistics.colStats.cbinary.maxLen" -> "3",
+        "spark.sql.statistics.colStats.cbinary.nullCount" -> "1",
+        "spark.sql.statistics.colStats.cbinary.version" -> "1",
+        "spark.sql.statistics.colStats.cbool.avgLen" -> "1",
+        "spark.sql.statistics.colStats.cbool.distinctCount" -> "2",
+        "spark.sql.statistics.colStats.cbool.max" -> "true",
+        "spark.sql.statistics.colStats.cbool.maxLen" -> "1",
+        "spark.sql.statistics.colStats.cbool.min" -> "false",
+        "spark.sql.statistics.colStats.cbool.nullCount" -> "1",
+        "spark.sql.statistics.colStats.cbool.version" -> "1",
+        "spark.sql.statistics.colStats.cbyte.avgLen" -> "1",
+        "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2",
+        "spark.sql.statistics.colStats.cbyte.max" -> "2",
+        "spark.sql.statistics.colStats.cbyte.maxLen" -> "1",
+        "spark.sql.statistics.colStats.cbyte.min" -> "1",
+        "spark.sql.statistics.colStats.cbyte.nullCount" -> "1",
+        "spark.sql.statistics.colStats.cbyte.version" -> "1",
+        "spark.sql.statistics.colStats.cdate.avgLen" -> "4",
+        "spark.sql.statistics.colStats.cdate.distinctCount" -> "2",
+        "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09",
+        "spark.sql.statistics.colStats.cdate.maxLen" -> "4",
+        "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08",
+        "spark.sql.statistics.colStats.cdate.nullCount" -> "1",
+        "spark.sql.statistics.colStats.cdate.version" -> "1",
+        "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16",
+        "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2",
+        "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000",
+        "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16",
+        "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000",
+        "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1",
+        "spark.sql.statistics.colStats.cdecimal.version" -> "1",
+        "spark.sql.statistics.colStats.cdouble.avgLen" -> "8",
+        "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2",
+        "spark.sql.statistics.colStats.cdouble.max" -> "6.0",
+        "spark.sql.statistics.colStats.cdouble.maxLen" -> "8",
+        "spark.sql.statistics.colStats.cdouble.min" -> "1.0",
+        "spark.sql.statistics.colStats.cdouble.nullCount" -> "1",
+        "spark.sql.statistics.colStats.cdouble.version" -> "1",
+        "spark.sql.statistics.colStats.cfloat.avgLen" -> "4",
+        "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2",
+        "spark.sql.statistics.colStats.cfloat.max" -> "7.0",
+        "spark.sql.statistics.colStats.cfloat.maxLen" -> "4",
+        "spark.sql.statistics.colStats.cfloat.min" -> "1.0",
+        "spark.sql.statistics.colStats.cfloat.nullCount" -> "1",
+        "spark.sql.statistics.colStats.cfloat.version" -> "1",
+        "spark.sql.statistics.colStats.cint.avgLen" -> "4",
+        "spark.sql.statistics.colStats.cint.distinctCount" -> "2",
+        "spark.sql.statistics.colStats.cint.max" -> "4",
+        "spark.sql.statistics.colStats.cint.maxLen" -> "4",
+        "spark.sql.statistics.colStats.cint.min" -> "1",
+        "spark.sql.statistics.colStats.cint.nullCount" -> "1",
+        "spark.sql.statistics.colStats.cint.version" -> "1",
+        "spark.sql.statistics.colStats.clong.avgLen" -> "8",
+        "spark.sql.statistics.colStats.clong.distinctCount" -> "2",
+        "spark.sql.statistics.colStats.clong.max" -> "5",
+        "spark.sql.statistics.colStats.clong.maxLen" -> "8",
+        "spark.sql.statistics.colStats.clong.min" -> "1",
+        "spark.sql.statistics.colStats.clong.nullCount" -> "1",
+        "spark.sql.statistics.colStats.clong.version" -> "1",
+        "spark.sql.statistics.colStats.cshort.avgLen" -> "2",
+        "spark.sql.statistics.colStats.cshort.distinctCount" -> "2",
+        "spark.sql.statistics.colStats.cshort.max" -> "3",
+        "spark.sql.statistics.colStats.cshort.maxLen" -> "2",
+        "spark.sql.statistics.colStats.cshort.min" -> "1",
+        "spark.sql.statistics.colStats.cshort.nullCount" -> "1",
+        "spark.sql.statistics.colStats.cshort.version" -> "1",
+        "spark.sql.statistics.colStats.cstring.avgLen" -> "3",
+        "spark.sql.statistics.colStats.cstring.distinctCount" -> "2",
+        "spark.sql.statistics.colStats.cstring.maxLen" -> "3",
+        "spark.sql.statistics.colStats.cstring.nullCount" -> "1",
+        "spark.sql.statistics.colStats.cstring.version" -> "1",
+        "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8",
+        "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2",
+        "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0",
+        "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8",
+        "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0",
+        "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1",
+        "spark.sql.statistics.colStats.ctimestamp.version" -> "1"
+      ))
+    }
+  }
+
   private def testUpdatingTableStats(tableDescription: String, createTableCmd: String): Unit = {
     test("test table-level statistics for " + tableDescription) {
       val parquetTable = "parquetTable"
@@ -319,7 +383,8 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
         TableIdentifier(parquetTable))
       assert(DDLUtils.isDatasourceTable(catalogTable))
 
-      sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src")
+      // Add a filter to avoid creating too many partitions
+      sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10")
       checkTableStats(
         parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None)
 
@@ -328,7 +393,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
       val fetchedStats1 = checkTableStats(
         parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None)
 
-      sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src")
+      sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10")
       sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan")
       val fetchedStats2 = checkTableStats(
         parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None)
@@ -340,7 +405,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
         parquetTable,
        isDataSourceTable = true,
         hasSizeInBytes = true,
-        expectedRowCounts = Some(1000))
+        expectedRowCounts = Some(20))
       assert(fetchedStats3.get.sizeInBytes == fetchedStats2.get.sizeInBytes)
     }
   }
@@ -369,6 +434,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
     }
   }
 
+  /** Used to test refreshing cached metadata once table stats are updated. */
   private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean): (Statistics, Statistics) = {
     val tableName = "tbl"
     var statsBeforeUpdate: Statistics = null
@@ -411,145 +477,6 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
     assert(statsAfterUpdate.rowCount == Some(2))
   }
 
-  test("test refreshing column stats of cached data source table by `ANALYZE TABLE` statement") {
-    val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = true)
-
-    assert(statsBeforeUpdate.sizeInBytes > 0)
-    assert(statsBeforeUpdate.rowCount == Some(1))
-    StatisticsTest.checkColStat(
-      dataType = IntegerType,
-      colStat = statsBeforeUpdate.colStats("key"),
-      expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
-      rsd = spark.sessionState.conf.ndvMaxError)
-
-    assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes)
-    assert(statsAfterUpdate.rowCount == Some(2))
-    StatisticsTest.checkColStat(
-      dataType = IntegerType,
-      colStat = statsAfterUpdate.colStats("key"),
-      expectedColStat = ColumnStat(InternalRow(0L, 2, 1, 2L)),
-      rsd = spark.sessionState.conf.ndvMaxError)
-  }
-
-  private lazy val (testDataFrame, expectedColStatsSeq) = {
-    import testImplicits._
-
-    val intSeq = Seq(1, 2)
-    val stringSeq = Seq("a", "bb")
-    val binarySeq = Seq("a", "bb").map(_.getBytes)
-    val booleanSeq = Seq(true, false)
-    val data = intSeq.indices.map { i =>
-      (intSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i))
-    }
-    val df: DataFrame = data.toDF("c1", "c2", "c3", "c4")
-    val expectedColStatsSeq: Seq[(StructField, ColumnStat)] = df.schema.map { f =>
-      val colStat = f.dataType match {
-        case IntegerType =>
-          ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
-        case StringType =>
-          ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
-            stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong))
-        case BinaryType =>
-          ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble,
-            binarySeq.map(_.length).max.toInt))
-        case BooleanType =>
-          ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
-            booleanSeq.count(_.equals(false)).toLong))
-      }
-      (f, colStat)
-    }
-    (df, expectedColStatsSeq)
-  }
-
-  private def checkColStats(
-      tableName: String,
-      isDataSourceTable: Boolean,
-      expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = {
-    val readback = spark.table(tableName)
-    val stats = readback.queryExecution.analyzed.collect {
-      case rel: MetastoreRelation =>
-        assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table")
-        rel.catalogTable.stats.get
-      case rel: LogicalRelation =>
-        assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table")
-        rel.catalogTable.get.stats.get
-    }
-    assert(stats.length == 1)
-    val columnStats = stats.head.colStats
-    assert(columnStats.size == expectedColStatsSeq.length)
-    expectedColStatsSeq.foreach { case (field, expectedColStat) =>
-      StatisticsTest.checkColStat(
-        dataType = field.dataType,
-        colStat = columnStats(field.name),
-        expectedColStat = expectedColStat,
-        rsd = spark.sessionState.conf.ndvMaxError)
-    }
-  }
-
-  test("generate and load column-level stats for data source table") {
-    val dsTable = "dsTable"
-    withTable(dsTable) {
-      testDataFrame.write.format("parquet").saveAsTable(dsTable)
-      sql(s"ANALYZE TABLE $dsTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
-      checkColStats(dsTable, isDataSourceTable = true, expectedColStatsSeq)
-    }
-  }
-
-  test("generate and load column-level stats for hive serde table") {
-    val hTable = "hTable"
-    val tmp = "tmp"
-    withTable(hTable, tmp) {
-      testDataFrame.write.format("parquet").saveAsTable(tmp)
-      sql(s"CREATE TABLE $hTable (c1 int, c2 string, c3 binary, c4 boolean) STORED AS TEXTFILE")
-      sql(s"INSERT INTO $hTable SELECT * FROM $tmp")
-      sql(s"ANALYZE TABLE $hTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
-      checkColStats(hTable, isDataSourceTable = false, expectedColStatsSeq)
-    }
-  }
-
-  // When caseSensitive is on, for columns with only case difference, they are different columns
-  // and we should generate column stats for all of them.
-  private def checkCaseSensitiveColStats(columnName: String): Unit = {
-    val tableName = "tbl"
-    withTable(tableName) {
-      val column1 = columnName.toLowerCase
-      val column2 = columnName.toUpperCase
-      withSQLConf("spark.sql.caseSensitive" -> "true") {
-        sql(s"CREATE TABLE $tableName (`$column1` int, `$column2` double) USING PARQUET")
-        sql(s"INSERT INTO $tableName SELECT 1, 3.0")
-        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS `$column1`, `$column2`")
-        val readback = spark.table(tableName)
-        val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation =>
-          val columnStats = rel.catalogTable.get.stats.get.colStats
-          assert(columnStats.size == 2)
-          StatisticsTest.checkColStat(
-            dataType = IntegerType,
-            colStat = columnStats(column1),
-            expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
-            rsd = spark.sessionState.conf.ndvMaxError)
-          StatisticsTest.checkColStat(
-            dataType = DoubleType,
-            colStat = columnStats(column2),
-            expectedColStat = ColumnStat(InternalRow(0L, 3.0d, 3.0d, 1L)),
-            rsd = spark.sessionState.conf.ndvMaxError)
-          rel
-        }
-        assert(relations.size == 1)
-      }
-    }
-  }
-
-  test("check column statistics for case sensitive column names") {
-    checkCaseSensitiveColStats(columnName = "c1")
-  }
-
-  test("check column statistics for case sensitive non-ascii column names") {
-    // scalastyle:off
-    // non ascii characters are not allowed in the source code, so we disable the scalastyle.
-    checkCaseSensitiveColStats(columnName = "列c")
-    // scalastyle:on
-  }
-
   test("estimates the size of a test MetastoreRelation") {
     val df = sql("""SELECT * FROM src""")
     val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation =>