[SPARK-35558] Optimizes for multi-quantile retrieval

### What changes were proposed in this pull request?
Optimizes the retrieval of approximate quantiles for an array of percentiles.
* Adds an overload for QuantileSummaries.query that accepts an array of percentiles and optimizes the computation to do a single pass over the sketch and avoid redundant computation.
* Modifies the ApproximatePercentiles operator to call into the new method.

All formatting changes are the result of running ./dev/scalafmt

### Why are the changes needed?
The existing implementation does repeated calls per input percentile resulting in redundant computation.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added unit tests for the new method.

Closes #32700 from alkispoly-db/spark_35558_approx_quants_array.

Authored-by: Alkis Polyzotis <alkis.polyzotis@databricks.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
Alkis Polyzotis 2021-06-05 14:25:33 -05:00 committed by Sean Owen
parent 510bde460a
commit 6f8c62047c
5 changed files with 149 additions and 57 deletions

View file

@ -261,19 +261,12 @@ object ApproximatePercentile {
* val Array(p25, median, p75) = percentileDigest.getPercentiles(Array(0.25, 0.5, 0.75)) * val Array(p25, median, p75) = percentileDigest.getPercentiles(Array(0.25, 0.5, 0.75))
* }}} * }}}
*/ */
def getPercentiles(percentages: Array[Double]): Array[Double] = { def getPercentiles(percentages: Array[Double]): Seq[Double] = {
if (!isCompressed) compress() if (!isCompressed) compress()
if (summaries.count == 0 || percentages.length == 0) { if (summaries.count == 0 || percentages.length == 0) {
Array.emptyDoubleArray Array.emptyDoubleArray
} else { } else {
val result = new Array[Double](percentages.length) summaries.query(percentages).get
var i = 0
while (i < percentages.length) {
// Since summaries.count != 0, the query here never return None.
result(i) = summaries.query(percentages(i)).get
i += 1
}
result
} }
} }

View file

@ -229,46 +229,99 @@ class QuantileSummaries(
} }
/** /**
* Runs a query for a given quantile. * Finds the approximate quantile for a percentile, starting at a specific index in the summary.
* This is a helper method that is called as we are making a pass over the summary and a sorted
* sequence of input percentiles.
*
* @param index The point at which to start scanning the summary for an approximate value.
* @param minRankAtIndex The accumulated minimum rank at the given index.
* @param targetError Target error from the summary.
* @param percentile The percentile whose value is computed.
* @return A tuple (i, r, a) where: i is the updated index for the next call, r is the updated
* rank at i, and a is the approximate quantile.
*/
private def findApproxQuantile(
index: Int,
minRankAtIndex: Long,
targetError: Double,
percentile: Double): (Int, Long, Double) = {
var curSample = sampled(index)
val rank = math.ceil(percentile * count).toLong
var i = index
var minRank = minRankAtIndex
while (i < sampled.length - 1) {
val maxRank = minRank + curSample.delta
if (maxRank - targetError <= rank && rank <= minRank + targetError) {
return (i, minRank, curSample.value)
} else {
i += 1
curSample = sampled(i)
minRank += curSample.g
}
}
(sampled.length - 1, 0, sampled.last.value)
}
/**
* Runs a query for a given sequence of percentiles.
* The result follows the approximation guarantees detailed above. * The result follows the approximation guarantees detailed above.
* The query can only be run on a compressed summary: you need to call compress() before using * The query can only be run on a compressed summary: you need to call compress() before using
* it. * it.
* *
* @param quantile the target quantile * @param percentiles the target percentiles
* @return * @return the corresponding approximate quantiles, in the same order as the input
*/ */
def query(quantile: Double): Option[Double] = { def query(percentiles: Seq[Double]): Option[Seq[Double]] = {
require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range [0.0, 1.0]") percentiles.foreach(p =>
require(headSampled.isEmpty, require(p >= 0 && p <= 1.0, "percentile should be in the range [0.0, 1.0]"))
require(
headSampled.isEmpty,
"Cannot operate on an uncompressed summary, call compress() first") "Cannot operate on an uncompressed summary, call compress() first")
if (sampled.isEmpty) return None if (sampled.isEmpty) return None
if (quantile <= relativeError) { val targetError = sampled.foldLeft(Long.MinValue)((currentMax, stats) =>
return Some(sampled.head.value) currentMax.max(stats.delta + stats.g)) / 2
}
if (quantile >= 1 - relativeError) { // Index to track the current sample
return Some(sampled.last.value) var index = 0
}
// Target rank
val rank = math.ceil(quantile * count).toLong
val targetError = sampled.map(s => s.delta + s.g).max / 2
// Minimum rank at current sample // Minimum rank at current sample
var minRank = 0L var minRank = sampled(0).g
var i = 0
while (i < sampled.length - 1) { val sortedPercentiles = percentiles.zipWithIndex.sortBy(_._1)
val curSample = sampled(i) val result = Array.fill(percentiles.length)(0.0)
minRank += curSample.g sortedPercentiles.foreach {
val maxRank = minRank + curSample.delta case (percentile, pos) =>
if (maxRank - targetError <= rank && rank <= minRank + targetError) { if (percentile <= relativeError) {
return Some(curSample.value) result(pos) = sampled.head.value
} else if (percentile >= 1 - relativeError) {
result(pos) = sampled.last.value
} else {
val (newIndex, newMinRank, approxQuantile) =
findApproxQuantile(index, minRank, targetError, percentile)
index = newIndex
minRank = newMinRank
result(pos) = approxQuantile
} }
i += 1
} }
Some(sampled.last.value) Some(result)
} }
/**
* Runs a query for a given percentile.
* The result follows the approximation guarantees detailed above.
* The query can only be run on a compressed summary: you need to call compress() before using
* it.
*
* @param percentile the target percentile
* @return the corresponding approximate quantile
*/
def query(percentile: Double): Option[Double] =
query(Seq(percentile)) match {
case Some(approxSeq) if approxSeq.nonEmpty => Some(approxSeq.head)
case _ => None
}
} }
object QuantileSummaries { object QuantileSummaries {

View file

@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.util package org.apache.spark.sql.catalyst.util
import scala.collection.mutable.ArrayBuffer
import scala.util.Random import scala.util.Random
import org.apache.spark.SparkFunSuite import org.apache.spark.SparkFunSuite
@ -54,10 +55,13 @@ class QuantileSummariesSuite extends SparkFunSuite {
summary summary
} }
private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = { private def validateQuantileApproximation(
if (data.nonEmpty) { approx: Double,
val approx = summary.query(quant).get percentile: Double,
// Get the rank of the approximation. data: Seq[Double],
summary: QuantileSummaries): Unit = {
assert(data.nonEmpty)
val rankOfValue = data.count(_ <= approx) val rankOfValue = data.count(_ <= approx)
val rankOfPreValue = data.count(_ < approx) val rankOfPreValue = data.count(_ < approx)
// `rankOfValue` is the last position of the quantile value. If the input repeats the value // `rankOfValue` is the last position of the quantile value. If the input repeats the value
@ -65,14 +69,37 @@ class QuantileSummariesSuite extends SparkFunSuite {
// improper to choose the last position as its rank. Instead, we get the rank by averaging // improper to choose the last position as its rank. Instead, we get the rank by averaging
// `rankOfValue` and `rankOfPreValue`. // `rankOfValue` and `rankOfPreValue`.
val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0) val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0)
val lower = math.floor((quant - summary.relativeError) * data.size) val lower = math.floor((percentile - summary.relativeError) * data.size)
val upper = math.ceil((quant + summary.relativeError) * data.size) val upper = math.ceil((percentile + summary.relativeError) * data.size)
val msg = val msg =
s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx" s"$rank not in [$lower $upper], requested percentile: $percentile, approx returned: $approx"
assert(rank >= lower, msg) assert(rank >= lower, msg)
assert(rank <= upper, msg) assert(rank <= upper, msg)
}
private def checkQuantile(
percentile: Double,
data: Seq[Double],
summary: QuantileSummaries): Unit = {
if (data.nonEmpty) {
val approx = summary.query(percentile).get
validateQuantileApproximation(approx, percentile, data, summary)
} else { } else {
assert(summary.query(quant).isEmpty) assert(summary.query(percentile).isEmpty)
}
}
private def checkQuantiles(
percentiles: Seq[Double],
data: Seq[Double],
summary: QuantileSummaries): Unit = {
if (data.nonEmpty) {
val approx = summary.query(percentiles).get
for ((q, a) <- percentiles zip approx) {
validateQuantileApproximation(a, q, data, summary)
}
} else {
assert(summary.query(percentiles).isEmpty)
} }
} }
@ -98,6 +125,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
checkQuantile(0.5, data, s) checkQuantile(0.5, data, s)
checkQuantile(0.1, data, s) checkQuantile(0.1, data, s)
checkQuantile(0.001, data, s) checkQuantile(0.001, data, s)
checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
} }
test(s"Some quantile values with epsi=$epsi and seq=$seq_name, compression=$compression " + test(s"Some quantile values with epsi=$epsi and seq=$seq_name, compression=$compression " +
@ -109,6 +138,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
checkQuantile(0.5, data, s) checkQuantile(0.5, data, s)
checkQuantile(0.1, data, s) checkQuantile(0.1, data, s)
checkQuantile(0.001, data, s) checkQuantile(0.001, data, s)
checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
} }
test(s"Tests on empty data with epsi=$epsi and seq=$seq_name, compression=$compression") { test(s"Tests on empty data with epsi=$epsi and seq=$seq_name, compression=$compression") {
@ -121,6 +152,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
checkQuantile(0.5, emptyData, s) checkQuantile(0.5, emptyData, s)
checkQuantile(0.1, emptyData, s) checkQuantile(0.1, emptyData, s)
checkQuantile(0.001, emptyData, s) checkQuantile(0.001, emptyData, s)
checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), emptyData, s)
checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), emptyData, s)
} }
} }
@ -149,6 +182,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
checkQuantile(0.5, data, s) checkQuantile(0.5, data, s)
checkQuantile(0.1, data, s) checkQuantile(0.1, data, s)
checkQuantile(0.001, data, s) checkQuantile(0.001, data, s)
checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
} }
val (data11, data12) = { val (data11, data12) = {
@ -168,6 +203,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
checkQuantile(0.5, data, s) checkQuantile(0.5, data, s)
checkQuantile(0.1, data, s) checkQuantile(0.1, data, s)
checkQuantile(0.001, data, s) checkQuantile(0.001, data, s)
checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
} }
// length of data21 is 4 * length of data22 // length of data21 is 4 * length of data22
@ -181,10 +218,14 @@ class QuantileSummariesSuite extends SparkFunSuite {
val s2 = buildSummary(data22, epsi, compression) val s2 = buildSummary(data22, epsi, compression)
val s = s1.merge(s2) val s = s1.merge(s2)
// Check all quantiles // Check all quantiles
val percentiles = ArrayBuffer[Double]()
for (queryRank <- 1 to n) { for (queryRank <- 1 to n) {
val queryQuantile = queryRank.toDouble / n.toDouble val percentile = queryRank.toDouble / n.toDouble
checkQuantile(queryQuantile, data, s) checkQuantile(percentile, data, s)
} percentiles += percentile
}
checkQuantiles(percentiles.toSeq, data, s)
checkQuantiles(percentiles.reverse.toSeq, data, s)
} }
} }
} }

View file

@ -102,7 +102,12 @@ object StatFunctions extends Logging {
} }
val summaries = df.select(columns: _*).rdd.treeAggregate(emptySummaries)(apply, merge) val summaries = df.select(columns: _*).rdd.treeAggregate(emptySummaries)(apply, merge)
summaries.map { summary => probabilities.flatMap(summary.query) } summaries.map {
summary => summary.query(probabilities) match {
case Some(q) => q
case None => Seq()
}
}
} }
/** Calculate the Pearson Correlation Coefficient for the given columns */ /** Calculate the Pearson Correlation Coefficient for the given columns */

View file

@ -204,7 +204,7 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession {
val e = intercept[IllegalArgumentException] { val e = intercept[IllegalArgumentException] {
df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2, -0.1), epsilons.head) df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2, -0.1), epsilons.head)
} }
assert(e.getMessage.contains("quantile should be in the range [0.0, 1.0]")) assert(e.getMessage.contains("percentile should be in the range [0.0, 1.0]"))
// relativeError should be non-negative // relativeError should be non-negative
val e2 = intercept[IllegalArgumentException] { val e2 = intercept[IllegalArgumentException] {