[SPARK-22310][SQL] Refactor join estimation to incorporate estimation logic for different kinds of statistics

## What changes were proposed in this pull request?

The current join estimation logic is based only on basic column statistics (such as ndv). If we want to add estimation based on other kinds of statistics (such as histograms), it is not easy to incorporate them into the current algorithm:
1. When there are multiple pairs of join keys, the current algorithm computes cardinality with a single formula. But if different join keys have different kinds of stats, the computation logic differs per pair of join keys, so a single formula no longer applies.
2. Currently it computes cardinality and updates the join keys' column stats in separate steps. It's better to do these two steps together, since both the computation and the update logic differ for different kinds of stats (see the sketch after this list).
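
As an editorial illustration of the intended shape, here is a minimal sketch in Scala; the names (`JoinEstimationSketch`, `KeyStats`, `estimateByNdv`) are hypothetical and not part of Spark's API. Each pair of join keys is estimated with whatever statistics it has, each per-pair estimate returns the cardinality and the updated key stats together, and the per-pair results are combined by keeping the most selective one:

```scala
// Minimal sketch of the refactored shape; names are hypothetical, not Spark's API.
object JoinEstimationSketch {
  // Simplified stand-in for per-key column statistics.
  case class KeyStats(ndv: BigInt)

  // Per-pair estimation returns the cardinality *and* the post-join key stats
  // together, since both depend on the kind of statistics available.
  def estimateByNdv(cartesianRows: BigInt, left: KeyStats, right: KeyStats): (BigInt, KeyStats) =
    (cartesianRows / left.ndv.max(right.ndv), KeyStats(left.ndv.min(right.ndv)))

  // A histogram-based estimator for some pair could slot in next to estimateByNdv;
  // the combination step stays the same: keep the most selective (smallest) estimate.
  def estimate(cartesianRows: BigInt, pairs: Seq[(KeyStats, KeyStats)]): (BigInt, Seq[KeyStats]) = {
    if (pairs.isEmpty) return (cartesianRows, Nil) // no key stats: fall back to cartesian product
    val perPair = pairs.map { case (l, r) => estimateByNdv(cartesianRows, l, r) }
    (perPair.map(_._1).min, perPair.map(_._2))
  }
}
```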

## How was this patch tested?

This patch is a refactor only and is covered by existing tests.

Author: Zhenhua Wang <wangzhenhua@huawei.com>

Closes #19531 from wzhfy/join_est_refactor.
Zhenhua Wang authored 2017-10-31 11:13:48 +01:00; committed by Wenchen Fan
commit 59589bc654, parent aa6db57e39
2 changed files with 94 additions and 100 deletions

BasicStatsPlanVisitor.scala

@@ -17,9 +17,7 @@
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
- import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
- import org.apache.spark.sql.types.LongType
/**
* A [[LogicalPlanVisitor]] that computes the statistics used in a cost-based optimizer.
@@ -54,7 +52,7 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
override def visitIntersect(p: Intersect): Statistics = fallback(p)
override def visitJoin(p: Join): Statistics = {
- JoinEstimation.estimate(p).getOrElse(fallback(p))
+ JoinEstimation(p).estimate.getOrElse(fallback(p))
}
override def visitLocalLimit(p: LocalLimit): Statistics = fallback(p)

JoinEstimation.scala

@@ -28,60 +28,58 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
- object JoinEstimation extends Logging {
- /**
-  * Estimate statistics after join. Return `None` if the join type is not supported, or we don't
-  * have enough statistics for estimation.
-  */
- def estimate(join: Join): Option[Statistics] = {
- join.joinType match {
- case Inner | Cross | LeftOuter | RightOuter | FullOuter =>
- InnerOuterEstimation(join).doEstimate()
- case LeftSemi | LeftAnti =>
- LeftSemiAntiEstimation(join).doEstimate()
- case _ =>
- logDebug(s"[CBO] Unsupported join type: ${join.joinType}")
- None
- }
- }
- }
- case class InnerOuterEstimation(join: Join) extends Logging {
+ case class JoinEstimation(join: Join) extends Logging {
private val leftStats = join.left.stats
private val rightStats = join.right.stats
+ /**
+  * Estimate statistics after join. Return `None` if the join type is not supported, or we don't
+  * have enough statistics for estimation.
+  */
+ def estimate: Option[Statistics] = {
+ join.joinType match {
+ case Inner | Cross | LeftOuter | RightOuter | FullOuter =>
+ estimateInnerOuterJoin()
+ case LeftSemi | LeftAnti =>
+ estimateLeftSemiAntiJoin()
+ case _ =>
+ logDebug(s"[CBO] Unsupported join type: ${join.joinType}")
+ None
+ }
+ }
/**
* Estimate output size and number of rows after a join operator, and update output column stats.
*/
- def doEstimate(): Option[Statistics] = join match {
+ private def estimateInnerOuterJoin(): Option[Statistics] = join match {
case _ if !rowCountsExist(join.left, join.right) =>
None
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) =>
// 1. Compute join selectivity
val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys)
- val selectivity = joinSelectivity(joinKeyPairs)
+ val (numInnerJoinedRows, keyStatsAfterJoin) = computeCardinalityAndStats(joinKeyPairs)
// 2. Estimate the number of output rows
val leftRows = leftStats.rowCount.get
val rightRows = rightStats.rowCount.get
- val innerJoinedRows = ceil(BigDecimal(leftRows * rightRows) * selectivity)
// Make sure outputRows won't be too small based on join type.
val outputRows = joinType match {
case LeftOuter =>
// All rows from left side should be in the result.
- leftRows.max(innerJoinedRows)
+ leftRows.max(numInnerJoinedRows)
case RightOuter =>
// All rows from right side should be in the result.
- rightRows.max(innerJoinedRows)
+ rightRows.max(numInnerJoinedRows)
case FullOuter =>
// T(A FOJ B) = T(A LOJ B) + T(A ROJ B) - T(A IJ B)
- leftRows.max(innerJoinedRows) + rightRows.max(innerJoinedRows) - innerJoinedRows
+ leftRows.max(numInnerJoinedRows) + rightRows.max(numInnerJoinedRows) - numInnerJoinedRows
case _ =>
assert(joinType == Inner || joinType == Cross)
// Don't change for inner or cross join
- innerJoinedRows
+ numInnerJoinedRows
}
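
A quick worked example of these per-join-type lower bounds, with invented numbers (T(A) = 1000, T(B) = 500, estimated inner join cardinality 200):

```scala
// Illustrative numbers only.
val leftRows = BigInt(1000)
val rightRows = BigInt(500)
val numInnerJoinedRows = BigInt(200)

val leftOuterRows = leftRows.max(numInnerJoinedRows)   // 1000: every left row appears
val rightOuterRows = rightRows.max(numInnerJoinedRows) // 500: every right row appears
// T(A FOJ B) = T(A LOJ B) + T(A ROJ B) - T(A IJ B)
val fullOuterRows = leftOuterRows + rightOuterRows - numInnerJoinedRows // 1300
```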
// 3. Update statistics based on the output of join
@@ -93,7 +91,7 @@ case class InnerOuterEstimation(join: Join) extends Logging {
val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) {
// The output is empty, we don't need to keep column stats.
Nil
- } else if (selectivity == 0) {
+ } else if (numInnerJoinedRows == 0) {
joinType match {
// For outer joins, if the join selectivity is 0, the number of output rows is the
// same as that of the outer side. And column stats of join keys from the outer side
@@ -113,26 +111,28 @@ case class InnerOuterEstimation(join: Join) extends Logging {
val oriColStat = inputAttrStats(a)
(a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows))
}
- case _ => Nil
+ case _ =>
+ assert(joinType == Inner || joinType == Cross)
+ Nil
}
- } else if (selectivity == 1) {
+ } else if (numInnerJoinedRows == leftRows * rightRows) {
// Cartesian product, just propagate the original column stats
inputAttrStats.toSeq
} else {
- val joinKeyStats = getIntersectedStats(joinKeyPairs)
join.joinType match {
// For outer joins, don't update column stats from the outer side.
case LeftOuter =>
fromLeft.map(a => (a, inputAttrStats(a))) ++
- updateAttrStats(outputRows, fromRight, inputAttrStats, joinKeyStats)
+ updateOutputStats(outputRows, fromRight, inputAttrStats, keyStatsAfterJoin)
case RightOuter =>
- updateAttrStats(outputRows, fromLeft, inputAttrStats, joinKeyStats) ++
+ updateOutputStats(outputRows, fromLeft, inputAttrStats, keyStatsAfterJoin) ++
fromRight.map(a => (a, inputAttrStats(a)))
case FullOuter =>
inputAttrStats.toSeq
case _ =>
+ assert(joinType == Inner || joinType == Cross)
// Update column stats from both sides for inner or cross join.
- updateAttrStats(outputRows, attributesWithStat, inputAttrStats, joinKeyStats)
+ updateOutputStats(outputRows, attributesWithStat, inputAttrStats, keyStatsAfterJoin)
}
}
@@ -157,64 +157,90 @@ case class InnerOuterEstimation(join: Join) extends Logging {
// scalastyle:off
/**
* The number of rows of A inner join B on A.k1 = B.k1 is estimated by this basic formula:
- * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), where V is the number of distinct values of
- * that column. The underlying assumption for this formula is: each value of the smaller domain
- * is included in the larger domain.
- * Generally, inner join with multiple join keys can also be estimated based on the above
- * formula:
+ * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)),
+ * where V is the number of distinct values (ndv) of that column. The underlying assumption for
+ * this formula is: each value of the smaller domain is included in the larger domain.
+ *
+ * Generally, inner join with multiple join keys can be estimated based on the above formula:
* T(A IJ B) = T(A) * T(B) / (max(V(A.k1), V(B.k1)) * max(V(A.k2), V(B.k2)) * ... * max(V(A.kn), V(B.kn)))
* However, the denominator can become very large and excessively reduce the result, so we use a
* conservative strategy to take only the largest max(V(A.ki), V(B.ki)) as the denominator.
+ *
+ * That is, join estimation is based on the most selective join keys. We follow this strategy
+ * when different types of column statistics are available. E.g., if card1 is the cardinality
+ * estimated by ndv of join key A.k1 and B.k1, card2 is the cardinality estimated by histograms
+ * of join key A.k2 and B.k2, then the result cardinality would be min(card1, card2).
+ *
+ * @param keyPairs pairs of join keys
+ *
+ * @return join cardinality, and column stats for join keys after the join
*/
// scalastyle:on
- def joinSelectivity(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]): BigDecimal = {
- var ndvDenom: BigInt = -1
+ private def computeCardinalityAndStats(keyPairs: Seq[(AttributeReference, AttributeReference)])
+ : (BigInt, AttributeMap[ColumnStat]) = {
+ // If there are no column stats available for the join keys, estimate as a cartesian product.
+ var joinCard: BigInt = leftStats.rowCount.get * rightStats.rowCount.get
+ val keyStatsAfterJoin = new mutable.HashMap[Attribute, ColumnStat]()
var i = 0
- while (i < joinKeyPairs.length && ndvDenom != 0) {
- val (leftKey, rightKey) = joinKeyPairs(i)
+ while (i < keyPairs.length && joinCard != 0) {
+ val (leftKey, rightKey) = keyPairs(i)
// Check if the two sides are disjoint
- val leftKeyStats = leftStats.attributeStats(leftKey)
- val rightKeyStats = rightStats.attributeStats(rightKey)
- val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType)
- val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType)
+ val leftKeyStat = leftStats.attributeStats(leftKey)
+ val rightKeyStat = rightStats.attributeStats(rightKey)
+ val lInterval = ValueInterval(leftKeyStat.min, leftKeyStat.max, leftKey.dataType)
+ val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType)
if (ValueInterval.isIntersected(lInterval, rInterval)) {
- // Get the largest ndv among pairs of join keys
- val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount)
- if (maxNdv > ndvDenom) ndvDenom = maxNdv
+ val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType)
+ val (card, joinStat) = computeByNdv(leftKey, rightKey, newMin, newMax)
+ keyStatsAfterJoin += (leftKey -> joinStat, rightKey -> joinStat)
+ // Return cardinality estimated from the most selective join keys.
+ if (card < joinCard) joinCard = card
} else {
- // Set ndvDenom to zero to indicate that this join should have no output
- ndvDenom = 0
+ // One of the join key pairs is disjoint, thus the two sides of the join are disjoint.
+ joinCard = 0
}
i += 1
}
+ (joinCard, AttributeMap(keyStatsAfterJoin.toSeq))
+ }
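
A worked example of the basic formula and the most-selective-key strategy, with invented statistics:

```scala
// Illustrative numbers only. Suppose T(A) = 100000 and T(B) = 20000,
// so the cartesian product has 2,000,000,000 rows.
val cartesianRows = BigInt(100000) * BigInt(20000)

// Key pair 1: V(A.k1) = 5000, V(B.k1) = 2000 -> 2e9 / 5000 = 400,000 rows
val card1 = cartesianRows / BigInt(5000).max(BigInt(2000))
// Key pair 2: V(A.k2) = 50000, V(B.k2) = 10000 -> 2e9 / 50000 = 40,000 rows
val card2 = cartesianRows / BigInt(50000).max(BigInt(10000))

// The conservative strategy keeps only the most selective key pair:
val joinCard = card1.min(card2) // 40,000
```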
- if (ndvDenom < 0) {
- // We can't find any join key pairs with column stats, estimate it as cartesian join.
- 1
- } else if (ndvDenom == 0) {
- // One of the join key pairs is disjoint, thus the two sides of join is disjoint.
- 0
- } else {
- 1 / BigDecimal(ndvDenom)
- }
+ /** Returns join cardinality and the column stat for this pair of join keys. */
+ private def computeByNdv(
+ leftKey: AttributeReference,
+ rightKey: AttributeReference,
+ newMin: Option[Any],
+ newMax: Option[Any]): (BigInt, ColumnStat) = {
+ val leftKeyStat = leftStats.attributeStats(leftKey)
+ val rightKeyStat = rightStats.attributeStats(rightKey)
+ val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount)
+ // Compute cardinality by the basic formula.
+ val card = BigDecimal(leftStats.rowCount.get * rightStats.rowCount.get) / BigDecimal(maxNdv)
+ // Get the intersected column stat.
+ val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount)
+ val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen)
+ val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2
+ val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen)
+ (ceil(card), newStats)
+ }
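
The intersected column stat built here can be illustrated with a small self-contained sketch; `ColStat` is a simplified stand-in for catalyst's `ColumnStat`, and the numbers are invented:

```scala
// Simplified stand-in for catalyst's ColumnStat, for illustration only.
case class ColStat(ndv: BigInt, min: Option[Int], max: Option[Int],
                   nullCount: BigInt, avgLen: Long, maxLen: Long)

val left  = ColStat(ndv = 1000, min = Some(0),   max = Some(5000), nullCount = 10, avgLen = 8, maxLen = 8)
val right = ColStat(ndv = 400,  min = Some(100), max = Some(9000), nullCount = 0,  avgLen = 4, maxLen = 8)

val joined = ColStat(
  ndv = left.ndv.min(right.ndv),               // 400: only the smaller domain can survive the join
  min = Some(left.min.get.max(right.min.get)), // 100: intersected lower bound
  max = Some(left.max.get.min(right.max.get)), // 5000: intersected upper bound
  nullCount = 0,                               // equi-join keys cannot be null in the output
  avgLen = (left.avgLen + right.avgLen) / 2,   // 6
  maxLen = left.maxLen.min(right.maxLen))      // 8
```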
/**
* Propagate or update column stats for output attributes.
*/
- private def updateAttrStats(
+ private def updateOutputStats(
outputRows: BigInt,
- attributes: Seq[Attribute],
+ output: Seq[Attribute],
oldAttrStats: AttributeMap[ColumnStat],
- joinKeyStats: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = {
+ keyStatsAfterJoin: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = {
val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]()
val leftRows = leftStats.rowCount.get
val rightRows = rightStats.rowCount.get
- attributes.foreach { a =>
+ output.foreach { a =>
// check if this attribute is a join key
- if (joinKeyStats.contains(a)) {
- outputAttrStats += a -> joinKeyStats(a)
+ if (keyStatsAfterJoin.contains(a)) {
+ outputAttrStats += a -> keyStatsAfterJoin(a)
} else {
val oldColStat = oldAttrStats(a)
val oldNdv = oldColStat.distinctCount
@@ -231,34 +257,6 @@ case class InnerOuterEstimation(join: Join) extends Logging {
outputAttrStats
}
- /** Get intersected column stats for join keys. */
- private def getIntersectedStats(joinKeyPairs: Seq[(AttributeReference, AttributeReference)])
- : AttributeMap[ColumnStat] = {
- val intersectedStats = new mutable.HashMap[Attribute, ColumnStat]()
- joinKeyPairs.foreach { case (leftKey, rightKey) =>
- val leftKeyStats = leftStats.attributeStats(leftKey)
- val rightKeyStats = rightStats.attributeStats(rightKey)
- val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType)
- val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType)
- // When we reach here, join selectivity is not zero, so each pair of join keys should be
- // intersected.
- assert(ValueInterval.isIntersected(lInterval, rInterval))
- // Update intersected column stats
- assert(leftKey.dataType.sameType(rightKey.dataType))
- val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount)
- val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType)
- val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen)
- val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2
- val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen)
- intersectedStats.put(leftKey, newStats)
- intersectedStats.put(rightKey, newStats)
- }
- AttributeMap(intersectedStats.toSeq)
- }
private def extractJoinKeysWithColStats(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression]): Seq[(AttributeReference, AttributeReference)] = {
@@ -270,10 +268,8 @@ case class InnerOuterEstimation(join: Join) extends Logging {
if columnStatsExist((leftStats, lk), (rightStats, rk)) => (lk, rk)
}
}
- }
- case class LeftSemiAntiEstimation(join: Join) {
- def doEstimate(): Option[Statistics] = {
+ private def estimateLeftSemiAntiJoin(): Option[Statistics] = {
// TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic
// column stats. Now we just propagate the statistics from left side. We should do more
// accurate estimation when advanced stats (e.g. histograms) are available.