[SPARK-35678][ML] add a common softmax function

### What changes were proposed in this pull request?
add softmax function in utils

### Why are the changes needed?
It can be used in multiple places.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing test suites.

Closes #32822 from zhengruifeng/add_softmax_func.

Authored-by: Ruifeng Zheng <ruifengz@foxmail.com>
Signed-off-by: Ruifeng Zheng <ruifengz@foxmail.com>
This commit is contained in:
Ruifeng Zheng 2021-06-15 10:33:57 +08:00
parent 2a56cc36ca
commit 8c4b535baf
3 changed files with 40 additions and 52 deletions

View file

@@ -17,6 +17,7 @@
package org.apache.spark.ml.impl
import org.apache.spark.ml.linalg.BLAS
private[spark] object Utils {
@@ -94,4 +95,34 @@ private[spark] object Utils {
math.log1p(math.exp(x))
}
}
/**
 * Convert an array of raw scores into a probability distribution, in place.
 *
 * Scores are shifted by the maximum before exponentiation for numerical
 * stability. A positive-infinite score dominates every other entry, so in
 * that case the result collapses to a one-hot vector at its index.
 *
 * @param values raw scores; overwritten with probabilities summing to 1.0
 */
def softmax(values: Array[Double]): Unit = {
  val n = values.length
  // First pass: locate the largest score, bailing out early to a one-hot
  // vector if any score is +Infinity.
  var largest = Double.MinValue
  var j = 0
  while (j < n) {
    val v = values(j)
    if (v.isPosInfinity) {
      java.util.Arrays.fill(values, 0)
      values(j) = 1.0
      return
    } else if (v > largest) {
      largest = v
    }
    j += 1
  }
  // Second pass: exponentiate the shifted scores and accumulate their sum.
  var total = 0.0
  j = 0
  while (j < n) {
    val e = math.exp(values(j) - largest)
    values(j) = e
    total += e
    j += 1
  }
  // Normalize so the entries sum to one.
  BLAS.javaBLAS.dscal(n, 1.0 / total, values, 1)
}
}

View file

@@ -29,6 +29,7 @@ import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature._
import org.apache.spark.ml.impl.Utils
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.optim.aggregator._
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
@@ -1224,36 +1225,12 @@ class LogisticRegressionModel private[spark] (
case dv: DenseVector =>
val values = dv.values
if (isMultinomial) {
// get the maximum margin
val maxMarginIndex = rawPrediction.argmax
val maxMargin = rawPrediction(maxMarginIndex)
if (maxMargin == Double.PositiveInfinity) {
var k = 0
while (k < numClasses) {
values(k) = if (k == maxMarginIndex) 1.0 else 0.0
k += 1
}
} else {
var sum = 0.0
var k = 0
while (k < numClasses) {
values(k) = if (maxMargin > 0) {
math.exp(values(k) - maxMargin)
} else {
math.exp(values(k))
}
sum += values(k)
k += 1
}
BLAS.scal(1 / sum, dv)
}
dv
Utils.softmax(values)
} else {
values(0) = 1.0 / (1.0 + math.exp(-values(0)))
values(1) = 1.0 - values(0)
dv
}
dv
case sv: SparseVector =>
throw new RuntimeException("Unexpected error in LogisticRegressionModel:" +
" raw2probabilitiesInPlace encountered SparseVector")

View file

@@ -19,6 +19,7 @@ package org.apache.spark.ml.optim.aggregator
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.InstanceBlock
import org.apache.spark.ml.impl.Utils
import org.apache.spark.ml.linalg._
/**
@@ -125,47 +126,26 @@ private[ml] class MultinomialLogisticBlockAggregator(
var localLossSum = 0.0
var localWeightSum = 0.0
var i = 0
val tmp = Array.ofDim[Double](numClasses)
val probs = Array.ofDim[Double](numClasses)
val multiplierSum = Array.ofDim[Double](numClasses)
while (i < size) {
val weight = block.getWeight(i)
localWeightSum += weight
if (weight > 0) {
val label = block.getLabel(i)
var maxMargin = Double.NegativeInfinity
var j = 0
while (j < numClasses) {
tmp(j) = mat(i, j)
maxMargin = math.max(maxMargin, tmp(j))
j += 1
}
// marginOfLabel is margins(label) in the formula
val marginOfLabel = tmp(label.toInt)
var sum = 0.0
j = 0
while (j < numClasses) {
if (maxMargin > 0) tmp(j) -= maxMargin
val exp = math.exp(tmp(j))
sum += exp
tmp(j) = exp
j += 1
}
while (j < numClasses) { probs(j) = mat(i, j); j += 1 }
Utils.softmax(probs)
j = 0
while (j < numClasses) {
val multiplier = weight * (tmp(j) / sum - (if (label == j) 1.0 else 0.0))
val multiplier = weight * (probs(j) - (if (label == j) 1.0 else 0.0))
mat.update(i, j, multiplier)
multiplierSum(j) += multiplier
j += 1
}
if (maxMargin > 0) {
localLossSum += weight * (math.log(sum) - marginOfLabel + maxMargin)
} else {
localLossSum += weight * (math.log(sum) - marginOfLabel)
}
localLossSum -= weight * math.log(probs(label.toInt))
} else {
var j = 0; while (j < numClasses) { mat.update(i, j, 0.0); j += 1 }
}