[SPARK-35558] Optimizes for multi-quantile retrieval
### What changes were proposed in this pull request? Optimizes the retrieval of approximate quantiles for an array of percentiles. * Adds an overload for QuantileSummaries.query that accepts an array of percentiles and optimizes the computation to do a single pass over the sketch and avoid redundant computation. * Modifies the ApproximatePercentiles operator to call into the new method. All formatting changes are the result of running ./dev/scalafmt ### Why are the changes needed? The existing implementation does repeated calls per input percentile resulting in redundant computation. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added unit tests for the new method. Closes #32700 from alkispoly-db/spark_35558_approx_quants_array. Authored-by: Alkis Polyzotis <alkis.polyzotis@databricks.com> Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
parent
510bde460a
commit
6f8c62047c
|
@ -261,19 +261,12 @@ object ApproximatePercentile {
|
|||
* val Array(p25, median, p75) = percentileDigest.getPercentiles(Array(0.25, 0.5, 0.75))
|
||||
* }}}
|
||||
*/
|
||||
def getPercentiles(percentages: Array[Double]): Array[Double] = {
|
||||
def getPercentiles(percentages: Array[Double]): Seq[Double] = {
|
||||
if (!isCompressed) compress()
|
||||
if (summaries.count == 0 || percentages.length == 0) {
|
||||
Array.emptyDoubleArray
|
||||
} else {
|
||||
val result = new Array[Double](percentages.length)
|
||||
var i = 0
|
||||
while (i < percentages.length) {
|
||||
// Since summaries.count != 0, the query here never return None.
|
||||
result(i) = summaries.query(percentages(i)).get
|
||||
i += 1
|
||||
}
|
||||
result
|
||||
summaries.query(percentages).get
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -229,46 +229,99 @@ class QuantileSummaries(
|
|||
}
|
||||
|
||||
/**
|
||||
* Runs a query for a given quantile.
|
||||
* Finds the approximate quantile for a percentile, starting at a specific index in the summary.
|
||||
* This is a helper method that is called as we are making a pass over the summary and a sorted
|
||||
* sequence of input percentiles.
|
||||
*
|
||||
* @param index The point at which to start scanning the summary for an approximate value.
|
||||
* @param minRankAtIndex The accumulated minimum rank at the given index.
|
||||
* @param targetError Target error from the summary.
|
||||
* @param percentile The percentile whose value is computed.
|
||||
* @return A tuple (i, r, a) where: i is the updated index for the next call, r is the updated
|
||||
* rank at i, and a is the approximate quantile.
|
||||
*/
|
||||
private def findApproxQuantile(
|
||||
index: Int,
|
||||
minRankAtIndex: Long,
|
||||
targetError: Double,
|
||||
percentile: Double): (Int, Long, Double) = {
|
||||
var curSample = sampled(index)
|
||||
val rank = math.ceil(percentile * count).toLong
|
||||
var i = index
|
||||
var minRank = minRankAtIndex
|
||||
while (i < sampled.length - 1) {
|
||||
val maxRank = minRank + curSample.delta
|
||||
if (maxRank - targetError <= rank && rank <= minRank + targetError) {
|
||||
return (i, minRank, curSample.value)
|
||||
} else {
|
||||
i += 1
|
||||
curSample = sampled(i)
|
||||
minRank += curSample.g
|
||||
}
|
||||
}
|
||||
(sampled.length - 1, 0, sampled.last.value)
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs a query for a given sequence of percentiles.
|
||||
* The result follows the approximation guarantees detailed above.
|
||||
* The query can only be run on a compressed summary: you need to call compress() before using
|
||||
* it.
|
||||
*
|
||||
* @param quantile the target quantile
|
||||
* @return
|
||||
* @param percentiles the target percentiles
|
||||
* @return the corresponding approximate quantiles, in the same order as the input
|
||||
*/
|
||||
def query(quantile: Double): Option[Double] = {
|
||||
require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range [0.0, 1.0]")
|
||||
require(headSampled.isEmpty,
|
||||
def query(percentiles: Seq[Double]): Option[Seq[Double]] = {
|
||||
percentiles.foreach(p =>
|
||||
require(p >= 0 && p <= 1.0, "percentile should be in the range [0.0, 1.0]"))
|
||||
require(
|
||||
headSampled.isEmpty,
|
||||
"Cannot operate on an uncompressed summary, call compress() first")
|
||||
|
||||
if (sampled.isEmpty) return None
|
||||
|
||||
if (quantile <= relativeError) {
|
||||
return Some(sampled.head.value)
|
||||
}
|
||||
val targetError = sampled.foldLeft(Long.MinValue)((currentMax, stats) =>
|
||||
currentMax.max(stats.delta + stats.g)) / 2
|
||||
|
||||
if (quantile >= 1 - relativeError) {
|
||||
return Some(sampled.last.value)
|
||||
}
|
||||
|
||||
// Target rank
|
||||
val rank = math.ceil(quantile * count).toLong
|
||||
val targetError = sampled.map(s => s.delta + s.g).max / 2
|
||||
// Index to track the current sample
|
||||
var index = 0
|
||||
// Minimum rank at current sample
|
||||
var minRank = 0L
|
||||
var i = 0
|
||||
while (i < sampled.length - 1) {
|
||||
val curSample = sampled(i)
|
||||
minRank += curSample.g
|
||||
val maxRank = minRank + curSample.delta
|
||||
if (maxRank - targetError <= rank && rank <= minRank + targetError) {
|
||||
return Some(curSample.value)
|
||||
}
|
||||
i += 1
|
||||
var minRank = sampled(0).g
|
||||
|
||||
val sortedPercentiles = percentiles.zipWithIndex.sortBy(_._1)
|
||||
val result = Array.fill(percentiles.length)(0.0)
|
||||
sortedPercentiles.foreach {
|
||||
case (percentile, pos) =>
|
||||
if (percentile <= relativeError) {
|
||||
result(pos) = sampled.head.value
|
||||
} else if (percentile >= 1 - relativeError) {
|
||||
result(pos) = sampled.last.value
|
||||
} else {
|
||||
val (newIndex, newMinRank, approxQuantile) =
|
||||
findApproxQuantile(index, minRank, targetError, percentile)
|
||||
index = newIndex
|
||||
minRank = newMinRank
|
||||
result(pos) = approxQuantile
|
||||
}
|
||||
}
|
||||
Some(sampled.last.value)
|
||||
Some(result)
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs a query for a given percentile.
|
||||
* The result follows the approximation guarantees detailed above.
|
||||
* The query can only be run on a compressed summary: you need to call compress() before using
|
||||
* it.
|
||||
*
|
||||
* @param percentile the target percentile
|
||||
* @return the corresponding approximate quantile
|
||||
*/
|
||||
def query(percentile: Double): Option[Double] =
|
||||
query(Seq(percentile)) match {
|
||||
case Some(approxSeq) if approxSeq.nonEmpty => Some(approxSeq.head)
|
||||
case _ => None
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
object QuantileSummaries {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.util
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
|
@ -54,25 +55,51 @@ class QuantileSummariesSuite extends SparkFunSuite {
|
|||
summary
|
||||
}
|
||||
|
||||
private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = {
|
||||
private def validateQuantileApproximation(
|
||||
approx: Double,
|
||||
percentile: Double,
|
||||
data: Seq[Double],
|
||||
summary: QuantileSummaries): Unit = {
|
||||
assert(data.nonEmpty)
|
||||
|
||||
val rankOfValue = data.count(_ <= approx)
|
||||
val rankOfPreValue = data.count(_ < approx)
|
||||
// `rankOfValue` is the last position of the quantile value. If the input repeats the value
|
||||
// chosen as the quantile, e.g. in (1,2,2,2,2,2,3), the 50% quantile is 2, then it's
|
||||
// improper to choose the last position as its rank. Instead, we get the rank by averaging
|
||||
// `rankOfValue` and `rankOfPreValue`.
|
||||
val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0)
|
||||
val lower = math.floor((percentile - summary.relativeError) * data.size)
|
||||
val upper = math.ceil((percentile + summary.relativeError) * data.size)
|
||||
val msg =
|
||||
s"$rank not in [$lower $upper], requested percentile: $percentile, approx returned: $approx"
|
||||
assert(rank >= lower, msg)
|
||||
assert(rank <= upper, msg)
|
||||
}
|
||||
|
||||
private def checkQuantile(
|
||||
percentile: Double,
|
||||
data: Seq[Double],
|
||||
summary: QuantileSummaries): Unit = {
|
||||
if (data.nonEmpty) {
|
||||
val approx = summary.query(quant).get
|
||||
// Get the rank of the approximation.
|
||||
val rankOfValue = data.count(_ <= approx)
|
||||
val rankOfPreValue = data.count(_ < approx)
|
||||
// `rankOfValue` is the last position of the quantile value. If the input repeats the value
|
||||
// chosen as the quantile, e.g. in (1,2,2,2,2,2,3), the 50% quantile is 2, then it's
|
||||
// improper to choose the last position as its rank. Instead, we get the rank by averaging
|
||||
// `rankOfValue` and `rankOfPreValue`.
|
||||
val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0)
|
||||
val lower = math.floor((quant - summary.relativeError) * data.size)
|
||||
val upper = math.ceil((quant + summary.relativeError) * data.size)
|
||||
val msg =
|
||||
s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx"
|
||||
assert(rank >= lower, msg)
|
||||
assert(rank <= upper, msg)
|
||||
val approx = summary.query(percentile).get
|
||||
validateQuantileApproximation(approx, percentile, data, summary)
|
||||
} else {
|
||||
assert(summary.query(quant).isEmpty)
|
||||
assert(summary.query(percentile).isEmpty)
|
||||
}
|
||||
}
|
||||
|
||||
private def checkQuantiles(
|
||||
percentiles: Seq[Double],
|
||||
data: Seq[Double],
|
||||
summary: QuantileSummaries): Unit = {
|
||||
if (data.nonEmpty) {
|
||||
val approx = summary.query(percentiles).get
|
||||
for ((q, a) <- percentiles zip approx) {
|
||||
validateQuantileApproximation(a, q, data, summary)
|
||||
}
|
||||
} else {
|
||||
assert(summary.query(percentiles).isEmpty)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -98,6 +125,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
|
|||
checkQuantile(0.5, data, s)
|
||||
checkQuantile(0.1, data, s)
|
||||
checkQuantile(0.001, data, s)
|
||||
checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
|
||||
checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
|
||||
}
|
||||
|
||||
test(s"Some quantile values with epsi=$epsi and seq=$seq_name, compression=$compression " +
|
||||
|
@ -109,6 +138,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
|
|||
checkQuantile(0.5, data, s)
|
||||
checkQuantile(0.1, data, s)
|
||||
checkQuantile(0.001, data, s)
|
||||
checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
|
||||
checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
|
||||
}
|
||||
|
||||
test(s"Tests on empty data with epsi=$epsi and seq=$seq_name, compression=$compression") {
|
||||
|
@ -121,6 +152,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
|
|||
checkQuantile(0.5, emptyData, s)
|
||||
checkQuantile(0.1, emptyData, s)
|
||||
checkQuantile(0.001, emptyData, s)
|
||||
checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), emptyData, s)
|
||||
checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), emptyData, s)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -149,6 +182,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
|
|||
checkQuantile(0.5, data, s)
|
||||
checkQuantile(0.1, data, s)
|
||||
checkQuantile(0.001, data, s)
|
||||
checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
|
||||
checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
|
||||
}
|
||||
|
||||
val (data11, data12) = {
|
||||
|
@ -168,6 +203,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
|
|||
checkQuantile(0.5, data, s)
|
||||
checkQuantile(0.1, data, s)
|
||||
checkQuantile(0.001, data, s)
|
||||
checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
|
||||
checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
|
||||
}
|
||||
|
||||
// length of data21 is 4 * length of data22
|
||||
|
@ -181,10 +218,14 @@ class QuantileSummariesSuite extends SparkFunSuite {
|
|||
val s2 = buildSummary(data22, epsi, compression)
|
||||
val s = s1.merge(s2)
|
||||
// Check all quantiles
|
||||
val percentiles = ArrayBuffer[Double]()
|
||||
for (queryRank <- 1 to n) {
|
||||
val queryQuantile = queryRank.toDouble / n.toDouble
|
||||
checkQuantile(queryQuantile, data, s)
|
||||
val percentile = queryRank.toDouble / n.toDouble
|
||||
checkQuantile(percentile, data, s)
|
||||
percentiles += percentile
|
||||
}
|
||||
checkQuantiles(percentiles.toSeq, data, s)
|
||||
checkQuantiles(percentiles.reverse.toSeq, data, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -102,7 +102,12 @@ object StatFunctions extends Logging {
|
|||
}
|
||||
val summaries = df.select(columns: _*).rdd.treeAggregate(emptySummaries)(apply, merge)
|
||||
|
||||
summaries.map { summary => probabilities.flatMap(summary.query) }
|
||||
summaries.map {
|
||||
summary => summary.query(probabilities) match {
|
||||
case Some(q) => q
|
||||
case None => Seq()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Calculate the Pearson Correlation Coefficient for the given columns */
|
||||
|
|
|
@ -204,7 +204,7 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession {
|
|||
val e = intercept[IllegalArgumentException] {
|
||||
df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2, -0.1), epsilons.head)
|
||||
}
|
||||
assert(e.getMessage.contains("quantile should be in the range [0.0, 1.0]"))
|
||||
assert(e.getMessage.contains("percentile should be in the range [0.0, 1.0]"))
|
||||
|
||||
// relativeError should be non-negative
|
||||
val e2 = intercept[IllegalArgumentException] {
|
||||
|
|
Loading…
Reference in a new issue