[SPARK-21180][SQL] Remove conf from stats functions since now we have conf in LogicalPlan

## What changes were proposed in this pull request?

After wiring `SQLConf` into the logical plan ([PR 18299](https://github.com/apache/spark/pull/18299)), we no longer need to pass `conf` into `def stats` and `def computeStats`, so this parameter is removed from both methods.
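
For illustration, a rough sketch of the resulting API (the `DummyRelation` class and its size formula are invented here for this description, not part of the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}

// A made-up leaf relation, just to show the new signature.
case class DummyRelation(output: Seq[Attribute]) extends LeafNode {
  // Before: override def computeStats(conf: SQLConf): Statistics = ...
  // After: `conf` is already available on the plan (wired in by PR 18299),
  // so the parameter disappears.
  override def computeStats: Statistics =
    Statistics(sizeInBytes = conf.defaultSizeInBytes)
}

// Call sites change from `plan.stats(conf)` to `plan.stats`.
```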

## How was this patch tested?

Covered by existing tests; some existing tests were modified to use the new parameterless API.
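
Concretely, the modified tests switch from evaluating stats against a copied conf to toggling the flag on the active `SQLConf` and restoring it afterwards; a simplified sketch of that pattern (`plan` and `expectedStats` stand in for whatever the individual suite uses):

```scala
import org.apache.spark.sql.internal.SQLConf

// Before: assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == expectedStats)
// After: flip the flag on the active conf and restore it when done.
val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
try {
  SQLConf.get.setConf(SQLConf.CBO_ENABLED, false)
  assert(plan.stats == expectedStats)
} finally {
  SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
}
```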

Author: wangzhenhua <wangzhenhua@huawei.com>
Author: Zhenhua Wang <wzh_zju@163.com>

Closes #18391 from wzhfy/removeConf.
Authored by wangzhenhua on 2017-06-23 10:33:53 -07:00; committed by Xiao Li
commit b803b66a81 (parent 07479b3cfb)
38 changed files with 176 additions and 171 deletions


@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Attri
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
@ -436,7 +435,7 @@ case class CatalogRelation(
createTime = -1
))
override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
// For data source tables, we will create a `LogicalRelation` and won't call this method, for
// hive serde tables, we will always generate a statistics.
// TODO: unify the table stats generation.


@ -58,7 +58,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr
// Do reordering if the number of items is appropriate and join conditions exist.
// We also need to check if costs of all items can be evaluated.
if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty &&
items.forall(_.stats(conf).rowCount.isDefined)) {
items.forall(_.stats.rowCount.isDefined)) {
JoinReorderDP.search(conf, items, conditions, output)
} else {
plan
@ -322,7 +322,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
/** Get the cost of the root node of this plan tree. */
def rootCost(conf: SQLConf): Cost = {
if (itemIds.size > 1) {
val rootStats = plan.stats(conf)
val rootStats = plan.stats
Cost(rootStats.rowCount.get, rootStats.sizeInBytes)
} else {
// If the plan is a leaf item, it has zero cost.


@ -317,7 +317,7 @@ case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] {
case FullOuter =>
(left.maxRows, right.maxRows) match {
case (None, None) =>
if (left.stats(conf).sizeInBytes >= right.stats(conf).sizeInBytes) {
if (left.stats.sizeInBytes >= right.stats.sizeInBytes) {
join.copy(left = maybePushLimit(exp, left))
} else {
join.copy(right = maybePushLimit(exp, right))


@ -82,7 +82,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
// Find if the input plans are eligible for star join detection.
// An eligible plan is a base table access with valid statistics.
val foundEligibleJoin = input.forall {
case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true
case PhysicalOperation(_, _, t: LeafNode) if t.stats.rowCount.isDefined => true
case _ => false
}
@ -181,7 +181,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
val leafCol = findLeafNodeCol(column, plan)
leafCol match {
case Some(col) if t.outputSet.contains(col) =>
val stats = t.stats(conf)
val stats = t.stats
stats.rowCount match {
case Some(rowCount) if rowCount >= 0 =>
if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) {
@ -237,7 +237,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
val leafCol = findLeafNodeCol(column, plan)
leafCol match {
case Some(col) if t.outputSet.contains(col) =>
val stats = t.stats(conf)
val stats = t.stats
stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)
case None => false
}
@ -296,11 +296,11 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
*/
private def getTableAccessCardinality(
input: LogicalPlan): Option[BigInt] = input match {
case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined =>
if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) {
Option(input.stats(conf).rowCount.get)
case PhysicalOperation(_, cond, t: LeafNode) if t.stats.rowCount.isDefined =>
if (conf.cboEnabled && input.stats.rowCount.isDefined) {
Option(input.stats.rowCount.get)
} else {
Option(t.stats(conf).rowCount.get)
Option(t.stats.rowCount.get)
}
case _ => None
}


@ -21,7 +21,6 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StructField, StructType}
object LocalRelation {
@ -67,7 +66,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
}
}
override def computeStats(conf: SQLConf): Statistics =
override def computeStats: Statistics =
Statistics(sizeInBytes =
output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)


@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
@ -90,8 +89,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
* first time. If the configuration changes, the cache can be invalidated by calling
* [[invalidateStatsCache()]].
*/
final def stats(conf: SQLConf): Statistics = statsCache.getOrElse {
statsCache = Some(computeStats(conf))
final def stats: Statistics = statsCache.getOrElse {
statsCache = Some(computeStats)
statsCache.get
}
@ -108,11 +107,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
*
* [[LeafNode]]s must override this.
*/
protected def computeStats(conf: SQLConf): Statistics = {
protected def computeStats: Statistics = {
if (children.isEmpty) {
throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
}
Statistics(sizeInBytes = children.map(_.stats(conf).sizeInBytes).product)
Statistics(sizeInBytes = children.map(_.stats.sizeInBytes).product)
}
override def verboseStringWithSuffix: String = {
@ -333,13 +332,13 @@ abstract class UnaryNode extends LogicalPlan {
override protected def validConstraints: Set[Expression] = child.constraints
override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
// There should be some overhead in Row object, the size should not be zero when there is
// no columns, this help to prevent divide-by-zero error.
val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8
val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
// Assume there will be the same number of rows as child has.
var sizeInBytes = (child.stats(conf).sizeInBytes * outputRowSize) / childRowSize
var sizeInBytes = (child.stats.sizeInBytes * outputRowSize) / childRowSize
if (sizeInBytes == 0) {
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
// (product of children).
@ -347,7 +346,7 @@ abstract class UnaryNode extends LogicalPlan {
}
// Don't propagate rowCount and attributeStats, since they are not estimated here.
Statistics(sizeInBytes = sizeInBytes, hints = child.stats(conf).hints)
Statistics(sizeInBytes = sizeInBytes, hints = child.stats.hints)
}
}


@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.RandomSampler
@ -65,11 +64,11 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
override def validConstraints: Set[Expression] =
child.constraints.union(getAliasedConstraints(projectList))
override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
if (conf.cboEnabled) {
ProjectEstimation.estimate(conf, this).getOrElse(super.computeStats(conf))
ProjectEstimation.estimate(this).getOrElse(super.computeStats)
} else {
super.computeStats(conf)
super.computeStats
}
}
}
@ -139,11 +138,11 @@ case class Filter(condition: Expression, child: LogicalPlan)
child.constraints.union(predicates.toSet)
}
override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
if (conf.cboEnabled) {
FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf))
FilterEstimation(this).estimate.getOrElse(super.computeStats)
} else {
super.computeStats(conf)
super.computeStats
}
}
}
@ -192,13 +191,13 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
}
}
override def computeStats(conf: SQLConf): Statistics = {
val leftSize = left.stats(conf).sizeInBytes
val rightSize = right.stats(conf).sizeInBytes
override def computeStats: Statistics = {
val leftSize = left.stats.sizeInBytes
val rightSize = right.stats.sizeInBytes
val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
Statistics(
sizeInBytes = sizeInBytes,
hints = left.stats(conf).hints.resetForJoin())
hints = left.stats.hints.resetForJoin())
}
}
@ -209,8 +208,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
override protected def validConstraints: Set[Expression] = leftConstraints
override def computeStats(conf: SQLConf): Statistics = {
left.stats(conf).copy()
override def computeStats: Statistics = {
left.stats.copy()
}
}
@ -248,8 +247,8 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
children.length > 1 && childrenResolved && allChildrenCompatible
}
override def computeStats(conf: SQLConf): Statistics = {
val sizeInBytes = children.map(_.stats(conf).sizeInBytes).sum
override def computeStats: Statistics = {
val sizeInBytes = children.map(_.stats.sizeInBytes).sum
Statistics(sizeInBytes = sizeInBytes)
}
@ -357,20 +356,20 @@ case class Join(
case _ => resolvedExceptNatural
}
override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
def simpleEstimation: Statistics = joinType match {
case LeftAnti | LeftSemi =>
// LeftSemi and LeftAnti won't ever be bigger than left
left.stats(conf)
left.stats
case _ =>
// Make sure we don't propagate isBroadcastable in other joins, because
// they could explode the size.
val stats = super.computeStats(conf)
val stats = super.computeStats
stats.copy(hints = stats.hints.resetForJoin())
}
if (conf.cboEnabled) {
JoinEstimation.estimate(conf, this).getOrElse(simpleEstimation)
JoinEstimation.estimate(this).getOrElse(simpleEstimation)
} else {
simpleEstimation
}
@ -523,7 +522,7 @@ case class Range(
override def newInstance(): Range = copy(output = output.map(_.newInstance()))
override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
val sizeInBytes = LongType.defaultSize * numElements
Statistics( sizeInBytes = sizeInBytes )
}
@ -556,20 +555,20 @@ case class Aggregate(
child.constraints.union(getAliasedConstraints(nonAgg))
}
override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
def simpleEstimation: Statistics = {
if (groupingExpressions.isEmpty) {
Statistics(
sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1),
rowCount = Some(1),
hints = child.stats(conf).hints)
hints = child.stats.hints)
} else {
super.computeStats(conf)
super.computeStats
}
}
if (conf.cboEnabled) {
AggregateEstimation.estimate(conf, this).getOrElse(simpleEstimation)
AggregateEstimation.estimate(this).getOrElse(simpleEstimation)
} else {
simpleEstimation
}
@ -672,8 +671,8 @@ case class Expand(
override def references: AttributeSet =
AttributeSet(projections.flatten.flatMap(_.references))
override def computeStats(conf: SQLConf): Statistics = {
val sizeInBytes = super.computeStats(conf).sizeInBytes * projections.length
override def computeStats: Statistics = {
val sizeInBytes = super.computeStats.sizeInBytes * projections.length
Statistics(sizeInBytes = sizeInBytes)
}
@ -743,9 +742,9 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
case _ => None
}
}
override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
val childStats = child.stats(conf)
val childStats = child.stats
val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit)
// Don't propagate column stats, because we don't know the distribution after a limit operation
Statistics(
@ -763,9 +762,9 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
case _ => None
}
}
override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
val childStats = child.stats(conf)
val childStats = child.stats
if (limit == 0) {
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
// (product of children).
@ -832,9 +831,9 @@ case class Sample(
override def output: Seq[Attribute] = child.output
override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
val ratio = upperBound - lowerBound
val childStats = child.stats(conf)
val childStats = child.stats
var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio)
if (sizeInBytes == 0) {
sizeInBytes = 1
@ -898,7 +897,7 @@ case class RepartitionByExpression(
case object OneRowRelation extends LeafNode {
override def maxRows: Option[Long] = Some(1)
override def output: Seq[Attribute] = Nil
override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = 1)
override def computeStats: Statistics = Statistics(sizeInBytes = 1)
}
/** A logical plan for `dropDuplicates`. */


@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.internal.SQLConf
/**
* A general hint for the child that is not yet resolved. This node is generated by the parser and
@ -44,8 +43,8 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo())
override lazy val canonicalized: LogicalPlan = child.canonicalized
override def computeStats(conf: SQLConf): Statistics = {
val stats = child.stats(conf)
override def computeStats: Statistics = {
val stats = child.stats
stats.copy(hints = hints)
}
}


@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics}
import org.apache.spark.sql.internal.SQLConf
object AggregateEstimation {
@ -29,13 +28,13 @@ object AggregateEstimation {
* Estimate the number of output rows based on column stats of group-by columns, and propagate
* column stats for aggregate expressions.
*/
def estimate(conf: SQLConf, agg: Aggregate): Option[Statistics] = {
val childStats = agg.child.stats(conf)
def estimate(agg: Aggregate): Option[Statistics] = {
val childStats = agg.child.stats
// Check if we have column stats for all group-by columns.
val colStatsExist = agg.groupingExpressions.forall { e =>
e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute])
}
if (rowCountsExist(conf, agg.child) && colStatsExist) {
if (rowCountsExist(agg.child) && colStatsExist) {
// Multiply distinct counts of group-by columns. This is an upper bound, which assumes
// the data contains all combinations of distinct values of group-by columns.
var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))(


@ -21,15 +21,14 @@ import scala.math.BigDecimal.RoundingMode
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, _}
object EstimationUtils {
/** Check if each plan has rowCount in its statistics. */
def rowCountsExist(conf: SQLConf, plans: LogicalPlan*): Boolean =
plans.forall(_.stats(conf).rowCount.isDefined)
def rowCountsExist(plans: LogicalPlan*): Boolean =
plans.forall(_.stats.rowCount.isDefined)
/** Check if each attribute has column stat in the corresponding statistics. */
def columnStatsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = {


@ -25,12 +25,11 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging {
case class FilterEstimation(plan: Filter) extends Logging {
private val childStats = plan.child.stats(catalystConf)
private val childStats = plan.child.stats
private val colStatsMap = new ColumnStatsMap(childStats.attributeStats)


@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
import org.apache.spark.sql.internal.SQLConf
object JoinEstimation extends Logging {
@ -34,12 +33,12 @@ object JoinEstimation extends Logging {
* Estimate statistics after join. Return `None` if the join type is not supported, or we don't
* have enough statistics for estimation.
*/
def estimate(conf: SQLConf, join: Join): Option[Statistics] = {
def estimate(join: Join): Option[Statistics] = {
join.joinType match {
case Inner | Cross | LeftOuter | RightOuter | FullOuter =>
InnerOuterEstimation(conf, join).doEstimate()
InnerOuterEstimation(join).doEstimate()
case LeftSemi | LeftAnti =>
LeftSemiAntiEstimation(conf, join).doEstimate()
LeftSemiAntiEstimation(join).doEstimate()
case _ =>
logDebug(s"[CBO] Unsupported join type: ${join.joinType}")
None
@ -47,16 +46,16 @@ object JoinEstimation extends Logging {
}
}
case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging {
case class InnerOuterEstimation(join: Join) extends Logging {
private val leftStats = join.left.stats(conf)
private val rightStats = join.right.stats(conf)
private val leftStats = join.left.stats
private val rightStats = join.right.stats
/**
* Estimate output size and number of rows after a join operator, and update output column stats.
*/
def doEstimate(): Option[Statistics] = join match {
case _ if !rowCountsExist(conf, join.left, join.right) =>
case _ if !rowCountsExist(join.left, join.right) =>
None
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) =>
@ -273,13 +272,13 @@ case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging {
}
}
case class LeftSemiAntiEstimation(conf: SQLConf, join: Join) {
case class LeftSemiAntiEstimation(join: Join) {
def doEstimate(): Option[Statistics] = {
// TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic
// column stats. Now we just propagate the statistics from left side. We should do more
// accurate estimation when advanced stats (e.g. histograms) are available.
if (rowCountsExist(conf, join.left)) {
val leftStats = join.left.stats(conf)
if (rowCountsExist(join.left)) {
val leftStats = join.left.stats
// Propagate the original column stats for cartesian product
val outputRows = leftStats.rowCount.get
Some(Statistics(


@ -19,14 +19,13 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics}
import org.apache.spark.sql.internal.SQLConf
object ProjectEstimation {
import EstimationUtils._
def estimate(conf: SQLConf, project: Project): Option[Statistics] = {
if (rowCountsExist(conf, project.child)) {
val childStats = project.child.stats(conf)
def estimate(project: Project): Option[Statistics] = {
if (rowCountsExist(project.child)) {
val childStats = project.child.stats
val inputAttrStats = childStats.attributeStats
// Match alias with its child's column stat
val aliasStats = project.expressions.collect {


@ -142,7 +142,7 @@ class JoinOptimizationSuite extends PlanTest {
comparePlans(optimized, expected)
val broadcastChildren = optimized.collect {
case Join(_, r, _, _) if r.stats(conf).sizeInBytes == 1 => r
case Join(_, r, _, _) if r.stats.sizeInBytes == 1 => r
}
assert(broadcastChildren.size == 1)
}


@ -112,7 +112,7 @@ class LimitPushdownSuite extends PlanTest {
}
test("full outer join where neither side is limited and both sides have same statistics") {
assert(x.stats(conf).sizeInBytes === y.stats(conf).sizeInBytes)
assert(x.stats.sizeInBytes === y.stats.sizeInBytes)
val originalQuery = x.join(y, FullOuter).limit(1)
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = Limit(1, LocalLimit(1, x).join(y, FullOuter)).analyze
@ -121,7 +121,7 @@ class LimitPushdownSuite extends PlanTest {
test("full outer join where neither side is limited and left side has larger statistics") {
val xBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('x)
assert(xBig.stats(conf).sizeInBytes > y.stats(conf).sizeInBytes)
assert(xBig.stats.sizeInBytes > y.stats.sizeInBytes)
val originalQuery = xBig.join(y, FullOuter).limit(1)
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = Limit(1, LocalLimit(1, xBig).join(y, FullOuter)).analyze
@ -130,7 +130,7 @@ class LimitPushdownSuite extends PlanTest {
test("full outer join where neither side is limited and right side has larger statistics") {
val yBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('y)
assert(x.stats(conf).sizeInBytes < yBig.stats(conf).sizeInBytes)
assert(x.stats.sizeInBytes < yBig.stats.sizeInBytes)
val originalQuery = x.join(yBig, FullOuter).limit(1)
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = Limit(1, x.join(LocalLimit(1, yBig), FullOuter)).analyze


@ -100,17 +100,23 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
size = Some(4 * (8 + 4)),
attributeStats = AttributeMap(Seq("key12").map(nameToColInfo)))
val noGroupAgg = Aggregate(groupingExpressions = Nil,
aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
assert(noGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) ==
// overhead + count result size
Statistics(sizeInBytes = 8 + 8, rowCount = Some(1)))
val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
try {
SQLConf.get.setConf(SQLConf.CBO_ENABLED, false)
val noGroupAgg = Aggregate(groupingExpressions = Nil,
aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
assert(noGroupAgg.stats ==
// overhead + count result size
Statistics(sizeInBytes = 8 + 8, rowCount = Some(1)))
val hasGroupAgg = Aggregate(groupingExpressions = attributes,
aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child)
assert(hasGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) ==
// From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
val hasGroupAgg = Aggregate(groupingExpressions = attributes,
aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child)
assert(hasGroupAgg.stats ==
// From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
} finally {
SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
}
}
private def checkAggStats(
@ -134,6 +140,6 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
rowCount = Some(expectedOutputRowCount),
attributeStats = expectedAttrStats)
assert(testAgg.stats(conf) == expectedStats)
assert(testAgg.stats == expectedStats)
}
}


@ -57,16 +57,16 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
val localLimit = LocalLimit(Literal(2), plan)
val globalLimit = GlobalLimit(Literal(2), plan)
// LocalLimit's stats is just its child's stats except column stats
checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
checkStats(localLimit, plan.stats.copy(attributeStats = AttributeMap(Nil)))
checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2)))
}
test("limit estimation: limit > child's rowCount") {
val localLimit = LocalLimit(Literal(20), plan)
val globalLimit = GlobalLimit(Literal(20), plan)
checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
checkStats(localLimit, plan.stats.copy(attributeStats = AttributeMap(Nil)))
// Limit is larger than child's rowCount, so GlobalLimit's stats is equal to its child's stats.
checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
checkStats(globalLimit, plan.stats.copy(attributeStats = AttributeMap(Nil)))
}
test("limit estimation: limit = 0") {
@ -113,12 +113,19 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
plan: LogicalPlan,
expectedStatsCboOn: Statistics,
expectedStatsCboOff: Statistics): Unit = {
// Invalidate statistics
plan.invalidateStatsCache()
assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> true)) == expectedStatsCboOn)
val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
try {
// Invalidate statistics
plan.invalidateStatsCache()
SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)
assert(plan.stats == expectedStatsCboOn)
plan.invalidateStatsCache()
assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == expectedStatsCboOff)
plan.invalidateStatsCache()
SQLConf.get.setConf(SQLConf.CBO_ENABLED, false)
assert(plan.stats == expectedStatsCboOff)
} finally {
SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
}
}
/** Check estimated stats when it's the same whether cbo is turned on or off. */
@ -135,6 +142,6 @@ private case class DummyLogicalPlan(
cboStats: Statistics) extends LogicalPlan {
override def output: Seq[Attribute] = Nil
override def children: Seq[LogicalPlan] = Nil
override def computeStats(conf: SQLConf): Statistics =
override def computeStats: Statistics =
if (conf.cboEnabled) cboStats else defaultStats
}


@ -620,7 +620,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
rowCount = Some(expectedRowCount),
attributeStats = expectedAttributeMap)
val filterStats = filter.stats(conf)
val filterStats = filter.stats
assert(filterStats.sizeInBytes == expectedStats.sizeInBytes)
assert(filterStats.rowCount == expectedStats.rowCount)
val rowCountValue = filterStats.rowCount.getOrElse(0)


@ -77,7 +77,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// Keep the column stat from both sides unchanged.
attributeStats = AttributeMap(
Seq("key-1-5", "key-5-9", "key-1-2", "key-2-4").map(nameToColInfo)))
assert(join.stats(conf) == expectedStats)
assert(join.stats == expectedStats)
}
test("disjoint inner join") {
@ -90,7 +90,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
sizeInBytes = 1,
rowCount = Some(0),
attributeStats = AttributeMap(Nil))
assert(join.stats(conf) == expectedStats)
assert(join.stats == expectedStats)
}
test("disjoint left outer join") {
@ -106,7 +106,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// Null count for right side columns = left row count
Seq(nameToAttr("key-1-2") -> nullColumnStat(nameToAttr("key-1-2").dataType, 5),
nameToAttr("key-2-4") -> nullColumnStat(nameToAttr("key-2-4").dataType, 5))))
assert(join.stats(conf) == expectedStats)
assert(join.stats == expectedStats)
}
test("disjoint right outer join") {
@ -122,7 +122,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// Null count for left side columns = right row count
Seq(nameToAttr("key-1-5") -> nullColumnStat(nameToAttr("key-1-5").dataType, 3),
nameToAttr("key-5-9") -> nullColumnStat(nameToAttr("key-5-9").dataType, 3))))
assert(join.stats(conf) == expectedStats)
assert(join.stats == expectedStats)
}
test("disjoint full outer join") {
@ -140,7 +140,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3),
nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5),
nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5))))
assert(join.stats(conf) == expectedStats)
assert(join.stats == expectedStats)
}
test("inner join") {
@ -161,7 +161,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
attributeStats = AttributeMap(
Seq(nameToAttr("key-1-5") -> joinedColStat, nameToAttr("key-1-2") -> joinedColStat,
nameToAttr("key-5-9") -> colStatForkey59, nameToColInfo("key-2-4"))))
assert(join.stats(conf) == expectedStats)
assert(join.stats == expectedStats)
}
test("inner join with multiple equi-join keys") {
@ -183,7 +183,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
attributeStats = AttributeMap(
Seq(nameToAttr("key-1-2") -> joinedColStat1, nameToAttr("key-1-2") -> joinedColStat1,
nameToAttr("key-2-4") -> joinedColStat2, nameToAttr("key-2-3") -> joinedColStat2)))
assert(join.stats(conf) == expectedStats)
assert(join.stats == expectedStats)
}
test("left outer join") {
@ -201,7 +201,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
attributeStats = AttributeMap(
Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-3"),
nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat)))
assert(join.stats(conf) == expectedStats)
assert(join.stats == expectedStats)
}
test("right outer join") {
@ -219,7 +219,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
attributeStats = AttributeMap(
Seq(nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat,
nameToColInfo("key-1-2"), nameToColInfo("key-2-3"))))
assert(join.stats(conf) == expectedStats)
assert(join.stats == expectedStats)
}
test("full outer join") {
@ -234,7 +234,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// Keep the column stat from both sides unchanged.
attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4"),
nameToColInfo("key-1-2"), nameToColInfo("key-2-3"))))
assert(join.stats(conf) == expectedStats)
assert(join.stats == expectedStats)
}
test("left semi/anti join") {
@ -248,7 +248,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
sizeInBytes = 3 * (8 + 4 * 2),
rowCount = Some(3),
attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4"))))
assert(join.stats(conf) == expectedStats)
assert(join.stats == expectedStats)
}
}
@ -306,7 +306,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
sizeInBytes = 1 * (8 + 2 * getColSize(key1, columnInfo1(key1))),
rowCount = Some(1),
attributeStats = AttributeMap(Seq(key1 -> columnInfo1(key1), key2 -> columnInfo1(key1))))
assert(join.stats(conf) == expectedStats)
assert(join.stats == expectedStats)
}
}
}
@ -323,6 +323,6 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
sizeInBytes = 1,
rowCount = Some(0),
attributeStats = AttributeMap(Nil))
assert(join.stats(conf) == expectedStats)
assert(join.stats == expectedStats)
}
}


@ -45,7 +45,7 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
sizeInBytes = 2 * (8 + 4 + 4),
rowCount = Some(2),
attributeStats = expectedAttrStats)
assert(proj.stats(conf) == expectedStats)
assert(proj.stats == expectedStats)
}
test("project on empty table") {
@ -131,6 +131,6 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
sizeInBytes = expectedSize,
rowCount = Some(expectedRowCount),
attributeStats = projectAttrMap)
assert(proj.stats(conf) == expectedStats)
assert(proj.stats == expectedStats)
}
}


@ -21,14 +21,24 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED}
import org.apache.spark.sql.types.{IntegerType, StringType}
trait StatsEstimationTestBase extends SparkFunSuite {
/** Enable stats estimation based on CBO. */
protected val conf = new SQLConf().copy(CASE_SENSITIVE -> true, CBO_ENABLED -> true)
var originalValue: Boolean = false
override def beforeAll(): Unit = {
super.beforeAll()
// Enable stats estimation based on CBO.
originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)
}
override def afterAll(): Unit = {
SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
super.afterAll()
}
def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match {
// For UTF8String: base + offset + numBytes
@ -55,7 +65,7 @@ case class StatsTestPlan(
attributeStats: AttributeMap[ColumnStat],
size: Option[BigInt] = None) extends LeafNode {
override def output: Seq[Attribute] = outputList
override def computeStats(conf: SQLConf): Statistics = Statistics(
override def computeStats: Statistics = Statistics(
// If sizeInBytes is useless in testing, we just use a fake value
sizeInBytes = size.getOrElse(Int.MaxValue),
rowCount = Some(rowCount),


@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils
@ -89,7 +88,7 @@ case class ExternalRDD[T](
override protected def stringArgs: Iterator[Any] = Iterator(output)
@transient override def computeStats(conf: SQLConf): Statistics = Statistics(
@transient override def computeStats: Statistics = Statistics(
// TODO: Instead of returning a default value here, find a way to return a meaningful size
// estimate for RDDs. See PR 1238 for more discussions.
sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
@ -157,7 +156,7 @@ case class LogicalRDD(
override protected def stringArgs: Iterator[Any] = Iterator(output)
@transient override def computeStats(conf: SQLConf): Statistics = Statistics(
@transient override def computeStats: Statistics = Statistics(
// TODO: Instead of returning a default value here, find a way to return a meaningful size
// estimate for RDDs. See PR 1238 for more discussions.
sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)


@ -221,7 +221,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
def stringWithStats: String = {
// trigger to compute stats for logical plans
optimizedPlan.stats(sparkSession.sessionState.conf)
optimizedPlan.stats
// only show optimized logical plan and physical plan
s"""== Optimized Logical Plan ==


@ -22,7 +22,6 @@ import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@ -114,9 +113,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* Matches a plan whose output should be small enough to be used in broadcast join.
*/
private def canBroadcast(plan: LogicalPlan): Boolean = {
plan.stats(conf).hints.broadcast ||
(plan.stats(conf).sizeInBytes >= 0 &&
plan.stats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold)
plan.stats.hints.broadcast ||
(plan.stats.sizeInBytes >= 0 &&
plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold)
}
/**
@ -126,7 +125,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* dynamic.
*/
private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = {
plan.stats(conf).sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
}
/**
@ -137,7 +136,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* use the size of bytes here as estimation.
*/
private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = {
a.stats(conf).sizeInBytes * 3 <= b.stats(conf).sizeInBytes
a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes
}
private def canBuildRight(joinType: JoinType): Boolean = joinType match {
@ -206,7 +205,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Join(left, right, joinType, condition) =>
val buildSide =
if (right.stats(conf).sizeInBytes <= left.stats(conf).sizeInBytes) {
if (right.stats.sizeInBytes <= left.stats.sizeInBytes) {
BuildRight
} else {
BuildLeft


@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.LongAccumulator
@ -70,7 +69,7 @@ case class InMemoryRelation(
@transient val partitionStatistics = new PartitionStatistics(output)
override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
if (batchStats.value == 0L) {
// Underlying columnar RDD hasn't been materialized, no useful statistics information
// available, return the default statistics.


@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.util.Utils
@ -46,7 +45,7 @@ case class LogicalRelation(
// Only care about relation when canonicalizing.
override def preCanonicalized: LogicalPlan = copy(catalogTable = None)
@transient override def computeStats(conf: SQLConf): Statistics = {
@transient override def computeStats: Statistics = {
catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse(
Statistics(sizeInBytes = relation.sizeInBytes))
}


@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@ -230,6 +229,6 @@ case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode
private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum
override def computeStats(conf: SQLConf): Statistics =
override def computeStats: Statistics =
Statistics(sizePerRow * sink.allData.size)
}


@ -313,7 +313,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
spark.table("testData").queryExecution.withCachedData.collect {
case cached: InMemoryRelation =>
val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum
assert(cached.stats(sqlConf).sizeInBytes === actualSizeInBytes)
assert(cached.stats.sizeInBytes === actualSizeInBytes)
}
}


@ -1146,7 +1146,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
// instead of Int for avoiding possible overflow.
val ds = (0 to 10000).map( i =>
(i, Seq((i, Seq((i, "This is really not that long of a string")))))).toDS()
val sizeInBytes = ds.logicalPlan.stats(sqlConf).sizeInBytes
val sizeInBytes = ds.logicalPlan.stats.sizeInBytes
// sizeInBytes is 2404280404, before the fix, it overflows to a negative number
assert(sizeInBytes > 0)
}


@ -33,7 +33,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
setupTestData()
def statisticSizeInByte(df: DataFrame): BigInt = {
df.queryExecution.optimizedPlan.stats(sqlConf).sizeInBytes
df.queryExecution.optimizedPlan.stats.sizeInBytes
}
test("equi-join is hash-join") {


@ -60,7 +60,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
val df = df1.join(df2, Seq("k"), "left")
val sizes = df.queryExecution.analyzed.collect { case g: Join =>
g.stats(conf).sizeInBytes
g.stats.sizeInBytes
}
assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
@ -107,9 +107,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
test("SPARK-15392: DataFrame created from RDD should not be broadcasted") {
val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType))
assert(df.queryExecution.analyzed.stats(conf).sizeInBytes >
assert(df.queryExecution.analyzed.stats.sizeInBytes >
spark.sessionState.conf.autoBroadcastJoinThreshold)
assert(df.selectExpr("a").queryExecution.analyzed.stats(conf).sizeInBytes >
assert(df.selectExpr("a").queryExecution.analyzed.stats.sizeInBytes >
spark.sessionState.conf.autoBroadcastJoinThreshold)
}
@ -250,13 +250,13 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
test("SPARK-18856: non-empty partitioned table should not report zero size") {
withTable("ds_tbl", "hive_tbl") {
spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl")
val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats(conf)
val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats
assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.")
if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") {
sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)")
sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1")
val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats(conf)
val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats
assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.")
}
}
@ -296,10 +296,10 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat))
// Check relation statistics
assert(relation.stats(conf).sizeInBytes == 0)
assert(relation.stats(conf).rowCount == Some(0))
assert(relation.stats(conf).attributeStats.size == 1)
val (attribute, colStat) = relation.stats(conf).attributeStats.head
assert(relation.stats.sizeInBytes == 0)
assert(relation.stats.rowCount == Some(0))
assert(relation.stats.attributeStats.size == 1)
val (attribute, colStat) = relation.stats.attributeStats.head
assert(attribute.name == "c1")
assert(colStat == emptyColStat)
}


@ -126,7 +126,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
.toDF().createOrReplaceTempView("sizeTst")
spark.catalog.cacheTable("sizeTst")
assert(
spark.table("sizeTst").queryExecution.analyzed.stats(sqlConf).sizeInBytes >
spark.table("sizeTst").queryExecution.analyzed.stats.sizeInBytes >
spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
}


@ -36,7 +36,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext {
})
val totalSize = allFiles.map(_.length()).sum
val df = spark.read.parquet(dir.toString)
assert(df.queryExecution.logical.stats(sqlConf).sizeInBytes === BigInt(totalSize))
assert(df.queryExecution.logical.stats.sizeInBytes === BigInt(totalSize))
}
}
}

View file

@ -216,15 +216,15 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
// Before adding data, check output
checkAnswer(sink.allData, Seq.empty)
assert(plan.stats(sqlConf).sizeInBytes === 0)
assert(plan.stats.sizeInBytes === 0)
sink.addBatch(0, 1 to 3)
plan.invalidateStatsCache()
assert(plan.stats(sqlConf).sizeInBytes === 12)
assert(plan.stats.sizeInBytes === 12)
sink.addBatch(1, 4 to 6)
plan.invalidateStatsCache()
assert(plan.stats(sqlConf).sizeInBytes === 24)
assert(plan.stats.sizeInBytes === 24)
}
ignore("stress test") {


@ -21,7 +21,6 @@ import java.nio.charset.StandardCharsets
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits}
import org.apache.spark.sql.internal.SQLConf
/**
* A collection of sample data used in SQL tests.
@ -29,8 +28,6 @@ import org.apache.spark.sql.internal.SQLConf
private[sql] trait SQLTestData { self =>
protected def spark: SparkSession
protected def sqlConf: SQLConf = spark.sessionState.conf
// Helper object to import SQL implicits without a concrete SQLContext
private object internalImplicits extends SQLImplicits {
protected override def _sqlContext: SQLContext = self.spark.sqlContext


@ -154,7 +154,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
Some(partitionSchema))
val logicalRelation = cached.getOrElse {
val sizeInBytes = relation.stats(sparkSession.sessionState.conf).sizeInBytes.toLong
val sizeInBytes = relation.stats.sizeInBytes.toLong
val fileIndex = {
val index = new CatalogFileIndex(sparkSession, relation.tableMeta, sizeInBytes)
if (lazyPruningEnabled) {


@ -68,7 +68,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
assert(properties("totalSize").toLong <= 0, "external table totalSize must be <= 0")
assert(properties("rawDataSize").toLong <= 0, "external table rawDataSize must be <= 0")
val sizeInBytes = relation.stats(conf).sizeInBytes
val sizeInBytes = relation.stats.sizeInBytes
assert(sizeInBytes === BigInt(file1.length() + file2.length()))
}
}
@ -77,7 +77,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
test("analyze Hive serde tables") {
def queryTotalSize(tableName: String): BigInt =
spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes
spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes
// Non-partitioned table
sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect()
@ -659,7 +659,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
test("estimates the size of a test Hive serde tables") {
val df = sql("""SELECT * FROM src""")
val sizes = df.queryExecution.analyzed.collect {
case relation: CatalogRelation => relation.stats(conf).sizeInBytes
case relation: CatalogRelation => relation.stats.sizeInBytes
}
assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}")
assert(sizes(0).equals(BigInt(5812)),
@ -679,7 +679,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
// Assert src has a size smaller than the threshold.
val sizes = df.queryExecution.analyzed.collect {
case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats(conf).sizeInBytes
case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats.sizeInBytes
}
assert(sizes.size === 2 && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold
&& sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold,
@ -733,7 +733,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
// Assert src has a size smaller than the threshold.
val sizes = df.queryExecution.analyzed.collect {
case relation: CatalogRelation => relation.stats(conf).sizeInBytes
case relation: CatalogRelation => relation.stats.sizeInBytes
}
assert(sizes.size === 2 && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold
&& sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold,


@ -86,7 +86,7 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te
case relation: LogicalRelation => relation
}
assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}")
val size2 = relations(0).computeStats(conf).sizeInBytes
val size2 = relations(0).computeStats.sizeInBytes
assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes)
assert(size2 < tableStats.get.sizeInBytes)
}