[SPARK-32455][ML] LogisticRegressionModel prediction optimization
### What changes were proposed in this pull request? for binary `LogisticRegressionModel`: 1. keep variables `_threshold` and `_rawThreshold` instead of computing them on each instance; 2. in `raw2probabilityInPlace`, make use of the characteristic that the sum of probabilities is 1.0; ### Why are the changes needed? for better performance ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? existing testsuite and performance test in REPL Closes #29255 from zhengruifeng/pred_opt. Authored-by: zhengruifeng <ruifengz@foxmail.com> Signed-off-by: Huaxin Gao <huaxing@us.ibm.com>
This commit is contained in:
parent
89d9b7cc64
commit
81b0785fb2
|
@ -1098,16 +1098,42 @@ class LogisticRegressionModel private[spark] (
|
|||
_intercept
|
||||
}
|
||||
|
||||
private lazy val _intercept = interceptVector.toArray.head
|
||||
private lazy val _intercept = interceptVector(0)
|
||||
private lazy val _interceptVector = interceptVector.toDense
|
||||
private var _threshold = Double.NaN
|
||||
private var _rawThreshold = Double.NaN
|
||||
|
||||
updateBinaryThreshold()
|
||||
|
||||
private def updateBinaryThreshold(): Unit = {
|
||||
if (!isMultinomial) {
|
||||
_threshold = getThreshold
|
||||
if (_threshold == 0.0) {
|
||||
_rawThreshold = Double.NegativeInfinity
|
||||
} else if (_threshold == 1.0) {
|
||||
_rawThreshold = Double.PositiveInfinity
|
||||
} else {
|
||||
_rawThreshold = math.log(_threshold / (1.0 - _threshold))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Since("1.5.0")
|
||||
override def setThreshold(value: Double): this.type = super.setThreshold(value)
|
||||
override def setThreshold(value: Double): this.type = {
|
||||
super.setThreshold(value)
|
||||
updateBinaryThreshold()
|
||||
this
|
||||
}
|
||||
|
||||
@Since("1.5.0")
|
||||
override def getThreshold: Double = super.getThreshold
|
||||
|
||||
@Since("1.5.0")
|
||||
override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
|
||||
override def setThresholds(value: Array[Double]): this.type = {
|
||||
super.setThresholds(value)
|
||||
updateBinaryThreshold()
|
||||
this
|
||||
}
|
||||
|
||||
@Since("1.5.0")
|
||||
override def getThresholds: Array[Double] = super.getThresholds
|
||||
|
@ -1119,7 +1145,7 @@ class LogisticRegressionModel private[spark] (
|
|||
|
||||
/** Margin (rawPrediction) for each class label. */
|
||||
private val margins: Vector => Vector = (features) => {
|
||||
val m = interceptVector.toDense.copy
|
||||
val m = _interceptVector.copy
|
||||
BLAS.gemv(1.0, coefficientMatrix, features, 1.0, m)
|
||||
m
|
||||
}
|
||||
|
@ -1178,30 +1204,27 @@ class LogisticRegressionModel private[spark] (
|
|||
override def predict(features: Vector): Double = if (isMultinomial) {
|
||||
super.predict(features)
|
||||
} else {
|
||||
// Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
|
||||
if (score(features) > getThreshold) 1 else 0
|
||||
// Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
|
||||
if (score(features) > _threshold) 1 else 0
|
||||
}
|
||||
|
||||
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
|
||||
rawPrediction match {
|
||||
case dv: DenseVector =>
|
||||
if (isMultinomial) {
|
||||
val size = dv.size
|
||||
val values = dv.values
|
||||
|
||||
if (isMultinomial) {
|
||||
// get the maximum margin
|
||||
val maxMarginIndex = rawPrediction.argmax
|
||||
val maxMargin = rawPrediction(maxMarginIndex)
|
||||
|
||||
if (maxMargin == Double.PositiveInfinity) {
|
||||
var k = 0
|
||||
while (k < size) {
|
||||
while (k < numClasses) {
|
||||
values(k) = if (k == maxMarginIndex) 1.0 else 0.0
|
||||
k += 1
|
||||
}
|
||||
} else {
|
||||
val sum = {
|
||||
var temp = 0.0
|
||||
var sum = 0.0
|
||||
var k = 0
|
||||
while (k < numClasses) {
|
||||
values(k) = if (maxMargin > 0) {
|
||||
|
@ -1209,21 +1232,15 @@ class LogisticRegressionModel private[spark] (
|
|||
} else {
|
||||
math.exp(values(k))
|
||||
}
|
||||
temp += values(k)
|
||||
sum += values(k)
|
||||
k += 1
|
||||
}
|
||||
temp
|
||||
}
|
||||
BLAS.scal(1 / sum, dv)
|
||||
}
|
||||
dv
|
||||
} else {
|
||||
var i = 0
|
||||
val size = dv.size
|
||||
while (i < size) {
|
||||
dv.values(i) = 1.0 / (1.0 + math.exp(-dv.values(i)))
|
||||
i += 1
|
||||
}
|
||||
values(0) = 1.0 / (1.0 + math.exp(-values(0)))
|
||||
values(1) = 1.0 - values(0)
|
||||
dv
|
||||
}
|
||||
case sv: SparseVector =>
|
||||
|
@ -1253,16 +1270,8 @@ class LogisticRegressionModel private[spark] (
|
|||
if (isMultinomial) {
|
||||
super.raw2prediction(rawPrediction)
|
||||
} else {
|
||||
// Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
|
||||
val t = getThreshold
|
||||
val rawThreshold = if (t == 0.0) {
|
||||
Double.NegativeInfinity
|
||||
} else if (t == 1.0) {
|
||||
Double.PositiveInfinity
|
||||
} else {
|
||||
math.log(t / (1.0 - t))
|
||||
}
|
||||
if (rawPrediction(1) > rawThreshold) 1 else 0
|
||||
// Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
|
||||
if (rawPrediction(1) > _rawThreshold) 1.0 else 0.0
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1270,8 +1279,8 @@ class LogisticRegressionModel private[spark] (
|
|||
if (isMultinomial) {
|
||||
super.probability2prediction(probability)
|
||||
} else {
|
||||
// Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
|
||||
if (probability(1) > getThreshold) 1 else 0
|
||||
// Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
|
||||
if (probability(1) > _threshold) 1.0 else 0.0
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue