[SPARK-30938][ML][MLLIB] BinaryClassificationMetrics optimization
### What changes were proposed in this pull request? 1, avoid `Iterator.grouped(size: Int)`, which needs to maintain an ArrayBuffer of `size` entries; 2, keep the number of partitions in curve computation. ### Why are the changes needed? 1, `BinaryClassificationMetrics` tends to fail (OOM) when `grouping=count/numBins` is too large, because `Iterator.grouped(size: Int)` needs to maintain an ArrayBuffer with `size` entries; however, in `BinaryClassificationMetrics` we do not need to maintain such a big array; 2, make the sizes of partitions more even. This PR makes metric computation more stable and a little faster. ### Does this PR introduce any user-facing change? No ### How was this patch tested? existing testsuites Closes #27682 from zhengruifeng/grouped_opt. Authored-by: zhengruifeng <ruifengz@foxmail.com> Signed-off-by: zhengruifeng <ruifengz@foxmail.com>
This commit is contained in:
parent
1383bd459a
commit
14bb639c55
|
@ -20,7 +20,7 @@ package org.apache.spark.mllib.evaluation
|
|||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.mllib.evaluation.binary._
|
||||
import org.apache.spark.rdd.{RDD, UnionRDD}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{DataFrame, Row}
|
||||
|
||||
/**
|
||||
|
@ -101,10 +101,19 @@ class BinaryClassificationMetrics @Since("3.0.0") (
|
|||
@Since("1.0.0")
|
||||
def roc(): RDD[(Double, Double)] = {
|
||||
val rocCurve = createCurve(FalsePositiveRate, Recall)
|
||||
val sc = confusions.context
|
||||
val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
|
||||
val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
|
||||
new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last))
|
||||
val numParts = rocCurve.getNumPartitions
|
||||
rocCurve.mapPartitionsWithIndex { case (pid, iter) =>
|
||||
if (numParts == 1) {
|
||||
require(pid == 0)
|
||||
Iterator.single((0.0, 0.0)) ++ iter ++ Iterator.single((1.0, 1.0))
|
||||
} else if (pid == 0) {
|
||||
Iterator.single((0.0, 0.0)) ++ iter
|
||||
} else if (pid == numParts - 1) {
|
||||
iter ++ Iterator.single((1.0, 1.0))
|
||||
} else {
|
||||
iter
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -124,7 +133,13 @@ class BinaryClassificationMetrics @Since("3.0.0") (
|
|||
def pr(): RDD[(Double, Double)] = {
|
||||
val prCurve = createCurve(Recall, Precision)
|
||||
val (_, firstPrecision) = prCurve.first()
|
||||
confusions.context.parallelize(Seq((0.0, firstPrecision)), 1).union(prCurve)
|
||||
prCurve.mapPartitionsWithIndex { case (pid, iter) =>
|
||||
if (pid == 0) {
|
||||
Iterator.single((0.0, firstPrecision)) ++ iter
|
||||
} else {
|
||||
iter
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -182,28 +197,40 @@ class BinaryClassificationMetrics @Since("3.0.0") (
|
|||
val countsSize = counts.count()
|
||||
// Group the iterator into chunks of about countsSize / numBins points,
|
||||
// so that the resulting number of bins is about numBins
|
||||
var grouping = countsSize / numBins
|
||||
val grouping = countsSize / numBins
|
||||
if (grouping < 2) {
|
||||
// numBins was more than half of the size; no real point in down-sampling to bins
|
||||
logInfo(s"Curve is too small ($countsSize) for $numBins bins to be useful")
|
||||
counts
|
||||
} else {
|
||||
if (grouping >= Int.MaxValue) {
|
||||
logWarning(
|
||||
s"Curve too large ($countsSize) for $numBins bins; capping at ${Int.MaxValue}")
|
||||
grouping = Int.MaxValue
|
||||
counts.mapPartitions { iter =>
|
||||
if (iter.hasNext) {
|
||||
var score = Double.NaN
|
||||
var agg = new BinaryLabelCounter()
|
||||
var cnt = 0L
|
||||
iter.flatMap { pair =>
|
||||
score = pair._1
|
||||
agg += pair._2
|
||||
cnt += 1
|
||||
if (cnt == grouping) {
|
||||
// The score of the combined point will be just the last one's score,
|
||||
// which is also the minimal in each chunk since all scores are already
|
||||
// sorted in descending.
|
||||
// The combined point will contain all counts in this chunk. Thus, calculated
|
||||
// metrics (like precision, recall, etc.) on its score (or so-called threshold)
|
||||
// are the same as those without sampling.
|
||||
val ret = (score, agg)
|
||||
agg = new BinaryLabelCounter()
|
||||
cnt = 0
|
||||
Some(ret)
|
||||
} else None
|
||||
} ++ {
|
||||
if (cnt > 0) {
|
||||
Iterator.single((score, agg))
|
||||
} else Iterator.empty
|
||||
}
|
||||
} else Iterator.empty
|
||||
}
|
||||
counts.mapPartitions(_.grouped(grouping.toInt).map { pairs =>
|
||||
// The score of the combined point will be just the last one's score, which is also
|
||||
// the minimal in each chunk since all scores are already sorted in descending.
|
||||
val lastScore = pairs.last._1
|
||||
// The combined point will contain all counts in this chunk. Thus, calculated
|
||||
// metrics (like precision, recall, etc.) on its score (or so-called threshold) are
|
||||
// the same as those without sampling.
|
||||
val agg = new BinaryLabelCounter()
|
||||
pairs.foreach(pair => agg += pair._2)
|
||||
(lastScore, agg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue