[SPARK-9834] [MLLIB] implement weighted least squares via normal equation
The goal of this PR is to have a weighted least squares implementation that takes the normal equation approach, and hence to be able to provide R-like summary statistics and support IRLS (used by GLMs). The tests match R's lm and glmnet. There are couple TODOs that can be addressed in future PRs: * consolidate summary statistics aggregators * move `dspr` to `BLAS` * etc It would be nice to have this merged first because it blocks couple other features. dbtsai Author: Xiangrui Meng <meng@databricks.com> Closes #8588 from mengxr/SPARK-9834.
This commit is contained in:
parent
820913f554
commit
52fe32f6ac
|
@ -0,0 +1,296 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.optim
|
||||
|
||||
import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
|
||||
import org.netlib.util.intW
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.mllib.linalg._
|
||||
import org.apache.spark.mllib.linalg.distributed.RowMatrix
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
/**
|
||||
* Model fitted by [[WeightedLeastSquares]].
|
||||
* @param coefficients model coefficients
|
||||
* @param intercept model intercept
|
||||
*/
|
||||
private[ml] class WeightedLeastSquaresModel(
|
||||
val coefficients: DenseVector,
|
||||
val intercept: Double) extends Serializable
|
||||
|
||||
/**
|
||||
* Weighted least squares solver via normal equation.
|
||||
* Given weighted observations (w,,i,,, a,,i,,, b,,i,,), we use the following weighted least squares
|
||||
* formulation:
|
||||
*
|
||||
* min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w_i
|
||||
* + 1/2 lambda / delta sum,,j,, (sigma,,j,, x,,j,,)^2^,
|
||||
*
|
||||
* where lambda is the regularization parameter, and delta and sigma,,j,, are controlled by
|
||||
* [[standardizeLabel]] and [[standardizeFeatures]], respectively.
|
||||
*
|
||||
* Set [[regParam]] to 0.0 and turn off both [[standardizeFeatures]] and [[standardizeLabel]] to
|
||||
* match R's `lm`.
|
||||
* Turn on [[standardizeLabel]] to match R's `glmnet`.
|
||||
*
|
||||
* @param fitIntercept whether to fit intercept. If false, z is 0.0.
|
||||
* @param regParam L2 regularization parameter (lambda)
|
||||
* @param standardizeFeatures whether to standardize features. If true, sigma_,,j,, is the
|
||||
* population standard deviation of the j-th column of A. Otherwise,
|
||||
* sigma,,j,, is 1.0.
|
||||
* @param standardizeLabel whether to standardize label. If true, delta is the population standard
|
||||
* deviation of the label column b. Otherwise, delta is 1.0.
|
||||
*/
|
||||
private[ml] class WeightedLeastSquares(
|
||||
val fitIntercept: Boolean,
|
||||
val regParam: Double,
|
||||
val standardizeFeatures: Boolean,
|
||||
val standardizeLabel: Boolean) extends Logging with Serializable {
|
||||
import WeightedLeastSquares._
|
||||
|
||||
require(regParam >= 0.0, s"regParam cannot be negative: $regParam")
|
||||
if (regParam == 0.0) {
|
||||
logWarning("regParam is zero, which might cause numerical instability and overfitting.")
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s.
|
||||
*/
|
||||
def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = {
|
||||
val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_))
|
||||
summary.validate()
|
||||
logInfo(s"Number of instances: ${summary.count}.")
|
||||
val triK = summary.triK
|
||||
val bBar = summary.bBar
|
||||
val bStd = summary.bStd
|
||||
val aBar = summary.aBar
|
||||
val aVar = summary.aVar
|
||||
val abBar = summary.abBar
|
||||
val aaBar = summary.aaBar
|
||||
val aaValues = aaBar.values
|
||||
|
||||
if (fitIntercept) {
|
||||
// shift centers
|
||||
// A^T A - aBar aBar^T
|
||||
RowMatrix.dspr(-1.0, aBar, aaValues)
|
||||
// A^T b - bBar aBar
|
||||
BLAS.axpy(-bBar, aBar, abBar)
|
||||
}
|
||||
|
||||
// add regularization to diagonals
|
||||
var i = 0
|
||||
var j = 2
|
||||
while (i < triK) {
|
||||
var lambda = regParam
|
||||
if (standardizeFeatures) {
|
||||
lambda *= aVar(j - 2)
|
||||
}
|
||||
if (standardizeLabel) {
|
||||
// TODO: handle the case when bStd = 0
|
||||
lambda /= bStd
|
||||
}
|
||||
aaValues(i) += lambda
|
||||
i += j
|
||||
j += 1
|
||||
}
|
||||
|
||||
val x = choleskySolve(aaBar.values, abBar)
|
||||
|
||||
// compute intercept
|
||||
val intercept = if (fitIntercept) {
|
||||
bBar - BLAS.dot(aBar, x)
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
|
||||
new WeightedLeastSquaresModel(x, intercept)
|
||||
}
|
||||
|
||||
/**
|
||||
* Solves a symmetric positive definite linear system via Cholesky factorization.
|
||||
* The input arguments are modified in-place to store the factorization and the solution.
|
||||
* @param A the upper triangular part of A
|
||||
* @param bx right-hand side
|
||||
* @return the solution vector
|
||||
*/
|
||||
// TODO: SPARK-10490 - consolidate this and the Cholesky solver in ALS
|
||||
private def choleskySolve(A: Array[Double], bx: DenseVector): DenseVector = {
|
||||
val k = bx.size
|
||||
val info = new intW(0)
|
||||
lapack.dppsv("U", k, 1, A, bx.values, k, info)
|
||||
val code = info.`val`
|
||||
assert(code == 0, s"lapack.dpotrs returned $code.")
|
||||
bx
|
||||
}
|
||||
}
|
||||
|
||||
private[ml] object WeightedLeastSquares {
|
||||
|
||||
/**
|
||||
* Case class for weighted observations.
|
||||
* @param w weight, must be positive
|
||||
* @param a features
|
||||
* @param b label
|
||||
*/
|
||||
case class Instance(w: Double, a: Vector, b: Double) {
|
||||
require(w >= 0.0, s"Weight cannot be negative: $w.")
|
||||
}
|
||||
|
||||
/**
|
||||
* Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]].
|
||||
*/
|
||||
// TODO: consolidate aggregates for summary statistics
|
||||
private class Aggregator extends Serializable {
|
||||
var initialized: Boolean = false
|
||||
var k: Int = _
|
||||
var count: Long = _
|
||||
var triK: Int = _
|
||||
private var wSum: Double = _
|
||||
private var wwSum: Double = _
|
||||
private var bSum: Double = _
|
||||
private var bbSum: Double = _
|
||||
private var aSum: DenseVector = _
|
||||
private var abSum: DenseVector = _
|
||||
private var aaSum: DenseVector = _
|
||||
|
||||
private def init(k: Int): Unit = {
|
||||
require(k <= 4096, "In order to take the normal equation approach efficiently, " +
|
||||
s"we set the max number of features to 4096 but got $k.")
|
||||
this.k = k
|
||||
triK = k * (k + 1) / 2
|
||||
count = 0L
|
||||
wSum = 0.0
|
||||
wwSum = 0.0
|
||||
bSum = 0.0
|
||||
bbSum = 0.0
|
||||
aSum = new DenseVector(Array.ofDim(k))
|
||||
abSum = new DenseVector(Array.ofDim(k))
|
||||
aaSum = new DenseVector(Array.ofDim(triK))
|
||||
initialized = true
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds an instance.
|
||||
*/
|
||||
def add(instance: Instance): this.type = {
|
||||
val Instance(w, a, b) = instance
|
||||
val ak = a.size
|
||||
if (!initialized) {
|
||||
init(ak)
|
||||
initialized = true
|
||||
}
|
||||
assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.")
|
||||
count += 1L
|
||||
wSum += w
|
||||
wwSum += w * w
|
||||
bSum += w * b
|
||||
bbSum += w * b * b
|
||||
BLAS.axpy(w, a, aSum)
|
||||
BLAS.axpy(w * b, a, abSum)
|
||||
RowMatrix.dspr(w, a, aaSum.values)
|
||||
this
|
||||
}
|
||||
|
||||
/**
|
||||
* Merges another [[Aggregator]].
|
||||
*/
|
||||
def merge(other: Aggregator): this.type = {
|
||||
if (!other.initialized) {
|
||||
this
|
||||
} else {
|
||||
if (!initialized) {
|
||||
init(other.k)
|
||||
}
|
||||
assert(k == other.k, s"dimension mismatch: this.k = $k but other.k = ${other.k}")
|
||||
count += other.count
|
||||
wSum += other.wSum
|
||||
wwSum += other.wwSum
|
||||
bSum += other.bSum
|
||||
bbSum += other.bbSum
|
||||
BLAS.axpy(1.0, other.aSum, aSum)
|
||||
BLAS.axpy(1.0, other.abSum, abSum)
|
||||
BLAS.axpy(1.0, other.aaSum, aaSum)
|
||||
this
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates that we have seen observations.
|
||||
*/
|
||||
def validate(): Unit = {
|
||||
assert(initialized, "Training dataset is empty.")
|
||||
assert(wSum > 0.0, "Sum of weights cannot be zero.")
|
||||
}
|
||||
|
||||
/**
|
||||
* Weighted mean of features.
|
||||
*/
|
||||
def aBar: DenseVector = {
|
||||
val output = aSum.copy
|
||||
BLAS.scal(1.0 / wSum, output)
|
||||
output
|
||||
}
|
||||
|
||||
/**
|
||||
* Weighted mean of labels.
|
||||
*/
|
||||
def bBar: Double = bSum / wSum
|
||||
|
||||
/**
|
||||
* Weighted population standard deviation of labels.
|
||||
*/
|
||||
def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)
|
||||
|
||||
/**
|
||||
* Weighted mean of (label * features).
|
||||
*/
|
||||
def abBar: DenseVector = {
|
||||
val output = abSum.copy
|
||||
BLAS.scal(1.0 / wSum, output)
|
||||
output
|
||||
}
|
||||
|
||||
/**
|
||||
* Weighted mean of (features * features^T^).
|
||||
*/
|
||||
def aaBar: DenseVector = {
|
||||
val output = aaSum.copy
|
||||
BLAS.scal(1.0 / wSum, output)
|
||||
output
|
||||
}
|
||||
|
||||
/**
|
||||
* Weighted population variance of features.
|
||||
*/
|
||||
def aVar: DenseVector = {
|
||||
val variance = Array.ofDim[Double](k)
|
||||
var i = 0
|
||||
var j = 2
|
||||
val aaValues = aaSum.values
|
||||
while (i < triK) {
|
||||
val l = j - 2
|
||||
val aw = aSum(l) / wSum
|
||||
variance(l) = aaValues(i) / wSum - aw * aw
|
||||
i += j
|
||||
j += 1
|
||||
}
|
||||
new DenseVector(variance)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -92,6 +92,13 @@ private[spark] object BLAS extends Serializable with Logging {
|
|||
}
|
||||
}
|
||||
|
||||
/** Y += a * x */
|
||||
private[spark] def axpy(a: Double, X: DenseMatrix, Y: DenseMatrix): Unit = {
|
||||
require(X.numRows == Y.numRows && X.numCols == Y.numCols, "Dimension mismatch: " +
|
||||
s"size(X) = ${(X.numRows, X.numCols)} but size(Y) = ${(Y.numRows, Y.numCols)}.")
|
||||
f2jBLAS.daxpy(X.numRows * X.numCols, a, X.values, 1, Y.values, 1)
|
||||
}
|
||||
|
||||
/**
|
||||
* dot(x, y)
|
||||
*/
|
||||
|
|
|
@ -678,7 +678,8 @@ object RowMatrix {
|
|||
*
|
||||
* @param U the upper triangular part of the matrix packed in an array (column major)
|
||||
*/
|
||||
private def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
|
||||
// TODO: SPARK-10491 - move this method to linalg.BLAS
|
||||
private[spark] def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
|
||||
// TODO: Find a better home (breeze?) for this method.
|
||||
val n = v.size
|
||||
v match {
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.optim
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.optim.WeightedLeastSquares.Instance
|
||||
import org.apache.spark.mllib.linalg.Vectors
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.mllib.util.TestingUtils._
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||
|
||||
private var instances: RDD[Instance] = _
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
/*
|
||||
R code:
|
||||
|
||||
A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
|
||||
b <- c(17, 19, 23, 29)
|
||||
w <- c(1, 2, 3, 4)
|
||||
*/
|
||||
instances = sc.parallelize(Seq(
|
||||
Instance(1.0, Vectors.dense(0.0, 5.0).toSparse, 17.0),
|
||||
Instance(2.0, Vectors.dense(1.0, 7.0), 19.0),
|
||||
Instance(3.0, Vectors.dense(2.0, 11.0), 23.0),
|
||||
Instance(4.0, Vectors.dense(3.0, 13.0), 29.0)
|
||||
), 2)
|
||||
}
|
||||
|
||||
test("WLS against lm") {
|
||||
/*
|
||||
R code:
|
||||
|
||||
df <- as.data.frame(cbind(A, b))
|
||||
for (formula in c(b ~ . -1, b ~ .)) {
|
||||
model <- lm(formula, data=df, weights=w)
|
||||
print(as.vector(coef(model)))
|
||||
}
|
||||
|
||||
[1] -3.727121 3.009983
|
||||
[1] 18.08 6.08 -0.60
|
||||
*/
|
||||
|
||||
val expected = Seq(
|
||||
Vectors.dense(0.0, -3.727121, 3.009983),
|
||||
Vectors.dense(18.08, 6.08, -0.60))
|
||||
|
||||
var idx = 0
|
||||
for (fitIntercept <- Seq(false, true)) {
|
||||
val wls = new WeightedLeastSquares(
|
||||
fitIntercept, regParam = 0.0, standardizeFeatures = false, standardizeLabel = false)
|
||||
.fit(instances)
|
||||
val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
|
||||
assert(actual ~== expected(idx) absTol 1e-4)
|
||||
idx += 1
|
||||
}
|
||||
}
|
||||
|
||||
test("WLS against glmnet") {
|
||||
/*
|
||||
R code:
|
||||
|
||||
library(glmnet)
|
||||
|
||||
for (intercept in c(FALSE, TRUE)) {
|
||||
for (lambda in c(0.0, 0.1, 1.0)) {
|
||||
for (standardize in c(FALSE, TRUE)) {
|
||||
model <- glmnet(A, b, weights=w, intercept=intercept, lambda=lambda,
|
||||
standardize=standardize, alpha=0, thresh=1E-14)
|
||||
print(as.vector(coef(model)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[1] 0.000000 -3.727117 3.009982
|
||||
[1] 0.000000 -3.727117 3.009982
|
||||
[1] 0.000000 -3.307532 2.924206
|
||||
[1] 0.000000 -2.914790 2.840627
|
||||
[1] 0.000000 -1.526575 2.558158
|
||||
[1] 0.00000000 0.06984238 2.20488344
|
||||
[1] 18.0799727 6.0799832 -0.5999941
|
||||
[1] 18.0799727 6.0799832 -0.5999941
|
||||
[1] 13.5356178 3.2714044 0.3770744
|
||||
[1] 14.064629 3.565802 0.269593
|
||||
[1] 10.1238013 0.9708569 1.1475466
|
||||
[1] 13.1860638 2.1761382 0.6213134
|
||||
*/
|
||||
|
||||
val expected = Seq(
|
||||
Vectors.dense(0.0, -3.727117, 3.009982),
|
||||
Vectors.dense(0.0, -3.727117, 3.009982),
|
||||
Vectors.dense(0.0, -3.307532, 2.924206),
|
||||
Vectors.dense(0.0, -2.914790, 2.840627),
|
||||
Vectors.dense(0.0, -1.526575, 2.558158),
|
||||
Vectors.dense(0.0, 0.06984238, 2.20488344),
|
||||
Vectors.dense(18.0799727, 6.0799832, -0.5999941),
|
||||
Vectors.dense(18.0799727, 6.0799832, -0.5999941),
|
||||
Vectors.dense(13.5356178, 3.2714044, 0.3770744),
|
||||
Vectors.dense(14.064629, 3.565802, 0.269593),
|
||||
Vectors.dense(10.1238013, 0.9708569, 1.1475466),
|
||||
Vectors.dense(13.1860638, 2.1761382, 0.6213134))
|
||||
|
||||
var idx = 0
|
||||
for (fitIntercept <- Seq(false, true);
|
||||
regParam <- Seq(0.0, 0.1, 1.0);
|
||||
standardizeFeatures <- Seq(false, true)) {
|
||||
val wls = new WeightedLeastSquares(
|
||||
fitIntercept, regParam, standardizeFeatures, standardizeLabel = true)
|
||||
.fit(instances)
|
||||
val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
|
||||
assert(actual ~== expected(idx) absTol 1e-4)
|
||||
idx += 1
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue