diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index 4cff72d45a..ca0775a2e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types.LongType /** * An [[LogicalPlanVisitor]] that computes a the statistics used in a cost-based optimizer. @@ -54,7 +52,7 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitIntersect(p: Intersect): Statistics = fallback(p) override def visitJoin(p: Join): Statistics = { - JoinEstimation.estimate(p).getOrElse(fallback(p)) + JoinEstimation(p).estimate.getOrElse(fallback(p)) } override def visitLocalLimit(p: LocalLimit): Statistics = fallback(p) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index dcbe36da91..b073108c26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -28,60 +28,58 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ -object JoinEstimation extends Logging { - /** - * Estimate statistics after join. Return `None` if the join type is not supported, or we don't - * have enough statistics for estimation. - */ - def estimate(join: Join): Option[Statistics] = { - join.joinType match { - case Inner | Cross | LeftOuter | RightOuter | FullOuter => - InnerOuterEstimation(join).doEstimate() - case LeftSemi | LeftAnti => - LeftSemiAntiEstimation(join).doEstimate() - case _ => - logDebug(s"[CBO] Unsupported join type: ${join.joinType}") - None - } - } -} - -case class InnerOuterEstimation(join: Join) extends Logging { +case class JoinEstimation(join: Join) extends Logging { private val leftStats = join.left.stats private val rightStats = join.right.stats + /** + * Estimate statistics after join. Return `None` if the join type is not supported, or we don't + * have enough statistics for estimation. + */ + def estimate: Option[Statistics] = { + join.joinType match { + case Inner | Cross | LeftOuter | RightOuter | FullOuter => + estimateInnerOuterJoin() + case LeftSemi | LeftAnti => + estimateLeftSemiAntiJoin() + case _ => + logDebug(s"[CBO] Unsupported join type: ${join.joinType}") + None + } + } + /** * Estimate output size and number of rows after a join operator, and update output column stats. */ - def doEstimate(): Option[Statistics] = join match { + private def estimateInnerOuterJoin(): Option[Statistics] = join match { case _ if !rowCountsExist(join.left, join.right) => None case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) => // 1. Compute join selectivity val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys) - val selectivity = joinSelectivity(joinKeyPairs) + val (numInnerJoinedRows, keyStatsAfterJoin) = computeCardinalityAndStats(joinKeyPairs) // 2. Estimate the number of output rows val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get - val innerJoinedRows = ceil(BigDecimal(leftRows * rightRows) * selectivity) // Make sure outputRows won't be too small based on join type. val outputRows = joinType match { case LeftOuter => // All rows from left side should be in the result. - leftRows.max(innerJoinedRows) + leftRows.max(numInnerJoinedRows) case RightOuter => // All rows from right side should be in the result. - rightRows.max(innerJoinedRows) + rightRows.max(numInnerJoinedRows) case FullOuter => // T(A FOJ B) = T(A LOJ B) + T(A ROJ B) - T(A IJ B) - leftRows.max(innerJoinedRows) + rightRows.max(innerJoinedRows) - innerJoinedRows + leftRows.max(numInnerJoinedRows) + rightRows.max(numInnerJoinedRows) - numInnerJoinedRows case _ => + assert(joinType == Inner || joinType == Cross) // Don't change for inner or cross join - innerJoinedRows + numInnerJoinedRows } // 3. Update statistics based on the output of join @@ -93,7 +91,7 @@ case class InnerOuterEstimation(join: Join) extends Logging { val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) { // The output is empty, we don't need to keep column stats. Nil - } else if (selectivity == 0) { + } else if (numInnerJoinedRows == 0) { joinType match { // For outer joins, if the join selectivity is 0, the number of output rows is the // same as that of the outer side. And column stats of join keys from the outer side @@ -113,26 +111,28 @@ case class InnerOuterEstimation(join: Join) extends Logging { val oriColStat = inputAttrStats(a) (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows)) } - case _ => Nil + case _ => + assert(joinType == Inner || joinType == Cross) + Nil } - } else if (selectivity == 1) { + } else if (numInnerJoinedRows == leftRows * rightRows) { // Cartesian product, just propagate the original column stats inputAttrStats.toSeq } else { - val joinKeyStats = getIntersectedStats(joinKeyPairs) join.joinType match { // For outer joins, don't update column stats from the outer side. case LeftOuter => fromLeft.map(a => (a, inputAttrStats(a))) ++ - updateAttrStats(outputRows, fromRight, inputAttrStats, joinKeyStats) + updateOutputStats(outputRows, fromRight, inputAttrStats, keyStatsAfterJoin) case RightOuter => - updateAttrStats(outputRows, fromLeft, inputAttrStats, joinKeyStats) ++ + updateOutputStats(outputRows, fromLeft, inputAttrStats, keyStatsAfterJoin) ++ fromRight.map(a => (a, inputAttrStats(a))) case FullOuter => inputAttrStats.toSeq case _ => + assert(joinType == Inner || joinType == Cross) // Update column stats from both sides for inner or cross join. - updateAttrStats(outputRows, attributesWithStat, inputAttrStats, joinKeyStats) + updateOutputStats(outputRows, attributesWithStat, inputAttrStats, keyStatsAfterJoin) } } @@ -157,64 +157,90 @@ case class InnerOuterEstimation(join: Join) extends Logging { // scalastyle:off /** * The number of rows of A inner join B on A.k1 = B.k1 is estimated by this basic formula: - * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), where V is the number of distinct values of - * that column. The underlying assumption for this formula is: each value of the smaller domain - * is included in the larger domain. - * Generally, inner join with multiple join keys can also be estimated based on the above - * formula: + * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), + * where V is the number of distinct values (ndv) of that column. The underlying assumption for + * this formula is: each value of the smaller domain is included in the larger domain. + * + * Generally, inner join with multiple join keys can be estimated based on the above formula: * T(A IJ B) = T(A) * T(B) / (max(V(A.k1), V(B.k1)) * max(V(A.k2), V(B.k2)) * ... * max(V(A.kn), V(B.kn))) * However, the denominator can become very large and excessively reduce the result, so we use a * conservative strategy to take only the largest max(V(A.ki), V(B.ki)) as the denominator. + * + * That is, join estimation is based on the most selective join keys. We follow this strategy + * when different types of column statistics are available. E.g., if card1 is the cardinality + * estimated by ndv of join key A.k1 and B.k1, card2 is the cardinality estimated by histograms + * of join key A.k2 and B.k2, then the result cardinality would be min(card1, card2). + * + * @param keyPairs pairs of join keys + * + * @return join cardinality, and column stats for join keys after the join */ // scalastyle:on - def joinSelectivity(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]): BigDecimal = { - var ndvDenom: BigInt = -1 + private def computeCardinalityAndStats(keyPairs: Seq[(AttributeReference, AttributeReference)]) + : (BigInt, AttributeMap[ColumnStat]) = { + // If there's no column stats available for join keys, estimate as cartesian product. + var joinCard: BigInt = leftStats.rowCount.get * rightStats.rowCount.get + val keyStatsAfterJoin = new mutable.HashMap[Attribute, ColumnStat]() var i = 0 - while(i < joinKeyPairs.length && ndvDenom != 0) { - val (leftKey, rightKey) = joinKeyPairs(i) + while(i < keyPairs.length && joinCard != 0) { + val (leftKey, rightKey) = keyPairs(i) // Check if the two sides are disjoint - val leftKeyStats = leftStats.attributeStats(leftKey) - val rightKeyStats = rightStats.attributeStats(rightKey) - val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) + val leftKeyStat = leftStats.attributeStats(leftKey) + val rightKeyStat = rightStats.attributeStats(rightKey) + val lInterval = ValueInterval(leftKeyStat.min, leftKeyStat.max, leftKey.dataType) + val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType) if (ValueInterval.isIntersected(lInterval, rInterval)) { - // Get the largest ndv among pairs of join keys - val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount) - if (maxNdv > ndvDenom) ndvDenom = maxNdv + val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) + val (card, joinStat) = computeByNdv(leftKey, rightKey, newMin, newMax) + keyStatsAfterJoin += (leftKey -> joinStat, rightKey -> joinStat) + // Return cardinality estimated from the most selective join keys. + if (card < joinCard) joinCard = card } else { - // Set ndvDenom to zero to indicate that this join should have no output - ndvDenom = 0 + // One of the join key pairs is disjoint, thus the two sides of join is disjoint. + joinCard = 0 } i += 1 } + (joinCard, AttributeMap(keyStatsAfterJoin.toSeq)) + } - if (ndvDenom < 0) { - // We can't find any join key pairs with column stats, estimate it as cartesian join. - 1 - } else if (ndvDenom == 0) { - // One of the join key pairs is disjoint, thus the two sides of join is disjoint. - 0 - } else { - 1 / BigDecimal(ndvDenom) - } + /** Returns join cardinality and the column stat for this pair of join keys. */ + private def computeByNdv( + leftKey: AttributeReference, + rightKey: AttributeReference, + newMin: Option[Any], + newMax: Option[Any]): (BigInt, ColumnStat) = { + val leftKeyStat = leftStats.attributeStats(leftKey) + val rightKeyStat = rightStats.attributeStats(rightKey) + val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount) + // Compute cardinality by the basic formula. + val card = BigDecimal(leftStats.rowCount.get * rightStats.rowCount.get) / BigDecimal(maxNdv) + + // Get the intersected column stat. + val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount) + val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) + val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 + val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) + + (ceil(card), newStats) } /** * Propagate or update column stats for output attributes. */ - private def updateAttrStats( + private def updateOutputStats( outputRows: BigInt, - attributes: Seq[Attribute], + output: Seq[Attribute], oldAttrStats: AttributeMap[ColumnStat], - joinKeyStats: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = { + keyStatsAfterJoin: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = { val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]() val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get - attributes.foreach { a => + output.foreach { a => // check if this attribute is a join key - if (joinKeyStats.contains(a)) { - outputAttrStats += a -> joinKeyStats(a) + if (keyStatsAfterJoin.contains(a)) { + outputAttrStats += a -> keyStatsAfterJoin(a) } else { val oldColStat = oldAttrStats(a) val oldNdv = oldColStat.distinctCount @@ -231,34 +257,6 @@ case class InnerOuterEstimation(join: Join) extends Logging { outputAttrStats } - /** Get intersected column stats for join keys. */ - private def getIntersectedStats(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]) - : AttributeMap[ColumnStat] = { - - val intersectedStats = new mutable.HashMap[Attribute, ColumnStat]() - joinKeyPairs.foreach { case (leftKey, rightKey) => - val leftKeyStats = leftStats.attributeStats(leftKey) - val rightKeyStats = rightStats.attributeStats(rightKey) - val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) - // When we reach here, join selectivity is not zero, so each pair of join keys should be - // intersected. - assert(ValueInterval.isIntersected(lInterval, rInterval)) - - // Update intersected column stats - assert(leftKey.dataType.sameType(rightKey.dataType)) - val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) - val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) - val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen) - val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2 - val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) - - intersectedStats.put(leftKey, newStats) - intersectedStats.put(rightKey, newStats) - } - AttributeMap(intersectedStats.toSeq) - } - private def extractJoinKeysWithColStats( leftKeys: Seq[Expression], rightKeys: Seq[Expression]): Seq[(AttributeReference, AttributeReference)] = { @@ -270,10 +268,8 @@ case class InnerOuterEstimation(join: Join) extends Logging { if columnStatsExist((leftStats, lk), (rightStats, rk)) => (lk, rk) } } -} -case class LeftSemiAntiEstimation(join: Join) { - def doEstimate(): Option[Statistics] = { + private def estimateLeftSemiAntiJoin(): Option[Statistics] = { // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic // column stats. Now we just propagate the statistics from left side. We should do more // accurate estimation when advanced stats (e.g. histograms) are available.