[SPARK-9834] [MLLIB] implement weighted least squares via normal equation

The goal of this PR is to add a weighted least squares implementation that takes the normal equation approach, which makes it possible to provide R-like summary statistics and to support IRLS (used by GLMs). The tests match R's `lm` and `glmnet`.
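For reference, the objective being minimized (spelled out again in the Scaladoc of `WeightedLeastSquares` below) and the system it reduces to can be written compactly in LaTeX; this restatement is mine, with notation chosen to match the code's `aaBar`/`abBar`/`aVar` terms:

    \min_{x,z}\; \frac{1}{2}\,\frac{\sum_i w_i (a_i^\top x + z - b_i)^2}{\sum_i w_i}
      + \frac{\lambda}{2\delta} \sum_j (\sigma_j x_j)^2

Setting the gradient to zero (after shifting out the weighted means when an intercept is fitted) yields the normal equation that the Cholesky solver handles:

    \Bigl(\frac{A^\top W A}{\sum_i w_i} + \frac{\lambda}{\delta}\,\mathrm{diag}(\sigma_j^2)\Bigr)\, x
      = \frac{A^\top W b}{\sum_i w_i}, \qquad z = \bar{b} - \bar{a}^\top x.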

There are a couple of TODOs that can be addressed in future PRs:
* consolidate summary statistics aggregators
* move `dspr` to `BLAS`
* etc

It would be nice to have this merged first because it blocks a couple of other features.

dbtsai

Author: Xiangrui Meng <meng@databricks.com>

Closes #8588 from mengxr/SPARK-9834.
4 changed files with 438 additions and 1 deletion

mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala

@@ -0,0 +1,296 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.optim

import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
import org.netlib.util.intW

import org.apache.spark.Logging
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.rdd.RDD

/**
 * Model fitted by [[WeightedLeastSquares]].
 * @param coefficients model coefficients
 * @param intercept model intercept
 */
private[ml] class WeightedLeastSquaresModel(
    val coefficients: DenseVector,
    val intercept: Double) extends Serializable

/**
 * Weighted least squares solver via normal equation.
 * Given weighted observations (w,,i,,, a,,i,,, b,,i,,), we use the following weighted least squares
 * formulation:
 *
 *   min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w,,i,,
 *     + 1/2 lambda / delta sum,,j,, (sigma,,j,, x,,j,,)^2^,
 *
 * where lambda is the regularization parameter, and delta and sigma,,j,, are controlled by
 * [[standardizeLabel]] and [[standardizeFeatures]], respectively.
 *
 * Set [[regParam]] to 0.0 and turn off both [[standardizeFeatures]] and [[standardizeLabel]] to
 * match R's `lm`.
 * Turn on [[standardizeLabel]] to match R's `glmnet`.
 *
 * @param fitIntercept whether to fit intercept. If false, z is 0.0.
 * @param regParam L2 regularization parameter (lambda)
 * @param standardizeFeatures whether to standardize features. If true, sigma,,j,, is the
 *                            population standard deviation of the j-th column of A. Otherwise,
 *                            sigma,,j,, is 1.0.
 * @param standardizeLabel whether to standardize label. If true, delta is the population standard
 *                         deviation of the label column b. Otherwise, delta is 1.0.
 */
private[ml] class WeightedLeastSquares(
    val fitIntercept: Boolean,
    val regParam: Double,
    val standardizeFeatures: Boolean,
    val standardizeLabel: Boolean) extends Logging with Serializable {
  import WeightedLeastSquares._

  require(regParam >= 0.0, s"regParam cannot be negative: $regParam")
  if (regParam == 0.0) {
    logWarning("regParam is zero, which might cause numerical instability and overfitting.")
  }

  /**
   * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s.
   */
  def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = {
    val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_))
    summary.validate()
    logInfo(s"Number of instances: ${summary.count}.")
    val triK = summary.triK
    val bBar = summary.bBar
    val bStd = summary.bStd
    val aBar = summary.aBar
    val aVar = summary.aVar
    val abBar = summary.abBar
    val aaBar = summary.aaBar
    val aaValues = aaBar.values

    if (fitIntercept) {
      // shift centers
      // A^T A - aBar aBar^T
      RowMatrix.dspr(-1.0, aBar, aaValues)
      // A^T b - bBar aBar
      BLAS.axpy(-bBar, aBar, abBar)
    }

    // add regularization to diagonals
    var i = 0
    var j = 2
    while (i < triK) {
      var lambda = regParam
      if (standardizeFeatures) {
        lambda *= aVar(j - 2)
      }
      if (standardizeLabel) {
        // TODO: handle the case when bStd = 0
        lambda /= bStd
      }
      aaValues(i) += lambda
      i += j
      j += 1
    }

    val x = choleskySolve(aaBar.values, abBar)

    // compute intercept
    val intercept = if (fitIntercept) {
      bBar - BLAS.dot(aBar, x)
    } else {
      0.0
    }

    new WeightedLeastSquaresModel(x, intercept)
  }

  /**
   * Solves a symmetric positive definite linear system via Cholesky factorization.
   * The input arguments are modified in-place to store the factorization and the solution.
   * @param A the upper triangular part of the matrix packed in an array (column major)
   * @param bx the right-hand side
   * @return the solution vector
   */
  // TODO: SPARK-10490 - consolidate this and the Cholesky solver in ALS
  private def choleskySolve(A: Array[Double], bx: DenseVector): DenseVector = {
    val k = bx.size
    val info = new intW(0)
    lapack.dppsv("U", k, 1, A, bx.values, k, info)
    val code = info.`val`
    assert(code == 0, s"lapack.dppsv returned $code.")
    bx
  }
}
private[ml] object WeightedLeastSquares {

  /**
   * Case class for weighted observations.
   * @param w weight, must be nonnegative
   * @param a features
   * @param b label
   */
  case class Instance(w: Double, a: Vector, b: Double) {
    require(w >= 0.0, s"Weight cannot be negative: $w.")
  }

  /**
   * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]].
   */
  // TODO: consolidate aggregates for summary statistics
  private class Aggregator extends Serializable {
    var initialized: Boolean = false
    var k: Int = _
    var count: Long = _
    var triK: Int = _
    private var wSum: Double = _
    private var wwSum: Double = _
    private var bSum: Double = _
    private var bbSum: Double = _
    private var aSum: DenseVector = _
    private var abSum: DenseVector = _
    private var aaSum: DenseVector = _

    private def init(k: Int): Unit = {
      require(k <= 4096, "In order to take the normal equation approach efficiently, " +
        s"we set the max number of features to 4096 but got $k.")
      this.k = k
      triK = k * (k + 1) / 2
      count = 0L
      wSum = 0.0
      wwSum = 0.0
      bSum = 0.0
      bbSum = 0.0
      aSum = new DenseVector(Array.ofDim(k))
      abSum = new DenseVector(Array.ofDim(k))
      aaSum = new DenseVector(Array.ofDim(triK))
      initialized = true
    }

    /**
     * Adds an instance.
     */
    def add(instance: Instance): this.type = {
      val Instance(w, a, b) = instance
      val ak = a.size
      if (!initialized) {
        init(ak)
      }
      assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.")
      count += 1L
      wSum += w
      wwSum += w * w
      bSum += w * b
      bbSum += w * b * b
      BLAS.axpy(w, a, aSum)
      BLAS.axpy(w * b, a, abSum)
      RowMatrix.dspr(w, a, aaSum.values)
      this
    }

    /**
     * Merges another [[Aggregator]].
     */
    def merge(other: Aggregator): this.type = {
      if (!other.initialized) {
        this
      } else {
        if (!initialized) {
          init(other.k)
        }
        assert(k == other.k, s"dimension mismatch: this.k = $k but other.k = ${other.k}")
        count += other.count
        wSum += other.wSum
        wwSum += other.wwSum
        bSum += other.bSum
        bbSum += other.bbSum
        BLAS.axpy(1.0, other.aSum, aSum)
        BLAS.axpy(1.0, other.abSum, abSum)
        BLAS.axpy(1.0, other.aaSum, aaSum)
        this
      }
    }

    /**
     * Validates that we have seen observations.
     */
    def validate(): Unit = {
      assert(initialized, "Training dataset is empty.")
      assert(wSum > 0.0, "Sum of weights cannot be zero.")
    }

    /**
     * Weighted mean of features.
     */
    def aBar: DenseVector = {
      val output = aSum.copy
      BLAS.scal(1.0 / wSum, output)
      output
    }

    /**
     * Weighted mean of labels.
     */
    def bBar: Double = bSum / wSum

    /**
     * Weighted population standard deviation of labels.
     */
    def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)

    /**
     * Weighted mean of (label * features).
     */
    def abBar: DenseVector = {
      val output = abSum.copy
      BLAS.scal(1.0 / wSum, output)
      output
    }

    /**
     * Weighted mean of (features * features^T^).
     */
    def aaBar: DenseVector = {
      val output = aaSum.copy
      BLAS.scal(1.0 / wSum, output)
      output
    }

    /**
     * Weighted population variance of features.
     */
    def aVar: DenseVector = {
      val variance = Array.ofDim[Double](k)
      var i = 0
      var j = 2
      val aaValues = aaSum.values
      while (i < triK) {
        val l = j - 2
        val aw = aSum(l) / wSum
        variance(l) = aaValues(i) / wSum - aw * aw
        i += j
        j += 1
      }
      new DenseVector(variance)
    }
  }
}
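
To see the normal-equation route end to end, here is a minimal standalone sketch (plain Scala, no Spark; `WlsByHand` is a hypothetical name, not part of this patch) that forms A^T W A and A^T W b for the data used in the test suite below and solves the resulting 2x2 system; it reproduces R's `lm(b ~ . - 1, weights = w)` coefficients:

object WlsByHand {
  def main(args: Array[String]): Unit = {
    // Same data as WeightedLeastSquaresSuite: A is 4x2, b labels, w weights.
    val a = Array(Array(0.0, 5.0), Array(1.0, 7.0), Array(2.0, 11.0), Array(3.0, 13.0))
    val b = Array(17.0, 19.0, 23.0, 29.0)
    val w = Array(1.0, 2.0, 3.0, 4.0)
    // Accumulate the normal equation (A^T W A) x = A^T W b, no intercept.
    var m00 = 0.0; var m01 = 0.0; var m11 = 0.0
    var r0 = 0.0; var r1 = 0.0
    for (i <- a.indices) {
      m00 += w(i) * a(i)(0) * a(i)(0)
      m01 += w(i) * a(i)(0) * a(i)(1)
      m11 += w(i) * a(i)(1) * a(i)(1)
      r0 += w(i) * a(i)(0) * b(i)
      r1 += w(i) * a(i)(1) * b(i)
    }
    // 2x2 solve by Cramer's rule; the patch uses LAPACK dppsv (Cholesky) instead.
    val det = m00 * m11 - m01 * m01
    val x0 = (r0 * m11 - m01 * r1) / det
    val x1 = (m00 * r1 - m01 * r0) / det
    println(f"x = ($x0%.6f, $x1%.6f)") // prints x = (-3.727121, 3.009983), matching R's lm
  }
}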

mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala

@@ -92,6 +92,13 @@ private[spark] object BLAS extends Serializable with Logging {
    }
  }

  /** Y += a * x */
  private[spark] def axpy(a: Double, X: DenseMatrix, Y: DenseMatrix): Unit = {
    require(X.numRows == Y.numRows && X.numCols == Y.numCols, "Dimension mismatch: " +
      s"size(X) = ${(X.numRows, X.numCols)} but size(Y) = ${(Y.numRows, Y.numCols)}.")
    f2jBLAS.daxpy(X.numRows * X.numCols, a, X.values, 1, Y.values, 1)
  }

  /**
   * dot(x, y)
   */
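
The new overload treats both matrices as their flat, column-major `values` arrays and hands them to `daxpy`. As a reference for what that computes, a minimal sketch in plain Scala (`axpyReference` is a hypothetical name; like the patch, it assumes both matrices share the same storage layout, i.e. neither is transposed):

def axpyReference(a: Double, x: Array[Double], y: Array[Double]): Unit = {
  require(x.length == y.length, "Dimension mismatch.")
  var i = 0
  while (i < x.length) {
    y(i) += a * x(i) // Y += a * X, element-wise over the backing arrays
    i += 1
  }
}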

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

@@ -678,7 +678,8 @@ object RowMatrix {
   *
   * @param U the upper triangular part of the matrix packed in an array (column major)
   */
-  private def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
+  // TODO: SPARK-10491 - move this method to linalg.BLAS
+  private[spark] def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
     // TODO: Find a better home (breeze?) for this method.
     val n = v.size
     v match {
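
`dspr` performs the packed rank-1 update U += alpha * v * v^T on the upper triangle stored column-major: entry (i, j) with i <= j lives at index j*(j+1)/2 + i, so the diagonal (j, j) sits at index j*(j+3)/2 (0, 2, 5, 9, ...). That is exactly why `fit` and `aVar` above walk the diagonals with `i += j; j += 1` starting from i = 0, j = 2. A dense-only reference sketch (hypothetical `dsprReference`, not the BLAS-backed implementation):

def dsprReference(alpha: Double, v: Array[Double], U: Array[Double]): Unit = {
  val n = v.length
  var j = 0
  var idx = 0 // next slot in the packed upper triangle, column-major
  while (j < n) {
    var i = 0
    while (i <= j) {
      U(idx) += alpha * v(i) * v(j) // entry (i, j) of the upper triangle
      idx += 1
      i += 1
    }
    j += 1
  }
}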

mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala

@@ -0,0 +1,133 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.optim

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.optim.WeightedLeastSquares.Instance
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD

class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext {

  private var instances: RDD[Instance] = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    /*
       R code:

       A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
       b <- c(17, 19, 23, 29)
       w <- c(1, 2, 3, 4)
     */
    instances = sc.parallelize(Seq(
      Instance(1.0, Vectors.dense(0.0, 5.0).toSparse, 17.0),
      Instance(2.0, Vectors.dense(1.0, 7.0), 19.0),
      Instance(3.0, Vectors.dense(2.0, 11.0), 23.0),
      Instance(4.0, Vectors.dense(3.0, 13.0), 29.0)
    ), 2)
  }

  test("WLS against lm") {
    /*
       R code:

       df <- as.data.frame(cbind(A, b))
       for (formula in c(b ~ . -1, b ~ .)) {
         model <- lm(formula, data=df, weights=w)
         print(as.vector(coef(model)))
       }

       [1] -3.727121 3.009983
       [1] 18.08 6.08 -0.60
     */
    val expected = Seq(
      Vectors.dense(0.0, -3.727121, 3.009983),
      Vectors.dense(18.08, 6.08, -0.60))
    var idx = 0
    for (fitIntercept <- Seq(false, true)) {
      val wls = new WeightedLeastSquares(
        fitIntercept, regParam = 0.0, standardizeFeatures = false, standardizeLabel = false)
        .fit(instances)
      val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
      assert(actual ~== expected(idx) absTol 1e-4)
      idx += 1
    }
  }

  test("WLS against glmnet") {
    /*
       R code:

       library(glmnet)
       for (intercept in c(FALSE, TRUE)) {
         for (lambda in c(0.0, 0.1, 1.0)) {
           for (standardize in c(FALSE, TRUE)) {
             model <- glmnet(A, b, weights=w, intercept=intercept, lambda=lambda,
                             standardize=standardize, alpha=0, thresh=1E-14)
             print(as.vector(coef(model)))
           }
         }
       }

       [1] 0.000000 -3.727117 3.009982
       [1] 0.000000 -3.727117 3.009982
       [1] 0.000000 -3.307532 2.924206
       [1] 0.000000 -2.914790 2.840627
       [1] 0.000000 -1.526575 2.558158
       [1] 0.00000000 0.06984238 2.20488344
       [1] 18.0799727 6.0799832 -0.5999941
       [1] 18.0799727 6.0799832 -0.5999941
       [1] 13.5356178 3.2714044 0.3770744
       [1] 14.064629 3.565802 0.269593
       [1] 10.1238013 0.9708569 1.1475466
       [1] 13.1860638 2.1761382 0.6213134
     */
    val expected = Seq(
      Vectors.dense(0.0, -3.727117, 3.009982),
      Vectors.dense(0.0, -3.727117, 3.009982),
      Vectors.dense(0.0, -3.307532, 2.924206),
      Vectors.dense(0.0, -2.914790, 2.840627),
      Vectors.dense(0.0, -1.526575, 2.558158),
      Vectors.dense(0.0, 0.06984238, 2.20488344),
      Vectors.dense(18.0799727, 6.0799832, -0.5999941),
      Vectors.dense(18.0799727, 6.0799832, -0.5999941),
      Vectors.dense(13.5356178, 3.2714044, 0.3770744),
      Vectors.dense(14.064629, 3.565802, 0.269593),
      Vectors.dense(10.1238013, 0.9708569, 1.1475466),
      Vectors.dense(13.1860638, 2.1761382, 0.6213134))
    var idx = 0
    for (fitIntercept <- Seq(false, true);
         regParam <- Seq(0.0, 0.1, 1.0);
         standardizeFeatures <- Seq(false, true)) {
      val wls = new WeightedLeastSquares(
        fitIntercept, regParam, standardizeFeatures, standardizeLabel = true)
        .fit(instances)
      val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
      assert(actual ~== expected(idx) absTol 1e-4)
      idx += 1
    }
  }
}
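
For intuition about the summary statistics the aggregator feeds into the solver, a quick hand check (plain Scala sketch, not part of the patch) on the same data:

val b = Array(17.0, 19.0, 23.0, 29.0)
val w = Array(1.0, 2.0, 3.0, 4.0)
val wSum = w.sum                                                      // 10.0
val bBar = w.zip(b).map { case (wi, bi) => wi * bi }.sum / wSum       // 240.0 / 10 = 24.0
val bbBar = w.zip(b).map { case (wi, bi) => wi * bi * bi }.sum / wSum // 5962.0 / 10 = 596.2
val bStd = math.sqrt(bbBar - bBar * bBar)                             // sqrt(20.2) ~= 4.4944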