[SPARK-21180][SQL] Remove conf from stats functions since now we have conf in LogicalPlan

## What changes were proposed in this pull request?

After wiring `SQLConf` into `LogicalPlan` ([PR 18299](https://github.com/apache/spark/pull/18299)), we no longer need to pass `conf` into `def stats` and `def computeStats`, so this patch removes that parameter.
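For illustration, here is a minimal, self-contained sketch of the resulting call pattern. The `Statistics`, `PlanNode`, `Leaf`, and `Pair` names below are hypothetical stand-ins, not Spark's actual classes; only the cached, parameterless `stats`/`computeStats` shape mirrors the patch.

```scala
// Sketch only: models the parameterless stats/computeStats pattern this patch adopts.
case class Statistics(sizeInBytes: BigInt, rowCount: Option[BigInt] = None)

abstract class PlanNode {
  def children: Seq[PlanNode]

  private var statsCache: Option[Statistics] = None

  // Call sites now read `plan.stats` instead of `plan.stats(conf)`.
  final def stats: Statistics = statsCache.getOrElse {
    statsCache = Some(computeStats)
    statsCache.get
  }

  def invalidateStatsCache(): Unit = { statsCache = None }

  // Non-leaf default: product of the children's sizes; leaf nodes override this.
  protected def computeStats: Statistics =
    Statistics(sizeInBytes = children.map(_.stats.sizeInBytes).product)
}

case class Leaf(size: BigInt) extends PlanNode {
  override def children: Seq[PlanNode] = Nil
  override protected def computeStats: Statistics = Statistics(sizeInBytes = size)
}

case class Pair(left: PlanNode, right: PlanNode) extends PlanNode {
  override def children: Seq[PlanNode] = Seq(left, right)
}

object StatsSketch extends App {
  val plan = Pair(Leaf(10), Leaf(20))
  println(plan.stats.sizeInBytes) // 200, computed without threading a SQLConf argument
}
```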

## How was this patch tested?

Covered by existing tests, some of which were modified to match the new API.

Author: wangzhenhua <wangzhenhua@huawei.com>
Author: Zhenhua Wang <wzh_zju@163.com>

Closes #18391 from wzhfy/removeConf.
Authored by wangzhenhua on 2017-06-23 10:33:53 -07:00, committed by Xiao Li
parent 07479b3cfb
commit b803b66a81
38 changed files with 176 additions and 171 deletions

@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Attri
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.catalyst.util.quoteIdentifier
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
@@ -436,7 +435,7 @@ case class CatalogRelation(
       createTime = -1
     ))
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     // For data source tables, we will create a `LogicalRelation` and won't call this method, for
     // hive serde tables, we will always generate a statistics.
     // TODO: unify the table stats generation.

@@ -58,7 +58,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr
     // Do reordering if the number of items is appropriate and join conditions exist.
     // We also need to check if costs of all items can be evaluated.
     if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty &&
-        items.forall(_.stats(conf).rowCount.isDefined)) {
+        items.forall(_.stats.rowCount.isDefined)) {
       JoinReorderDP.search(conf, items, conditions, output)
     } else {
       plan
@@ -322,7 +322,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
     /** Get the cost of the root node of this plan tree. */
     def rootCost(conf: SQLConf): Cost = {
       if (itemIds.size > 1) {
-        val rootStats = plan.stats(conf)
+        val rootStats = plan.stats
         Cost(rootStats.rowCount.get, rootStats.sizeInBytes)
       } else {
         // If the plan is a leaf item, it has zero cost.

@@ -317,7 +317,7 @@ case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] {
       case FullOuter =>
         (left.maxRows, right.maxRows) match {
           case (None, None) =>
-            if (left.stats(conf).sizeInBytes >= right.stats(conf).sizeInBytes) {
+            if (left.stats.sizeInBytes >= right.stats.sizeInBytes) {
               join.copy(left = maybePushLimit(exp, left))
             } else {
               join.copy(right = maybePushLimit(exp, right))

@@ -82,7 +82,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
       // Find if the input plans are eligible for star join detection.
       // An eligible plan is a base table access with valid statistics.
       val foundEligibleJoin = input.forall {
-        case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true
+        case PhysicalOperation(_, _, t: LeafNode) if t.stats.rowCount.isDefined => true
         case _ => false
       }
@@ -181,7 +181,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
           val leafCol = findLeafNodeCol(column, plan)
           leafCol match {
             case Some(col) if t.outputSet.contains(col) =>
-              val stats = t.stats(conf)
+              val stats = t.stats
               stats.rowCount match {
                 case Some(rowCount) if rowCount >= 0 =>
                   if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) {
@@ -237,7 +237,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
           val leafCol = findLeafNodeCol(column, plan)
           leafCol match {
             case Some(col) if t.outputSet.contains(col) =>
-              val stats = t.stats(conf)
+              val stats = t.stats
               stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)
             case None => false
           }
@@ -296,11 +296,11 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
   */
  private def getTableAccessCardinality(
      input: LogicalPlan): Option[BigInt] = input match {
-    case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined =>
-      if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) {
-        Option(input.stats(conf).rowCount.get)
+    case PhysicalOperation(_, cond, t: LeafNode) if t.stats.rowCount.isDefined =>
+      if (conf.cboEnabled && input.stats.rowCount.isDefined) {
+        Option(input.stats.rowCount.get)
       } else {
-        Option(t.stats(conf).rowCount.get)
+        Option(t.stats.rowCount.get)
       }
     case _ => None
  }

@@ -21,7 +21,6 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{StructField, StructType}
 object LocalRelation {
@@ -67,7 +66,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
     }
   }
-  override def computeStats(conf: SQLConf): Statistics =
+  override def computeStats: Statistics =
     Statistics(sizeInBytes =
       output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)

@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
@@ -90,8 +89,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
   * first time. If the configuration changes, the cache can be invalidated by calling
   * [[invalidateStatsCache()]].
   */
-  final def stats(conf: SQLConf): Statistics = statsCache.getOrElse {
-    statsCache = Some(computeStats(conf))
+  final def stats: Statistics = statsCache.getOrElse {
+    statsCache = Some(computeStats)
     statsCache.get
   }
@@ -108,11 +107,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
   *
   * [[LeafNode]]s must override this.
   */
-  protected def computeStats(conf: SQLConf): Statistics = {
+  protected def computeStats: Statistics = {
     if (children.isEmpty) {
       throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
     }
-    Statistics(sizeInBytes = children.map(_.stats(conf).sizeInBytes).product)
+    Statistics(sizeInBytes = children.map(_.stats.sizeInBytes).product)
   }
   override def verboseStringWithSuffix: String = {
@@ -333,13 +332,13 @@ abstract class UnaryNode extends LogicalPlan {
   override protected def validConstraints: Set[Expression] = child.constraints
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     // There should be some overhead in Row object, the size should not be zero when there is
     // no columns, this help to prevent divide-by-zero error.
     val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8
     val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
     // Assume there will be the same number of rows as child has.
-    var sizeInBytes = (child.stats(conf).sizeInBytes * outputRowSize) / childRowSize
+    var sizeInBytes = (child.stats.sizeInBytes * outputRowSize) / childRowSize
     if (sizeInBytes == 0) {
       // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
       // (product of children).
@@ -347,7 +346,7 @@ abstract class UnaryNode extends LogicalPlan {
     }
     // Don't propagate rowCount and attributeStats, since they are not estimated here.
-    Statistics(sizeInBytes = sizeInBytes, hints = child.stats(conf).hints)
+    Statistics(sizeInBytes = sizeInBytes, hints = child.stats.hints)
   }
 }

@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.RandomSampler
@@ -65,11 +64,11 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
   override def validConstraints: Set[Expression] =
     child.constraints.union(getAliasedConstraints(projectList))
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     if (conf.cboEnabled) {
-      ProjectEstimation.estimate(conf, this).getOrElse(super.computeStats(conf))
+      ProjectEstimation.estimate(this).getOrElse(super.computeStats)
     } else {
-      super.computeStats(conf)
+      super.computeStats
     }
   }
 }
@@ -139,11 +138,11 @@ case class Filter(condition: Expression, child: LogicalPlan)
     child.constraints.union(predicates.toSet)
   }
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     if (conf.cboEnabled) {
-      FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf))
+      FilterEstimation(this).estimate.getOrElse(super.computeStats)
     } else {
-      super.computeStats(conf)
+      super.computeStats
     }
   }
 }
@@ -192,13 +191,13 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
     }
   }
-  override def computeStats(conf: SQLConf): Statistics = {
-    val leftSize = left.stats(conf).sizeInBytes
-    val rightSize = right.stats(conf).sizeInBytes
+  override def computeStats: Statistics = {
+    val leftSize = left.stats.sizeInBytes
+    val rightSize = right.stats.sizeInBytes
     val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
     Statistics(
       sizeInBytes = sizeInBytes,
-      hints = left.stats(conf).hints.resetForJoin())
+      hints = left.stats.hints.resetForJoin())
   }
 }
@@ -209,8 +208,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
   override protected def validConstraints: Set[Expression] = leftConstraints
-  override def computeStats(conf: SQLConf): Statistics = {
-    left.stats(conf).copy()
+  override def computeStats: Statistics = {
+    left.stats.copy()
   }
 }
@@ -248,8 +247,8 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
     children.length > 1 && childrenResolved && allChildrenCompatible
   }
-  override def computeStats(conf: SQLConf): Statistics = {
-    val sizeInBytes = children.map(_.stats(conf).sizeInBytes).sum
+  override def computeStats: Statistics = {
+    val sizeInBytes = children.map(_.stats.sizeInBytes).sum
     Statistics(sizeInBytes = sizeInBytes)
   }
@@ -357,20 +356,20 @@ case class Join(
     case _ => resolvedExceptNatural
   }
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     def simpleEstimation: Statistics = joinType match {
       case LeftAnti | LeftSemi =>
         // LeftSemi and LeftAnti won't ever be bigger than left
-        left.stats(conf)
+        left.stats
       case _ =>
         // Make sure we don't propagate isBroadcastable in other joins, because
         // they could explode the size.
-        val stats = super.computeStats(conf)
+        val stats = super.computeStats
         stats.copy(hints = stats.hints.resetForJoin())
     }
     if (conf.cboEnabled) {
-      JoinEstimation.estimate(conf, this).getOrElse(simpleEstimation)
+      JoinEstimation.estimate(this).getOrElse(simpleEstimation)
     } else {
       simpleEstimation
     }
@@ -523,7 +522,7 @@ case class Range(
   override def newInstance(): Range = copy(output = output.map(_.newInstance()))
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     val sizeInBytes = LongType.defaultSize * numElements
     Statistics(sizeInBytes = sizeInBytes)
   }
@@ -556,20 +555,20 @@ case class Aggregate(
     child.constraints.union(getAliasedConstraints(nonAgg))
   }
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     def simpleEstimation: Statistics = {
       if (groupingExpressions.isEmpty) {
         Statistics(
           sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1),
           rowCount = Some(1),
-          hints = child.stats(conf).hints)
+          hints = child.stats.hints)
       } else {
-        super.computeStats(conf)
+        super.computeStats
       }
     }
     if (conf.cboEnabled) {
-      AggregateEstimation.estimate(conf, this).getOrElse(simpleEstimation)
+      AggregateEstimation.estimate(this).getOrElse(simpleEstimation)
     } else {
       simpleEstimation
     }
@@ -672,8 +671,8 @@ case class Expand(
   override def references: AttributeSet =
     AttributeSet(projections.flatten.flatMap(_.references))
-  override def computeStats(conf: SQLConf): Statistics = {
-    val sizeInBytes = super.computeStats(conf).sizeInBytes * projections.length
+  override def computeStats: Statistics = {
+    val sizeInBytes = super.computeStats.sizeInBytes * projections.length
     Statistics(sizeInBytes = sizeInBytes)
   }
@@ -743,9 +742,9 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
       case _ => None
     }
   }
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
-    val childStats = child.stats(conf)
+    val childStats = child.stats
     val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit)
     // Don't propagate column stats, because we don't know the distribution after a limit operation
     Statistics(
@@ -763,9 +762,9 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
       case _ => None
     }
   }
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
-    val childStats = child.stats(conf)
+    val childStats = child.stats
     if (limit == 0) {
       // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
       // (product of children).
@@ -832,9 +831,9 @@ case class Sample(
   override def output: Seq[Attribute] = child.output
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     val ratio = upperBound - lowerBound
-    val childStats = child.stats(conf)
+    val childStats = child.stats
     var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio)
     if (sizeInBytes == 0) {
       sizeInBytes = 1
@@ -898,7 +897,7 @@ case class RepartitionByExpression(
 case object OneRowRelation extends LeafNode {
   override def maxRows: Option[Long] = Some(1)
   override def output: Seq[Attribute] = Nil
-  override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = 1)
+  override def computeStats: Statistics = Statistics(sizeInBytes = 1)
 }
 /** A logical plan for `dropDuplicates`. */

@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.internal.SQLConf
 /**
  * A general hint for the child that is not yet resolved. This node is generated by the parser and
@@ -44,8 +43,8 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo())
   override lazy val canonicalized: LogicalPlan = child.canonicalized
-  override def computeStats(conf: SQLConf): Statistics = {
-    val stats = child.stats(conf)
+  override def computeStats: Statistics = {
+    val stats = child.stats
     stats.copy(hints = hints)
   }
 }

@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics}
-import org.apache.spark.sql.internal.SQLConf
 object AggregateEstimation {
@@ -29,13 +28,13 @@ object AggregateEstimation {
   * Estimate the number of output rows based on column stats of group-by columns, and propagate
   * column stats for aggregate expressions.
   */
-  def estimate(conf: SQLConf, agg: Aggregate): Option[Statistics] = {
-    val childStats = agg.child.stats(conf)
+  def estimate(agg: Aggregate): Option[Statistics] = {
+    val childStats = agg.child.stats
     // Check if we have column stats for all group-by columns.
     val colStatsExist = agg.groupingExpressions.forall { e =>
       e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute])
     }
-    if (rowCountsExist(conf, agg.child) && colStatsExist) {
+    if (rowCountsExist(agg.child) && colStatsExist) {
       // Multiply distinct counts of group-by columns. This is an upper bound, which assumes
       // the data contains all combinations of distinct values of group-by columns.
       var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))(

@@ -21,15 +21,14 @@ import scala.math.BigDecimal.RoundingMode
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DecimalType, _}
 object EstimationUtils {
   /** Check if each plan has rowCount in its statistics. */
-  def rowCountsExist(conf: SQLConf, plans: LogicalPlan*): Boolean =
-    plans.forall(_.stats(conf).rowCount.isDefined)
+  def rowCountsExist(plans: LogicalPlan*): Boolean =
+    plans.forall(_.stats.rowCount.isDefined)
   /** Check if each attribute has column stat in the corresponding statistics. */
   def columnStatsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = {

@@ -25,12 +25,11 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging {
-  private val childStats = plan.child.stats(catalystConf)
+case class FilterEstimation(plan: Filter) extends Logging {
+  private val childStats = plan.child.stats
   private val colStatsMap = new ColumnStatsMap(childStats.attributeStats)

@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
-import org.apache.spark.sql.internal.SQLConf
 object JoinEstimation extends Logging {
@@ -34,12 +33,12 @@ object JoinEstimation extends Logging {
   * Estimate statistics after join. Return `None` if the join type is not supported, or we don't
   * have enough statistics for estimation.
   */
-  def estimate(conf: SQLConf, join: Join): Option[Statistics] = {
+  def estimate(join: Join): Option[Statistics] = {
     join.joinType match {
       case Inner | Cross | LeftOuter | RightOuter | FullOuter =>
-        InnerOuterEstimation(conf, join).doEstimate()
+        InnerOuterEstimation(join).doEstimate()
       case LeftSemi | LeftAnti =>
-        LeftSemiAntiEstimation(conf, join).doEstimate()
+        LeftSemiAntiEstimation(join).doEstimate()
       case _ =>
         logDebug(s"[CBO] Unsupported join type: ${join.joinType}")
         None
@@ -47,16 +46,16 @@ object JoinEstimation extends Logging {
   }
 }
-case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging {
-  private val leftStats = join.left.stats(conf)
-  private val rightStats = join.right.stats(conf)
+case class InnerOuterEstimation(join: Join) extends Logging {
+  private val leftStats = join.left.stats
+  private val rightStats = join.right.stats
   /**
    * Estimate output size and number of rows after a join operator, and update output column stats.
    */
   def doEstimate(): Option[Statistics] = join match {
-    case _ if !rowCountsExist(conf, join.left, join.right) =>
+    case _ if !rowCountsExist(join.left, join.right) =>
       None
     case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) =>
@@ -273,13 +272,13 @@ case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging {
   }
 }
-case class LeftSemiAntiEstimation(conf: SQLConf, join: Join) {
+case class LeftSemiAntiEstimation(join: Join) {
   def doEstimate(): Option[Statistics] = {
     // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic
     // column stats. Now we just propagate the statistics from left side. We should do more
     // accurate estimation when advanced stats (e.g. histograms) are available.
-    if (rowCountsExist(conf, join.left)) {
-      val leftStats = join.left.stats(conf)
+    if (rowCountsExist(join.left)) {
+      val leftStats = join.left.stats
       // Propagate the original column stats for cartesian product
       val outputRows = leftStats.rowCount.get
       Some(Statistics(

@@ -19,14 +19,13 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap}
 import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics}
-import org.apache.spark.sql.internal.SQLConf
 object ProjectEstimation {
   import EstimationUtils._
-  def estimate(conf: SQLConf, project: Project): Option[Statistics] = {
-    if (rowCountsExist(conf, project.child)) {
-      val childStats = project.child.stats(conf)
+  def estimate(project: Project): Option[Statistics] = {
+    if (rowCountsExist(project.child)) {
+      val childStats = project.child.stats
       val inputAttrStats = childStats.attributeStats
       // Match alias with its child's column stat
       val aliasStats = project.expressions.collect {

@@ -142,7 +142,7 @@ class JoinOptimizationSuite extends PlanTest {
     comparePlans(optimized, expected)
     val broadcastChildren = optimized.collect {
-      case Join(_, r, _, _) if r.stats(conf).sizeInBytes == 1 => r
+      case Join(_, r, _, _) if r.stats.sizeInBytes == 1 => r
     }
     assert(broadcastChildren.size == 1)
   }

@@ -112,7 +112,7 @@ class LimitPushdownSuite extends PlanTest {
   }
   test("full outer join where neither side is limited and both sides have same statistics") {
-    assert(x.stats(conf).sizeInBytes === y.stats(conf).sizeInBytes)
+    assert(x.stats.sizeInBytes === y.stats.sizeInBytes)
     val originalQuery = x.join(y, FullOuter).limit(1)
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = Limit(1, LocalLimit(1, x).join(y, FullOuter)).analyze
@@ -121,7 +121,7 @@ class LimitPushdownSuite extends PlanTest {
   test("full outer join where neither side is limited and left side has larger statistics") {
     val xBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('x)
-    assert(xBig.stats(conf).sizeInBytes > y.stats(conf).sizeInBytes)
+    assert(xBig.stats.sizeInBytes > y.stats.sizeInBytes)
     val originalQuery = xBig.join(y, FullOuter).limit(1)
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = Limit(1, LocalLimit(1, xBig).join(y, FullOuter)).analyze
@@ -130,7 +130,7 @@ class LimitPushdownSuite extends PlanTest {
   test("full outer join where neither side is limited and right side has larger statistics") {
     val yBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('y)
-    assert(x.stats(conf).sizeInBytes < yBig.stats(conf).sizeInBytes)
+    assert(x.stats.sizeInBytes < yBig.stats.sizeInBytes)
     val originalQuery = x.join(yBig, FullOuter).limit(1)
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = Limit(1, x.join(LocalLimit(1, yBig), FullOuter)).analyze

@@ -100,17 +100,23 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
       size = Some(4 * (8 + 4)),
       attributeStats = AttributeMap(Seq("key12").map(nameToColInfo)))
-    val noGroupAgg = Aggregate(groupingExpressions = Nil,
-      aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
-    assert(noGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) ==
-      // overhead + count result size
-      Statistics(sizeInBytes = 8 + 8, rowCount = Some(1)))
-    val hasGroupAgg = Aggregate(groupingExpressions = attributes,
-      aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child)
-    assert(hasGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) ==
-      // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
-      Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
+    val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
+    try {
+      SQLConf.get.setConf(SQLConf.CBO_ENABLED, false)
+      val noGroupAgg = Aggregate(groupingExpressions = Nil,
+        aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
+      assert(noGroupAgg.stats ==
+        // overhead + count result size
+        Statistics(sizeInBytes = 8 + 8, rowCount = Some(1)))
+      val hasGroupAgg = Aggregate(groupingExpressions = attributes,
+        aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child)
+      assert(hasGroupAgg.stats ==
+        // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
+        Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
+    } finally {
+      SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
+    }
   }
   private def checkAggStats(
@@ -134,6 +140,6 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
       rowCount = Some(expectedOutputRowCount),
       attributeStats = expectedAttrStats)
-    assert(testAgg.stats(conf) == expectedStats)
+    assert(testAgg.stats == expectedStats)
   }
 }

@@ -57,16 +57,16 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
     val localLimit = LocalLimit(Literal(2), plan)
     val globalLimit = GlobalLimit(Literal(2), plan)
     // LocalLimit's stats is just its child's stats except column stats
-    checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
+    checkStats(localLimit, plan.stats.copy(attributeStats = AttributeMap(Nil)))
     checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2)))
   }
   test("limit estimation: limit > child's rowCount") {
     val localLimit = LocalLimit(Literal(20), plan)
     val globalLimit = GlobalLimit(Literal(20), plan)
-    checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
+    checkStats(localLimit, plan.stats.copy(attributeStats = AttributeMap(Nil)))
     // Limit is larger than child's rowCount, so GlobalLimit's stats is equal to its child's stats.
-    checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
+    checkStats(globalLimit, plan.stats.copy(attributeStats = AttributeMap(Nil)))
   }
   test("limit estimation: limit = 0") {
@@ -113,12 +113,19 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
       plan: LogicalPlan,
       expectedStatsCboOn: Statistics,
       expectedStatsCboOff: Statistics): Unit = {
-    // Invalidate statistics
-    plan.invalidateStatsCache()
-    assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> true)) == expectedStatsCboOn)
-    plan.invalidateStatsCache()
-    assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == expectedStatsCboOff)
+    val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
+    try {
+      // Invalidate statistics
+      plan.invalidateStatsCache()
+      SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)
+      assert(plan.stats == expectedStatsCboOn)
+      plan.invalidateStatsCache()
+      SQLConf.get.setConf(SQLConf.CBO_ENABLED, false)
+      assert(plan.stats == expectedStatsCboOff)
+    } finally {
+      SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
+    }
   }
   /** Check estimated stats when it's the same whether cbo is turned on or off. */
@@ -135,6 +142,6 @@ private case class DummyLogicalPlan(
     cboStats: Statistics) extends LogicalPlan {
   override def output: Seq[Attribute] = Nil
   override def children: Seq[LogicalPlan] = Nil
-  override def computeStats(conf: SQLConf): Statistics =
+  override def computeStats: Statistics =
     if (conf.cboEnabled) cboStats else defaultStats
 }

@@ -620,7 +620,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       rowCount = Some(expectedRowCount),
       attributeStats = expectedAttributeMap)
-    val filterStats = filter.stats(conf)
+    val filterStats = filter.stats
     assert(filterStats.sizeInBytes == expectedStats.sizeInBytes)
     assert(filterStats.rowCount == expectedStats.rowCount)
     val rowCountValue = filterStats.rowCount.getOrElse(0)

@@ -77,7 +77,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       // Keep the column stat from both sides unchanged.
       attributeStats = AttributeMap(
         Seq("key-1-5", "key-5-9", "key-1-2", "key-2-4").map(nameToColInfo)))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
   test("disjoint inner join") {
@@ -90,7 +90,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       sizeInBytes = 1,
       rowCount = Some(0),
       attributeStats = AttributeMap(Nil))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
   test("disjoint left outer join") {
@@ -106,7 +106,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
         // Null count for right side columns = left row count
         Seq(nameToAttr("key-1-2") -> nullColumnStat(nameToAttr("key-1-2").dataType, 5),
           nameToAttr("key-2-4") -> nullColumnStat(nameToAttr("key-2-4").dataType, 5))))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
   test("disjoint right outer join") {
@@ -122,7 +122,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
         // Null count for left side columns = right row count
         Seq(nameToAttr("key-1-5") -> nullColumnStat(nameToAttr("key-1-5").dataType, 3),
          nameToAttr("key-5-9") -> nullColumnStat(nameToAttr("key-5-9").dataType, 3))))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
   test("disjoint full outer join") {
@@ -140,7 +140,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
        nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3),
        nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5),
        nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5))))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
   test("inner join") {
@@ -161,7 +161,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       attributeStats = AttributeMap(
         Seq(nameToAttr("key-1-5") -> joinedColStat, nameToAttr("key-1-2") -> joinedColStat,
           nameToAttr("key-5-9") -> colStatForkey59, nameToColInfo("key-2-4"))))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
   test("inner join with multiple equi-join keys") {
@@ -183,7 +183,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       attributeStats = AttributeMap(
         Seq(nameToAttr("key-1-2") -> joinedColStat1, nameToAttr("key-1-2") -> joinedColStat1,
           nameToAttr("key-2-4") -> joinedColStat2, nameToAttr("key-2-3") -> joinedColStat2)))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
   test("left outer join") {
@@ -201,7 +201,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       attributeStats = AttributeMap(
         Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-3"),
           nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat)))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
   test("right outer join") {
@@ -219,7 +219,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       attributeStats = AttributeMap(
         Seq(nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat,
           nameToColInfo("key-1-2"), nameToColInfo("key-2-3"))))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
  }
  test("full outer join") {
@@ -234,7 +234,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       // Keep the column stat from both sides unchanged.
       attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4"),
         nameToColInfo("key-1-2"), nameToColInfo("key-2-3"))))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
   test("left semi/anti join") {
@@ -248,7 +248,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
         sizeInBytes = 3 * (8 + 4 * 2),
         rowCount = Some(3),
         attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4"))))
-      assert(join.stats(conf) == expectedStats)
+      assert(join.stats == expectedStats)
     }
   }
@@ -306,7 +306,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
         sizeInBytes = 1 * (8 + 2 * getColSize(key1, columnInfo1(key1))),
         rowCount = Some(1),
         attributeStats = AttributeMap(Seq(key1 -> columnInfo1(key1), key2 -> columnInfo1(key1))))
-      assert(join.stats(conf) == expectedStats)
+      assert(join.stats == expectedStats)
     }
   }
 }
@@ -323,6 +323,6 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       sizeInBytes = 1,
       rowCount = Some(0),
       attributeStats = AttributeMap(Nil))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
 }

@@ -45,7 +45,7 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
       sizeInBytes = 2 * (8 + 4 + 4),
       rowCount = Some(2),
       attributeStats = expectedAttrStats)
-    assert(proj.stats(conf) == expectedStats)
+    assert(proj.stats == expectedStats)
   }
   test("project on empty table") {
@@ -131,6 +131,6 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
       sizeInBytes = expectedSize,
       rowCount = Some(expectedRowCount),
       attributeStats = projectAttrMap)
-    assert(proj.stats(conf) == expectedStats)
+    assert(proj.stats == expectedStats)
   }
 }

@@ -21,14 +21,24 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED}
 import org.apache.spark.sql.types.{IntegerType, StringType}
 trait StatsEstimationTestBase extends SparkFunSuite {
-  /** Enable stats estimation based on CBO. */
-  protected val conf = new SQLConf().copy(CASE_SENSITIVE -> true, CBO_ENABLED -> true)
+  var originalValue: Boolean = false
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Enable stats estimation based on CBO.
+    originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
+    SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)
+  }
+
+  override def afterAll(): Unit = {
+    SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
+    super.afterAll()
+  }
   def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match {
     // For UTF8String: base + offset + numBytes
@@ -55,7 +65,7 @@ case class StatsTestPlan(
     attributeStats: AttributeMap[ColumnStat],
     size: Option[BigInt] = None) extends LeafNode {
   override def output: Seq[Attribute] = outputList
-  override def computeStats(conf: SQLConf): Statistics = Statistics(
+  override def computeStats: Statistics = Statistics(
     // If sizeInBytes is useless in testing, we just use a fake value
     sizeInBytes = size.getOrElse(Int.MaxValue),
     rowCount = Some(rowCount),

@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
@@ -89,7 +88,7 @@ case class ExternalRDD[T](
   override protected def stringArgs: Iterator[Any] = Iterator(output)
-  @transient override def computeStats(conf: SQLConf): Statistics = Statistics(
+  @transient override def computeStats: Statistics = Statistics(
     // TODO: Instead of returning a default value here, find a way to return a meaningful size
     // estimate for RDDs. See PR 1238 for more discussions.
     sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
@@ -157,7 +156,7 @@ case class LogicalRDD(
   override protected def stringArgs: Iterator[Any] = Iterator(output)
-  @transient override def computeStats(conf: SQLConf): Statistics = Statistics(
+  @transient override def computeStats: Statistics = Statistics(
     // TODO: Instead of returning a default value here, find a way to return a meaningful size
     // estimate for RDDs. See PR 1238 for more discussions.
     sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)

@@ -221,7 +221,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
   def stringWithStats: String = {
     // trigger to compute stats for logical plans
-    optimizedPlan.stats(sparkSession.sessionState.conf)
+    optimizedPlan.stats
     // only show optimized logical plan and physical plan
     s"""== Optimized Logical Plan ==

@@ -22,7 +22,6 @@ import org.apache.spark.sql.Strategy
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -114,9 +113,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   * Matches a plan whose output should be small enough to be used in broadcast join.
   */
  private def canBroadcast(plan: LogicalPlan): Boolean = {
-    plan.stats(conf).hints.broadcast ||
-      (plan.stats(conf).sizeInBytes >= 0 &&
-        plan.stats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold)
+    plan.stats.hints.broadcast ||
+      (plan.stats.sizeInBytes >= 0 &&
+        plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold)
  }
  /**
@@ -126,7 +125,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   * dynamic.
   */
  private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = {
-    plan.stats(conf).sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
+    plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
  }
  /**
@@ -137,7 +136,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   * use the size of bytes here as estimation.
   */
  private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = {
-    a.stats(conf).sizeInBytes * 3 <= b.stats(conf).sizeInBytes
+    a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes
  }
  private def canBuildRight(joinType: JoinType): Boolean = joinType match {
@@ -206,7 +205,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Join(left, right, joinType, condition) =>
         val buildSide =
-          if (right.stats(conf).sizeInBytes <= left.stats(conf).sizeInBytes) {
+          if (right.stats.sizeInBytes <= left.stats.sizeInBytes) {
             BuildRight
           } else {
             BuildLeft

@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.LongAccumulator
@@ -70,7 +69,7 @@ case class InMemoryRelation(
   @transient val partitionStatistics = new PartitionStatistics(output)
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     if (batchStats.value == 0L) {
       // Underlying columnar RDD hasn't been materialized, no useful statistics information
       // available, return the default statistics.

@@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.util.Utils
@@ -46,7 +45,7 @@ case class LogicalRelation(
   // Only care about relation when canonicalizing.
   override def preCanonicalized: LogicalPlan = copy(catalogTable = None)
-  @transient override def computeStats(conf: SQLConf): Statistics = {
+  @transient override def computeStats: Statistics = {
     catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse(
       Statistics(sizeInBytes = relation.sizeInBytes))
   }

@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
@@ -230,6 +229,6 @@ case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode
   private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum

-  override def computeStats(conf: SQLConf): Statistics =
+  override def computeStats: Statistics =
     Statistics(sizePerRow * sink.allData.size)
 }

@@ -313,7 +313,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     spark.table("testData").queryExecution.withCachedData.collect {
       case cached: InMemoryRelation =>
         val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum
-        assert(cached.stats(sqlConf).sizeInBytes === actualSizeInBytes)
+        assert(cached.stats.sizeInBytes === actualSizeInBytes)
     }
   }

@@ -1146,7 +1146,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     // instead of Int for avoiding possible overflow.
     val ds = (0 to 10000).map( i =>
       (i, Seq((i, Seq((i, "This is really not that long of a string")))))).toDS()
-    val sizeInBytes = ds.logicalPlan.stats(sqlConf).sizeInBytes
+    val sizeInBytes = ds.logicalPlan.stats.sizeInBytes
     // sizeInBytes is 2404280404, before the fix, it overflows to a negative number
     assert(sizeInBytes > 0)
   }

@@ -33,7 +33,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
   setupTestData()

   def statisticSizeInByte(df: DataFrame): BigInt = {
-    df.queryExecution.optimizedPlan.stats(sqlConf).sizeInBytes
+    df.queryExecution.optimizedPlan.stats.sizeInBytes
   }

   test("equi-join is hash-join") {

@@ -60,7 +60,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
       val df = df1.join(df2, Seq("k"), "left")

       val sizes = df.queryExecution.analyzed.collect { case g: Join =>
-        g.stats(conf).sizeInBytes
+        g.stats.sizeInBytes
       }
       assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
@@ -107,9 +107,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
   test("SPARK-15392: DataFrame created from RDD should not be broadcasted") {
     val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
     val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType))
-    assert(df.queryExecution.analyzed.stats(conf).sizeInBytes >
+    assert(df.queryExecution.analyzed.stats.sizeInBytes >
       spark.sessionState.conf.autoBroadcastJoinThreshold)
-    assert(df.selectExpr("a").queryExecution.analyzed.stats(conf).sizeInBytes >
+    assert(df.selectExpr("a").queryExecution.analyzed.stats.sizeInBytes >
       spark.sessionState.conf.autoBroadcastJoinThreshold)
   }
@@ -250,13 +250,13 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
   test("SPARK-18856: non-empty partitioned table should not report zero size") {
     withTable("ds_tbl", "hive_tbl") {
       spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl")
-      val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats(conf)
+      val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats
       assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.")

       if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") {
         sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)")
         sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1")
-        val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats(conf)
+        val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats
         assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.")
       }
     }
@@ -296,10 +296,10 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
     assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat))

     // Check relation statistics
-    assert(relation.stats(conf).sizeInBytes == 0)
-    assert(relation.stats(conf).rowCount == Some(0))
-    assert(relation.stats(conf).attributeStats.size == 1)
-    val (attribute, colStat) = relation.stats(conf).attributeStats.head
+    assert(relation.stats.sizeInBytes == 0)
+    assert(relation.stats.rowCount == Some(0))
+    assert(relation.stats.attributeStats.size == 1)
+    val (attribute, colStat) = relation.stats.attributeStats.head
     assert(attribute.name == "c1")
     assert(colStat == emptyColStat)
   }

@@ -126,7 +126,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
       .toDF().createOrReplaceTempView("sizeTst")
     spark.catalog.cacheTable("sizeTst")
     assert(
-      spark.table("sizeTst").queryExecution.analyzed.stats(sqlConf).sizeInBytes >
+      spark.table("sizeTst").queryExecution.analyzed.stats.sizeInBytes >
         spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
   }

@@ -36,7 +36,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext {
       })
       val totalSize = allFiles.map(_.length()).sum
       val df = spark.read.parquet(dir.toString)
-      assert(df.queryExecution.logical.stats(sqlConf).sizeInBytes === BigInt(totalSize))
+      assert(df.queryExecution.logical.stats.sizeInBytes === BigInt(totalSize))
     }
   }
 }

@@ -216,15 +216,15 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
     // Before adding data, check output
     checkAnswer(sink.allData, Seq.empty)
-    assert(plan.stats(sqlConf).sizeInBytes === 0)
+    assert(plan.stats.sizeInBytes === 0)

     sink.addBatch(0, 1 to 3)
     plan.invalidateStatsCache()
-    assert(plan.stats(sqlConf).sizeInBytes === 12)
+    assert(plan.stats.sizeInBytes === 12)

     sink.addBatch(1, 4 to 6)
     plan.invalidateStatsCache()
-    assert(plan.stats(sqlConf).sizeInBytes === 24)
+    assert(plan.stats.sizeInBytes === 24)
   }

   ignore("stress test") {

@@ -21,7 +21,6 @@ import java.nio.charset.StandardCharsets
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits}
-import org.apache.spark.sql.internal.SQLConf

 /**
  * A collection of sample data used in SQL tests.
@@ -29,8 +28,6 @@ import org.apache.spark.sql.internal.SQLConf
 private[sql] trait SQLTestData { self =>
   protected def spark: SparkSession

-  protected def sqlConf: SQLConf = spark.sessionState.conf
-
   // Helper object to import SQL implicits without a concrete SQLContext
   private object internalImplicits extends SQLImplicits {
     protected override def _sqlContext: SQLContext = self.spark.sqlContext

@@ -154,7 +154,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
           Some(partitionSchema))

         val logicalRelation = cached.getOrElse {
-          val sizeInBytes = relation.stats(sparkSession.sessionState.conf).sizeInBytes.toLong
+          val sizeInBytes = relation.stats.sizeInBytes.toLong
           val fileIndex = {
             val index = new CatalogFileIndex(sparkSession, relation.tableMeta, sizeInBytes)
             if (lazyPruningEnabled) {

@@ -68,7 +68,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
       assert(properties("totalSize").toLong <= 0, "external table totalSize must be <= 0")
       assert(properties("rawDataSize").toLong <= 0, "external table rawDataSize must be <= 0")

-      val sizeInBytes = relation.stats(conf).sizeInBytes
+      val sizeInBytes = relation.stats.sizeInBytes
       assert(sizeInBytes === BigInt(file1.length() + file2.length()))
     }
   }
@@ -77,7 +77,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
   test("analyze Hive serde tables") {
     def queryTotalSize(tableName: String): BigInt =
-      spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes
+      spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes

     // Non-partitioned table
     sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect()
@@ -659,7 +659,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
   test("estimates the size of a test Hive serde tables") {
     val df = sql("""SELECT * FROM src""")
     val sizes = df.queryExecution.analyzed.collect {
-      case relation: CatalogRelation => relation.stats(conf).sizeInBytes
+      case relation: CatalogRelation => relation.stats.sizeInBytes
     }
     assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}")
     assert(sizes(0).equals(BigInt(5812)),
@@ -679,7 +679,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
       // Assert src has a size smaller than the threshold.
       val sizes = df.queryExecution.analyzed.collect {
-        case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats(conf).sizeInBytes
+        case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats.sizeInBytes
       }
       assert(sizes.size === 2 && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold
         && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold,
@@ -733,7 +733,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
       // Assert src has a size smaller than the threshold.
       val sizes = df.queryExecution.analyzed.collect {
-        case relation: CatalogRelation => relation.stats(conf).sizeInBytes
+        case relation: CatalogRelation => relation.stats.sizeInBytes
       }
       assert(sizes.size === 2 && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold
         && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold,

@@ -86,7 +86,7 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te
        case relation: LogicalRelation => relation
      }
      assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}")
-      val size2 = relations(0).computeStats(conf).sizeInBytes
+      val size2 = relations(0).computeStats.sizeInBytes
      assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes)
      assert(size2 < tableStats.get.sizeInBytes)
    }
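
For reference, a minimal usage sketch of the API after this change: `stats` on `LogicalPlan` (and `computeStats` on leaf nodes) is now parameterless because each plan already carries its `SQLConf`. The local session, object name, and example data below are illustrative only and not part of the patch.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical demo object; any Spark build containing this patch behaves the same way.
object StatsWithoutConfDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("stats-without-conf")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "name")

    // Before this patch callers wrote: plan.stats(spark.sessionState.conf)
    // After this patch no SQLConf argument is needed; the plan reads its own conf.
    val stats = df.queryExecution.optimizedPlan.stats
    println(s"estimated sizeInBytes = ${stats.sizeInBytes}")
    println(s"rowCount (if available) = ${stats.rowCount}")

    spark.stop()
  }
}
```

This is why the strategies and test suites above simply drop the `(conf)` / `(sqlConf)` argument rather than looking the configuration up some other way.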