[SPARK-35678][ML][FOLLOWUP] softmax support offset and step
### What changes were proposed in this pull request?
Add a `softmax` overload that supports `offset` and `step` parameters, so it can normalize a strided slice of a flat array in place; then reuse it in ANN and NB.

### Why are the changes needed?
To simplify the implementations.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing test suites.

Closes #32991 from zhengruifeng/softmax_support_offset_step.

Authored-by: Ruifeng Zheng <ruifengz@foxmail.com>
Signed-off-by: Huaxin Gao <huaxin_gao@apple.com>
Commit a66738823c (parent be9089731a)
@@ -99,30 +99,42 @@ private[spark] object Utils {
   /**
    * Perform in-place softmax conversion.
    */
-  def softmax(values: Array[Double]): Unit = {
+  def softmax(array: Array[Double]): Unit =
+    softmax(array, array.length, 0, 1, array)
+
+  /**
+   * Perform softmax conversion.
+   */
+  def softmax(
+      input: Array[Double],
+      n: Int,
+      offset: Int,
+      step: Int,
+      output: Array[Double]): Unit = {
     var maxValue = Double.MinValue
-    var i = 0
-    while (i < values.length) {
-      val value = values(i)
-      if (value.isPosInfinity) {
-        java.util.Arrays.fill(values, 0)
-        values(i) = 1.0
+    var i = offset
+    val end = offset + step * n
+    while (i < end) {
+      val v = input(i)
+      if (v.isPosInfinity) {
+        BLAS.javaBLAS.dscal(n, 0.0, output, offset, step)
+        output(i) = 1.0
         return
-      } else if (value > maxValue) {
-        maxValue = value
+      } else if (v > maxValue) {
+        maxValue = v
       }
-      i += 1
+      i += step
    }

     var sum = 0.0
-    i = 0
-    while (i < values.length) {
-      val exp = math.exp(values(i) - maxValue)
-      values(i) = exp
+    i = offset
+    while (i < end) {
+      val exp = math.exp(input(i) - maxValue)
+      output(i) = exp
       sum += exp
-      i += 1
+      i += step
     }

-    BLAS.javaBLAS.dscal(values.length, 1.0 / sum, values, 1)
+    BLAS.javaBLAS.dscal(n, 1.0 / sum, output, offset, step)
   }
 }
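The new overload normalizes the n entries found at indices offset, offset + step, ..., offset + step * (n - 1), so a caller can softmax one row or one column of a flattened matrix in place, without copying it out first. Below is a minimal standalone sketch of the strided indexing (plain Scala, hypothetical values; it omits the positive-infinity guard and uses a plain loop instead of BLAS.javaBLAS.dscal for the final scaling):

object SoftmaxSketch {
  // Illustration-only reimplementation of the strided softmax above.
  def softmax(input: Array[Double], n: Int, offset: Int, step: Int,
      output: Array[Double]): Unit = {
    var maxValue = Double.MinValue
    var i = offset
    val end = offset + step * n
    while (i < end) { if (input(i) > maxValue) maxValue = input(i); i += step }
    var sum = 0.0
    i = offset
    while (i < end) { output(i) = math.exp(input(i) - maxValue); sum += output(i); i += step }
    i = offset
    while (i < end) { output(i) /= sum; i += step }  // normalize in place
  }

  def main(args: Array[String]): Unit = {
    // A 2 x 3 matrix stored column-major; its columns are (1,2), (3,4), (5,6).
    val m = Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
    // Softmax over row 0 (entries 1, 3, 5): offset = 0, step = numRows = 2.
    softmax(m, 3, 0, 2, m)
    println(m.mkString(", "))  // row 0 now sums to 1; row 1 is untouched
  }
}

With a column-major rows x cols matrix, offset = j * rows with step = 1 selects column j, and offset = i with step = rows selects row i; the ANN and aggregator call sites below rely on exactly these two layouts.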
@@ -22,6 +22,8 @@ import java.util.Random
 import breeze.linalg.{sum => Bsum, DenseMatrix => BDM, DenseVector => BDV}
 import breeze.numerics.{log => brzlog}

+import org.apache.spark.ml.impl.Utils
+
 /**
  * Trait for loss function
  */
@@ -79,30 +81,10 @@ private[ann] class SoftmaxLayerModelWithCrossEntropyLoss extends LayerModel with
   val weights = new BDV[Double](0)

   override def eval(data: BDM[Double], output: BDM[Double]): Unit = {
+    require(!data.isTranspose && !output.isTranspose)
     var j = 0
-    // find max value to make sure later that exponent is computable
     while (j < data.cols) {
-      var i = 0
-      var max = Double.MinValue
-      while (i < data.rows) {
-        if (data(i, j) > max) {
-          max = data(i, j)
-        }
-        i += 1
-      }
-      var sum = 0.0
-      i = 0
-      while (i < data.rows) {
-        val res = math.exp(data(i, j) - max)
-        output(i, j) = res
-        sum += res
-        i += 1
-      }
-      i = 0
-      while (i < data.rows) {
-        output(i, j) /= sum
-        i += 1
-      }
+      Utils.softmax(data.data, data.rows, j * data.rows, 1, output.data)
       j += 1
     }
   }
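Because the breeze matrices here are column-major and not transposed (the new require enforces this), column j of data occupies the contiguous slice of data.data starting at j * data.rows, which is exactly what the single Utils.softmax call walks with step 1. A small sketch of that layout invariant (hypothetical 2 x 3 matrix):

import breeze.linalg.{DenseMatrix => BDM}

// Column-major storage: column j is the contiguous run
// data.data(j * rows) ... data.data(j * rows + rows - 1).
val data = new BDM[Double](2, 3, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val j = 1
val colStart = j * data.rows
assert(data(0, j) == data.data(colStart))      // 3.0
assert(data(1, j) == data.data(colStart + 1))  // 4.0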
@@ -23,6 +23,7 @@ import org.json4s.DefaultFormats
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.functions.checkNonNegativeWeight
+import org.apache.spark.ml.impl.Utils
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.HasWeightCol
@@ -527,19 +528,7 @@ class NaiveBayesModel private[ml] (
   override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
     rawPrediction match {
       case dv: DenseVector =>
-        var i = 0
-        val size = dv.size
-        val maxLog = dv.values.max
-        while (i < size) {
-          dv.values(i) = math.exp(dv.values(i) - maxLog)
-          i += 1
-        }
-        val probSum = dv.values.sum
-        i = 0
-        while (i < size) {
-          dv.values(i) = dv.values(i) / probSum
-          i += 1
-        }
+        Utils.softmax(dv.values)
         dv
       case sv: SparseVector =>
         throw new RuntimeException("Unexpected error in NaiveBayesModel:" +
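Here the raw prediction holds unnormalized log-posteriors, and Utils.softmax(dv.values) reproduces the deleted loops: shift by the max before exponentiating (so very negative logs cannot all underflow to zero), then divide by the sum. A sketch of the same computation on hypothetical values:

// Softmax over log-posteriors: exp(x - max(x)) / sum(exp(x - max(x)))
// equals exp(x) / sum(exp(x)), but stays finite for very negative logs.
val logPosteriors = Array(-1000.0, -1001.0, -1002.0)  // plain exp would underflow
val maxLog = logPosteriors.max
val exps = logPosteriors.map(v => math.exp(v - maxLog))
val probs = exps.map(_ / exps.sum)  // ~ Array(0.665, 0.245, 0.090)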
@@ -74,7 +74,7 @@ private[ml] class MultinomialLogisticBlockAggregator(
     new DenseMatrix(numClasses, numFeatures, coefficientsArray)
   }

-  private lazy val intercept = if (fitIntercept) {
+  @transient private lazy val intercept = if (fitIntercept) {
     new DenseVector(coefficientsArray.takeRight(numClasses))
   } else {
     null
@@ -83,8 +83,8 @@ private[ml] class MultinomialLogisticBlockAggregator(
   // pre-computed margin of an empty vector.
   // with this variable as an offset, for a sparse vector, we only need to
   // deal with non-zero values in prediction.
-  private val marginOffset = if (fitWithMean) {
-    val offset = intercept.copy
+  @transient private lazy val marginOffset = if (fitWithMean) {
+    val offset = new DenseVector(coefficientsArray.takeRight(numClasses)) // intercept
     BLAS.gemv(-1.0, linear, Vectors.dense(bcScaledMean.value), 1.0, offset)
     offset
   } else {
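Marking intercept and marginOffset as @transient lazy val keeps them out of the serialized task; each executor recomputes them on first access from coefficientsArray, which is serialized. marginOffset now also rebuilds the intercept slice directly rather than calling intercept.copy, presumably so one lazy field does not need to force the other. A sketch of the pattern (hypothetical class, not the real aggregator):

class Agg(val coefficients: Array[Double], val numClasses: Int) extends Serializable {
  // Excluded from serialization; lazily rebuilt after deserialization
  // from `coefficients`, which does travel with the object.
  @transient private lazy val intercept: Array[Double] =
    coefficients.takeRight(numClasses)
}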
@@ -115,7 +115,7 @@ private[ml] class MultinomialLogisticBlockAggregator(
     val offset = if (fitWithMean) marginOffset else intercept
     var j = 0
     while (j < numClasses) {
-      java.util.Arrays.fill(arr, j * size, (j + 1) * size, offset(j))
+      if (offset(j) != 0) java.util.Arrays.fill(arr, j * size, (j + 1) * size, offset(j))
       j += 1
     }
   }
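The fill is now skipped for classes whose offset is zero, which is sound if the margin buffer starts out zeroed, e.g. a freshly allocated JVM double array (an assumption here: the allocation site is outside this hunk). Sketch:

// A new JVM double array is already all zeros, so a conditional fill
// only pays for classes with a nonzero per-class offset.
val arr = new Array[Double](6)  // Array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
val offsets = Array(0.0, 1.5)   // hypothetical per-class offsets
val size = 3
for (j <- offsets.indices if offsets(j) != 0) {
  java.util.Arrays.fill(arr, j * size, (j + 1) * size, offsets(j))
}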
@@ -126,28 +126,17 @@ private[ml] class MultinomialLogisticBlockAggregator(
     var localLossSum = 0.0
     var localWeightSum = 0.0
     var i = 0
-    val probs = Array.ofDim[Double](numClasses)
-    val multiplierSum = Array.ofDim[Double](numClasses)
     while (i < size) {
       val weight = block.getWeight(i)
       localWeightSum += weight
       if (weight > 0) {
-        val label = block.getLabel(i)
-        var j = 0
-        while (j < numClasses) { probs(j) = mat(i, j); j += 1 }
-        Utils.softmax(probs)
-
-        j = 0
-        while (j < numClasses) {
-          val multiplier = weight * (probs(j) - (if (label == j) 1.0 else 0.0))
-          mat.update(i, j, multiplier)
-          multiplierSum(j) += multiplier
-          j += 1
-        }
-
-        localLossSum -= weight * math.log(probs(label.toInt))
+        val labelIndex = i + block.getLabel(i).toInt * size
+        Utils.softmax(arr, numClasses, i, size, arr) // prob
+        localLossSum -= weight * math.log(arr(labelIndex))
+        if (weight != 1) BLAS.javaBLAS.dscal(numClasses, weight, arr, i, size)
+        arr(labelIndex) -= weight
       } else {
-        var j = 0; while (j < numClasses) { mat.update(i, j, 0.0); j += 1 }
+        BLAS.javaBLAS.dscal(numClasses, 0, arr, i, size)
       }
       i += 1
     }
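With margins stored column-major in arr (size entries per class), row i is the strided slice at offset i with step size, so Utils.softmax(arr, numClasses, i, size, arr) converts one example's margins to probabilities in place. The lines after it fuse the old per-class multiplier loop: scaling the row by weight and then subtracting weight at the label's index yields weight * (prob(j) - 1{j == label}) for every class j. A sketch checking that identity on hypothetical values:

// Gradient multiplier for multinomial logistic loss, one row:
//   multiplier(j) = weight * (prob(j) - 1{j == label})
// Scaling the whole row by weight and subtracting weight at the label
// index produces exactly the same vector.
val probs = Array(0.2, 0.5, 0.3)
val label = 1
val weight = 2.0
val direct = probs.zipWithIndex.map { case (p, j) =>
  weight * (p - (if (j == label) 1.0 else 0.0))
}
val fused = probs.map(_ * weight)
fused(label) -= weight
assert(direct.sameElements(fused))  // both are (0.4, -1.0, 0.6)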
@@ -173,6 +162,16 @@ private[ml] class MultinomialLogisticBlockAggregator(
       linearGradSumMat.foreachActive { (i, j, v) => gradientSumArray(i * numClasses + j) += v }
     }

+    if (fitIntercept) {
+      val multiplierSum = Array.ofDim[Double](numClasses)
+      var j = 0
+      while (j < numClasses) {
+        var i = j * size
+        val end = i + size
+        while (i < end) { multiplierSum(j) += arr(i); i += 1 }
+        j += 1
+      }
+
     if (fitWithMean) {
       // above update of the linear part of gradientSumArray does NOT take the centering
       // into account, here we need to adjust this part.
@@ -182,7 +181,6 @@ private[ml] class MultinomialLogisticBlockAggregator(
         bcScaledMean.value, 1, gradientSumArray, numClasses)
     }

-    if (fitIntercept) {
       BLAS.javaBLAS.daxpy(numClasses, 1.0, multiplierSum, 0, 1,
         gradientSumArray, numClasses * numFeatures, 1)
     }
@@ -83,7 +83,7 @@ class MultilayerPerceptronClassifierTest(SparkSessionTestCase):
         result = model.transform(test).head()
         expected_prediction = 2.0
         expected_probability = [0.0, 0.0, 1.0]
-        expected_rawPrediction = [-11.6081922998, -8.15827998691, 22.17757045]
+        expected_rawPrediction = [-11.824, -8.298, 22.5299]
         self.assertTrue(result.prediction, expected_prediction)
         self.assertTrue(np.allclose(result.probability, expected_probability, atol=1E-4))
         self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, rtol=0.1))