[SPARK-30938][ML][MLLIB] BinaryClassificationMetrics optimization

### What changes were proposed in this pull request?
1. Avoid `Iterator.grouped(size: Int)`, which needs to maintain an `ArrayBuffer` of `size` entries.
2. Keep the number of partitions in the curve computation.

### Why are the changes needed?
1. `BinaryClassificationMetrics` tends to fail with OOM when `grouping = count / numBins` is too large, because `Iterator.grouped(size: Int)` must maintain an `ArrayBuffer` with `size` entries; `BinaryClassificationMetrics` does not actually need to materialize such a big array, since each chunk can be aggregated on the fly.
2. Make the sizes of the partitions more even.

This PR makes the metrics computation more stable and a little faster; the sketch below illustrates the memory difference.
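
Below is a minimal sketch (not part of the patch) of that memory difference on a plain Scala `Iterator`, with counts simplified to a `Long` sum instead of a `BinaryLabelCounter`; the chunk size `n` and all values are made up:

```scala
object GroupedVsStreaming {
  def main(args: Array[String]): Unit = {
    val n = 1 << 20  // stand-in for `grouping`
    val data = Iterator.range(0, 5 * n).map(i => (1.0 - i.toDouble / (5 * n), 1L))

    // Buffering approach (old code): `grouped` materializes every chunk as a
    // collection of `n` elements before aggregating it, so peak memory grows with `n`:
    //   data.grouped(n).map(chunk => (chunk.last._1, chunk.map(_._2).sum))

    // Streaming approach (the idea behind the patch): keep one running
    // aggregate and emit it every `n` elements, using O(1) extra memory.
    var lastScore = Double.NaN
    var sum = 0L
    var cnt = 0L
    val bins = data.flatMap { case (score, count) =>
      lastScore = score; sum += count; cnt += 1
      if (cnt == n) { val bin = (lastScore, sum); sum = 0L; cnt = 0L; Some(bin) }
      else None
    } ++ (if (cnt > 0) Iterator.single((lastScore, sum)) else Iterator.empty)

    bins.foreach(println)  // 5 bins of 2^20 points each
  }
}
```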

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Existing test suites.

Closes #27682 from zhengruifeng/grouped_opt.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: zhengruifeng <ruifengz@foxmail.com>

@@ -20,7 +20,7 @@ package org.apache.spark.mllib.evaluation
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.mllib.evaluation.binary._
-import org.apache.spark.rdd.{RDD, UnionRDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 
 /**
@@ -101,10 +101,19 @@ class BinaryClassificationMetrics @Since("3.0.0") (
   @Since("1.0.0")
   def roc(): RDD[(Double, Double)] = {
     val rocCurve = createCurve(FalsePositiveRate, Recall)
-    val sc = confusions.context
-    val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
-    val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
-    new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last))
+    val numParts = rocCurve.getNumPartitions
+    rocCurve.mapPartitionsWithIndex { case (pid, iter) =>
+      if (numParts == 1) {
+        require(pid == 0)
+        Iterator.single((0.0, 0.0)) ++ iter ++ Iterator.single((1.0, 1.0))
+      } else if (pid == 0) {
+        Iterator.single((0.0, 0.0)) ++ iter
+      } else if (pid == numParts - 1) {
+        iter ++ Iterator.single((1.0, 1.0))
+      } else {
+        iter
+      }
+    }
   }
 
 /**
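
The old code glued two extra single-partition RDDs onto the curve via `UnionRDD`; the new code folds the endpoints into the existing first and last partitions, so the partition count is preserved. As a standalone sketch, the same technique generalizes to any RDD (the helper name `withEndpoints` is hypothetical, not part of the patch):

```scala
import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD

// Prepend a sentinel to the first partition and append one to the last,
// keeping the number of partitions unchanged. With a single partition,
// both branches apply to partition 0, mirroring the numParts == 1 case above.
def withEndpoints[T: ClassTag](rdd: RDD[T], first: T, last: T): RDD[T] = {
  val numParts = rdd.getNumPartitions
  rdd.mapPartitionsWithIndex { case (pid, iter) =>
    val headed = if (pid == 0) Iterator.single(first) ++ iter else iter
    if (pid == numParts - 1) headed ++ Iterator.single(last) else headed
  }
}
```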
@@ -124,7 +133,13 @@ class BinaryClassificationMetrics @Since("3.0.0") (
   def pr(): RDD[(Double, Double)] = {
     val prCurve = createCurve(Recall, Precision)
     val (_, firstPrecision) = prCurve.first()
-    confusions.context.parallelize(Seq((0.0, firstPrecision)), 1).union(prCurve)
+    prCurve.mapPartitionsWithIndex { case (pid, iter) =>
+      if (pid == 0) {
+        Iterator.single((0.0, firstPrecision)) ++ iter
+      } else {
+        iter
+      }
+    }
   }
 
 /**
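
For reference, a minimal usage sketch; it assumes an active `SparkContext` named `sc`, and the (score, label) pairs are made up:

```scala
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

// Made-up (score, label) pairs, spread over two partitions.
val scoreAndLabels = sc.parallelize(
  Seq((0.9, 1.0), (0.8, 1.0), (0.6, 0.0), (0.4, 1.0), (0.2, 0.0)), 2)

val metrics = new BinaryClassificationMetrics(scoreAndLabels)
metrics.roc().collect().foreach(println)  // starts at (0.0, 0.0), ends at (1.0, 1.0)
metrics.pr().collect().foreach(println)   // starts at (0.0, firstPrecision)
```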
@@ -182,28 +197,40 @@ class BinaryClassificationMetrics @Since("3.0.0") (
       val countsSize = counts.count()
       // Group the iterator into chunks of about countsSize / numBins points,
       // so that the resulting number of bins is about numBins
-      var grouping = countsSize / numBins
+      val grouping = countsSize / numBins
       if (grouping < 2) {
         // numBins was more than half of the size; no real point in down-sampling to bins
         logInfo(s"Curve is too small ($countsSize) for $numBins bins to be useful")
         counts
       } else {
-        if (grouping >= Int.MaxValue) {
-          logWarning(
-            s"Curve too large ($countsSize) for $numBins bins; capping at ${Int.MaxValue}")
-          grouping = Int.MaxValue
-        }
-        counts.mapPartitions(_.grouped(grouping.toInt).map { pairs =>
-          // The score of the combined point will be just the last one's score, which is also
-          // the minimal in each chunk since all scores are already sorted in descending.
-          val lastScore = pairs.last._1
-          // The combined point will contain all counts in this chunk. Thus, calculated
-          // metrics (like precision, recall, etc.) on its score (or so-called threshold) are
-          // the same as those without sampling.
-          val agg = new BinaryLabelCounter()
-          pairs.foreach(pair => agg += pair._2)
-          (lastScore, agg)
-        })
+        counts.mapPartitions { iter =>
+          if (iter.hasNext) {
+            var score = Double.NaN
+            var agg = new BinaryLabelCounter()
+            var cnt = 0L
+            iter.flatMap { pair =>
+              score = pair._1
+              agg += pair._2
+              cnt += 1
+              if (cnt == grouping) {
+                // The score of the combined point will be just the last one's score,
+                // which is also the minimal in each chunk since all scores are already
+                // sorted in descending.
+                // The combined point will contain all counts in this chunk. Thus, calculated
+                // metrics (like precision, recall, etc.) on its score (or so-called threshold)
+                // are the same as those without sampling.
+                val ret = (score, agg)
+                agg = new BinaryLabelCounter()
+                cnt = 0
+                Some(ret)
+              } else None
+            } ++ {
+              if (cnt > 0) {
+                Iterator.single((score, agg))
+              } else Iterator.empty
+            }
+          } else Iterator.empty
+        }
       }
     }
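
To see that the streaming aggregation produces the same bins as the old `grouped`-based code, here is a self-contained check on a plain iterator (no Spark; counts are simplified to `Long` sums, and `grouping` and the data are made up):

```scala
val grouping = 3
// Made-up (score, count) pairs, already sorted by descending score.
val pairs = (1 to 10).map(i => (1.0 - 0.05 * i, i.toLong))

// Streaming version, as in the patch: remember the last score seen and a
// running sum; emit a bin every `grouping` elements, then flush the remainder.
var lastScore = Double.NaN
var sum = 0L
var cnt = 0L
val streamed = pairs.iterator.flatMap { case (score, count) =>
  lastScore = score; sum += count; cnt += 1
  if (cnt == grouping) { val bin = (lastScore, sum); sum = 0L; cnt = 0L; Some(bin) }
  else None
} ++ (if (cnt > 0) Iterator.single((lastScore, sum)) else Iterator.empty)

// Old version: buffer each chunk of `grouping` pairs, then aggregate it.
val buffered = pairs.grouped(grouping).map(c => (c.last._1, c.map(_._2).sum))

assert(streamed.toList == buffered.toList)  // same bins, O(1) vs O(grouping) memory
```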