From 6f8c62047cea125d52af5dad7fb5ad3eadb7f7d0 Mon Sep 17 00:00:00 2001 From: Alkis Polyzotis Date: Sat, 5 Jun 2021 14:25:33 -0500 Subject: [PATCH] [SPARK-35558] Optimizes for multi-quantile retrieval ### What changes were proposed in this pull request? Optimizes the retrieval of approximate quantiles for an array of percentiles. * Adds an overload for QuantileSummaries.query that accepts a sequence of percentiles (`Seq[Double]`) and optimizes the computation to do a single pass over the sketch and avoid redundant computation. * Modifies the ApproximatePercentile operator to call into the new method. All formatting changes are the result of running ./dev/scalafmt ### Why are the changes needed? The existing implementation makes a separate query call for each input percentile, and every call rescans the sketch from its first sample, resulting in redundant computation. ### Does this PR introduce _any_ user-facing change? No, except that the range-validation error message now reads "percentile should be in the range [0.0, 1.0]" (previously "quantile should be in the range [0.0, 1.0]"). ### How was this patch tested? Added unit tests for the new method. Closes #32700 from alkispoly-db/spark_35558_approx_quants_array. Authored-by: Alkis Polyzotis Signed-off-by: Sean Owen --- .../aggregate/ApproximatePercentile.scala | 11 +- .../sql/catalyst/util/QuantileSummaries.scala | 107 +++++++++++++----- .../util/QuantileSummariesSuite.scala | 79 +++++++++---- .../sql/execution/stat/StatFunctions.scala | 7 +- .../apache/spark/sql/DataFrameStatSuite.scala | 2 +- 5 files changed, 149 insertions(+), 57 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 38d8d7d71e..78e64bfcd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -261,19 +261,12 @@ object ApproximatePercentile { * val Array(p25, median, p75) = percentileDigest.getPercentiles(Array(0.25, 0.5, 
0.75)) * }}} */ - def getPercentiles(percentages: Array[Double]): Array[Double] = { + def getPercentiles(percentages: Array[Double]): Seq[Double] = { if (!isCompressed) compress() if (summaries.count == 0 || percentages.length == 0) { Array.emptyDoubleArray } else { - val result = new Array[Double](percentages.length) - var i = 0 - while (i < percentages.length) { - // Since summaries.count != 0, the query here never return None. - result(i) = summaries.query(percentages(i)).get - i += 1 - } - result + summaries.query(percentages).get } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index addf1408a3..e0cd6139a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -229,46 +229,99 @@ class QuantileSummaries( } /** - * Runs a query for a given quantile. + * Finds the approximate quantile for a percentile, starting at a specific index in the summary. + * This is a helper method that is called as we are making a pass over the summary and a sorted + * sequence of input percentiles. + * + * @param index The point at which to start scanning the summary for an approximate value. + * @param minRankAtIndex The accumulated minimum rank at the given index. + * @param targetError Target error from the summary. + * @param percentile The percentile whose value is computed. + * @return A tuple (i, r, a) where: i is the updated index for the next call, r is the updated + * rank at i, and a is the approximate quantile. 
+ */ + private def findApproxQuantile( + index: Int, + minRankAtIndex: Long, + targetError: Double, + percentile: Double): (Int, Long, Double) = { + var curSample = sampled(index) + val rank = math.ceil(percentile * count).toLong + var i = index + var minRank = minRankAtIndex + while (i < sampled.length - 1) { + val maxRank = minRank + curSample.delta + if (maxRank - targetError <= rank && rank <= minRank + targetError) { + return (i, minRank, curSample.value) + } else { + i += 1 + curSample = sampled(i) + minRank += curSample.g + } + } + (sampled.length - 1, 0, sampled.last.value) + } + + /** + * Runs a query for a given sequence of percentiles. * The result follows the approximation guarantees detailed above. * The query can only be run on a compressed summary: you need to call compress() before using * it. * - * @param quantile the target quantile - * @return + * @param percentiles the target percentiles + * @return the corresponding approximate quantiles, in the same order as the input */ - def query(quantile: Double): Option[Double] = { - require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range [0.0, 1.0]") - require(headSampled.isEmpty, + def query(percentiles: Seq[Double]): Option[Seq[Double]] = { + percentiles.foreach(p => + require(p >= 0 && p <= 1.0, "percentile should be in the range [0.0, 1.0]")) + require( + headSampled.isEmpty, "Cannot operate on an uncompressed summary, call compress() first") if (sampled.isEmpty) return None - if (quantile <= relativeError) { - return Some(sampled.head.value) - } + val targetError = sampled.foldLeft(Long.MinValue)((currentMax, stats) => + currentMax.max(stats.delta + stats.g)) / 2 - if (quantile >= 1 - relativeError) { - return Some(sampled.last.value) - } - - // Target rank - val rank = math.ceil(quantile * count).toLong - val targetError = sampled.map(s => s.delta + s.g).max / 2 + // Index to track the current sample + var index = 0 // Minimum rank at current sample - var minRank = 0L - var i = 0 
- while (i < sampled.length - 1) { - val curSample = sampled(i) - minRank += curSample.g - val maxRank = minRank + curSample.delta - if (maxRank - targetError <= rank && rank <= minRank + targetError) { - return Some(curSample.value) - } - i += 1 + var minRank = sampled(0).g + + val sortedPercentiles = percentiles.zipWithIndex.sortBy(_._1) + val result = Array.fill(percentiles.length)(0.0) + sortedPercentiles.foreach { + case (percentile, pos) => + if (percentile <= relativeError) { + result(pos) = sampled.head.value + } else if (percentile >= 1 - relativeError) { + result(pos) = sampled.last.value + } else { + val (newIndex, newMinRank, approxQuantile) = + findApproxQuantile(index, minRank, targetError, percentile) + index = newIndex + minRank = newMinRank + result(pos) = approxQuantile + } } - Some(sampled.last.value) + Some(result) } + + /** + * Runs a query for a given percentile. + * The result follows the approximation guarantees detailed above. + * The query can only be run on a compressed summary: you need to call compress() before using + * it. 
+ * + * @param percentile the target percentile + * @return the corresponding approximate quantile + */ + def query(percentile: Double): Option[Double] = + query(Seq(percentile)) match { + case Some(approxSeq) if approxSeq.nonEmpty => Some(approxSeq.head) + case _ => None + } + } object QuantileSummaries { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala index e53d0bbccc..018db3aed7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.util +import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark.SparkFunSuite @@ -54,25 +55,51 @@ class QuantileSummariesSuite extends SparkFunSuite { summary } - private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = { + private def validateQuantileApproximation( + approx: Double, + percentile: Double, + data: Seq[Double], + summary: QuantileSummaries): Unit = { + assert(data.nonEmpty) + + val rankOfValue = data.count(_ <= approx) + val rankOfPreValue = data.count(_ < approx) + // `rankOfValue` is the last position of the quantile value. If the input repeats the value + // chosen as the quantile, e.g. in (1,2,2,2,2,2,3), the 50% quantile is 2, then it's + // improper to choose the last position as its rank. Instead, we get the rank by averaging + // `rankOfValue` and `rankOfPreValue`. 
+ val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0) + val lower = math.floor((percentile - summary.relativeError) * data.size) + val upper = math.ceil((percentile + summary.relativeError) * data.size) + val msg = + s"$rank not in [$lower $upper], requested percentile: $percentile, approx returned: $approx" + assert(rank >= lower, msg) + assert(rank <= upper, msg) + } + + private def checkQuantile( + percentile: Double, + data: Seq[Double], + summary: QuantileSummaries): Unit = { if (data.nonEmpty) { - val approx = summary.query(quant).get - // Get the rank of the approximation. - val rankOfValue = data.count(_ <= approx) - val rankOfPreValue = data.count(_ < approx) - // `rankOfValue` is the last position of the quantile value. If the input repeats the value - // chosen as the quantile, e.g. in (1,2,2,2,2,2,3), the 50% quantile is 2, then it's - // improper to choose the last position as its rank. Instead, we get the rank by averaging - // `rankOfValue` and `rankOfPreValue`. - val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0) - val lower = math.floor((quant - summary.relativeError) * data.size) - val upper = math.ceil((quant + summary.relativeError) * data.size) - val msg = - s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx" - assert(rank >= lower, msg) - assert(rank <= upper, msg) + val approx = summary.query(percentile).get + validateQuantileApproximation(approx, percentile, data, summary) } else { - assert(summary.query(quant).isEmpty) + assert(summary.query(percentile).isEmpty) + } + } + + private def checkQuantiles( + percentiles: Seq[Double], + data: Seq[Double], + summary: QuantileSummaries): Unit = { + if (data.nonEmpty) { + val approx = summary.query(percentiles).get + for ((q, a) <- percentiles zip approx) { + validateQuantileApproximation(a, q, data, summary) + } + } else { + assert(summary.query(percentiles).isEmpty) } } @@ -98,6 +125,8 @@ class QuantileSummariesSuite extends SparkFunSuite { 
checkQuantile(0.5, data, s) checkQuantile(0.1, data, s) checkQuantile(0.001, data, s) + checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s) + checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s) } test(s"Some quantile values with epsi=$epsi and seq=$seq_name, compression=$compression " + @@ -109,6 +138,8 @@ class QuantileSummariesSuite extends SparkFunSuite { checkQuantile(0.5, data, s) checkQuantile(0.1, data, s) checkQuantile(0.001, data, s) + checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s) + checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s) } test(s"Tests on empty data with epsi=$epsi and seq=$seq_name, compression=$compression") { @@ -121,6 +152,8 @@ class QuantileSummariesSuite extends SparkFunSuite { checkQuantile(0.5, emptyData, s) checkQuantile(0.1, emptyData, s) checkQuantile(0.001, emptyData, s) + checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), emptyData, s) + checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), emptyData, s) } } @@ -149,6 +182,8 @@ class QuantileSummariesSuite extends SparkFunSuite { checkQuantile(0.5, data, s) checkQuantile(0.1, data, s) checkQuantile(0.001, data, s) + checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s) + checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s) } val (data11, data12) = { @@ -168,6 +203,8 @@ class QuantileSummariesSuite extends SparkFunSuite { checkQuantile(0.5, data, s) checkQuantile(0.1, data, s) checkQuantile(0.001, data, s) + checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s) + checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s) } // length of data21 is 4 * length of data22 @@ -181,10 +218,14 @@ class QuantileSummariesSuite extends SparkFunSuite { val s2 = buildSummary(data22, epsi, compression) val s = s1.merge(s2) // Check all quantiles + val percentiles = ArrayBuffer[Double]() for (queryRank <- 1 to n) { - val queryQuantile = queryRank.toDouble / n.toDouble - checkQuantile(queryQuantile, data, s) + val percentile = queryRank.toDouble / 
n.toDouble + checkQuantile(percentile, data, s) + percentiles += percentile } + checkQuantiles(percentiles.toSeq, data, s) + checkQuantiles(percentiles.reverse.toSeq, data, s) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 0a9954e679..5dc0ff0ac4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -102,7 +102,12 @@ object StatFunctions extends Logging { } val summaries = df.select(columns: _*).rdd.treeAggregate(emptySummaries)(apply, merge) - summaries.map { summary => probabilities.flatMap(summary.query) } + summaries.map { + summary => summary.query(probabilities) match { + case Some(q) => q + case None => Seq() + } + } } /** Calculate the Pearson Correlation Coefficient for the given columns */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index cdd2568771..79ab3cda99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -204,7 +204,7 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { val e = intercept[IllegalArgumentException] { df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2, -0.1), epsilons.head) } - assert(e.getMessage.contains("quantile should be in the range [0.0, 1.0]")) + assert(e.getMessage.contains("percentile should be in the range [0.0, 1.0]")) // relativeError should be non-negative val e2 = intercept[IllegalArgumentException] {