[SPARK-21180][SQL] Remove conf from stats functions since now we have conf in LogicalPlan
## What changes were proposed in this pull request?

After wiring `SQLConf` into the logical plan ([PR 18299](https://github.com/apache/spark/pull/18299)), we no longer need to pass `conf` into `def stats` and `def computeStats`.

## How was this patch tested?

Covered by existing tests, plus some modified existing tests.

Author: wangzhenhua <wangzhenhua@huawei.com>
Author: Zhenhua Wang <wzh_zju@163.com>

Closes #18391 from wzhfy/removeConf.
This commit is contained in:
parent
07479b3cfb
commit
b803b66a81
|
@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Attri
|
|||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
|
||||
import org.apache.spark.sql.catalyst.util.quoteIdentifier
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
|
||||
|
@ -436,7 +435,7 @@ case class CatalogRelation(
|
|||
createTime = -1
|
||||
))
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
override def computeStats: Statistics = {
|
||||
// For data source tables, we will create a `LogicalRelation` and won't call this method; for
|
||||
// hive serde tables, we will always generate statistics.
|
||||
// TODO: unify the table stats generation.
|
||||
|
|
|
@ -58,7 +58,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr
|
|||
// Do reordering if the number of items is appropriate and join conditions exist.
|
||||
// We also need to check if costs of all items can be evaluated.
|
||||
if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty &&
|
||||
items.forall(_.stats(conf).rowCount.isDefined)) {
|
||||
items.forall(_.stats.rowCount.isDefined)) {
|
||||
JoinReorderDP.search(conf, items, conditions, output)
|
||||
} else {
|
||||
plan
|
||||
|
@ -322,7 +322,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
|
|||
/** Get the cost of the root node of this plan tree. */
|
||||
def rootCost(conf: SQLConf): Cost = {
|
||||
if (itemIds.size > 1) {
|
||||
val rootStats = plan.stats(conf)
|
||||
val rootStats = plan.stats
|
||||
Cost(rootStats.rowCount.get, rootStats.sizeInBytes)
|
||||
} else {
|
||||
// If the plan is a leaf item, it has zero cost.
|
||||
|
|
|
@ -317,7 +317,7 @@ case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] {
|
|||
case FullOuter =>
|
||||
(left.maxRows, right.maxRows) match {
|
||||
case (None, None) =>
|
||||
if (left.stats(conf).sizeInBytes >= right.stats(conf).sizeInBytes) {
|
||||
if (left.stats.sizeInBytes >= right.stats.sizeInBytes) {
|
||||
join.copy(left = maybePushLimit(exp, left))
|
||||
} else {
|
||||
join.copy(right = maybePushLimit(exp, right))
|
||||
|
|
|
@ -82,7 +82,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
|
|||
// Find if the input plans are eligible for star join detection.
|
||||
// An eligible plan is a base table access with valid statistics.
|
||||
val foundEligibleJoin = input.forall {
|
||||
case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true
|
||||
case PhysicalOperation(_, _, t: LeafNode) if t.stats.rowCount.isDefined => true
|
||||
case _ => false
|
||||
}
|
||||
|
||||
|
@ -181,7 +181,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
|
|||
val leafCol = findLeafNodeCol(column, plan)
|
||||
leafCol match {
|
||||
case Some(col) if t.outputSet.contains(col) =>
|
||||
val stats = t.stats(conf)
|
||||
val stats = t.stats
|
||||
stats.rowCount match {
|
||||
case Some(rowCount) if rowCount >= 0 =>
|
||||
if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) {
|
||||
|
@ -237,7 +237,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
|
|||
val leafCol = findLeafNodeCol(column, plan)
|
||||
leafCol match {
|
||||
case Some(col) if t.outputSet.contains(col) =>
|
||||
val stats = t.stats(conf)
|
||||
val stats = t.stats
|
||||
stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)
|
||||
case None => false
|
||||
}
|
||||
|
@ -296,11 +296,11 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
|
|||
*/
|
||||
private def getTableAccessCardinality(
|
||||
input: LogicalPlan): Option[BigInt] = input match {
|
||||
case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined =>
|
||||
if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) {
|
||||
Option(input.stats(conf).rowCount.get)
|
||||
case PhysicalOperation(_, cond, t: LeafNode) if t.stats.rowCount.isDefined =>
|
||||
if (conf.cboEnabled && input.stats.rowCount.isDefined) {
|
||||
Option(input.stats.rowCount.get)
|
||||
} else {
|
||||
Option(t.stats(conf).rowCount.get)
|
||||
Option(t.stats.rowCount.get)
|
||||
}
|
||||
case _ => None
|
||||
}
|
||||
|
|
|
@ -21,7 +21,6 @@ import org.apache.spark.sql.Row
|
|||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
||||
import org.apache.spark.sql.catalyst.analysis
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types.{StructField, StructType}
|
||||
|
||||
object LocalRelation {
|
||||
|
@ -67,7 +66,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
|
|||
}
|
||||
}
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics =
|
||||
override def computeStats: Statistics =
|
||||
Statistics(sizeInBytes =
|
||||
output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)
|
||||
|
||||
|
|
|
@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis._
|
|||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.QueryPlan
|
||||
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
|
||||
|
@ -90,8 +89,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
|
|||
* first time. If the configuration changes, the cache can be invalidated by calling
|
||||
* [[invalidateStatsCache()]].
|
||||
*/
|
||||
final def stats(conf: SQLConf): Statistics = statsCache.getOrElse {
|
||||
statsCache = Some(computeStats(conf))
|
||||
final def stats: Statistics = statsCache.getOrElse {
|
||||
statsCache = Some(computeStats)
|
||||
statsCache.get
|
||||
}
|
||||
|
||||
|
@ -108,11 +107,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
|
|||
*
|
||||
* [[LeafNode]]s must override this.
|
||||
*/
|
||||
protected def computeStats(conf: SQLConf): Statistics = {
|
||||
protected def computeStats: Statistics = {
|
||||
if (children.isEmpty) {
|
||||
throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
|
||||
}
|
||||
Statistics(sizeInBytes = children.map(_.stats(conf).sizeInBytes).product)
|
||||
Statistics(sizeInBytes = children.map(_.stats.sizeInBytes).product)
|
||||
}
|
||||
|
||||
override def verboseStringWithSuffix: String = {
|
||||
|
@ -333,13 +332,13 @@ abstract class UnaryNode extends LogicalPlan {
|
|||
|
||||
override protected def validConstraints: Set[Expression] = child.constraints
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
override def computeStats: Statistics = {
|
||||
// There should be some overhead in a Row object, so the size should not be zero when there are
|
||||
// no columns; this helps to prevent a divide-by-zero error.
|
||||
val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8
|
||||
val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
|
||||
// Assume there will be the same number of rows as child has.
|
||||
var sizeInBytes = (child.stats(conf).sizeInBytes * outputRowSize) / childRowSize
|
||||
var sizeInBytes = (child.stats.sizeInBytes * outputRowSize) / childRowSize
|
||||
if (sizeInBytes == 0) {
|
||||
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
|
||||
// (product of children).
|
||||
|
@ -347,7 +346,7 @@ abstract class UnaryNode extends LogicalPlan {
|
|||
}
|
||||
|
||||
// Don't propagate rowCount and attributeStats, since they are not estimated here.
|
||||
Statistics(sizeInBytes = sizeInBytes, hints = child.stats(conf).hints)
|
||||
Statistics(sizeInBytes = sizeInBytes, hints = child.stats.hints)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions._
|
|||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.util.random.RandomSampler
|
||||
|
@ -65,11 +64,11 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
|
|||
override def validConstraints: Set[Expression] =
|
||||
child.constraints.union(getAliasedConstraints(projectList))
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
override def computeStats: Statistics = {
|
||||
if (conf.cboEnabled) {
|
||||
ProjectEstimation.estimate(conf, this).getOrElse(super.computeStats(conf))
|
||||
ProjectEstimation.estimate(this).getOrElse(super.computeStats)
|
||||
} else {
|
||||
super.computeStats(conf)
|
||||
super.computeStats
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -139,11 +138,11 @@ case class Filter(condition: Expression, child: LogicalPlan)
|
|||
child.constraints.union(predicates.toSet)
|
||||
}
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
override def computeStats: Statistics = {
|
||||
if (conf.cboEnabled) {
|
||||
FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf))
|
||||
FilterEstimation(this).estimate.getOrElse(super.computeStats)
|
||||
} else {
|
||||
super.computeStats(conf)
|
||||
super.computeStats
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -192,13 +191,13 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
|
|||
}
|
||||
}
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
val leftSize = left.stats(conf).sizeInBytes
|
||||
val rightSize = right.stats(conf).sizeInBytes
|
||||
override def computeStats: Statistics = {
|
||||
val leftSize = left.stats.sizeInBytes
|
||||
val rightSize = right.stats.sizeInBytes
|
||||
val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
|
||||
Statistics(
|
||||
sizeInBytes = sizeInBytes,
|
||||
hints = left.stats(conf).hints.resetForJoin())
|
||||
hints = left.stats.hints.resetForJoin())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -209,8 +208,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
|
|||
|
||||
override protected def validConstraints: Set[Expression] = leftConstraints
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
left.stats(conf).copy()
|
||||
override def computeStats: Statistics = {
|
||||
left.stats.copy()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -248,8 +247,8 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
|
|||
children.length > 1 && childrenResolved && allChildrenCompatible
|
||||
}
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
val sizeInBytes = children.map(_.stats(conf).sizeInBytes).sum
|
||||
override def computeStats: Statistics = {
|
||||
val sizeInBytes = children.map(_.stats.sizeInBytes).sum
|
||||
Statistics(sizeInBytes = sizeInBytes)
|
||||
}
|
||||
|
||||
|
@ -357,20 +356,20 @@ case class Join(
|
|||
case _ => resolvedExceptNatural
|
||||
}
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
override def computeStats: Statistics = {
|
||||
def simpleEstimation: Statistics = joinType match {
|
||||
case LeftAnti | LeftSemi =>
|
||||
// LeftSemi and LeftAnti won't ever be bigger than left
|
||||
left.stats(conf)
|
||||
left.stats
|
||||
case _ =>
|
||||
// Make sure we don't propagate isBroadcastable in other joins, because
|
||||
// they could explode the size.
|
||||
val stats = super.computeStats(conf)
|
||||
val stats = super.computeStats
|
||||
stats.copy(hints = stats.hints.resetForJoin())
|
||||
}
|
||||
|
||||
if (conf.cboEnabled) {
|
||||
JoinEstimation.estimate(conf, this).getOrElse(simpleEstimation)
|
||||
JoinEstimation.estimate(this).getOrElse(simpleEstimation)
|
||||
} else {
|
||||
simpleEstimation
|
||||
}
|
||||
|
@ -523,7 +522,7 @@ case class Range(
|
|||
|
||||
override def newInstance(): Range = copy(output = output.map(_.newInstance()))
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
override def computeStats: Statistics = {
|
||||
val sizeInBytes = LongType.defaultSize * numElements
|
||||
Statistics( sizeInBytes = sizeInBytes )
|
||||
}
|
||||
|
@ -556,20 +555,20 @@ case class Aggregate(
|
|||
child.constraints.union(getAliasedConstraints(nonAgg))
|
||||
}
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
override def computeStats: Statistics = {
|
||||
def simpleEstimation: Statistics = {
|
||||
if (groupingExpressions.isEmpty) {
|
||||
Statistics(
|
||||
sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1),
|
||||
rowCount = Some(1),
|
||||
hints = child.stats(conf).hints)
|
||||
hints = child.stats.hints)
|
||||
} else {
|
||||
super.computeStats(conf)
|
||||
super.computeStats
|
||||
}
|
||||
}
|
||||
|
||||
if (conf.cboEnabled) {
|
||||
AggregateEstimation.estimate(conf, this).getOrElse(simpleEstimation)
|
||||
AggregateEstimation.estimate(this).getOrElse(simpleEstimation)
|
||||
} else {
|
||||
simpleEstimation
|
||||
}
|
||||
|
@ -672,8 +671,8 @@ case class Expand(
|
|||
override def references: AttributeSet =
|
||||
AttributeSet(projections.flatten.flatMap(_.references))
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
val sizeInBytes = super.computeStats(conf).sizeInBytes * projections.length
|
||||
override def computeStats: Statistics = {
|
||||
val sizeInBytes = super.computeStats.sizeInBytes * projections.length
|
||||
Statistics(sizeInBytes = sizeInBytes)
|
||||
}
|
||||
|
||||
|
@ -743,9 +742,9 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
|
|||
case _ => None
|
||||
}
|
||||
}
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
override def computeStats: Statistics = {
|
||||
val limit = limitExpr.eval().asInstanceOf[Int]
|
||||
val childStats = child.stats(conf)
|
||||
val childStats = child.stats
|
||||
val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit)
|
||||
// Don't propagate column stats, because we don't know the distribution after a limit operation
|
||||
Statistics(
|
||||
|
@ -763,9 +762,9 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
|
|||
case _ => None
|
||||
}
|
||||
}
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
override def computeStats: Statistics = {
|
||||
val limit = limitExpr.eval().asInstanceOf[Int]
|
||||
val childStats = child.stats(conf)
|
||||
val childStats = child.stats
|
||||
if (limit == 0) {
|
||||
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
|
||||
// (product of children).
|
||||
|
@ -832,9 +831,9 @@ case class Sample(
|
|||
|
||||
override def output: Seq[Attribute] = child.output
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
override def computeStats: Statistics = {
|
||||
val ratio = upperBound - lowerBound
|
||||
val childStats = child.stats(conf)
|
||||
val childStats = child.stats
|
||||
var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio)
|
||||
if (sizeInBytes == 0) {
|
||||
sizeInBytes = 1
|
||||
|
@ -898,7 +897,7 @@ case class RepartitionByExpression(
|
|||
case object OneRowRelation extends LeafNode {
|
||||
override def maxRows: Option[Long] = Some(1)
|
||||
override def output: Seq[Attribute] = Nil
|
||||
override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = 1)
|
||||
override def computeStats: Statistics = Statistics(sizeInBytes = 1)
|
||||
}
|
||||
|
||||
/** A logical plan for `dropDuplicates`. */
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
package org.apache.spark.sql.catalyst.plans.logical
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.Attribute
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
||||
/**
|
||||
* A general hint for the child that is not yet resolved. This node is generated by the parser and
|
||||
|
@ -44,8 +43,8 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo())
|
|||
|
||||
override lazy val canonicalized: LogicalPlan = child.canonicalized
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
val stats = child.stats(conf)
|
||||
override def computeStats: Statistics = {
|
||||
val stats = child.stats
|
||||
stats.copy(hints = hints)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
|
|||
|
||||
import org.apache.spark.sql.catalyst.expressions.Attribute
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics}
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
||||
|
||||
object AggregateEstimation {
|
||||
|
@ -29,13 +28,13 @@ object AggregateEstimation {
|
|||
* Estimate the number of output rows based on column stats of group-by columns, and propagate
|
||||
* column stats for aggregate expressions.
|
||||
*/
|
||||
def estimate(conf: SQLConf, agg: Aggregate): Option[Statistics] = {
|
||||
val childStats = agg.child.stats(conf)
|
||||
def estimate(agg: Aggregate): Option[Statistics] = {
|
||||
val childStats = agg.child.stats
|
||||
// Check if we have column stats for all group-by columns.
|
||||
val colStatsExist = agg.groupingExpressions.forall { e =>
|
||||
e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute])
|
||||
}
|
||||
if (rowCountsExist(conf, agg.child) && colStatsExist) {
|
||||
if (rowCountsExist(agg.child) && colStatsExist) {
|
||||
// Multiply distinct counts of group-by columns. This is an upper bound, which assumes
|
||||
// the data contains all combinations of distinct values of group-by columns.
|
||||
var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))(
|
||||
|
|
|
@ -21,15 +21,14 @@ import scala.math.BigDecimal.RoundingMode
|
|||
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types.{DecimalType, _}
|
||||
|
||||
|
||||
object EstimationUtils {
|
||||
|
||||
/** Check if each plan has rowCount in its statistics. */
|
||||
def rowCountsExist(conf: SQLConf, plans: LogicalPlan*): Boolean =
|
||||
plans.forall(_.stats(conf).rowCount.isDefined)
|
||||
def rowCountsExist(plans: LogicalPlan*): Boolean =
|
||||
plans.forall(_.stats.rowCount.isDefined)
|
||||
|
||||
/** Check if each attribute has column stat in the corresponding statistics. */
|
||||
def columnStatsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = {
|
||||
|
|
|
@ -25,12 +25,11 @@ import org.apache.spark.sql.catalyst.expressions._
|
|||
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging {
|
||||
case class FilterEstimation(plan: Filter) extends Logging {
|
||||
|
||||
private val childStats = plan.child.stats(catalystConf)
|
||||
private val childStats = plan.child.stats
|
||||
|
||||
private val colStatsMap = new ColumnStatsMap(childStats.attributeStats)
|
||||
|
||||
|
|
|
@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
|
|||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
||||
|
||||
object JoinEstimation extends Logging {
|
||||
|
@ -34,12 +33,12 @@ object JoinEstimation extends Logging {
|
|||
* Estimate statistics after join. Return `None` if the join type is not supported, or we don't
|
||||
* have enough statistics for estimation.
|
||||
*/
|
||||
def estimate(conf: SQLConf, join: Join): Option[Statistics] = {
|
||||
def estimate(join: Join): Option[Statistics] = {
|
||||
join.joinType match {
|
||||
case Inner | Cross | LeftOuter | RightOuter | FullOuter =>
|
||||
InnerOuterEstimation(conf, join).doEstimate()
|
||||
InnerOuterEstimation(join).doEstimate()
|
||||
case LeftSemi | LeftAnti =>
|
||||
LeftSemiAntiEstimation(conf, join).doEstimate()
|
||||
LeftSemiAntiEstimation(join).doEstimate()
|
||||
case _ =>
|
||||
logDebug(s"[CBO] Unsupported join type: ${join.joinType}")
|
||||
None
|
||||
|
@ -47,16 +46,16 @@ object JoinEstimation extends Logging {
|
|||
}
|
||||
}
|
||||
|
||||
case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging {
|
||||
case class InnerOuterEstimation(join: Join) extends Logging {
|
||||
|
||||
private val leftStats = join.left.stats(conf)
|
||||
private val rightStats = join.right.stats(conf)
|
||||
private val leftStats = join.left.stats
|
||||
private val rightStats = join.right.stats
|
||||
|
||||
/**
|
||||
* Estimate output size and number of rows after a join operator, and update output column stats.
|
||||
*/
|
||||
def doEstimate(): Option[Statistics] = join match {
|
||||
case _ if !rowCountsExist(conf, join.left, join.right) =>
|
||||
case _ if !rowCountsExist(join.left, join.right) =>
|
||||
None
|
||||
|
||||
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) =>
|
||||
|
@ -273,13 +272,13 @@ case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging {
|
|||
}
|
||||
}
|
||||
|
||||
case class LeftSemiAntiEstimation(conf: SQLConf, join: Join) {
|
||||
case class LeftSemiAntiEstimation(join: Join) {
|
||||
def doEstimate(): Option[Statistics] = {
|
||||
// TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic
|
||||
// column stats. Now we just propagate the statistics from left side. We should do more
|
||||
// accurate estimation when advanced stats (e.g. histograms) are available.
|
||||
if (rowCountsExist(conf, join.left)) {
|
||||
val leftStats = join.left.stats(conf)
|
||||
if (rowCountsExist(join.left)) {
|
||||
val leftStats = join.left.stats
|
||||
// Propagate the original column stats for cartesian product
|
||||
val outputRows = leftStats.rowCount.get
|
||||
Some(Statistics(
|
||||
|
|
|
@ -19,14 +19,13 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
|
|||
|
||||
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics}
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
||||
object ProjectEstimation {
|
||||
import EstimationUtils._
|
||||
|
||||
def estimate(conf: SQLConf, project: Project): Option[Statistics] = {
|
||||
if (rowCountsExist(conf, project.child)) {
|
||||
val childStats = project.child.stats(conf)
|
||||
def estimate(project: Project): Option[Statistics] = {
|
||||
if (rowCountsExist(project.child)) {
|
||||
val childStats = project.child.stats
|
||||
val inputAttrStats = childStats.attributeStats
|
||||
// Match alias with its child's column stat
|
||||
val aliasStats = project.expressions.collect {
|
||||
|
|
|
@ -142,7 +142,7 @@ class JoinOptimizationSuite extends PlanTest {
|
|||
comparePlans(optimized, expected)
|
||||
|
||||
val broadcastChildren = optimized.collect {
|
||||
case Join(_, r, _, _) if r.stats(conf).sizeInBytes == 1 => r
|
||||
case Join(_, r, _, _) if r.stats.sizeInBytes == 1 => r
|
||||
}
|
||||
assert(broadcastChildren.size == 1)
|
||||
}
|
||||
|
|
|
@ -112,7 +112,7 @@ class LimitPushdownSuite extends PlanTest {
|
|||
}
|
||||
|
||||
test("full outer join where neither side is limited and both sides have same statistics") {
|
||||
assert(x.stats(conf).sizeInBytes === y.stats(conf).sizeInBytes)
|
||||
assert(x.stats.sizeInBytes === y.stats.sizeInBytes)
|
||||
val originalQuery = x.join(y, FullOuter).limit(1)
|
||||
val optimized = Optimize.execute(originalQuery.analyze)
|
||||
val correctAnswer = Limit(1, LocalLimit(1, x).join(y, FullOuter)).analyze
|
||||
|
@ -121,7 +121,7 @@ class LimitPushdownSuite extends PlanTest {
|
|||
|
||||
test("full outer join where neither side is limited and left side has larger statistics") {
|
||||
val xBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('x)
|
||||
assert(xBig.stats(conf).sizeInBytes > y.stats(conf).sizeInBytes)
|
||||
assert(xBig.stats.sizeInBytes > y.stats.sizeInBytes)
|
||||
val originalQuery = xBig.join(y, FullOuter).limit(1)
|
||||
val optimized = Optimize.execute(originalQuery.analyze)
|
||||
val correctAnswer = Limit(1, LocalLimit(1, xBig).join(y, FullOuter)).analyze
|
||||
|
@ -130,7 +130,7 @@ class LimitPushdownSuite extends PlanTest {
|
|||
|
||||
test("full outer join where neither side is limited and right side has larger statistics") {
|
||||
val yBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('y)
|
||||
assert(x.stats(conf).sizeInBytes < yBig.stats(conf).sizeInBytes)
|
||||
assert(x.stats.sizeInBytes < yBig.stats.sizeInBytes)
|
||||
val originalQuery = x.join(yBig, FullOuter).limit(1)
|
||||
val optimized = Optimize.execute(originalQuery.analyze)
|
||||
val correctAnswer = Limit(1, x.join(LocalLimit(1, yBig), FullOuter)).analyze
|
||||
|
|
|
@ -100,17 +100,23 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
|
|||
size = Some(4 * (8 + 4)),
|
||||
attributeStats = AttributeMap(Seq("key12").map(nameToColInfo)))
|
||||
|
||||
val noGroupAgg = Aggregate(groupingExpressions = Nil,
|
||||
aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
|
||||
assert(noGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) ==
|
||||
// overhead + count result size
|
||||
Statistics(sizeInBytes = 8 + 8, rowCount = Some(1)))
|
||||
val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
|
||||
try {
|
||||
SQLConf.get.setConf(SQLConf.CBO_ENABLED, false)
|
||||
val noGroupAgg = Aggregate(groupingExpressions = Nil,
|
||||
aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
|
||||
assert(noGroupAgg.stats ==
|
||||
// overhead + count result size
|
||||
Statistics(sizeInBytes = 8 + 8, rowCount = Some(1)))
|
||||
|
||||
val hasGroupAgg = Aggregate(groupingExpressions = attributes,
|
||||
aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child)
|
||||
assert(hasGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) ==
|
||||
// From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
|
||||
Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
|
||||
val hasGroupAgg = Aggregate(groupingExpressions = attributes,
|
||||
aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child)
|
||||
assert(hasGroupAgg.stats ==
|
||||
// From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
|
||||
Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
|
||||
} finally {
|
||||
SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
|
||||
}
|
||||
}
|
||||
|
||||
private def checkAggStats(
|
||||
|
@ -134,6 +140,6 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
|
|||
rowCount = Some(expectedOutputRowCount),
|
||||
attributeStats = expectedAttrStats)
|
||||
|
||||
assert(testAgg.stats(conf) == expectedStats)
|
||||
assert(testAgg.stats == expectedStats)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -57,16 +57,16 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
|
|||
val localLimit = LocalLimit(Literal(2), plan)
|
||||
val globalLimit = GlobalLimit(Literal(2), plan)
|
||||
// LocalLimit's stats is just its child's stats except column stats
|
||||
checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
|
||||
checkStats(localLimit, plan.stats.copy(attributeStats = AttributeMap(Nil)))
|
||||
checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2)))
|
||||
}
|
||||
|
||||
test("limit estimation: limit > child's rowCount") {
|
||||
val localLimit = LocalLimit(Literal(20), plan)
|
||||
val globalLimit = GlobalLimit(Literal(20), plan)
|
||||
checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
|
||||
checkStats(localLimit, plan.stats.copy(attributeStats = AttributeMap(Nil)))
|
||||
// Limit is larger than child's rowCount, so GlobalLimit's stats is equal to its child's stats.
|
||||
checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
|
||||
checkStats(globalLimit, plan.stats.copy(attributeStats = AttributeMap(Nil)))
|
||||
}
|
||||
|
||||
test("limit estimation: limit = 0") {
|
||||
|
@ -113,12 +113,19 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
|
|||
plan: LogicalPlan,
|
||||
expectedStatsCboOn: Statistics,
|
||||
expectedStatsCboOff: Statistics): Unit = {
|
||||
// Invalidate statistics
|
||||
plan.invalidateStatsCache()
|
||||
assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> true)) == expectedStatsCboOn)
|
||||
val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
|
||||
try {
|
||||
// Invalidate statistics
|
||||
plan.invalidateStatsCache()
|
||||
SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)
|
||||
assert(plan.stats == expectedStatsCboOn)
|
||||
|
||||
plan.invalidateStatsCache()
|
||||
assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == expectedStatsCboOff)
|
||||
plan.invalidateStatsCache()
|
||||
SQLConf.get.setConf(SQLConf.CBO_ENABLED, false)
|
||||
assert(plan.stats == expectedStatsCboOff)
|
||||
} finally {
|
||||
SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
|
||||
}
|
||||
}
|
||||
|
||||
/** Check estimated stats when it's the same whether cbo is turned on or off. */
|
||||
|
@ -135,6 +142,6 @@ private case class DummyLogicalPlan(
|
|||
cboStats: Statistics) extends LogicalPlan {
|
||||
override def output: Seq[Attribute] = Nil
|
||||
override def children: Seq[LogicalPlan] = Nil
|
||||
override def computeStats(conf: SQLConf): Statistics =
|
||||
override def computeStats: Statistics =
|
||||
if (conf.cboEnabled) cboStats else defaultStats
|
||||
}
|
||||
|
|
|
@ -620,7 +620,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
|
|||
rowCount = Some(expectedRowCount),
|
||||
attributeStats = expectedAttributeMap)
|
||||
|
||||
val filterStats = filter.stats(conf)
|
||||
val filterStats = filter.stats
|
||||
assert(filterStats.sizeInBytes == expectedStats.sizeInBytes)
|
||||
assert(filterStats.rowCount == expectedStats.rowCount)
|
||||
val rowCountValue = filterStats.rowCount.getOrElse(0)
|
||||
|
|
|
@ -77,7 +77,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
|
|||
// Keep the column stat from both sides unchanged.
|
||||
attributeStats = AttributeMap(
|
||||
Seq("key-1-5", "key-5-9", "key-1-2", "key-2-4").map(nameToColInfo)))
|
||||
assert(join.stats(conf) == expectedStats)
|
||||
assert(join.stats == expectedStats)
|
||||
}
|
||||
|
||||
test("disjoint inner join") {
|
||||
|
@ -90,7 +90,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
|
|||
sizeInBytes = 1,
|
||||
rowCount = Some(0),
|
||||
attributeStats = AttributeMap(Nil))
|
||||
assert(join.stats(conf) == expectedStats)
|
||||
assert(join.stats == expectedStats)
|
||||
}
|
||||
|
||||
test("disjoint left outer join") {
|
||||
|
@ -106,7 +106,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
|
|||
// Null count for right side columns = left row count
|
||||
Seq(nameToAttr("key-1-2") -> nullColumnStat(nameToAttr("key-1-2").dataType, 5),
|
||||
nameToAttr("key-2-4") -> nullColumnStat(nameToAttr("key-2-4").dataType, 5))))
|
||||
assert(join.stats(conf) == expectedStats)
|
||||
assert(join.stats == expectedStats)
|
||||
}
|
||||
|
||||
test("disjoint right outer join") {
|
||||
|
@ -122,7 +122,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
|
|||
// Null count for left side columns = right row count
|
||||
Seq(nameToAttr("key-1-5") -> nullColumnStat(nameToAttr("key-1-5").dataType, 3),
|
||||
nameToAttr("key-5-9") -> nullColumnStat(nameToAttr("key-5-9").dataType, 3))))
|
||||
assert(join.stats(conf) == expectedStats)
|
||||
assert(join.stats == expectedStats)
|
||||
}
|
||||
|
||||
test("disjoint full outer join") {
|
||||
|
@ -140,7 +140,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
|
|||
nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3),
|
||||
nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5),
|
||||
nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5))))
|
||||
assert(join.stats(conf) == expectedStats)
|
||||
assert(join.stats == expectedStats)
|
||||
}
|
||||
|
||||
test("inner join") {
|
||||
|
@ -161,7 +161,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
|
|||
attributeStats = AttributeMap(
|
||||
Seq(nameToAttr("key-1-5") -> joinedColStat, nameToAttr("key-1-2") -> joinedColStat,
|
||||
nameToAttr("key-5-9") -> colStatForkey59, nameToColInfo("key-2-4"))))
|
||||
assert(join.stats(conf) == expectedStats)
|
||||
assert(join.stats == expectedStats)
|
||||
}
|
||||
|
||||
test("inner join with multiple equi-join keys") {
|
||||
|
@ -183,7 +183,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
|
|||
attributeStats = AttributeMap(
|
||||
Seq(nameToAttr("key-1-2") -> joinedColStat1, nameToAttr("key-1-2") -> joinedColStat1,
|
||||
nameToAttr("key-2-4") -> joinedColStat2, nameToAttr("key-2-3") -> joinedColStat2)))
|
||||
assert(join.stats(conf) == expectedStats)
|
||||
assert(join.stats == expectedStats)
|
||||
}
|
||||
|
||||
test("left outer join") {
|
||||
|
@ -201,7 +201,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
|
|||
attributeStats = AttributeMap(
|
||||
Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-3"),
|
||||
nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat)))
|
||||
assert(join.stats(conf) == expectedStats)
|
||||
assert(join.stats == expectedStats)
|
||||
}
|
||||
|
||||
test("right outer join") {
|
||||
|
@ -219,7 +219,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
|
|||
attributeStats = AttributeMap(
|
||||
Seq(nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat,
|
||||
nameToColInfo("key-1-2"), nameToColInfo("key-2-3"))))
|
||||
assert(join.stats(conf) == expectedStats)
|
||||
assert(join.stats == expectedStats)
|
||||
}
|
||||
|
||||
test("full outer join") {
|
||||
|
@ -234,7 +234,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
|
|||
// Keep the column stat from both sides unchanged.
|
||||
attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4"),
|
||||
nameToColInfo("key-1-2"), nameToColInfo("key-2-3"))))
|
||||
assert(join.stats(conf) == expectedStats)
|
||||
assert(join.stats == expectedStats)
|
||||
}
|
||||
|
||||
test("left semi/anti join") {
|
||||
|
@ -248,7 +248,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
|
|||
sizeInBytes = 3 * (8 + 4 * 2),
|
||||
rowCount = Some(3),
|
||||
attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4"))))
|
||||
assert(join.stats(conf) == expectedStats)
|
||||
assert(join.stats == expectedStats)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -306,7 +306,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
|
|||
sizeInBytes = 1 * (8 + 2 * getColSize(key1, columnInfo1(key1))),
|
||||
rowCount = Some(1),
|
||||
attributeStats = AttributeMap(Seq(key1 -> columnInfo1(key1), key2 -> columnInfo1(key1))))
|
||||
assert(join.stats(conf) == expectedStats)
|
||||
assert(join.stats == expectedStats)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -323,6 +323,6 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
|
|||
sizeInBytes = 1,
|
||||
rowCount = Some(0),
|
||||
attributeStats = AttributeMap(Nil))
|
||||
assert(join.stats(conf) == expectedStats)
|
||||
assert(join.stats == expectedStats)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
|
|||
sizeInBytes = 2 * (8 + 4 + 4),
|
||||
rowCount = Some(2),
|
||||
attributeStats = expectedAttrStats)
|
||||
assert(proj.stats(conf) == expectedStats)
|
||||
assert(proj.stats == expectedStats)
|
||||
}
|
||||
|
||||
test("project on empty table") {
|
||||
|
@ -131,6 +131,6 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
|
|||
sizeInBytes = expectedSize,
|
||||
rowCount = Some(expectedRowCount),
|
||||
attributeStats = projectAttrMap)
|
||||
assert(proj.stats(conf) == expectedStats)
|
||||
assert(proj.stats == expectedStats)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,14 +21,24 @@ import org.apache.spark.SparkFunSuite
|
|||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED}
|
||||
import org.apache.spark.sql.types.{IntegerType, StringType}
|
||||
|
||||
|
||||
trait StatsEstimationTestBase extends SparkFunSuite {
|
||||
|
||||
/** Enable stats estimation based on CBO. */
|
||||
protected val conf = new SQLConf().copy(CASE_SENSITIVE -> true, CBO_ENABLED -> true)
|
||||
var originalValue: Boolean = false
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
// Enable stats estimation based on CBO.
|
||||
originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
|
||||
SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)
|
||||
}
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
|
||||
super.afterAll()
|
||||
}
|
||||
|
||||
def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match {
|
||||
// For UTF8String: base + offset + numBytes
|
||||
|
@ -55,7 +65,7 @@ case class StatsTestPlan(
|
|||
attributeStats: AttributeMap[ColumnStat],
|
||||
size: Option[BigInt] = None) extends LeafNode {
|
||||
override def output: Seq[Attribute] = outputList
|
||||
override def computeStats(conf: SQLConf): Statistics = Statistics(
|
||||
override def computeStats: Statistics = Statistics(
|
||||
// If sizeInBytes is useless in testing, we just use a fake value
|
||||
sizeInBytes = size.getOrElse(Int.MaxValue),
|
||||
rowCount = Some(rowCount),
|
||||
|
|
|
@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._
|
|||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types.DataType
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
|
@ -89,7 +88,7 @@ case class ExternalRDD[T](
|
|||
|
||||
override protected def stringArgs: Iterator[Any] = Iterator(output)
|
||||
|
||||
@transient override def computeStats(conf: SQLConf): Statistics = Statistics(
|
||||
@transient override def computeStats: Statistics = Statistics(
|
||||
// TODO: Instead of returning a default value here, find a way to return a meaningful size
|
||||
// estimate for RDDs. See PR 1238 for more discussions.
|
||||
sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
|
||||
|
@ -157,7 +156,7 @@ case class LogicalRDD(
|
|||
|
||||
override protected def stringArgs: Iterator[Any] = Iterator(output)
|
||||
|
||||
@transient override def computeStats(conf: SQLConf): Statistics = Statistics(
|
||||
@transient override def computeStats: Statistics = Statistics(
|
||||
// TODO: Instead of returning a default value here, find a way to return a meaningful size
|
||||
// estimate for RDDs. See PR 1238 for more discussions.
|
||||
sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
|
||||
|
|
|
@ -221,7 +221,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
|
|||
|
||||
def stringWithStats: String = {
|
||||
// trigger to compute stats for logical plans
|
||||
optimizedPlan.stats(sparkSession.sessionState.conf)
|
||||
optimizedPlan.stats
|
||||
|
||||
// only show optimized logical plan and physical plan
|
||||
s"""== Optimized Logical Plan ==
|
||||
|
|
|
@ -22,7 +22,6 @@ import org.apache.spark.sql.Strategy
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.encoders.RowEncoder
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.First
|
||||
import org.apache.spark.sql.catalyst.planning._
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
|
@ -114,9 +113,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
* Matches a plan whose output should be small enough to be used in broadcast join.
|
||||
*/
|
||||
private def canBroadcast(plan: LogicalPlan): Boolean = {
|
||||
plan.stats(conf).hints.broadcast ||
|
||||
(plan.stats(conf).sizeInBytes >= 0 &&
|
||||
plan.stats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold)
|
||||
plan.stats.hints.broadcast ||
|
||||
(plan.stats.sizeInBytes >= 0 &&
|
||||
plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -126,7 +125,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
* dynamic.
|
||||
*/
|
||||
private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = {
|
||||
plan.stats(conf).sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
|
||||
plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -137,7 +136,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
* use the size of bytes here as estimation.
|
||||
*/
|
||||
private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = {
|
||||
a.stats(conf).sizeInBytes * 3 <= b.stats(conf).sizeInBytes
|
||||
a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes
|
||||
}
|
||||
|
||||
private def canBuildRight(joinType: JoinType): Boolean = joinType match {
|
||||
|
@ -206,7 +205,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
|
||||
case logical.Join(left, right, joinType, condition) =>
|
||||
val buildSide =
|
||||
if (right.stats(conf).sizeInBytes <= left.stats(conf).sizeInBytes) {
|
||||
if (right.stats.sizeInBytes <= left.stats.sizeInBytes) {
|
||||
BuildRight
|
||||
} else {
|
||||
BuildLeft
|
||||
|
|
|
@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._
|
|||
import org.apache.spark.sql.catalyst.plans.logical
|
||||
import org.apache.spark.sql.catalyst.plans.logical.Statistics
|
||||
import org.apache.spark.sql.execution.SparkPlan
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
import org.apache.spark.util.LongAccumulator
|
||||
|
||||
|
@ -70,7 +69,7 @@ case class InMemoryRelation(
|
|||
|
||||
@transient val partitionStatistics = new PartitionStatistics(output)
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics = {
|
||||
override def computeStats: Statistics = {
|
||||
if (batchStats.value == 0L) {
|
||||
// Underlying columnar RDD hasn't been materialized, no useful statistics information
|
||||
// available, return the default statistics.
|
||||
|
|
|
@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
|
|||
import org.apache.spark.sql.catalyst.catalog.CatalogTable
|
||||
import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.sources.BaseRelation
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
|
@ -46,7 +45,7 @@ case class LogicalRelation(
|
|||
// Only care about relation when canonicalizing.
|
||||
override def preCanonicalized: LogicalPlan = copy(catalogTable = None)
|
||||
|
||||
@transient override def computeStats(conf: SQLConf): Statistics = {
|
||||
@transient override def computeStats: Statistics = {
|
||||
catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse(
|
||||
Statistics(sizeInBytes = relation.sizeInBytes))
|
||||
}
|
||||
|
|
|
@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
|
|||
import org.apache.spark.sql.catalyst.expressions.Attribute
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
|
||||
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.streaming.OutputMode
|
||||
import org.apache.spark.sql.types.StructType
|
||||
import org.apache.spark.util.Utils
|
||||
|
@ -230,6 +229,6 @@ case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode
|
|||
|
||||
private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum
|
||||
|
||||
override def computeStats(conf: SQLConf): Statistics =
|
||||
override def computeStats: Statistics =
|
||||
Statistics(sizePerRow * sink.allData.size)
|
||||
}
|
||||
|
|
|
@ -313,7 +313,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
|
|||
spark.table("testData").queryExecution.withCachedData.collect {
|
||||
case cached: InMemoryRelation =>
|
||||
val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum
|
||||
assert(cached.stats(sqlConf).sizeInBytes === actualSizeInBytes)
|
||||
assert(cached.stats.sizeInBytes === actualSizeInBytes)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1146,7 +1146,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
|
|||
// instead of Int for avoiding possible overflow.
|
||||
val ds = (0 to 10000).map( i =>
|
||||
(i, Seq((i, Seq((i, "This is really not that long of a string")))))).toDS()
|
||||
val sizeInBytes = ds.logicalPlan.stats(sqlConf).sizeInBytes
|
||||
val sizeInBytes = ds.logicalPlan.stats.sizeInBytes
|
||||
// sizeInBytes is 2404280404, before the fix, it overflows to a negative number
|
||||
assert(sizeInBytes > 0)
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
|
|||
setupTestData()
|
||||
|
||||
def statisticSizeInByte(df: DataFrame): BigInt = {
|
||||
df.queryExecution.optimizedPlan.stats(sqlConf).sizeInBytes
|
||||
df.queryExecution.optimizedPlan.stats.sizeInBytes
|
||||
}
|
||||
|
||||
test("equi-join is hash-join") {
|
||||
|
|
|
@ -60,7 +60,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
|
|||
val df = df1.join(df2, Seq("k"), "left")
|
||||
|
||||
val sizes = df.queryExecution.analyzed.collect { case g: Join =>
|
||||
g.stats(conf).sizeInBytes
|
||||
g.stats.sizeInBytes
|
||||
}
|
||||
|
||||
assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
|
||||
|
@ -107,9 +107,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
|
|||
test("SPARK-15392: DataFrame created from RDD should not be broadcasted") {
|
||||
val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
|
||||
val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType))
|
||||
assert(df.queryExecution.analyzed.stats(conf).sizeInBytes >
|
||||
assert(df.queryExecution.analyzed.stats.sizeInBytes >
|
||||
spark.sessionState.conf.autoBroadcastJoinThreshold)
|
||||
assert(df.selectExpr("a").queryExecution.analyzed.stats(conf).sizeInBytes >
|
||||
assert(df.selectExpr("a").queryExecution.analyzed.stats.sizeInBytes >
|
||||
spark.sessionState.conf.autoBroadcastJoinThreshold)
|
||||
}
|
||||
|
||||
|
@ -250,13 +250,13 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
|
|||
test("SPARK-18856: non-empty partitioned table should not report zero size") {
|
||||
withTable("ds_tbl", "hive_tbl") {
|
||||
spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl")
|
||||
val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats(conf)
|
||||
val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats
|
||||
assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.")
|
||||
|
||||
if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") {
|
||||
sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)")
|
||||
sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1")
|
||||
val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats(conf)
|
||||
val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats
|
||||
assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.")
|
||||
}
|
||||
}
|
||||
|
@ -296,10 +296,10 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
|
|||
assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat))
|
||||
|
||||
// Check relation statistics
|
||||
assert(relation.stats(conf).sizeInBytes == 0)
|
||||
assert(relation.stats(conf).rowCount == Some(0))
|
||||
assert(relation.stats(conf).attributeStats.size == 1)
|
||||
val (attribute, colStat) = relation.stats(conf).attributeStats.head
|
||||
assert(relation.stats.sizeInBytes == 0)
|
||||
assert(relation.stats.rowCount == Some(0))
|
||||
assert(relation.stats.attributeStats.size == 1)
|
||||
val (attribute, colStat) = relation.stats.attributeStats.head
|
||||
assert(attribute.name == "c1")
|
||||
assert(colStat == emptyColStat)
|
||||
}
|
||||
|
|
|
@ -126,7 +126,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
|
|||
.toDF().createOrReplaceTempView("sizeTst")
|
||||
spark.catalog.cacheTable("sizeTst")
|
||||
assert(
|
||||
spark.table("sizeTst").queryExecution.analyzed.stats(sqlConf).sizeInBytes >
|
||||
spark.table("sizeTst").queryExecution.analyzed.stats.sizeInBytes >
|
||||
spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
|
||||
}
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext {
|
|||
})
|
||||
val totalSize = allFiles.map(_.length()).sum
|
||||
val df = spark.read.parquet(dir.toString)
|
||||
assert(df.queryExecution.logical.stats(sqlConf).sizeInBytes === BigInt(totalSize))
|
||||
assert(df.queryExecution.logical.stats.sizeInBytes === BigInt(totalSize))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -216,15 +216,15 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
|
|||
|
||||
// Before adding data, check output
|
||||
checkAnswer(sink.allData, Seq.empty)
|
||||
assert(plan.stats(sqlConf).sizeInBytes === 0)
|
||||
assert(plan.stats.sizeInBytes === 0)
|
||||
|
||||
sink.addBatch(0, 1 to 3)
|
||||
plan.invalidateStatsCache()
|
||||
assert(plan.stats(sqlConf).sizeInBytes === 12)
|
||||
assert(plan.stats.sizeInBytes === 12)
|
||||
|
||||
sink.addBatch(1, 4 to 6)
|
||||
plan.invalidateStatsCache()
|
||||
assert(plan.stats(sqlConf).sizeInBytes === 24)
|
||||
assert(plan.stats.sizeInBytes === 24)
|
||||
}
|
||||
|
||||
ignore("stress test") {
|
||||
|
|
|
@ -21,7 +21,6 @@ import java.nio.charset.StandardCharsets
|
|||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits}
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
||||
/**
|
||||
* A collection of sample data used in SQL tests.
|
||||
|
@ -29,8 +28,6 @@ import org.apache.spark.sql.internal.SQLConf
|
|||
private[sql] trait SQLTestData { self =>
|
||||
protected def spark: SparkSession
|
||||
|
||||
protected def sqlConf: SQLConf = spark.sessionState.conf
|
||||
|
||||
// Helper object to import SQL implicits without a concrete SQLContext
|
||||
private object internalImplicits extends SQLImplicits {
|
||||
protected override def _sqlContext: SQLContext = self.spark.sqlContext
|
||||
|
|
|
@ -154,7 +154,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
|
|||
Some(partitionSchema))
|
||||
|
||||
val logicalRelation = cached.getOrElse {
|
||||
val sizeInBytes = relation.stats(sparkSession.sessionState.conf).sizeInBytes.toLong
|
||||
val sizeInBytes = relation.stats.sizeInBytes.toLong
|
||||
val fileIndex = {
|
||||
val index = new CatalogFileIndex(sparkSession, relation.tableMeta, sizeInBytes)
|
||||
if (lazyPruningEnabled) {
|
||||
|
|
|
@ -68,7 +68,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
|
|||
assert(properties("totalSize").toLong <= 0, "external table totalSize must be <= 0")
|
||||
assert(properties("rawDataSize").toLong <= 0, "external table rawDataSize must be <= 0")
|
||||
|
||||
val sizeInBytes = relation.stats(conf).sizeInBytes
|
||||
val sizeInBytes = relation.stats.sizeInBytes
|
||||
assert(sizeInBytes === BigInt(file1.length() + file2.length()))
|
||||
}
|
||||
}
|
||||
|
@ -77,7 +77,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
|
|||
|
||||
test("analyze Hive serde tables") {
|
||||
def queryTotalSize(tableName: String): BigInt =
|
||||
spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes
|
||||
spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes
|
||||
|
||||
// Non-partitioned table
|
||||
sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect()
|
||||
|
@ -659,7 +659,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
|
|||
test("estimates the size of a test Hive serde tables") {
|
||||
val df = sql("""SELECT * FROM src""")
|
||||
val sizes = df.queryExecution.analyzed.collect {
|
||||
case relation: CatalogRelation => relation.stats(conf).sizeInBytes
|
||||
case relation: CatalogRelation => relation.stats.sizeInBytes
|
||||
}
|
||||
assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}")
|
||||
assert(sizes(0).equals(BigInt(5812)),
|
||||
|
@ -679,7 +679,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
|
|||
|
||||
// Assert src has a size smaller than the threshold.
|
||||
val sizes = df.queryExecution.analyzed.collect {
|
||||
case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats(conf).sizeInBytes
|
||||
case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats.sizeInBytes
|
||||
}
|
||||
assert(sizes.size === 2 && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold
|
||||
&& sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold,
|
||||
|
@ -733,7 +733,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
|
|||
|
||||
// Assert src has a size smaller than the threshold.
|
||||
val sizes = df.queryExecution.analyzed.collect {
|
||||
case relation: CatalogRelation => relation.stats(conf).sizeInBytes
|
||||
case relation: CatalogRelation => relation.stats.sizeInBytes
|
||||
}
|
||||
assert(sizes.size === 2 && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold
|
||||
&& sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold,
|
||||
|
|
|
@ -86,7 +86,7 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te
|
|||
case relation: LogicalRelation => relation
|
||||
}
|
||||
assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}")
|
||||
val size2 = relations(0).computeStats(conf).sizeInBytes
|
||||
val size2 = relations(0).computeStats.sizeInBytes
|
||||
assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes)
|
||||
assert(size2 < tableStats.get.sizeInBytes)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue