[SPARK-35678][ML] add a common softmax function
### What changes were proposed in this pull request?
Add a softmax function in utils.

### Why are the changes needed?
It can be used in multiple places.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing test suites.

Closes #32822 from zhengruifeng/add_softmax_func.

Authored-by: Ruifeng Zheng <ruifengz@foxmail.com>
Signed-off-by: Ruifeng Zheng <ruifengz@foxmail.com>
This commit is contained in:
parent
2a56cc36ca
commit
8c4b535baf
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
package org.apache.spark.ml.impl
|
package org.apache.spark.ml.impl
|
||||||
|
|
||||||
|
import org.apache.spark.ml.linalg.BLAS
|
||||||
|
|
||||||
private[spark] object Utils {
|
private[spark] object Utils {
|
||||||
|
|
||||||
|
@ -94,4 +95,34 @@ private[spark] object Utils {
|
||||||
math.log1p(math.exp(x))
|
math.log1p(math.exp(x))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Convert an array of raw margins into a probability distribution, in place.
 *
 * If any entry is positive infinity, that entry receives the entire
 * probability mass (1.0) and every other entry becomes 0.0, mirroring the
 * limit of the softmax. Otherwise the standard max-shifted exponentiation
 * is applied for numerical stability, followed by normalization so the
 * entries sum to 1.
 *
 * @param values raw margin values; overwritten with the resulting probabilities
 */
def softmax(values: Array[Double]): Unit = {
  val n = values.length

  // First pass: locate the largest entry, bailing out early if any entry
  // is +Infinity (the degenerate case where one class dominates entirely).
  var largest = Double.MinValue
  var idx = 0
  while (idx < n) {
    val v = values(idx)
    if (v.isPosInfinity) {
      java.util.Arrays.fill(values, 0)
      values(idx) = 1.0
      return
    } else if (v > largest) {
      largest = v
    }
    idx += 1
  }

  // Second pass: exponentiate after subtracting the maximum to avoid
  // overflow, accumulating the total for normalization.
  var total = 0.0
  idx = 0
  while (idx < n) {
    val e = math.exp(values(idx) - largest)
    values(idx) = e
    total += e
    idx += 1
  }

  // Scale in place so the entries form a probability distribution.
  BLAS.javaBLAS.dscal(n, 1.0 / total, values, 1)
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@ import org.apache.spark.SparkException
|
||||||
import org.apache.spark.annotation.Since
|
import org.apache.spark.annotation.Since
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
import org.apache.spark.ml.feature._
|
import org.apache.spark.ml.feature._
|
||||||
|
import org.apache.spark.ml.impl.Utils
|
||||||
import org.apache.spark.ml.linalg._
|
import org.apache.spark.ml.linalg._
|
||||||
import org.apache.spark.ml.optim.aggregator._
|
import org.apache.spark.ml.optim.aggregator._
|
||||||
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
|
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
|
||||||
|
@ -1224,36 +1225,12 @@ class LogisticRegressionModel private[spark] (
|
||||||
case dv: DenseVector =>
|
case dv: DenseVector =>
|
||||||
val values = dv.values
|
val values = dv.values
|
||||||
if (isMultinomial) {
|
if (isMultinomial) {
|
||||||
// get the maximum margin
|
Utils.softmax(values)
|
||||||
val maxMarginIndex = rawPrediction.argmax
|
|
||||||
val maxMargin = rawPrediction(maxMarginIndex)
|
|
||||||
|
|
||||||
if (maxMargin == Double.PositiveInfinity) {
|
|
||||||
var k = 0
|
|
||||||
while (k < numClasses) {
|
|
||||||
values(k) = if (k == maxMarginIndex) 1.0 else 0.0
|
|
||||||
k += 1
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
var sum = 0.0
|
|
||||||
var k = 0
|
|
||||||
while (k < numClasses) {
|
|
||||||
values(k) = if (maxMargin > 0) {
|
|
||||||
math.exp(values(k) - maxMargin)
|
|
||||||
} else {
|
|
||||||
math.exp(values(k))
|
|
||||||
}
|
|
||||||
sum += values(k)
|
|
||||||
k += 1
|
|
||||||
}
|
|
||||||
BLAS.scal(1 / sum, dv)
|
|
||||||
}
|
|
||||||
dv
|
|
||||||
} else {
|
} else {
|
||||||
values(0) = 1.0 / (1.0 + math.exp(-values(0)))
|
values(0) = 1.0 / (1.0 + math.exp(-values(0)))
|
||||||
values(1) = 1.0 - values(0)
|
values(1) = 1.0 - values(0)
|
||||||
dv
|
|
||||||
}
|
}
|
||||||
|
dv
|
||||||
case sv: SparseVector =>
|
case sv: SparseVector =>
|
||||||
throw new RuntimeException("Unexpected error in LogisticRegressionModel:" +
|
throw new RuntimeException("Unexpected error in LogisticRegressionModel:" +
|
||||||
" raw2probabilitiesInPlace encountered SparseVector")
|
" raw2probabilitiesInPlace encountered SparseVector")
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.ml.optim.aggregator
|
||||||
import org.apache.spark.broadcast.Broadcast
|
import org.apache.spark.broadcast.Broadcast
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
import org.apache.spark.ml.feature.InstanceBlock
|
import org.apache.spark.ml.feature.InstanceBlock
|
||||||
|
import org.apache.spark.ml.impl.Utils
|
||||||
import org.apache.spark.ml.linalg._
|
import org.apache.spark.ml.linalg._
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -125,47 +126,26 @@ private[ml] class MultinomialLogisticBlockAggregator(
|
||||||
var localLossSum = 0.0
|
var localLossSum = 0.0
|
||||||
var localWeightSum = 0.0
|
var localWeightSum = 0.0
|
||||||
var i = 0
|
var i = 0
|
||||||
val tmp = Array.ofDim[Double](numClasses)
|
val probs = Array.ofDim[Double](numClasses)
|
||||||
val multiplierSum = Array.ofDim[Double](numClasses)
|
val multiplierSum = Array.ofDim[Double](numClasses)
|
||||||
while (i < size) {
|
while (i < size) {
|
||||||
val weight = block.getWeight(i)
|
val weight = block.getWeight(i)
|
||||||
localWeightSum += weight
|
localWeightSum += weight
|
||||||
if (weight > 0) {
|
if (weight > 0) {
|
||||||
val label = block.getLabel(i)
|
val label = block.getLabel(i)
|
||||||
var maxMargin = Double.NegativeInfinity
|
|
||||||
var j = 0
|
var j = 0
|
||||||
while (j < numClasses) {
|
while (j < numClasses) { probs(j) = mat(i, j); j += 1 }
|
||||||
tmp(j) = mat(i, j)
|
Utils.softmax(probs)
|
||||||
maxMargin = math.max(maxMargin, tmp(j))
|
|
||||||
j += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// marginOfLabel is margins(label) in the formula
|
|
||||||
val marginOfLabel = tmp(label.toInt)
|
|
||||||
|
|
||||||
var sum = 0.0
|
|
||||||
j = 0
|
|
||||||
while (j < numClasses) {
|
|
||||||
if (maxMargin > 0) tmp(j) -= maxMargin
|
|
||||||
val exp = math.exp(tmp(j))
|
|
||||||
sum += exp
|
|
||||||
tmp(j) = exp
|
|
||||||
j += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
j = 0
|
j = 0
|
||||||
while (j < numClasses) {
|
while (j < numClasses) {
|
||||||
val multiplier = weight * (tmp(j) / sum - (if (label == j) 1.0 else 0.0))
|
val multiplier = weight * (probs(j) - (if (label == j) 1.0 else 0.0))
|
||||||
mat.update(i, j, multiplier)
|
mat.update(i, j, multiplier)
|
||||||
multiplierSum(j) += multiplier
|
multiplierSum(j) += multiplier
|
||||||
j += 1
|
j += 1
|
||||||
}
|
}
|
||||||
|
|
||||||
if (maxMargin > 0) {
|
localLossSum -= weight * math.log(probs(label.toInt))
|
||||||
localLossSum += weight * (math.log(sum) - marginOfLabel + maxMargin)
|
|
||||||
} else {
|
|
||||||
localLossSum += weight * (math.log(sum) - marginOfLabel)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
var j = 0; while (j < numClasses) { mat.update(i, j, 0.0); j += 1 }
|
var j = 0; while (j < numClasses) { mat.update(i, j, 0.0); j += 1 }
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue