[SPARK-16934][ML][MLLIB] Update LogisticCostAggregator serialization code to make it consistent with LinearRegression
## What changes were proposed in this pull request?

Update the LogisticCostAggregator serialization code to make it consistent with #14109: broadcast the coefficients and feature standard deviations once per `calculate` call instead of capturing them in every task closure.

## How was this patch tested?

Existing tests; serialized task-size comparison below.

MLlib 2.0: ![image](https://cloud.githubusercontent.com/assets/19235986/17649601/5e2a79ac-61ee-11e6-833c-3bd8b5250470.png)

After this PR: ![image](https://cloud.githubusercontent.com/assets/19235986/17649599/52b002ae-61ee-11e6-9402-9feb3439880f.png)

Author: WeichenXu <WeichenXu123@outlook.com>

Closes #14520 from WeichenXu123/improve_logistic_regression_costfun.
This commit is contained in:
parent
ddf0d1e3fe
commit
3d8bfe7a39
|
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path
|
|||
|
||||
import org.apache.spark.SparkException
|
||||
import org.apache.spark.annotation.{Experimental, Since}
|
||||
import org.apache.spark.broadcast.Broadcast
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.ml.feature.Instance
|
||||
import org.apache.spark.ml.linalg._
|
||||
|
@@ -346,8 +347,9 @@ class LogisticRegression @Since("1.2.0") (
|
|||
val regParamL1 = $(elasticNetParam) * $(regParam)
|
||||
val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
|
||||
|
||||
val bcFeaturesStd = instances.context.broadcast(featuresStd)
|
||||
val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
|
||||
$(standardization), featuresStd, featuresMean, regParamL2)
|
||||
$(standardization), bcFeaturesStd, regParamL2)
|
||||
|
||||
val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
|
||||
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
|
||||
|
@@ -442,6 +444,7 @@ class LogisticRegression @Since("1.2.0") (
|
|||
rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
|
||||
i += 1
|
||||
}
|
||||
bcFeaturesStd.destroy(blocking = false)
|
||||
|
||||
if ($(fitIntercept)) {
|
||||
(Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last,
|
||||
|
@@ -938,11 +941,15 @@ class BinaryLogisticRegressionSummary private[classification] (
|
|||
* Two LogisticAggregator can be merged together to have a summary of loss and gradient of
|
||||
* the corresponding joint dataset.
|
||||
*
|
||||
* @param bcCoefficients The broadcast coefficients corresponding to the features.
|
||||
* @param bcFeaturesStd The broadcast standard deviation values of the features.
|
||||
* @param numClasses the number of possible outcomes for k classes classification problem in
|
||||
* Multinomial Logistic Regression.
|
||||
* @param fitIntercept Whether to fit an intercept term.
|
||||
*/
|
||||
private class LogisticAggregator(
|
||||
val bcCoefficients: Broadcast[Vector],
|
||||
val bcFeaturesStd: Broadcast[Array[Double]],
|
||||
private val numFeatures: Int,
|
||||
numClasses: Int,
|
||||
fitIntercept: Boolean) extends Serializable {
|
||||
|
@@ -958,14 +965,9 @@ private class LogisticAggregator(
|
|||
* of the objective function.
|
||||
*
|
||||
* @param instance The instance of data point to be added.
|
||||
* @param coefficients The coefficients corresponding to the features.
|
||||
* @param featuresStd The standard deviation values of the features.
|
||||
* @return This LogisticAggregator object.
|
||||
*/
|
||||
def add(
|
||||
instance: Instance,
|
||||
coefficients: Vector,
|
||||
featuresStd: Array[Double]): this.type = {
|
||||
def add(instance: Instance): this.type = {
|
||||
instance match { case Instance(label, weight, features) =>
|
||||
require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." +
|
||||
s" Expecting $numFeatures but got ${features.size}.")
|
||||
|
@@ -973,14 +975,16 @@ private class LogisticAggregator(
|
|||
|
||||
if (weight == 0.0) return this
|
||||
|
||||
val coefficientsArray = coefficients match {
|
||||
val coefficientsArray = bcCoefficients.value match {
|
||||
case dv: DenseVector => dv.values
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(
|
||||
s"coefficients only supports dense vector but got type ${coefficients.getClass}.")
|
||||
"coefficients only supports dense vector" +
|
||||
s"but got type ${bcCoefficients.value.getClass}.")
|
||||
}
|
||||
val localGradientSumArray = gradientSumArray
|
||||
|
||||
val featuresStd = bcFeaturesStd.value
|
||||
numClasses match {
|
||||
case 2 =>
|
||||
// For Binary Logistic Regression.
|
||||
|
@@ -1077,24 +1081,23 @@ private class LogisticCostFun(
|
|||
numClasses: Int,
|
||||
fitIntercept: Boolean,
|
||||
standardization: Boolean,
|
||||
featuresStd: Array[Double],
|
||||
featuresMean: Array[Double],
|
||||
bcFeaturesStd: Broadcast[Array[Double]],
|
||||
regParamL2: Double) extends DiffFunction[BDV[Double]] {
|
||||
|
||||
val featuresStd = bcFeaturesStd.value
|
||||
|
||||
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
|
||||
val numFeatures = featuresStd.length
|
||||
val coeffs = Vectors.fromBreeze(coefficients)
|
||||
val bcCoeffs = instances.context.broadcast(coeffs)
|
||||
val n = coeffs.size
|
||||
val localFeaturesStd = featuresStd
|
||||
|
||||
|
||||
val logisticAggregator = {
|
||||
val seqOp = (c: LogisticAggregator, instance: Instance) =>
|
||||
c.add(instance, coeffs, localFeaturesStd)
|
||||
val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance)
|
||||
val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
|
||||
|
||||
instances.treeAggregate(
|
||||
new LogisticAggregator(numFeatures, numClasses, fitIntercept)
|
||||
new LogisticAggregator(bcCoeffs, bcFeaturesStd, numFeatures, numClasses, fitIntercept)
|
||||
)(seqOp, combOp)
|
||||
}
|
||||
|
||||
|
@@ -1134,6 +1137,7 @@ private class LogisticCostFun(
|
|||
}
|
||||
0.5 * regParamL2 * sum
|
||||
}
|
||||
bcCoeffs.destroy(blocking = false)
|
||||
|
||||
(logisticAggregator.loss + regVal, new BDV(totalGradientArray))
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue