[SPARK-35558] Optimizes for multi-quantile retrieval
### What changes were proposed in this pull request?

Optimizes the retrieval of approximate quantiles for an array of percentiles.

* Adds an overload of `QuantileSummaries.query` that accepts a sequence of percentiles and computes all of them in a single pass over the sketch, avoiding redundant computation.
* Modifies the `ApproximatePercentile` operator to call into the new method.

All formatting changes are the result of running `./dev/scalafmt`.

### Why are the changes needed?

The existing implementation makes a separate call for each input percentile, resulting in redundant computation.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added unit tests for the new method.

Closes #32700 from alkispoly-db/spark_35558_approx_quants_array.

Authored-by: Alkis Polyzotis <alkis.polyzotis@databricks.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
parent
510bde460a
commit
6f8c62047c
|
@ -261,19 +261,12 @@ object ApproximatePercentile {
|
||||||
* val Array(p25, median, p75) = percentileDigest.getPercentiles(Array(0.25, 0.5, 0.75))
|
* val Array(p25, median, p75) = percentileDigest.getPercentiles(Array(0.25, 0.5, 0.75))
|
||||||
* }}}
|
* }}}
|
||||||
*/
|
*/
|
||||||
def getPercentiles(percentages: Array[Double]): Array[Double] = {
|
def getPercentiles(percentages: Array[Double]): Seq[Double] = {
|
||||||
if (!isCompressed) compress()
|
if (!isCompressed) compress()
|
||||||
if (summaries.count == 0 || percentages.length == 0) {
|
if (summaries.count == 0 || percentages.length == 0) {
|
||||||
Array.emptyDoubleArray
|
Array.emptyDoubleArray
|
||||||
} else {
|
} else {
|
||||||
val result = new Array[Double](percentages.length)
|
summaries.query(percentages).get
|
||||||
var i = 0
|
|
||||||
while (i < percentages.length) {
|
|
||||||
// Since summaries.count != 0, the query here never return None.
|
|
||||||
result(i) = summaries.query(percentages(i)).get
|
|
||||||
i += 1
|
|
||||||
}
|
|
||||||
result
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -229,46 +229,99 @@ class QuantileSummaries(
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Runs a query for a given quantile.
|
* Finds the approximate quantile for a percentile, starting at a specific index in the summary.
|
||||||
|
* This is a helper method that is called as we are making a pass over the summary and a sorted
|
||||||
|
* sequence of input percentiles.
|
||||||
|
*
|
||||||
|
* @param index The point at which to start scanning the summary for an approximate value.
|
||||||
|
* @param minRankAtIndex The accumulated minimum rank at the given index.
|
||||||
|
* @param targetError Target error from the summary.
|
||||||
|
* @param percentile The percentile whose value is computed.
|
||||||
|
* @return A tuple (i, r, a) where: i is the updated index for the next call, r is the updated
|
||||||
|
* rank at i, and a is the approximate quantile.
|
||||||
|
*/
|
||||||
|
private def findApproxQuantile(
|
||||||
|
index: Int,
|
||||||
|
minRankAtIndex: Long,
|
||||||
|
targetError: Double,
|
||||||
|
percentile: Double): (Int, Long, Double) = {
|
||||||
|
var curSample = sampled(index)
|
||||||
|
val rank = math.ceil(percentile * count).toLong
|
||||||
|
var i = index
|
||||||
|
var minRank = minRankAtIndex
|
||||||
|
while (i < sampled.length - 1) {
|
||||||
|
val maxRank = minRank + curSample.delta
|
||||||
|
if (maxRank - targetError <= rank && rank <= minRank + targetError) {
|
||||||
|
return (i, minRank, curSample.value)
|
||||||
|
} else {
|
||||||
|
i += 1
|
||||||
|
curSample = sampled(i)
|
||||||
|
minRank += curSample.g
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(sampled.length - 1, 0, sampled.last.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs a query for a given sequence of percentiles.
|
||||||
* The result follows the approximation guarantees detailed above.
|
* The result follows the approximation guarantees detailed above.
|
||||||
* The query can only be run on a compressed summary: you need to call compress() before using
|
* The query can only be run on a compressed summary: you need to call compress() before using
|
||||||
* it.
|
* it.
|
||||||
*
|
*
|
||||||
* @param quantile the target quantile
|
* @param percentiles the target percentiles
|
||||||
* @return
|
* @return the corresponding approximate quantiles, in the same order as the input
|
||||||
*/
|
*/
|
||||||
def query(quantile: Double): Option[Double] = {
|
def query(percentiles: Seq[Double]): Option[Seq[Double]] = {
|
||||||
require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range [0.0, 1.0]")
|
percentiles.foreach(p =>
|
||||||
require(headSampled.isEmpty,
|
require(p >= 0 && p <= 1.0, "percentile should be in the range [0.0, 1.0]"))
|
||||||
|
require(
|
||||||
|
headSampled.isEmpty,
|
||||||
"Cannot operate on an uncompressed summary, call compress() first")
|
"Cannot operate on an uncompressed summary, call compress() first")
|
||||||
|
|
||||||
if (sampled.isEmpty) return None
|
if (sampled.isEmpty) return None
|
||||||
|
|
||||||
if (quantile <= relativeError) {
|
val targetError = sampled.foldLeft(Long.MinValue)((currentMax, stats) =>
|
||||||
return Some(sampled.head.value)
|
currentMax.max(stats.delta + stats.g)) / 2
|
||||||
}
|
|
||||||
|
|
||||||
if (quantile >= 1 - relativeError) {
|
// Index to track the current sample
|
||||||
return Some(sampled.last.value)
|
var index = 0
|
||||||
}
|
|
||||||
|
|
||||||
// Target rank
|
|
||||||
val rank = math.ceil(quantile * count).toLong
|
|
||||||
val targetError = sampled.map(s => s.delta + s.g).max / 2
|
|
||||||
// Minimum rank at current sample
|
// Minimum rank at current sample
|
||||||
var minRank = 0L
|
var minRank = sampled(0).g
|
||||||
var i = 0
|
|
||||||
while (i < sampled.length - 1) {
|
val sortedPercentiles = percentiles.zipWithIndex.sortBy(_._1)
|
||||||
val curSample = sampled(i)
|
val result = Array.fill(percentiles.length)(0.0)
|
||||||
minRank += curSample.g
|
sortedPercentiles.foreach {
|
||||||
val maxRank = minRank + curSample.delta
|
case (percentile, pos) =>
|
||||||
if (maxRank - targetError <= rank && rank <= minRank + targetError) {
|
if (percentile <= relativeError) {
|
||||||
return Some(curSample.value)
|
result(pos) = sampled.head.value
|
||||||
|
} else if (percentile >= 1 - relativeError) {
|
||||||
|
result(pos) = sampled.last.value
|
||||||
|
} else {
|
||||||
|
val (newIndex, newMinRank, approxQuantile) =
|
||||||
|
findApproxQuantile(index, minRank, targetError, percentile)
|
||||||
|
index = newIndex
|
||||||
|
minRank = newMinRank
|
||||||
|
result(pos) = approxQuantile
|
||||||
}
|
}
|
||||||
i += 1
|
|
||||||
}
|
}
|
||||||
Some(sampled.last.value)
|
Some(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs a query for a given percentile.
|
||||||
|
* The result follows the approximation guarantees detailed above.
|
||||||
|
* The query can only be run on a compressed summary: you need to call compress() before using
|
||||||
|
* it.
|
||||||
|
*
|
||||||
|
* @param percentile the target percentile
|
||||||
|
* @return the corresponding approximate quantile
|
||||||
|
*/
|
||||||
|
def query(percentile: Double): Option[Double] =
|
||||||
|
query(Seq(percentile)) match {
|
||||||
|
case Some(approxSeq) if approxSeq.nonEmpty => Some(approxSeq.head)
|
||||||
|
case _ => None
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
object QuantileSummaries {
|
object QuantileSummaries {
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
package org.apache.spark.sql.catalyst.util
|
package org.apache.spark.sql.catalyst.util
|
||||||
|
|
||||||
|
import scala.collection.mutable.ArrayBuffer
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
|
@ -54,10 +55,13 @@ class QuantileSummariesSuite extends SparkFunSuite {
|
||||||
summary
|
summary
|
||||||
}
|
}
|
||||||
|
|
||||||
private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = {
|
private def validateQuantileApproximation(
|
||||||
if (data.nonEmpty) {
|
approx: Double,
|
||||||
val approx = summary.query(quant).get
|
percentile: Double,
|
||||||
// Get the rank of the approximation.
|
data: Seq[Double],
|
||||||
|
summary: QuantileSummaries): Unit = {
|
||||||
|
assert(data.nonEmpty)
|
||||||
|
|
||||||
val rankOfValue = data.count(_ <= approx)
|
val rankOfValue = data.count(_ <= approx)
|
||||||
val rankOfPreValue = data.count(_ < approx)
|
val rankOfPreValue = data.count(_ < approx)
|
||||||
// `rankOfValue` is the last position of the quantile value. If the input repeats the value
|
// `rankOfValue` is the last position of the quantile value. If the input repeats the value
|
||||||
|
@ -65,14 +69,37 @@ class QuantileSummariesSuite extends SparkFunSuite {
|
||||||
// improper to choose the last position as its rank. Instead, we get the rank by averaging
|
// improper to choose the last position as its rank. Instead, we get the rank by averaging
|
||||||
// `rankOfValue` and `rankOfPreValue`.
|
// `rankOfValue` and `rankOfPreValue`.
|
||||||
val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0)
|
val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0)
|
||||||
val lower = math.floor((quant - summary.relativeError) * data.size)
|
val lower = math.floor((percentile - summary.relativeError) * data.size)
|
||||||
val upper = math.ceil((quant + summary.relativeError) * data.size)
|
val upper = math.ceil((percentile + summary.relativeError) * data.size)
|
||||||
val msg =
|
val msg =
|
||||||
s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx"
|
s"$rank not in [$lower $upper], requested percentile: $percentile, approx returned: $approx"
|
||||||
assert(rank >= lower, msg)
|
assert(rank >= lower, msg)
|
||||||
assert(rank <= upper, msg)
|
assert(rank <= upper, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
private def checkQuantile(
|
||||||
|
percentile: Double,
|
||||||
|
data: Seq[Double],
|
||||||
|
summary: QuantileSummaries): Unit = {
|
||||||
|
if (data.nonEmpty) {
|
||||||
|
val approx = summary.query(percentile).get
|
||||||
|
validateQuantileApproximation(approx, percentile, data, summary)
|
||||||
} else {
|
} else {
|
||||||
assert(summary.query(quant).isEmpty)
|
assert(summary.query(percentile).isEmpty)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private def checkQuantiles(
|
||||||
|
percentiles: Seq[Double],
|
||||||
|
data: Seq[Double],
|
||||||
|
summary: QuantileSummaries): Unit = {
|
||||||
|
if (data.nonEmpty) {
|
||||||
|
val approx = summary.query(percentiles).get
|
||||||
|
for ((q, a) <- percentiles zip approx) {
|
||||||
|
validateQuantileApproximation(a, q, data, summary)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert(summary.query(percentiles).isEmpty)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,6 +125,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
|
||||||
checkQuantile(0.5, data, s)
|
checkQuantile(0.5, data, s)
|
||||||
checkQuantile(0.1, data, s)
|
checkQuantile(0.1, data, s)
|
||||||
checkQuantile(0.001, data, s)
|
checkQuantile(0.001, data, s)
|
||||||
|
checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
|
||||||
|
checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
test(s"Some quantile values with epsi=$epsi and seq=$seq_name, compression=$compression " +
|
test(s"Some quantile values with epsi=$epsi and seq=$seq_name, compression=$compression " +
|
||||||
|
@ -109,6 +138,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
|
||||||
checkQuantile(0.5, data, s)
|
checkQuantile(0.5, data, s)
|
||||||
checkQuantile(0.1, data, s)
|
checkQuantile(0.1, data, s)
|
||||||
checkQuantile(0.001, data, s)
|
checkQuantile(0.001, data, s)
|
||||||
|
checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
|
||||||
|
checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
test(s"Tests on empty data with epsi=$epsi and seq=$seq_name, compression=$compression") {
|
test(s"Tests on empty data with epsi=$epsi and seq=$seq_name, compression=$compression") {
|
||||||
|
@ -121,6 +152,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
|
||||||
checkQuantile(0.5, emptyData, s)
|
checkQuantile(0.5, emptyData, s)
|
||||||
checkQuantile(0.1, emptyData, s)
|
checkQuantile(0.1, emptyData, s)
|
||||||
checkQuantile(0.001, emptyData, s)
|
checkQuantile(0.001, emptyData, s)
|
||||||
|
checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), emptyData, s)
|
||||||
|
checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), emptyData, s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,6 +182,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
|
||||||
checkQuantile(0.5, data, s)
|
checkQuantile(0.5, data, s)
|
||||||
checkQuantile(0.1, data, s)
|
checkQuantile(0.1, data, s)
|
||||||
checkQuantile(0.001, data, s)
|
checkQuantile(0.001, data, s)
|
||||||
|
checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
|
||||||
|
checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
val (data11, data12) = {
|
val (data11, data12) = {
|
||||||
|
@ -168,6 +203,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
|
||||||
checkQuantile(0.5, data, s)
|
checkQuantile(0.5, data, s)
|
||||||
checkQuantile(0.1, data, s)
|
checkQuantile(0.1, data, s)
|
||||||
checkQuantile(0.001, data, s)
|
checkQuantile(0.001, data, s)
|
||||||
|
checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
|
||||||
|
checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// length of data21 is 4 * length of data22
|
// length of data21 is 4 * length of data22
|
||||||
|
@ -181,10 +218,14 @@ class QuantileSummariesSuite extends SparkFunSuite {
|
||||||
val s2 = buildSummary(data22, epsi, compression)
|
val s2 = buildSummary(data22, epsi, compression)
|
||||||
val s = s1.merge(s2)
|
val s = s1.merge(s2)
|
||||||
// Check all quantiles
|
// Check all quantiles
|
||||||
|
val percentiles = ArrayBuffer[Double]()
|
||||||
for (queryRank <- 1 to n) {
|
for (queryRank <- 1 to n) {
|
||||||
val queryQuantile = queryRank.toDouble / n.toDouble
|
val percentile = queryRank.toDouble / n.toDouble
|
||||||
checkQuantile(queryQuantile, data, s)
|
checkQuantile(percentile, data, s)
|
||||||
}
|
percentiles += percentile
|
||||||
|
}
|
||||||
|
checkQuantiles(percentiles.toSeq, data, s)
|
||||||
|
checkQuantiles(percentiles.reverse.toSeq, data, s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -102,7 +102,12 @@ object StatFunctions extends Logging {
|
||||||
}
|
}
|
||||||
val summaries = df.select(columns: _*).rdd.treeAggregate(emptySummaries)(apply, merge)
|
val summaries = df.select(columns: _*).rdd.treeAggregate(emptySummaries)(apply, merge)
|
||||||
|
|
||||||
summaries.map { summary => probabilities.flatMap(summary.query) }
|
summaries.map {
|
||||||
|
summary => summary.query(probabilities) match {
|
||||||
|
case Some(q) => q
|
||||||
|
case None => Seq()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Calculate the Pearson Correlation Coefficient for the given columns */
|
/** Calculate the Pearson Correlation Coefficient for the given columns */
|
||||||
|
|
|
@ -204,7 +204,7 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession {
|
||||||
val e = intercept[IllegalArgumentException] {
|
val e = intercept[IllegalArgumentException] {
|
||||||
df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2, -0.1), epsilons.head)
|
df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2, -0.1), epsilons.head)
|
||||||
}
|
}
|
||||||
assert(e.getMessage.contains("quantile should be in the range [0.0, 1.0]"))
|
assert(e.getMessage.contains("percentile should be in the range [0.0, 1.0]"))
|
||||||
|
|
||||||
// relativeError should be non-negative
|
// relativeError should be non-negative
|
||||||
val e2 = intercept[IllegalArgumentException] {
|
val e2 = intercept[IllegalArgumentException] {
|
||||||
|
|
Loading…
Reference in a new issue