diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
new file mode 100644
index 0000000000..a99e2ac4c6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -0,0 +1,296 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.optim
+
+import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
+import org.netlib.util.intW
+
+import org.apache.spark.Logging
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.linalg.distributed.RowMatrix
+import org.apache.spark.rdd.RDD
+
+/**
+ * Model fitted by [[WeightedLeastSquares]].
+ * @param coefficients model coefficients
+ * @param intercept model intercept
+ */
+private[ml] class WeightedLeastSquaresModel(
+    val coefficients: DenseVector,
+    val intercept: Double) extends Serializable
+
+/**
+ * Weighted least squares solver via normal equation.
+ * Given weighted observations (w,,i,,, a,,i,,, b,,i,,), we use the following weighted least squares
+ * formulation:
+ *
+ *   min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w,,i,,
+ *     + 1/2 lambda / delta sum,,j,, (sigma,,j,, x,,j,,)^2^,
+ *
+ * where lambda is the regularization parameter, and delta and sigma,,j,, are controlled by
+ * [[standardizeLabel]] and [[standardizeFeatures]], respectively.
+ *
+ * Set [[regParam]] to 0.0 and turn off both [[standardizeFeatures]] and [[standardizeLabel]] to
+ * match R's `lm`.
+ * Turn on [[standardizeLabel]] to match R's `glmnet`.
+ *
+ * @param fitIntercept whether to fit intercept. If false, z is 0.0.
+ * @param regParam L2 regularization parameter (lambda)
+ * @param standardizeFeatures whether to standardize features. If true, sigma,,j,, is the
+ *                            population standard deviation of the j-th column of A. Otherwise,
+ *                            sigma,,j,, is 1.0.
+ * @param standardizeLabel whether to standardize label. If true, delta is the population standard
+ *                         deviation of the label column b. Otherwise, delta is 1.0.
+ */
+private[ml] class WeightedLeastSquares(
+    val fitIntercept: Boolean,
+    val regParam: Double,
+    val standardizeFeatures: Boolean,
+    val standardizeLabel: Boolean) extends Logging with Serializable {
+  import WeightedLeastSquares._
+
+  require(regParam >= 0.0, s"regParam cannot be negative: $regParam")
+  if (regParam == 0.0) {
+    logWarning("regParam is zero, which might cause numerical instability and overfitting.")
+  }
+
+  /**
+   * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s.
+   */
+  def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = {
+    val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_))
+    summary.validate()
+    logInfo(s"Number of instances: ${summary.count}.")
+    val triK = summary.triK
+    val bBar = summary.bBar
+    val bStd = summary.bStd
+    val aBar = summary.aBar
+    val aVar = summary.aVar
+    val abBar = summary.abBar
+    val aaBar = summary.aaBar
+    val aaValues = aaBar.values
+
+    if (fitIntercept) {
+      // shift centers
+      // A^T A - aBar aBar^T
+      RowMatrix.dspr(-1.0, aBar, aaValues)
+      // A^T b - bBar aBar
+      BLAS.axpy(-bBar, aBar, abBar)
+    }
+
+    // add regularization to diagonals
+    var i = 0
+    var j = 2
+    while (i < triK) {
+      var lambda = regParam
+      if (standardizeFeatures) {
+        lambda *= aVar(j - 2)
+      }
+      if (standardizeLabel) {
+        // TODO: handle the case when bStd = 0
+        lambda /= bStd
+      }
+      aaValues(i) += lambda
+      i += j
+      j += 1
+    }
+
+    val x = choleskySolve(aaBar.values, abBar)
+
+    // compute intercept
+    val intercept = if (fitIntercept) {
+      bBar - BLAS.dot(aBar, x)
+    } else {
+      0.0
+    }
+
+    new WeightedLeastSquaresModel(x, intercept)
+  }
+
+  /**
+   * Solves a symmetric positive definite linear system via Cholesky factorization.
+   * The input arguments are modified in-place to store the factorization and the solution.
+   * @param A the upper triangular part of A
+   * @param bx right-hand side
+   * @return the solution vector
+   */
+  // TODO: SPARK-10490 - consolidate this and the Cholesky solver in ALS
+  private def choleskySolve(A: Array[Double], bx: DenseVector): DenseVector = {
+    val k = bx.size
+    val info = new intW(0)
+    lapack.dppsv("U", k, 1, A, bx.values, k, info)
+    val code = info.`val`
+    assert(code == 0, s"lapack.dppsv returned $code.")
+    bx
+  }
+}
+
+private[ml] object WeightedLeastSquares {
+
+  /**
+   * Case class for weighted observations.
+   * @param w weight, must be non-negative
+   * @param a features
+   * @param b label
+   */
+  case class Instance(w: Double, a: Vector, b: Double) {
+    require(w >= 0.0, s"Weight cannot be negative: $w.")
+  }
+
+  /**
+   * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]].
+   */
+  // TODO: consolidate aggregates for summary statistics
+  private class Aggregator extends Serializable {
+    var initialized: Boolean = false
+    var k: Int = _
+    var count: Long = _
+    var triK: Int = _
+    private var wSum: Double = _
+    private var wwSum: Double = _
+    private var bSum: Double = _
+    private var bbSum: Double = _
+    private var aSum: DenseVector = _
+    private var abSum: DenseVector = _
+    private var aaSum: DenseVector = _
+
+    private def init(k: Int): Unit = {
+      require(k <= 4096, "In order to take the normal equation approach efficiently, " +
+        s"we set the max number of features to 4096 but got $k.")
+      this.k = k
+      triK = k * (k + 1) / 2
+      count = 0L
+      wSum = 0.0
+      wwSum = 0.0
+      bSum = 0.0
+      bbSum = 0.0
+      aSum = new DenseVector(Array.ofDim(k))
+      abSum = new DenseVector(Array.ofDim(k))
+      aaSum = new DenseVector(Array.ofDim(triK))
+      initialized = true
+    }
+
+    /**
+     * Adds an instance.
+     */
+    def add(instance: Instance): this.type = {
+      val Instance(w, a, b) = instance
+      val ak = a.size
+      if (!initialized) {
+        init(ak)
+        initialized = true
+      }
+      assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.")
+      count += 1L
+      wSum += w
+      wwSum += w * w
+      bSum += w * b
+      bbSum += w * b * b
+      BLAS.axpy(w, a, aSum)
+      BLAS.axpy(w * b, a, abSum)
+      RowMatrix.dspr(w, a, aaSum.values)
+      this
+    }
+
+    /**
+     * Merges another [[Aggregator]].
+     */
+    def merge(other: Aggregator): this.type = {
+      if (!other.initialized) {
+        this
+      } else {
+        if (!initialized) {
+          init(other.k)
+        }
+        assert(k == other.k, s"dimension mismatch: this.k = $k but other.k = ${other.k}")
+        count += other.count
+        wSum += other.wSum
+        wwSum += other.wwSum
+        bSum += other.bSum
+        bbSum += other.bbSum
+        BLAS.axpy(1.0, other.aSum, aSum)
+        BLAS.axpy(1.0, other.abSum, abSum)
+        BLAS.axpy(1.0, other.aaSum, aaSum)
+        this
+      }
+    }
+
+    /**
+     * Validates that we have seen observations.
+     */
+    def validate(): Unit = {
+      assert(initialized, "Training dataset is empty.")
+      assert(wSum > 0.0, "Sum of weights cannot be zero.")
+    }
+
+    /**
+     * Weighted mean of features.
+     */
+    def aBar: DenseVector = {
+      val output = aSum.copy
+      BLAS.scal(1.0 / wSum, output)
+      output
+    }
+
+    /**
+     * Weighted mean of labels.
+     */
+    def bBar: Double = bSum / wSum
+
+    /**
+     * Weighted population standard deviation of labels.
+     */
+    def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)
+
+    /**
+     * Weighted mean of (label * features).
+     */
+    def abBar: DenseVector = {
+      val output = abSum.copy
+      BLAS.scal(1.0 / wSum, output)
+      output
+    }
+
+    /**
+     * Weighted mean of (features * features^T^).
+     */
+    def aaBar: DenseVector = {
+      val output = aaSum.copy
+      BLAS.scal(1.0 / wSum, output)
+      output
+    }
+
+    /**
+     * Weighted population variance of features.
+     */
+    def aVar: DenseVector = {
+      val variance = Array.ofDim[Double](k)
+      var i = 0
+      var j = 2
+      val aaValues = aaSum.values
+      while (i < triK) {
+        val l = j - 2
+        val aw = aSum(l) / wSum
+        variance(l) = aaValues(i) / wSum - aw * aw
+        i += j
+        j += 1
+      }
+      new DenseVector(variance)
+    }
+  }
+}
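The solver above reduces the weighted problem to normal equations on weighted means: it centers E[a a^T] and E[a b] when an intercept is fitted, adds lambda to the diagonal, and solves the packed system by Cholesky. As a sanity check of that algebra outside Spark, here is a minimal, dependency-free Scala sketch of the same closed-form solve (no regularization, with intercept) on the toy data used in the test suite at the end of this patch. It is an illustration only; WlsNormalEquationCheck is not part of the patch, and the 2x2 system is inverted directly instead of using the packed Cholesky route.

object WlsNormalEquationCheck {
  def main(args: Array[String]): Unit = {
    // Toy data from the R snippet in WeightedLeastSquaresSuite:
    // two feature columns, label b, instance weights w.
    val a = Array(Array(0.0, 5.0), Array(1.0, 7.0), Array(2.0, 11.0), Array(3.0, 13.0))
    val b = Array(17.0, 19.0, 23.0, 29.0)
    val w = Array(1.0, 2.0, 3.0, 4.0)

    val wSum = w.sum
    // Weighted means of features and label (aBar, bBar in the aggregator).
    val aBar = Array(0, 1).map(j => a.indices.map(i => w(i) * a(i)(j)).sum / wSum)
    val bBar = a.indices.map(i => w(i) * b(i)).sum / wSum
    // Weighted second moments: E[a a^T] and E[a b].
    def mom(j: Int, l: Int) = a.indices.map(i => w(i) * a(i)(j) * a(i)(l)).sum / wSum
    def momB(j: Int) = a.indices.map(i => w(i) * a(i)(j) * b(i)).sum / wSum
    // Center them, which fit() does via dspr(-1.0, aBar, ...) and axpy(-bBar, aBar, ...).
    val g00 = mom(0, 0) - aBar(0) * aBar(0)
    val g01 = mom(0, 1) - aBar(0) * aBar(1)
    val g11 = mom(1, 1) - aBar(1) * aBar(1)
    val r0 = momB(0) - bBar * aBar(0)
    val r1 = momB(1) - bBar * aBar(1)
    // Solve the 2x2 system G x = r directly.
    val det = g00 * g11 - g01 * g01
    val x0 = (g11 * r0 - g01 * r1) / det
    val x1 = (g00 * r1 - g01 * r0) / det
    val intercept = bBar - (aBar(0) * x0 + aBar(1) * x1)
    // Expected: intercept ~ 18.08, x0 ~ 6.08, x1 ~ -0.60, matching R's weighted lm.
    println(s"intercept = $intercept, x = ($x0, $x1)")
  }
}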
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index ab475af264..9ee81eda8a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -92,6 +92,13 @@ private[spark] object BLAS extends Serializable with Logging {
     }
   }
 
+  /** Y += a * X */
+  private[spark] def axpy(a: Double, X: DenseMatrix, Y: DenseMatrix): Unit = {
+    require(X.numRows == Y.numRows && X.numCols == Y.numCols, "Dimension mismatch: " +
+      s"size(X) = ${(X.numRows, X.numCols)} but size(Y) = ${(Y.numRows, Y.numCols)}.")
+    f2jBLAS.daxpy(X.numRows * X.numCols, a, X.values, 1, Y.values, 1)
+  }
+
   /**
    * dot(x, y)
    */
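The new matrix overload of axpy simply delegates to daxpy over the two matrices' backing arrays, so it is an element-wise Y += a * X on column-major storage of identical shape. A dependency-free sketch of the same update, for illustration only (axpySketch is a hypothetical helper, not part of this patch):

// What BLAS.axpy(a, X, Y) computes for dense matrices: both operands are stored
// column-major as flat arrays, so Y += a * X is one pass over numRows * numCols entries.
def axpySketch(a: Double, x: Array[Double], y: Array[Double]): Unit = {
  require(x.length == y.length, "Dimension mismatch between X and Y.")
  var i = 0
  while (i < x.length) {
    y(i) += a * x(i)
    i += 1
  }
}

// Example: Y += 2 * X for 2x2 matrices packed column-major.
val x = Array(1.0, 2.0, 3.0, 4.0)
val y = Array(10.0, 10.0, 10.0, 10.0)
axpySketch(2.0, x, y)  // y is now Array(12.0, 14.0, 16.0, 18.0)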
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 9a423ddafd..83779ac889 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -678,7 +678,8 @@ object RowMatrix {
    *
    * @param U the upper triangular part of the matrix packed in an array (column major)
    */
-  private def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
+  // TODO: SPARK-10491 - move this method to linalg.BLAS
+  private[spark] def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
     // TODO: Find a better home (breeze?) for this method.
     val n = v.size
     v match {
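dspr performs the symmetric rank-1 update U += alpha * v * v^T on the upper triangle packed column by column, so entry (i, j) with i <= j lives at flat index i + j * (j + 1) / 2. That layout is why fit() and aVar() can walk the diagonal with the `i += j; j += 1` pattern instead of computing indices explicitly. A small illustrative sketch of the indexing under that assumption (not part of the patch):

// Packed upper-triangular (column-major) index of entry (i, j), with i <= j.
def packedIndex(i: Int, j: Int): Int = i + j * (j + 1) / 2

// Diagonal entries for k = 4 features land at packed indices 0, 2, 5, 9.
val k = 4
val diag = (0 until k).map(j => packedIndex(j, j))

// The same positions produced by the loop used in fit() and aVar():
// start at i = 0 with step j = 2, then grow the step by one each time.
val viaLoop = scala.collection.mutable.ArrayBuffer.empty[Int]
var i = 0
var j = 2
while (i < k * (k + 1) / 2) {
  viaLoop += i
  i += j
  j += 1
}
assert(diag.toList == viaLoop.toList)  // both give List(0, 2, 5, 9)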
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
new file mode 100644
index 0000000000..652f3adb98
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.optim
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.optim.WeightedLeastSquares.Instance
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.rdd.RDD
+
+class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  private var instances: RDD[Instance] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    /*
+       R code:
+
+       A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+       b <- c(17, 19, 23, 29)
+       w <- c(1, 2, 3, 4)
+     */
+    instances = sc.parallelize(Seq(
+      Instance(1.0, Vectors.dense(0.0, 5.0).toSparse, 17.0),
+      Instance(2.0, Vectors.dense(1.0, 7.0), 19.0),
+      Instance(3.0, Vectors.dense(2.0, 11.0), 23.0),
+      Instance(4.0, Vectors.dense(3.0, 13.0), 29.0)
+    ), 2)
+  }
+
+  test("WLS against lm") {
+    /*
+       R code:
+
+       df <- as.data.frame(cbind(A, b))
+       for (formula in c(b ~ . -1, b ~ .)) {
+         model <- lm(formula, data=df, weights=w)
+         print(as.vector(coef(model)))
+       }
+
+       [1] -3.727121  3.009983
+       [1] 18.08  6.08 -0.60
+     */
+    val expected = Seq(
+      Vectors.dense(0.0, -3.727121, 3.009983),
+      Vectors.dense(18.08, 6.08, -0.60))
+
+    var idx = 0
+    for (fitIntercept <- Seq(false, true)) {
+      val wls = new WeightedLeastSquares(
+        fitIntercept, regParam = 0.0, standardizeFeatures = false, standardizeLabel = false)
+        .fit(instances)
+      val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
+      assert(actual ~== expected(idx) absTol 1e-4)
+      idx += 1
+    }
+  }
+
+  test("WLS against glmnet") {
+    /*
+       R code:
+
+       library(glmnet)
+
+       for (intercept in c(FALSE, TRUE)) {
+         for (lambda in c(0.0, 0.1, 1.0)) {
+           for (standardize in c(FALSE, TRUE)) {
+             model <- glmnet(A, b, weights=w, intercept=intercept, lambda=lambda,
+                             standardize=standardize, alpha=0, thresh=1E-14)
+             print(as.vector(coef(model)))
+           }
+         }
+       }
+
+       [1] 0.000000 -3.727117  3.009982
+       [1] 0.000000 -3.727117  3.009982
+       [1] 0.000000 -3.307532  2.924206
+       [1] 0.000000 -2.914790  2.840627
+       [1] 0.000000 -1.526575  2.558158
+       [1] 0.00000000 0.06984238 2.20488344
+       [1] 18.0799727  6.0799832 -0.5999941
+       [1] 18.0799727  6.0799832 -0.5999941
+       [1] 13.5356178  3.2714044  0.3770744
+       [1] 14.064629  3.565802  0.269593
+       [1] 10.1238013  0.9708569  1.1475466
+       [1] 13.1860638  2.1761382  0.6213134
+     */
+    val expected = Seq(
+      Vectors.dense(0.0, -3.727117, 3.009982),
+      Vectors.dense(0.0, -3.727117, 3.009982),
+      Vectors.dense(0.0, -3.307532, 2.924206),
+      Vectors.dense(0.0, -2.914790, 2.840627),
+      Vectors.dense(0.0, -1.526575, 2.558158),
+      Vectors.dense(0.0, 0.06984238, 2.20488344),
+      Vectors.dense(18.0799727, 6.0799832, -0.5999941),
+      Vectors.dense(18.0799727, 6.0799832, -0.5999941),
+      Vectors.dense(13.5356178, 3.2714044, 0.3770744),
+      Vectors.dense(14.064629, 3.565802, 0.269593),
+      Vectors.dense(10.1238013, 0.9708569, 1.1475466),
+      Vectors.dense(13.1860638, 2.1761382, 0.6213134))
+
+    var idx = 0
+    for (fitIntercept <- Seq(false, true);
+         regParam <- Seq(0.0, 0.1, 1.0);
+         standardizeFeatures <- Seq(false, true)) {
+      val wls = new WeightedLeastSquares(
+        fitIntercept, regParam, standardizeFeatures, standardizeLabel = true)
+        .fit(instances)
+      val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
+      assert(actual ~== expected(idx) absTol 1e-4)
+      idx += 1
+    }
+  }
+}
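For completeness, the call pattern the suite relies on looks like the sketch below. It is only illustrative: the solver is private[ml], so this compiles only from code inside the org.apache.spark.ml package, and an existing SparkContext named sc is assumed (as provided by MLlibTestSparkContext in the tests).

import org.apache.spark.ml.optim.WeightedLeastSquares
import org.apache.spark.ml.optim.WeightedLeastSquares.Instance
import org.apache.spark.mllib.linalg.Vectors

// Same toy data as in the suite above; `sc` is an assumed SparkContext.
val instances = sc.parallelize(Seq(
  Instance(1.0, Vectors.dense(0.0, 5.0), 17.0),
  Instance(2.0, Vectors.dense(1.0, 7.0), 19.0),
  Instance(3.0, Vectors.dense(2.0, 11.0), 23.0),
  Instance(4.0, Vectors.dense(3.0, 13.0), 29.0)), 2)

// No regularization and no standardization reproduces R's weighted lm fit.
val model = new WeightedLeastSquares(
  fitIntercept = true, regParam = 0.0,
  standardizeFeatures = false, standardizeLabel = false).fit(instances)

println(model.intercept)     // ~ 18.08
println(model.coefficients)  // ~ [6.08, -0.60]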