[SPARK-35678][ML] add a common softmax function
### What changes were proposed in this pull request? Add a softmax function in utils. ### Why are the changes needed? It can be used in multiple places. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing test suites. Closes #32822 from zhengruifeng/add_softmax_func. Authored-by: Ruifeng Zheng <ruifengz@foxmail.com> Signed-off-by: Ruifeng Zheng <ruifengz@foxmail.com>
This commit is contained in:
parent
2a56cc36ca
commit
8c4b535baf
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.ml.impl
|
||||
|
||||
import org.apache.spark.ml.linalg.BLAS
|
||||
|
||||
private[spark] object Utils {
|
||||
|
||||
|
@ -94,4 +95,34 @@ private[spark] object Utils {
|
|||
math.log1p(math.exp(x))
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Perform in-place softmax conversion.
 *
 * The input array is overwritten with a probability distribution:
 * each entry becomes exp(v - max) normalized by the sum of all such terms.
 * Subtracting the maximum before exponentiation keeps the computation
 * numerically stable for large margins.
 */
def softmax(values: Array[Double]): Unit = {
  // Pass 1: find the largest entry, short-circuiting on +Infinity,
  // which receives all of the probability mass.
  var maxValue = Double.MinValue
  var i = 0
  while (i < values.length) {
    val v = values(i)
    if (v.isPosInfinity) {
      // An infinite margin dominates every finite one: the result is a
      // one-hot distribution at the first +Infinity encountered.
      java.util.Arrays.fill(values, 0)
      values(i) = 1.0
      return
    }
    if (v > maxValue) maxValue = v
    i += 1
  }

  // Pass 2: exponentiate each shifted entry in place while accumulating
  // the normalization constant.
  var expSum = 0.0
  i = 0
  while (i < values.length) {
    val e = math.exp(values(i) - maxValue)
    values(i) = e
    expSum += e
    i += 1
  }

  // Rescale so the entries sum to 1 (single BLAS scal over the array).
  BLAS.javaBLAS.dscal(values.length, 1.0 / expSum, values, 1)
}
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.apache.spark.SparkException
|
|||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.ml.feature._
|
||||
import org.apache.spark.ml.impl.Utils
|
||||
import org.apache.spark.ml.linalg._
|
||||
import org.apache.spark.ml.optim.aggregator._
|
||||
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
|
||||
|
@ -1224,36 +1225,12 @@ class LogisticRegressionModel private[spark] (
|
|||
case dv: DenseVector =>
|
||||
val values = dv.values
|
||||
if (isMultinomial) {
|
||||
// get the maximum margin
|
||||
val maxMarginIndex = rawPrediction.argmax
|
||||
val maxMargin = rawPrediction(maxMarginIndex)
|
||||
|
||||
if (maxMargin == Double.PositiveInfinity) {
|
||||
var k = 0
|
||||
while (k < numClasses) {
|
||||
values(k) = if (k == maxMarginIndex) 1.0 else 0.0
|
||||
k += 1
|
||||
}
|
||||
} else {
|
||||
var sum = 0.0
|
||||
var k = 0
|
||||
while (k < numClasses) {
|
||||
values(k) = if (maxMargin > 0) {
|
||||
math.exp(values(k) - maxMargin)
|
||||
} else {
|
||||
math.exp(values(k))
|
||||
}
|
||||
sum += values(k)
|
||||
k += 1
|
||||
}
|
||||
BLAS.scal(1 / sum, dv)
|
||||
}
|
||||
dv
|
||||
Utils.softmax(values)
|
||||
} else {
|
||||
values(0) = 1.0 / (1.0 + math.exp(-values(0)))
|
||||
values(1) = 1.0 - values(0)
|
||||
dv
|
||||
}
|
||||
dv
|
||||
case sv: SparseVector =>
|
||||
throw new RuntimeException("Unexpected error in LogisticRegressionModel:" +
|
||||
" raw2probabilitiesInPlace encountered SparseVector")
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.ml.optim.aggregator
|
|||
import org.apache.spark.broadcast.Broadcast
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.ml.feature.InstanceBlock
|
||||
import org.apache.spark.ml.impl.Utils
|
||||
import org.apache.spark.ml.linalg._
|
||||
|
||||
/**
|
||||
|
@ -125,47 +126,26 @@ private[ml] class MultinomialLogisticBlockAggregator(
|
|||
var localLossSum = 0.0
|
||||
var localWeightSum = 0.0
|
||||
var i = 0
|
||||
val tmp = Array.ofDim[Double](numClasses)
|
||||
val probs = Array.ofDim[Double](numClasses)
|
||||
val multiplierSum = Array.ofDim[Double](numClasses)
|
||||
while (i < size) {
|
||||
val weight = block.getWeight(i)
|
||||
localWeightSum += weight
|
||||
if (weight > 0) {
|
||||
val label = block.getLabel(i)
|
||||
var maxMargin = Double.NegativeInfinity
|
||||
var j = 0
|
||||
while (j < numClasses) {
|
||||
tmp(j) = mat(i, j)
|
||||
maxMargin = math.max(maxMargin, tmp(j))
|
||||
j += 1
|
||||
}
|
||||
|
||||
// marginOfLabel is margins(label) in the formula
|
||||
val marginOfLabel = tmp(label.toInt)
|
||||
|
||||
var sum = 0.0
|
||||
j = 0
|
||||
while (j < numClasses) {
|
||||
if (maxMargin > 0) tmp(j) -= maxMargin
|
||||
val exp = math.exp(tmp(j))
|
||||
sum += exp
|
||||
tmp(j) = exp
|
||||
j += 1
|
||||
}
|
||||
while (j < numClasses) { probs(j) = mat(i, j); j += 1 }
|
||||
Utils.softmax(probs)
|
||||
|
||||
j = 0
|
||||
while (j < numClasses) {
|
||||
val multiplier = weight * (tmp(j) / sum - (if (label == j) 1.0 else 0.0))
|
||||
val multiplier = weight * (probs(j) - (if (label == j) 1.0 else 0.0))
|
||||
mat.update(i, j, multiplier)
|
||||
multiplierSum(j) += multiplier
|
||||
j += 1
|
||||
}
|
||||
|
||||
if (maxMargin > 0) {
|
||||
localLossSum += weight * (math.log(sum) - marginOfLabel + maxMargin)
|
||||
} else {
|
||||
localLossSum += weight * (math.log(sum) - marginOfLabel)
|
||||
}
|
||||
localLossSum -= weight * math.log(probs(label.toInt))
|
||||
} else {
|
||||
var j = 0; while (j < numClasses) { mat.update(i, j, 0.0); j += 1 }
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue