[SPARK-34860][ML] Multinomial Logistic Regression with intercept support centering

### What changes were proposed in this pull request?
1. Use the new `MultinomialLogisticBlockAggregator`, which supports virtual centering;
2. Remove the now-unused `BlockLogisticAggregator`.

### Why are the changes needed?
1. Better convergence, thanks to virtual centering of the features (a minimal sketch of the identity is below);
2. The resulting solution is much closer to GLMNET's.
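
For intuition, here is a minimal, hypothetical sketch (not code from this PR) of the identity the aggregator exploits: since w · ((x − mean) / std) = w · (x / std) − w · (mean / std), the margin on centered, standardized features equals the margin on merely standardized features minus one precomputed constant, so sparsity is preserved while the optimizer still sees a centered problem:

```scala
// Hypothetical sketch of the identity behind virtual centering; all names are illustrative.
object VirtualCenteringSketch {
  def main(args: Array[String]): Unit = {
    val x    = Array(1.0, 0.0, 3.0)   // raw features (often sparse)
    val mean = Array(0.5, 1.0, 2.0)   // per-feature means
    val std  = Array(2.0, 4.0, 1.0)   // per-feature standard deviations
    val w    = Array(0.3, -0.2, 0.7)  // coefficients

    // explicit centering densifies sparse x: w . ((x - mean) / std)
    val explicit = (0 until 3).map(i => w(i) * (x(i) - mean(i)) / std(i)).sum

    // virtual centering: w . (x / std) minus the precomputed constant w . (mean / std),
    // so only x / std (which preserves sparsity) is ever materialized
    val scaledMean = Array.tabulate(3)(i => mean(i) / std(i))
    val offset  = (0 until 3).map(i => w(i) * scaledMean(i)).sum
    val virtual = (0 until 3).map(i => w(i) * x(i) / std(i)).sum - offset

    assert(math.abs(explicit - virtual) < 1e-12)  // identical margins
  }
}
```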

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Updated and new test suites.

Closes #31985 from zhengruifeng/mlr_center.

Authored-by: Ruifeng Zheng <ruifengz@foxmail.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
commit d372e6e094 (parent c902f77b42)
Authored by Ruifeng Zheng on 2021-03-30 18:06:59 -05:00; committed by Sean Owen
3 changed files with 178 additions and 301 deletions

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

@@ -941,19 +941,21 @@ class LogisticRegression @Since("1.2.0") (
optimizer: FirstOrderMinimizer[BDV[Double], DiffFunction[BDV[Double]]]) = {
val multinomial = checkMultinomial(numClasses)
// for binary LR, we can center the input vector, if and only if:
// for LR, we can center the input vector, if and only if:
// 1, fitIntercept is true;
// 2, no penalty on the intercept, which is always true in existing impl;
// 3, no bounds on the intercept.
val fitWithMean = !multinomial && $(fitIntercept) &&
(!isSet(lowerBoundsOnIntercepts) || $(lowerBoundsOnIntercepts)(0).isNegInfinity) &&
(!isSet(upperBoundsOnIntercepts) || $(upperBoundsOnIntercepts)(0).isPosInfinity)
val fitWithMean = $(fitIntercept) &&
(!isSet(lowerBoundsOnIntercepts) ||
$(lowerBoundsOnIntercepts).toArray.forall(_.isNegInfinity)) &&
(!isSet(upperBoundsOnIntercepts) ||
$(upperBoundsOnIntercepts).toArray.forall(_.isPosInfinity))
val numFeatures = featuresStd.length
val inverseStd = featuresStd.map(std => if (std != 0) 1.0 / std else 0.0)
val scaledMean = Array.tabulate(numFeatures)(i => inverseStd(i) * featuresMean(i))
val bcInverseStd = instances.context.broadcast(inverseStd)
var bcObjects = Seq(bcInverseStd)
val bcScaledMean = instances.context.broadcast(scaledMean)
val scaled = instances.mapPartitions { iter =>
val func = StandardScalerModel.getTransformFunc(Array.empty, bcInverseStd.value, false, true)
@@ -966,25 +968,30 @@
.setName(s"$uid: training blocks (blockSizeInMB=$actualBlockSizeInMB)")
val costFun = if (multinomial) {
// TODO: create a separate MultinomialLogisticBlockAggregator for clarity
val getAggregatorFunc = new BlockLogisticAggregator(numFeatures, numClasses,
$(fitIntercept), true)(_)
val getAggregatorFunc = new MultinomialLogisticBlockAggregator(bcInverseStd, bcScaledMean,
$(fitIntercept), fitWithMean)(_)
new RDDLossFunction(blocks, getAggregatorFunc, regularization, $(aggregationDepth))
} else {
val bcScaledMean = instances.context.broadcast(scaledMean)
bcObjects +:= bcScaledMean
val getAggregatorFunc = new BinaryLogisticBlockAggregator(bcInverseStd, bcScaledMean,
$(fitIntercept), fitWithMean)(_)
new RDDLossFunction(blocks, getAggregatorFunc, regularization, $(aggregationDepth))
}
if (fitWithMean) {
// original `initialCoefWithInterceptArray` is for the problem:
// y = f(w1 * x1 / std_x1, w2 * x2 / std_x2, ..., intercept)
// we should adjust it to the initial solution for the problem:
// y = f(w1 * (x1 - avg_x1) / std_x1, w2 * (x2 - avg_x2) / std_x2, ..., intercept)
val adapt = BLAS.getBLAS(numFeatures).ddot(numFeatures, initialSolution, 1, scaledMean, 1)
initialSolution(numFeatures) += adapt
if (multinomial) {
val adapt = Array.ofDim[Double](numClasses)
BLAS.f2jBLAS.dgemv("N", numClasses, numFeatures, 1.0,
initialSolution, numClasses, scaledMean, 1, 0.0, adapt, 1)
BLAS.getBLAS(numFeatures).daxpy(numClasses, 1.0, adapt, 0, 1,
initialSolution, numClasses * numFeatures, 1)
} else {
// original `initialCoefWithInterceptArray` is for the problem:
// y = f(w1 * x1 / std_x1, w2 * x2 / std_x2, ..., intercept)
// we should adjust it to the initial solution for the problem:
// y = f(w1 * (x1 - avg_x1) / std_x1, w2 * (x2 - avg_x2) / std_x2, ..., intercept)
val adapt = BLAS.getBLAS(numFeatures).ddot(numFeatures, initialSolution, 1, scaledMean, 1)
initialSolution(numFeatures) += adapt
}
}
val states = optimizer.iterations(new CachedDiffFunction(costFun),
@@ -1002,16 +1009,25 @@
arrayBuilder += state.adjustedValue
}
blocks.unpersist()
bcObjects.foreach(_.destroy())
bcInverseStd.destroy()
bcScaledMean.destroy()
val solution = if (state == null) null else state.x.toArray
if (fitWithMean && solution != null) {
// the final solution is for the problem:
// y = f(w1 * (x1 - avg_x1) / std_x1, w2 * (x2 - avg_x2) / std_x2, ..., intercept)
// we should adjust it back for the original problem:
// y = f(w1 * x1 / std_x1, w2 * x2 / std_x2, ..., intercept)
val adapt = BLAS.getBLAS(numFeatures).ddot(numFeatures, solution, 1, scaledMean, 1)
solution(numFeatures) -= adapt
if (multinomial) {
val adapt = Array.ofDim[Double](numClasses)
BLAS.f2jBLAS.dgemv("N", numClasses, numFeatures, 1.0,
solution, numClasses, scaledMean, 1, 0.0, adapt, 1)
BLAS.getBLAS(numFeatures).daxpy(numClasses, -1.0, adapt, 0, 1,
solution, numClasses * numFeatures, 1)
} else {
// the final solution is for the problem:
// y = f(w1 * (x1 - avg_x1) / std_x1, w2 * (x2 - avg_x2) / std_x2, ..., intercept)
// we should adjust it back for the original problem:
// y = f(w1 * x1 / std_x1, w2 * x2 / std_x2, ..., intercept)
val adapt = BLAS.getBLAS(numFeatures).ddot(numFeatures, solution, 1, scaledMean, 1)
solution(numFeatures) -= adapt
}
}
(solution, arrayBuilder.result)
}
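
For readers unfamiliar with the BLAS calls in this hunk, the following is a hypothetical plain-Scala equivalent of the multinomial initial-solution adjustment (the post-training adjustment is identical with a -1.0 sign). Sizes and values are made up; the layout matches the dgemv/daxpy arguments above: coefficient (k, j) lives at index j * numClasses + k, and the numClasses intercepts occupy the tail of the array:

```scala
// Sketch (not the PR's code) of the multinomial intercept adjustment without BLAS.
object InterceptAdjustSketch {
  def main(args: Array[String]): Unit = {
    val numClasses  = 3
    val numFeatures = 4
    val scaledMean  = Array(0.1, 0.2, 0.3, 0.4)                       // mean_j / std_j
    val solution    = Array.fill(numClasses * (numFeatures + 1))(0.5) // coefs ++ intercepts

    var k = 0
    while (k < numClasses) {
      // adapt(k) = sum_j coef(k, j) * scaledMean(j): one row of the dgemv product
      var adapt = 0.0
      var j = 0
      while (j < numFeatures) {
        adapt += solution(j * numClasses + k) * scaledMean(j)
        j += 1
      }
      // daxpy at offset numClasses * numFeatures: shift class k's intercept
      solution(numClasses * numFeatures + k) += adapt
      k += 1
    }
    println(solution.takeRight(numClasses).mkString(", "))  // 1.0, 1.0, 1.0
  }
}
```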

mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/BlockLogisticAggregator.scala (deleted)

@@ -1,264 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.optim.aggregator
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.InstanceBlock
import org.apache.spark.ml.impl.Utils
import org.apache.spark.ml.linalg._
/**
* BlockLogisticAggregator computes the gradient and loss used in logistic classification
* for blocks of instances in sparse or dense matrix form, in an online fashion.
*
* Two BlockLogisticAggregators can be merged together to have a summary of loss and gradient of
* the corresponding joint dataset.
*
* NOTE: The feature values are expected to be standardized before computation.
*
* @param bcCoefficients The coefficients corresponding to the features.
* @param fitIntercept Whether to fit an intercept term.
*/
private[ml] class BlockLogisticAggregator(
numFeatures: Int,
numClasses: Int,
fitIntercept: Boolean,
multinomial: Boolean)(bcCoefficients: Broadcast[Vector])
extends DifferentiableLossAggregator[InstanceBlock, BlockLogisticAggregator] with Logging {
if (multinomial && numClasses <= 2) {
logInfo(s"Multinomial logistic regression for binary classification yields separate " +
s"coefficients for positive and negative classes. When no regularization is applied, the" +
s"result will be effectively the same as binary logistic regression. When regularization" +
s"is applied, multinomial loss will produce a result different from binary loss.")
}
private val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures
private val coefficientSize = bcCoefficients.value.size
protected override val dim: Int = coefficientSize
if (multinomial) {
require(numClasses == coefficientSize / numFeaturesPlusIntercept, s"The number of " +
s"coefficients should be ${numClasses * numFeaturesPlusIntercept} but was $coefficientSize")
} else {
require(coefficientSize == numFeaturesPlusIntercept, s"Expected $numFeaturesPlusIntercept " +
s"coefficients but got $coefficientSize")
require(numClasses == 1 || numClasses == 2, s"Binary logistic aggregator requires numClasses " +
s"in {1, 2} but found $numClasses.")
}
@transient private lazy val coefficientsArray = bcCoefficients.value match {
case DenseVector(values) => values
case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector but " +
s"got type ${bcCoefficients.value.getClass}.)")
}
@transient private lazy val binaryLinear = (multinomial, fitIntercept) match {
case (false, true) => Vectors.dense(coefficientsArray.take(numFeatures))
case (false, false) => Vectors.dense(coefficientsArray)
case _ => null
}
@transient private lazy val multinomialLinear = (multinomial, fitIntercept) match {
case (true, true) =>
Matrices.dense(numClasses, numFeatures,
coefficientsArray.take(numClasses * numFeatures)).toDense
case (true, false) =>
Matrices.dense(numClasses, numFeatures, coefficientsArray).toDense
case _ => null
}
/**
* Add a new training instance block to this BlockLogisticAggregator, and update the loss and
* gradient of the objective function.
*
* @param block The instance block of data point to be added.
* @return This BlockLogisticAggregator object.
*/
def add(block: InstanceBlock): this.type = {
require(block.matrix.isTransposed)
require(numFeatures == block.numFeatures, s"Dimensions mismatch when adding new " +
s"instance. Expecting $numFeatures but got ${block.numFeatures}.")
require(block.weightIter.forall(_ >= 0),
s"instance weights ${block.weightIter.mkString("[", ",", "]")} has to be >= 0.0")
if (block.weightIter.forall(_ == 0)) return this
if (multinomial) {
multinomialUpdateInPlace(block)
} else {
binaryUpdateInPlace(block)
}
this
}
/** Update gradient and loss using binary loss function. */
private def binaryUpdateInPlace(block: InstanceBlock): Unit = {
val size = block.size
// vec here represents margins or negative dotProducts
val vec = if (fitIntercept) {
Vectors.dense(Array.fill(size)(coefficientsArray.last)).toDense
} else {
Vectors.zeros(size).toDense
}
BLAS.gemv(-1.0, block.matrix, binaryLinear, -1.0, vec)
// in-place convert margins to multiplier
// then, vec represents multiplier
var localLossSum = 0.0
var i = 0
while (i < size) {
val weight = block.getWeight(i)
if (weight > 0) {
val label = block.getLabel(i)
val margin = vec(i)
if (label > 0) {
// The following is equivalent to log(1 + exp(margin)) but more numerically stable.
localLossSum += weight * Utils.log1pExp(margin)
} else {
localLossSum += weight * (Utils.log1pExp(margin) - margin)
}
val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)
vec.values(i) = multiplier
} else { vec.values(i) = 0.0 }
i += 1
}
lossSum += localLossSum
weightSum += block.weightIter.sum
// predictions are all correct, no gradient signal
if (vec.values.forall(_ == 0)) return
block.matrix match {
case dm: DenseMatrix =>
BLAS.nativeBLAS.dgemv("N", dm.numCols, dm.numRows, 1.0, dm.values, dm.numCols,
vec.values, 1, 1.0, gradientSumArray, 1)
case sm: SparseMatrix if fitIntercept =>
val linearGradSumVec = Vectors.zeros(numFeatures).toDense
BLAS.gemv(1.0, sm.transpose, vec, 0.0, linearGradSumVec)
BLAS.getBLAS(numFeatures).daxpy(numFeatures, 1.0, linearGradSumVec.values, 1,
gradientSumArray, 1)
case sm: SparseMatrix if !fitIntercept =>
val gradSumVec = new DenseVector(gradientSumArray)
BLAS.gemv(1.0, sm.transpose, vec, 1.0, gradSumVec)
case m =>
throw new IllegalArgumentException(s"Unknown matrix type ${m.getClass}.")
}
if (fitIntercept) gradientSumArray(numFeatures) += vec.values.sum
}
/** Update gradient and loss using multinomial (softmax) loss function. */
private def multinomialUpdateInPlace(block: InstanceBlock): Unit = {
val size = block.size
// mat here represents margins, shape: S X C
val mat = DenseMatrix.zeros(size, numClasses)
if (fitIntercept) {
val localCoefficientsArray = coefficientsArray
val offset = numClasses * numFeatures
var j = 0
while (j < numClasses) {
val intercept = localCoefficientsArray(offset + j)
var i = 0
while (i < size) { mat.update(i, j, intercept); i += 1 }
j += 1
}
}
BLAS.gemm(1.0, block.matrix, multinomialLinear.transpose, 1.0, mat)
// in-place convert margins to multipliers
// then, mat represents multipliers
var localLossSum = 0.0
var i = 0
val tmp = Array.ofDim[Double](numClasses)
val interceptGradSumArr = if (fitIntercept) Array.ofDim[Double](numClasses) else null
while (i < size) {
val weight = block.getWeight(i)
if (weight > 0) {
val label = block.getLabel(i)
var maxMargin = Double.NegativeInfinity
var j = 0
while (j < numClasses) {
tmp(j) = mat(i, j)
maxMargin = math.max(maxMargin, tmp(j))
j += 1
}
// marginOfLabel is margins(label) in the formula
val marginOfLabel = tmp(label.toInt)
var sum = 0.0
j = 0
while (j < numClasses) {
if (maxMargin > 0) tmp(j) -= maxMargin
val exp = math.exp(tmp(j))
sum += exp
tmp(j) = exp
j += 1
}
j = 0
while (j < numClasses) {
val multiplier = weight * (tmp(j) / sum - (if (label == j) 1.0 else 0.0))
mat.update(i, j, multiplier)
if (fitIntercept) interceptGradSumArr(j) += multiplier
j += 1
}
if (maxMargin > 0) {
localLossSum += weight * (math.log(sum) - marginOfLabel + maxMargin)
} else {
localLossSum += weight * (math.log(sum) - marginOfLabel)
}
} else {
var j = 0; while (j < numClasses) { mat.update(i, j, 0.0); j += 1 }
}
i += 1
}
lossSum += localLossSum
weightSum += block.weightIter.sum
// mat (multipliers): S X C, dense N
// mat.transpose (multipliers): C X S, dense T
// block.matrix: S X F, unknown type T
// gradSumMat(gradientSumArray): C X FPI (numFeaturesPlusIntercept), dense N
block.matrix match {
case dm: DenseMatrix =>
BLAS.nativeBLAS.dgemm("T", "T", numClasses, numFeatures, size, 1.0,
mat.values, size, dm.values, numFeatures, 1.0, gradientSumArray, numClasses)
case sm: SparseMatrix =>
// linearGradSumMat = matrix.T X mat
val linearGradSumMat = DenseMatrix.zeros(numFeatures, numClasses)
BLAS.gemm(1.0, sm.transpose, mat, 0.0, linearGradSumMat)
linearGradSumMat.foreachActive { (i, j, v) => gradientSumArray(i * numClasses + j) += v }
}
if (fitIntercept) {
BLAS.getBLAS(numClasses).daxpy(numClasses, 1.0, interceptGradSumArr, 0, 1,
gradientSumArray, numClasses * numFeatures, 1)
}
}
}
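
The multinomial update above avoids overflow by subtracting the max margin before exponentiating and adding it back in the loss term. A self-contained sketch of that log-sum-exp trick, with made-up margins (not the aggregator's API):

```scala
// Numerically stable softmax loss for a single instance; illustrative sketch only.
object StableSoftmaxSketch {
  def main(args: Array[String]): Unit = {
    val margins = Array(1000.0, 1001.0, 998.0)  // naive math.exp here would overflow
    val label   = 1

    val maxMargin  = margins.max
    val expShifted = margins.map(m => math.exp(m - maxMargin))  // all <= 1.0, finite
    val sum        = expShifted.sum

    // loss = log(sum_j exp(margin_j)) - margin_label
    //      = log(sum_j exp(margin_j - max)) + max - margin_label
    val loss  = math.log(sum) + maxMargin - margins(label)
    // the aggregator's multipliers are weight * (prob_j - indicator(label == j))
    val probs = expShifted.map(_ / sum)

    assert(!loss.isNaN && !loss.isInfinity && math.abs(probs.sum - 1.0) < 1e-12)
  }
}
```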

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

@@ -46,6 +46,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
@transient var binaryDataset: DataFrame = _
@transient var binaryDatasetWithSmallVar: DataFrame = _
@transient var multinomialDataset: DataFrame = _
@transient var multinomialDatasetWithSmallVar: DataFrame = _
@transient var multinomialDatasetWithZeroVar: DataFrame = _
private val eps: Double = 1e-5
@@ -118,6 +119,23 @@
df
}
multinomialDatasetWithSmallVar = {
val nPoints = 50000
val coefficients = Array(
-0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
-0.16624, -0.84355, -0.048509, -0.301789, 4.170682)
val xMean = Array(5.843, 3.057, 3.758, 10.199)
val xVariance = Array(0.6856, 0.1899, 3.116, 0.001)
val testData = generateMultinomialLogisticInput(
coefficients, xMean, xVariance, addIntercept = true, nPoints, seed)
val df = sc.parallelize(testData, 4).toDF().withColumn("weight", rand(seed))
df.cache()
df
}
multinomialDatasetWithZeroVar = {
val nPoints = 100
val coefficients = Array(
@@ -141,18 +159,21 @@
* so we can validate the training accuracy compared with R's glmnet package.
*/
ignore("export test data into CSV format") {
binaryDataset.rdd.map { case Row(label: Double, features: Vector, weight: Double) =>
label + "," + weight + "," + features.toArray.mkString(",")
binaryDataset.rdd.map { case Row(l: Double, f: Vector, w: Double) =>
l + "," + w + "," + f.toArray.mkString(",")
}.repartition(1).saveAsTextFile("target/tmp/LogisticRegressionSuite/binaryDataset")
binaryDatasetWithSmallVar.rdd.map { case Row(label: Double, features: Vector, weight: Double) =>
label + "," + weight + "," + features.toArray.mkString(",")
binaryDatasetWithSmallVar.rdd.map { case Row(l: Double, f: Vector, w: Double) =>
l + "," + w + "," + f.toArray.mkString(",")
}.repartition(1).saveAsTextFile("target/tmp/LogisticRegressionSuite/binaryDatasetWithSmallVar")
multinomialDataset.rdd.map { case Row(label: Double, features: Vector, weight: Double) =>
label + "," + weight + "," + features.toArray.mkString(",")
multinomialDataset.rdd.map { case Row(l: Double, f: Vector, w: Double) =>
l + "," + w + "," + f.toArray.mkString(",")
}.repartition(1).saveAsTextFile("target/tmp/LogisticRegressionSuite/multinomialDataset")
multinomialDatasetWithZeroVar.rdd.map {
case Row(label: Double, features: Vector, weight: Double) =>
label + "," + weight + "," + features.toArray.mkString(",")
multinomialDatasetWithSmallVar.rdd.map { case Row(l: Double, f: Vector, w: Double) =>
l + "," + w + "," + f.toArray.mkString(",")
}.repartition(1)
.saveAsTextFile("target/tmp/LogisticRegressionSuite/multinomialDatasetWithSmallVar")
multinomialDatasetWithZeroVar.rdd.map { case Row(l: Double, f: Vector, w: Double) =>
l + "," + w + "," + f.toArray.mkString(",")
}.repartition(1)
.saveAsTextFile("target/tmp/LogisticRegressionSuite/multinomialDatasetWithZeroVar")
}
@@ -1863,21 +1884,125 @@
0.0, 0.0, 0.0, 0.09064661,
-0.1144333, 0.3204703, -0.1621061, -0.2308192,
0.0, -0.4832131, 0.0, 0.0), isTransposed = true)
val interceptsRStd = Vectors.dense(-0.72638218, -0.01737265, 0.74375484)
val interceptsRStd = Vectors.dense(-0.69265374, -0.2260274, 0.9186811)
val coefficientsR = new DenseMatrix(3, 4, Array(
0.0, 0.0, 0.01641412, 0.03570376,
-0.05110822, 0.0, -0.21595670, -0.16162836,
0.0, 0.0, 0.0, 0.0), isTransposed = true)
val interceptsR = Vectors.dense(-0.44707756, 0.75180900, -0.3047314)
assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.05)
assert(model1.interceptVector ~== interceptsRStd relTol 0.1)
assert(model1.coefficientMatrix ~== coefficientsRStd absTol 1e-3)
assert(model1.interceptVector ~== interceptsRStd relTol 1e-3)
assert(model1.interceptVector.toArray.sum ~== 0.0 absTol eps)
assert(model2.coefficientMatrix ~== coefficientsR absTol 0.02)
assert(model2.interceptVector ~== interceptsR relTol 0.1)
assert(model2.coefficientMatrix ~== coefficientsR absTol 1e-3)
assert(model2.interceptVector ~== interceptsR relTol 1e-3)
assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps)
}
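
The `interceptVector.toArray.sum ~== 0.0` assertions above reflect that the over-parameterized multinomial model is only identifiable up to a per-class constant shift, so a mean-centered representative is reported. A tiny standalone check of that shift invariance (hypothetical numbers, not suite code):

```scala
// Softmax is invariant to adding a constant to every class margin; sketch only.
object ShiftInvarianceSketch {
  def softmax(z: Array[Double]): Array[Double] = {
    val e = z.map(math.exp)
    val s = e.sum
    e.map(_ / s)
  }

  def main(args: Array[String]): Unit = {
    val intercepts = Array(-0.3, 0.1, 0.8)
    val centered   = intercepts.map(_ - intercepts.sum / intercepts.length)
    assert(math.abs(centered.sum) < 1e-12)  // the representative the tests expect

    // shifting all intercepts by any constant leaves the probabilities unchanged
    val p1 = softmax(intercepts)
    val p2 = softmax(intercepts.map(_ + 5.0))
    assert(p1.zip(p2).forall { case (a, b) => math.abs(a - b) < 1e-9 })
  }
}
```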
test("SPARK-34860: multinomial logistic regression with intercept, with small var") {
val trainer1 = new LogisticRegression().setFitIntercept(true).setStandardization(true)
.setWeightCol("weight")
val trainer2 = new LogisticRegression().setFitIntercept(true).setStandardization(false)
.setWeightCol("weight")
val trainer3 = new LogisticRegression().setFitIntercept(true).setStandardization(true)
.setElasticNetParam(0.0001).setRegParam(0.5).setWeightCol("weight")
val model1 = trainer1.fit(multinomialDatasetWithSmallVar)
val model2 = trainer2.fit(multinomialDatasetWithSmallVar)
val model3 = trainer3.fit(multinomialDatasetWithSmallVar)
/*
Use the following R code to load the data and train the model using the glmnet package.
library("glmnet")
data <- read.csv("path", header=FALSE)
label = factor(data$V1)
w = data$V2
features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6))
coefficients = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0,
lambda = 0))
coefficients
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
2.91748298
data.V3 0.21755977
data.V4 0.01647541
data.V5 0.16507778
data.V6 -0.14016680
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
-17.5107460
data.V3 -0.2443600
data.V4 0.7564655
data.V5 -0.2955698
data.V6 1.3262009
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
14.59326301
data.V3 0.02680026
data.V4 -0.77294095
data.V5 0.13049206
data.V6 -1.18603411
coefficientsStd = coef(glmnet(features, label, weights=w, family="multinomial",
alpha = 0.0001, lambda = 0.5, standardize=T))
coefficientsStd
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
1.751626027
data.V3 0.019970169
data.V4 0.079611293
data.V5 0.003959452
data.V6 0.110024399
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
-3.9297124987
data.V3 -0.0004788494
data.V4 0.0010097453
data.V5 -0.0005832701
data.V6 .
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
2.178086472
data.V3 -0.019369990
data.V4 -0.080851149
data.V5 -0.003319687
data.V6 -0.112435972
*/
val interceptsR = Vectors.dense(2.91748298, -17.5107460, 14.59326301)
val coefficientsR = new DenseMatrix(3, 4, Array(
0.21755977, 0.01647541, 0.16507778, -0.14016680,
-0.2443600, 0.7564655, -0.2955698, 1.3262009,
0.02680026, -0.77294095, 0.13049206, -1.18603411), isTransposed = true)
assert(model1.interceptVector ~== interceptsR relTol 1e-2)
assert(model1.coefficientMatrix ~= coefficientsR relTol 1e-1)
// Without regularization, training with or without standardization converges to the same solution.
assert(model2.interceptVector ~== interceptsR relTol 1e-2)
assert(model2.coefficientMatrix ~= coefficientsR relTol 1e-1)
val interceptsR2 = Vectors.dense(1.751626027, -3.9297124987, 2.178086472)
val coefficientsR2 = new DenseMatrix(3, 4, Array(
0.019970169, 0.079611293, 0.003959452, 0.110024399,
-0.0004788494, 0.0010097453, -0.0005832701, 0.0,
-0.019369990, -0.080851149, -0.003319687, -0.112435972), isTransposed = true)
assert(model3.interceptVector ~== interceptsR2 relTol 1e-3)
assert(model3.coefficientMatrix ~= coefficientsR2 relTol 1e-2)
}
test("multinomial logistic regression without intercept with L1 regularization") {
val trainer1 = (new LogisticRegression).setFitIntercept(false)
.setElasticNetParam(1.0).setRegParam(0.05).setStandardization(true).setWeightCol("weight")