[SPARK-32455][ML] LogisticRegressionModel prediction optimization
### What changes were proposed in this pull request? For the binary `LogisticRegressionModel`: 1, keep variables `_threshold` and `_rawThreshold` instead of computing them on each instance; 2, in `raw2probabilityInPlace`, make use of the characteristic that the sum of probabilities is 1.0. ### Why are the changes needed? For better performance. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing test suites and a performance test in the REPL. Closes #29255 from zhengruifeng/pred_opt. Authored-by: zhengruifeng <ruifengz@foxmail.com> Signed-off-by: Huaxin Gao <huaxing@us.ibm.com>
This commit is contained in:
parent
89d9b7cc64
commit
81b0785fb2
|
@ -1098,16 +1098,42 @@ class LogisticRegressionModel private[spark] (
|
||||||
_intercept
|
_intercept
|
||||||
}
|
}
|
||||||
|
|
||||||
private lazy val _intercept = interceptVector.toArray.head
|
private lazy val _intercept = interceptVector(0)
|
||||||
|
private lazy val _interceptVector = interceptVector.toDense
|
||||||
|
private var _threshold = Double.NaN
|
||||||
|
private var _rawThreshold = Double.NaN
|
||||||
|
|
||||||
|
updateBinaryThreshold()
|
||||||
|
|
||||||
|
private def updateBinaryThreshold(): Unit = {
|
||||||
|
if (!isMultinomial) {
|
||||||
|
_threshold = getThreshold
|
||||||
|
if (_threshold == 0.0) {
|
||||||
|
_rawThreshold = Double.NegativeInfinity
|
||||||
|
} else if (_threshold == 1.0) {
|
||||||
|
_rawThreshold = Double.PositiveInfinity
|
||||||
|
} else {
|
||||||
|
_rawThreshold = math.log(_threshold / (1.0 - _threshold))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
override def setThreshold(value: Double): this.type = super.setThreshold(value)
|
override def setThreshold(value: Double): this.type = {
|
||||||
|
super.setThreshold(value)
|
||||||
|
updateBinaryThreshold()
|
||||||
|
this
|
||||||
|
}
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
override def getThreshold: Double = super.getThreshold
|
override def getThreshold: Double = super.getThreshold
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
|
override def setThresholds(value: Array[Double]): this.type = {
|
||||||
|
super.setThresholds(value)
|
||||||
|
updateBinaryThreshold()
|
||||||
|
this
|
||||||
|
}
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
override def getThresholds: Array[Double] = super.getThresholds
|
override def getThresholds: Array[Double] = super.getThresholds
|
||||||
|
@ -1119,7 +1145,7 @@ class LogisticRegressionModel private[spark] (
|
||||||
|
|
||||||
/** Margin (rawPrediction) for each class label. */
|
/** Margin (rawPrediction) for each class label. */
|
||||||
private val margins: Vector => Vector = (features) => {
|
private val margins: Vector => Vector = (features) => {
|
||||||
val m = interceptVector.toDense.copy
|
val m = _interceptVector.copy
|
||||||
BLAS.gemv(1.0, coefficientMatrix, features, 1.0, m)
|
BLAS.gemv(1.0, coefficientMatrix, features, 1.0, m)
|
||||||
m
|
m
|
||||||
}
|
}
|
||||||
|
@ -1178,52 +1204,43 @@ class LogisticRegressionModel private[spark] (
|
||||||
override def predict(features: Vector): Double = if (isMultinomial) {
|
override def predict(features: Vector): Double = if (isMultinomial) {
|
||||||
super.predict(features)
|
super.predict(features)
|
||||||
} else {
|
} else {
|
||||||
// Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
|
// Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
|
||||||
if (score(features) > getThreshold) 1 else 0
|
if (score(features) > _threshold) 1 else 0
|
||||||
}
|
}
|
||||||
|
|
||||||
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
|
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
|
||||||
rawPrediction match {
|
rawPrediction match {
|
||||||
case dv: DenseVector =>
|
case dv: DenseVector =>
|
||||||
|
val values = dv.values
|
||||||
if (isMultinomial) {
|
if (isMultinomial) {
|
||||||
val size = dv.size
|
|
||||||
val values = dv.values
|
|
||||||
|
|
||||||
// get the maximum margin
|
// get the maximum margin
|
||||||
val maxMarginIndex = rawPrediction.argmax
|
val maxMarginIndex = rawPrediction.argmax
|
||||||
val maxMargin = rawPrediction(maxMarginIndex)
|
val maxMargin = rawPrediction(maxMarginIndex)
|
||||||
|
|
||||||
if (maxMargin == Double.PositiveInfinity) {
|
if (maxMargin == Double.PositiveInfinity) {
|
||||||
var k = 0
|
var k = 0
|
||||||
while (k < size) {
|
while (k < numClasses) {
|
||||||
values(k) = if (k == maxMarginIndex) 1.0 else 0.0
|
values(k) = if (k == maxMarginIndex) 1.0 else 0.0
|
||||||
k += 1
|
k += 1
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
val sum = {
|
var sum = 0.0
|
||||||
var temp = 0.0
|
var k = 0
|
||||||
var k = 0
|
while (k < numClasses) {
|
||||||
while (k < numClasses) {
|
values(k) = if (maxMargin > 0) {
|
||||||
values(k) = if (maxMargin > 0) {
|
math.exp(values(k) - maxMargin)
|
||||||
math.exp(values(k) - maxMargin)
|
} else {
|
||||||
} else {
|
math.exp(values(k))
|
||||||
math.exp(values(k))
|
|
||||||
}
|
|
||||||
temp += values(k)
|
|
||||||
k += 1
|
|
||||||
}
|
}
|
||||||
temp
|
sum += values(k)
|
||||||
|
k += 1
|
||||||
}
|
}
|
||||||
BLAS.scal(1 / sum, dv)
|
BLAS.scal(1 / sum, dv)
|
||||||
}
|
}
|
||||||
dv
|
dv
|
||||||
} else {
|
} else {
|
||||||
var i = 0
|
values(0) = 1.0 / (1.0 + math.exp(-values(0)))
|
||||||
val size = dv.size
|
values(1) = 1.0 - values(0)
|
||||||
while (i < size) {
|
|
||||||
dv.values(i) = 1.0 / (1.0 + math.exp(-dv.values(i)))
|
|
||||||
i += 1
|
|
||||||
}
|
|
||||||
dv
|
dv
|
||||||
}
|
}
|
||||||
case sv: SparseVector =>
|
case sv: SparseVector =>
|
||||||
|
@ -1253,16 +1270,8 @@ class LogisticRegressionModel private[spark] (
|
||||||
if (isMultinomial) {
|
if (isMultinomial) {
|
||||||
super.raw2prediction(rawPrediction)
|
super.raw2prediction(rawPrediction)
|
||||||
} else {
|
} else {
|
||||||
// Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
|
// Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
|
||||||
val t = getThreshold
|
if (rawPrediction(1) > _rawThreshold) 1.0 else 0.0
|
||||||
val rawThreshold = if (t == 0.0) {
|
|
||||||
Double.NegativeInfinity
|
|
||||||
} else if (t == 1.0) {
|
|
||||||
Double.PositiveInfinity
|
|
||||||
} else {
|
|
||||||
math.log(t / (1.0 - t))
|
|
||||||
}
|
|
||||||
if (rawPrediction(1) > rawThreshold) 1 else 0
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1270,8 +1279,8 @@ class LogisticRegressionModel private[spark] (
|
||||||
if (isMultinomial) {
|
if (isMultinomial) {
|
||||||
super.probability2prediction(probability)
|
super.probability2prediction(probability)
|
||||||
} else {
|
} else {
|
||||||
// Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
|
// Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
|
||||||
if (probability(1) > getThreshold) 1 else 0
|
if (probability(1) > _threshold) 1.0 else 0.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue