[SPARK-18505][SQL] Simplify AnalyzeColumnCommand

## What changes were proposed in this pull request?
I'm spending more time at the design and code level of the cost-based optimizer now, and have found a number of issues related to maintainability and compatibility that I would like to address.

This is a small pull request to clean up AnalyzeColumnCommand:

1. Removed the warning on duplicated columns. Warnings in log messages are largely useless here, since most users who run SQL never see them.
2. Removed the nested updateStats function, by just inlining the function.
3. Renamed a few functions to better reflect what they do.
4. Removed the factory apply method for ColumnStatStruct. It is a bad pattern for an apply method to return an instance of a type other than its own (ColumnStatStruct.apply used to return CreateNamedStruct); see the sketch after this list.
5. Renamed ColumnStatStruct to just AnalyzeColumnCommand.
6. Added more documentation explaining some of the non-obvious return types and code blocks.
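
To make item 4 concrete, here is a minimal, self-contained sketch of the antipattern using made-up names (not the actual Spark code):

```scala
// Hypothetical sketch: an `apply` method that returns a value of an unrelated type.
case class NamedStruct(fields: Seq[String])

object StatStruct {
  // Reads like a constructor call for StatStruct, but actually yields a NamedStruct.
  def apply(column: String): NamedStruct = NamedStruct(Seq(column))
}

object StatStructBetter {
  // A descriptive factory name makes the return type unsurprising at the call site.
  def createStatStruct(column: String): NamedStruct = NamedStruct(Seq(column))
}
```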

In follow-up pull requests, I'd like to address the following:

1. Get rid of the Map[String, ColumnStat], since internally we should be using Attribute to reference columns rather than strings (see the sketch after this list).
2. Decouple the fields exposed by ColumnStat and internals of Spark SQL's execution path. Currently the two are coupled because ColumnStat takes in an InternalRow.
3. Correctness: Remove code path that stores statistics in the catalog using the base64 encoding of the UnsafeRow format, which is not stable across Spark versions.
4. Clearly document the data representation stored in the catalog for statistics.
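
To illustrate follow-up item 1, a simplified sketch of why attribute-based keys are more robust than raw name strings (the types below are made-up stand-ins, not Spark's Attribute or ColumnStat):

```scala
object AttributeKeyedStatsSketch {
  // Simplified stand-ins for Spark's Attribute and ColumnStat.
  case class AttrRef(name: String, exprId: Long)
  case class Stats(distinctCount: Long, nullCount: Long)

  def main(args: Array[String]): Unit = {
    // String keys depend on exactly how the caller spells the column name.
    val byName: Map[String, Stats] = Map("ID" -> Stats(10, 0))
    println(byName.get("id")) // None, unless every lookup re-applies the resolver

    // Attribute-like keys carry a resolved identity (name plus expression id),
    // so lookups no longer depend on the surface spelling.
    val id = AttrRef("id", exprId = 1L)
    val byAttr: Map[AttrRef, Stats] = Map(id -> Stats(10, 0))
    println(byAttr.get(id)) // Some(Stats(10,0))
  }
}
```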

## How was this patch tested?
Affected test cases have been updated.
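
For reference, the updated tests now invoke computeColStats as a static helper on the companion object. The snippet below is a hypothetical sketch adapted from the test diff further down, not part of this patch; the suite name, table name, and column names are placeholders, and it is placed in org.apache.spark.sql so that the internal spark.sessionState is accessible, like the existing suites:

```scala
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.command.AnalyzeColumnCommand
import org.apache.spark.sql.test.SharedSQLContext

class ComputeColStatsSketchSuite extends QueryTest with SharedSQLContext {
  test("computeColStats can be called without an AnalyzeColumnCommand instance") {
    // Create a small data source table to analyze.
    spark.range(10).selectExpr("id AS c1", "id % 2 AS c2").write.saveAsTable("my_table")
    val tableIdent = TableIdentifier("my_table", Some("default"))
    val relation = spark.sessionState.catalog.lookupRelation(tableIdent)
    // Static call; no command instance is needed anymore.
    val (rowCount, columnStats) =
      AnalyzeColumnCommand.computeColStats(spark, relation, Seq("c1", "c2"))
    assert(rowCount == 10)
    assert(columnStats.contains("c1") && columnStats.contains("c2"))
  }
}
```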

Author: Reynold Xin <rxin@databricks.com>

Closes #15933 from rxin/SPARK-18505.
Commit 6f7ff75091 (parent e5f5c29e02)
Reynold Xin <rxin@databricks.com>, 2016-11-18 16:34:11 -08:00
5 changed files with 74 additions and 56 deletions

@@ -17,8 +17,7 @@
package org.apache.spark.sql.execution.command
import scala.collection.mutable
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
@@ -44,13 +43,16 @@ case class AnalyzeColumnCommand(
val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db))
val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB))
relation match {
// Compute total size
val (catalogTable: CatalogTable, sizeInBytes: Long) = relation match {
case catalogRel: CatalogRelation =>
updateStats(catalogRel.catalogTable,
// This is a Hive serde format table
(catalogRel.catalogTable,
AnalyzeTableCommand.calculateTotalSize(sessionState, catalogRel.catalogTable))
case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined =>
updateStats(logicalRel.catalogTable.get,
// This is a data source format table
(logicalRel.catalogTable.get,
AnalyzeTableCommand.calculateTotalSize(sessionState, logicalRel.catalogTable.get))
case otherRelation =>
@@ -58,45 +60,45 @@ case class AnalyzeColumnCommand(
s"${otherRelation.nodeName}.")
}
def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = {
val (rowCount, columnStats) = computeColStats(sparkSession, relation)
// We also update table-level stats in order to keep them consistent with column-level stats.
val statistics = Statistics(
sizeInBytes = newTotalSize,
rowCount = Some(rowCount),
// Newly computed column stats should override the existing ones.
colStats = catalogTable.stats.map(_.colStats).getOrElse(Map()) ++ columnStats)
sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics)))
// Refresh the cached data source table in the catalog.
sessionState.catalog.refreshTable(tableIdentWithDB)
}
// Compute stats for each column
val (rowCount, newColStats) =
AnalyzeColumnCommand.computeColStats(sparkSession, relation, columnNames)
// We also update table-level stats in order to keep them consistent with column-level stats.
val statistics = Statistics(
sizeInBytes = sizeInBytes,
rowCount = Some(rowCount),
// Newly computed column stats should override the existing ones.
colStats = catalogTable.stats.map(_.colStats).getOrElse(Map.empty) ++ newColStats)
sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics)))
// Refresh the cached data source table in the catalog.
sessionState.catalog.refreshTable(tableIdentWithDB)
Seq.empty[Row]
}
}
object AnalyzeColumnCommand extends Logging {
/**
* Compute stats for the given columns.
* @return (row count, map from column name to ColumnStats)
*
* This is visible for testing.
*/
def computeColStats(
sparkSession: SparkSession,
relation: LogicalPlan): (Long, Map[String, ColumnStat]) = {
relation: LogicalPlan,
columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = {
// check correctness of column names
val attributesToAnalyze = mutable.MutableList[Attribute]()
val duplicatedColumns = mutable.MutableList[String]()
// Resolve the column names and dedup using AttributeSet
val resolver = sparkSession.sessionState.conf.resolver
columnNames.foreach { col =>
val attributesToAnalyze = AttributeSet(columnNames.map { col =>
val exprOption = relation.output.find(attr => resolver(attr.name, col))
val expr = exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col."))
// do deduplication
if (!attributesToAnalyze.contains(expr)) {
attributesToAnalyze += expr
} else {
duplicatedColumns += col
}
}
if (duplicatedColumns.nonEmpty) {
logWarning("Duplicate column names were deduplicated in `ANALYZE TABLE` statement. " +
s"Input columns: ${columnNames.mkString("(", ", ", ")")}. " +
s"Duplicate columns: ${duplicatedColumns.mkString("(", ", ", ")")}.")
}
exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col."))
}).toSeq
// Collect statistics per column.
// The first element in the result will be the overall row count, the following elements
@@ -104,22 +106,21 @@ case class AnalyzeColumnCommand(
// The layout of each struct follows the layout of the ColumnStats.
val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError
val expressions = Count(Literal(1)).toAggregateExpression() +:
attributesToAnalyze.map(ColumnStatStruct(_, ndvMaxErr))
attributesToAnalyze.map(AnalyzeColumnCommand.createColumnStatStruct(_, ndvMaxErr))
val namedExpressions = expressions.map(e => Alias(e, e.toString)())
val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation))
.queryExecution.toRdd.collect().head
// unwrap the result
// TODO: Get rid of numFields by using the public Dataset API.
val rowCount = statsRow.getLong(0)
val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) =>
val numFields = ColumnStatStruct.numStatFields(expr.dataType)
val numFields = AnalyzeColumnCommand.numStatFields(expr.dataType)
(expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields)))
}.toMap
(rowCount, columnStats)
}
}
object ColumnStatStruct {
private val zero = Literal(0, LongType)
private val one = Literal(1, LongType)
@@ -137,7 +138,11 @@ object ColumnStatStruct {
private def numTrues(e: Expression): Expression = Sum(If(e, one, zero))
private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero))
private def getStruct(exprs: Seq[Expression]): CreateNamedStruct = {
/**
* Creates a struct that groups the sequence of expressions together. This is used to create
* one top level struct per column.
*/
private def createStruct(exprs: Seq[Expression]): CreateNamedStruct = {
CreateStruct(exprs.map { expr: Expression =>
expr.transformUp {
case af: AggregateFunction => af.toAggregateExpression()
@@ -161,6 +166,7 @@ object ColumnStatStruct {
Seq(numNulls(e), numTrues(e), numFalses(e))
}
// TODO(rxin): Get rid of this function.
def numStatFields(dataType: DataType): Int = {
dataType match {
case BinaryType | BooleanType => 3
@@ -168,14 +174,25 @@ object ColumnStatStruct {
}
}
def apply(attr: Attribute, relativeSD: Double): CreateNamedStruct = attr.dataType match {
// Use aggregate functions to compute statistics we need.
case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD))
case StringType => getStruct(stringColumnStat(attr, relativeSD))
case BinaryType => getStruct(binaryColumnStat(attr))
case BooleanType => getStruct(booleanColumnStat(attr))
case otherType =>
throw new AnalysisException("Analyzing columns is not supported for column " +
s"${attr.name} of data type: ${attr.dataType}.")
/**
* Creates a struct expression that contains the statistics to collect for a column.
*
* @param attr column to collect statistics
* @param relativeSD relative error for approximate number of distinct values.
*/
def createColumnStatStruct(attr: Attribute, relativeSD: Double): CreateNamedStruct = {
attr.dataType match {
case _: NumericType | TimestampType | DateType =>
createStruct(numericColumnStat(attr, relativeSD))
case StringType =>
createStruct(stringColumnStat(attr, relativeSD))
case BinaryType =>
createStruct(binaryColumnStat(attr))
case BooleanType =>
createStruct(booleanColumnStat(attr))
case otherType =>
throw new AnalysisException("Analyzing columns is not supported for column " +
s"${attr.name} of data type: ${attr.dataType}.")
}
}
}

@@ -79,7 +79,7 @@ class StatisticsColumnSuite extends StatisticsTest {
val tableIdent = TableIdentifier(table, Some("default"))
val relation = spark.sessionState.catalog.lookupRelation(tableIdent)
val (_, columnStats) =
AnalyzeColumnCommand(tableIdent, columnsToAnalyze).computeColStats(spark, relation)
AnalyzeColumnCommand.computeColStats(spark, relation, columnsToAnalyze)
assert(columnStats.contains(colName1))
assert(columnStats.contains(colName2))
// check deduplication

@@ -19,11 +19,12 @@ package org.apache.spark.sql
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, ColumnStatStruct}
import org.apache.spark.sql.execution.command.AnalyzeColumnCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
trait StatisticsTest extends QueryTest with SharedSQLContext {
def checkColStats(
@@ -36,7 +37,7 @@ trait StatisticsTest extends QueryTest with SharedSQLContext {
val tableIdent = TableIdentifier(table, Some("default"))
val relation = spark.sessionState.catalog.lookupRelation(tableIdent)
val (_, columnStats) =
AnalyzeColumnCommand(tableIdent, columns.map(_.name)).computeColStats(spark, relation)
AnalyzeColumnCommand.computeColStats(spark, relation, columns.map(_.name))
expectedColStatsSeq.foreach { case (field, expectedColStat) =>
assert(columnStats.contains(field.name))
val colStat = columnStats(field.name)
@@ -48,7 +49,7 @@ trait StatisticsTest extends QueryTest with SharedSQLContext {
// check if we get the same colStat after encoding and decoding
val encodedCS = colStat.toString
val numFields = ColumnStatStruct.numStatFields(field.dataType)
val numFields = AnalyzeColumnCommand.numStatFields(field.dataType)
val decodedCS = ColumnStat(numFields, encodedCS)
StatisticsTest.checkColStat(
dataType = field.dataType,

@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils}
import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, DDLUtils}
import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.HiveSerDe
import org.apache.spark.sql.internal.StaticSQLConf._
@@ -634,7 +634,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
.map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) }
val colStats: Map[String, ColumnStat] = tableWithSchema.schema.collect {
case f if colStatsProps.contains(f.name) =>
val numFields = ColumnStatStruct.numStatFields(f.dataType)
val numFields = AnalyzeColumnCommand.numStatFields(f.dataType)
(f.name, ColumnStat(numFields, colStatsProps(f.name)))
}.toMap
tableWithSchema.copy(

@@ -97,7 +97,7 @@ private[hive] class HiveClientImpl(
}
// Create an internal session state for this HiveClientImpl.
val state = {
val state: SessionState = {
val original = Thread.currentThread().getContextClassLoader
// Switch to the initClassLoader.
Thread.currentThread().setContextClassLoader(initClassLoader)