[SPARK-32455][ML] LogisticRegressionModel prediction optimization

### What changes were proposed in this pull request?
For the binary `LogisticRegressionModel`:
1. keep the variables `_threshold` and `_rawThreshold` instead of recomputing them for each instance;
2. in `raw2probabilityInPlace`, exploit the fact that the two class probabilities sum to 1.0, so the second entry is just the complement of the first (see the sketch below).
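A minimal, self-contained sketch of the two ideas (plain Scala with illustrative names such as `CachedThresholds` and `predictFromRaw`, not the actual Spark internals):

```scala
// Sketch of the two optimizations for a binary model, assuming rawPrediction = (-margin, margin).
object BinaryPredictSketch {

  // 1) Cache the probability threshold and its raw-space (logit) counterpart once,
  //    instead of recomputing log(t / (1 - t)) for every row.
  final class CachedThresholds(val threshold: Double) {
    val rawThreshold: Double =
      if (threshold == 0.0) Double.NegativeInfinity
      else if (threshold == 1.0) Double.PositiveInfinity
      else math.log(threshold / (1.0 - threshold))

    def predictFromRaw(margin: Double): Double =
      if (margin > rawThreshold) 1.0 else 0.0
  }

  // 2) The two class probabilities sum to 1.0, so only one sigmoid is evaluated
  //    and the other entry is obtained by subtraction.
  def raw2probability(raw: Array[Double]): Array[Double] = {
    raw(0) = 1.0 / (1.0 + math.exp(-raw(0)))
    raw(1) = 1.0 - raw(0)
    raw
  }

  def main(args: Array[String]): Unit = {
    val cached = new CachedThresholds(0.5)
    println(cached.predictFromRaw(0.3))                // 1.0, since 0.3 > logit(0.5) = 0.0
    println(raw2probability(Array(-0.3, 0.3)).toList)  // two probabilities summing to 1.0
  }
}
```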

### Why are the changes needed?
For better prediction performance.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing test suites and a performance test in the REPL.

Closes #29255 from zhengruifeng/pred_opt.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: Huaxin Gao <huaxing@us.ibm.com>
zhengruifeng 2020-07-29 19:53:28 -07:00 committed by Huaxin Gao
parent 89d9b7cc64
commit 81b0785fb2

@@ -1098,16 +1098,42 @@ class LogisticRegressionModel private[spark] (
_intercept
}
private lazy val _intercept = interceptVector.toArray.head
private lazy val _intercept = interceptVector(0)
private lazy val _interceptVector = interceptVector.toDense
private var _threshold = Double.NaN
private var _rawThreshold = Double.NaN
updateBinaryThreshold()
private def updateBinaryThreshold(): Unit = {
if (!isMultinomial) {
_threshold = getThreshold
if (_threshold == 0.0) {
_rawThreshold = Double.NegativeInfinity
} else if (_threshold == 1.0) {
_rawThreshold = Double.PositiveInfinity
} else {
_rawThreshold = math.log(_threshold / (1.0 - _threshold))
}
}
}
@Since("1.5.0")
override def setThreshold(value: Double): this.type = super.setThreshold(value)
override def setThreshold(value: Double): this.type = {
super.setThreshold(value)
updateBinaryThreshold()
this
}
@Since("1.5.0")
override def getThreshold: Double = super.getThreshold
@Since("1.5.0")
override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
override def setThresholds(value: Array[Double]): this.type = {
super.setThresholds(value)
updateBinaryThreshold()
this
}
@Since("1.5.0")
override def getThresholds: Array[Double] = super.getThresholds
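For context on the hunk above (not part of the patch), caching `_rawThreshold = log(t / (1 - t))` works because the sigmoid is monotonic: `sigmoid(margin) > t` exactly when `margin > logit(t)`, with the 0.0 and 1.0 endpoints mapped to -Infinity and +Infinity. A small standalone check with illustrative values:

```scala
// Illustrative check, not Spark code: the decision made in raw-margin space with a
// cached logit(threshold) matches the decision made in probability space.
def sigmoid(m: Double): Double = 1.0 / (1.0 + math.exp(-m))

def logit(t: Double): Double =
  if (t == 0.0) Double.NegativeInfinity
  else if (t == 1.0) Double.PositiveInfinity
  else math.log(t / (1.0 - t))

val threshold = 0.7                  // any value in [0, 1]
val rawThreshold = logit(threshold)  // computed once, reused for every row
for (margin <- Seq(-2.0, 0.0, 0.5, 3.0)) {
  val viaProbability = sigmoid(margin) > threshold
  val viaRawMargin   = margin > rawThreshold
  assert(viaProbability == viaRawMargin)
}
```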
@@ -1119,7 +1145,7 @@ class LogisticRegressionModel private[spark] (
/** Margin (rawPrediction) for each class label. */
private val margins: Vector => Vector = (features) => {
val m = interceptVector.toDense.copy
val m = _interceptVector.copy
BLAS.gemv(1.0, coefficientMatrix, features, 1.0, m)
m
}
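The hunk above only swaps in the cached dense intercept; the margin computation itself is unchanged. As a plain-Scala sketch of what `margins` produces (using arrays instead of the private `BLAS.gemv` call, with made-up coefficients):

```scala
// Margin of class k = dot(row k of the coefficient matrix, features) + intercept(k).
def margins(
    coefficients: Array[Array[Double]],  // numClasses x numFeatures
    intercepts: Array[Double],           // numClasses, densified once up front
    features: Array[Double]): Array[Double] =
  Array.tabulate(intercepts.length) { k =>
    var dot = 0.0
    var j = 0
    while (j < features.length) { dot += coefficients(k)(j) * features(j); j += 1 }
    intercepts(k) + dot
  }

// Example: 2 classes, 3 features.
val m = margins(
  Array(Array(0.1, -0.2, 0.4), Array(0.3, -0.5, 0.6)),
  Array(0.05, -0.05),
  Array(1.0, 2.0, 3.0))
```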
@@ -1178,52 +1204,43 @@ class LogisticRegressionModel private[spark] (
override def predict(features: Vector): Double = if (isMultinomial) {
super.predict(features)
} else {
// Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
if (score(features) > getThreshold) 1 else 0
// Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
if (score(features) > _threshold) 1 else 0
}
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
rawPrediction match {
case dv: DenseVector =>
val values = dv.values
if (isMultinomial) {
val size = dv.size
val values = dv.values
// get the maximum margin
val maxMarginIndex = rawPrediction.argmax
val maxMargin = rawPrediction(maxMarginIndex)
if (maxMargin == Double.PositiveInfinity) {
var k = 0
while (k < size) {
while (k < numClasses) {
values(k) = if (k == maxMarginIndex) 1.0 else 0.0
k += 1
}
} else {
val sum = {
var temp = 0.0
var k = 0
while (k < numClasses) {
values(k) = if (maxMargin > 0) {
math.exp(values(k) - maxMargin)
} else {
math.exp(values(k))
}
temp += values(k)
k += 1
var sum = 0.0
var k = 0
while (k < numClasses) {
values(k) = if (maxMargin > 0) {
math.exp(values(k) - maxMargin)
} else {
math.exp(values(k))
}
temp
sum += values(k)
k += 1
}
BLAS.scal(1 / sum, dv)
}
dv
} else {
var i = 0
val size = dv.size
while (i < size) {
dv.values(i) = 1.0 / (1.0 + math.exp(-dv.values(i)))
i += 1
}
values(0) = 1.0 / (1.0 + math.exp(-values(0)))
values(1) = 1.0 - values(0)
dv
}
case sv: SparseVector =>
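To see why the new binary branch above is equivalent to the old per-component loop (a standalone check with a made-up margin, not Spark code): the binary raw prediction is `(-margin, margin)`, and since `sigmoid(-m) + sigmoid(m) = 1`, the second sigmoid call can be replaced by a subtraction:

```scala
// Old branch: one sigmoid per component. New branch: one sigmoid plus a complement.
def sigmoid(x: Double): Double = 1.0 / (1.0 + math.exp(-x))

def oldRaw2Probability(raw: Array[Double]): Array[Double] =
  raw.map(sigmoid)

def newRaw2Probability(raw: Array[Double]): Array[Double] = {
  val p0 = sigmoid(raw(0))
  Array(p0, 1.0 - p0)
}

val margin = 1.25                 // illustrative value
val raw = Array(-margin, margin)  // binary LR rawPrediction layout
val oldP = oldRaw2Probability(raw)
val newP = newRaw2Probability(raw)
assert(oldP.zip(newP).forall { case (a, b) => math.abs(a - b) < 1e-12 })
```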
@@ -1253,16 +1270,8 @@ class LogisticRegressionModel private[spark] (
if (isMultinomial) {
super.raw2prediction(rawPrediction)
} else {
// Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
val t = getThreshold
val rawThreshold = if (t == 0.0) {
Double.NegativeInfinity
} else if (t == 1.0) {
Double.PositiveInfinity
} else {
math.log(t / (1.0 - t))
}
if (rawPrediction(1) > rawThreshold) 1 else 0
// Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
if (rawPrediction(1) > _rawThreshold) 1.0 else 0.0
}
}
@@ -1270,8 +1279,8 @@ class LogisticRegressionModel private[spark] (
if (isMultinomial) {
super.probability2prediction(probability)
} else {
// Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
if (probability(1) > getThreshold) 1 else 0
// Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
if (probability(1) > _threshold) 1.0 else 0.0
}
}