Address Reynold's comments. Also use a builder pattern to construct the regression classes.

Shivaram Venkataraman 2013-06-12 16:30:04 -07:00 committed by Matei Zaharia
parent 48770419bd
commit fd137bd7c6
7 changed files with 233 additions and 83 deletions

View file

@@ -3,6 +3,13 @@ package spark.ml
import org.jblas.DoubleMatrix
abstract class Gradient extends Serializable {
/**
* Compute the gradient for a given row of data.
*
* @param data - One row of data. Row matrix of size 1xn where n is the number of features.
* @param label - Label for this data item.
* @param weights - Column matrix containing weights for every feature.
*/
def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
(DoubleMatrix, Double)
}
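For reference only (not part of this diff), another Gradient implementation following the same compute contract could look like the sketch below; the name LeastSquaresGradient and the squared-error loss are illustrative assumptions.

// Hypothetical example, not in this commit: squared-error gradient with the same
// (data, label, weights) => (gradient, loss) contract as the Gradient class above.
class LeastSquaresGradient extends Gradient {
  override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
      (DoubleMatrix, Double) = {
    // Error of the linear prediction for this row.
    val diff = data.dot(weights) - label
    // Gradient of diff^2 with respect to the weights; same shape as data,
    // mirroring LogisticGradient's use of data.mul(...).
    val gradient = data.mul(2.0 * diff)
    (gradient, diff * diff)
  }
}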
@@ -14,11 +21,13 @@ class LogisticGradient extends Gradient {
val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
val gradient = data.mul(gradientMultiplier)
val loss = if (margin > 0) {
math.log(1 + math.exp(0 - margin))
} else {
math.log(1 + math.exp(margin)) - margin
}
val loss =
if (margin > 0) {
math.log(1 + math.exp(0 - margin))
} else {
math.log(1 + math.exp(margin)) - margin
}
(gradient, loss)
}
}
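A note on the branch above: both arms compute the same value, log(1 + exp(-margin)), since log(1 + exp(margin)) - margin = log(exp(-margin) + 1); the branch only picks the form that never exponentiates a large positive number. A quick illustration (not part of the diff):

val m = 800.0
println(math.log(1 + math.exp(-m))) // 0.0 -- exponent is negative, no overflow
println(math.log(1 + math.exp(m)) - m) // Infinity -- exp(800) overflows to Infinity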

View file

@@ -10,6 +10,20 @@ import scala.collection.mutable.ArrayBuffer
object GradientDescent {
/**
* Run gradient descent in parallel using mini batches.
*
* @param data - Input data for SGD. RDD of form (label, [feature values]).
* @param gradient - Gradient object that will be used to compute the gradient.
* @param updater - Updater object that will be used to update the model.
* @param stepSize - step size to be used for each update.
* @param numIters - number of iterations that SGD should be run.
* @param miniBatchFraction - fraction of the input data set that should be used for
* one iteration of SGD. Default value 1.0.
*
* @return weights - Column matrix containing weights for every feature.
* @return lossHistory - Array containing the loss computed for every iteration.
*/
def runMiniBatchSGD(
data: RDD[(Double, Array[Double])],
gradient: Gradient,
@@ -18,22 +32,23 @@ object GradientDescent {
numIters: Int,
miniBatchFraction: Double=1.0) : (DoubleMatrix, Array[Double]) = {
val lossHistory = new ArrayBuffer[Double]
val lossHistory = new ArrayBuffer[Double](numIters)
val nfeatures: Int = data.take(1)(0)._2.length
val nexamples: Long = data.count()
val miniBatchSize = nexamples * miniBatchFraction
// Initialize weights as a column matrix
var weights = DoubleMatrix.ones(nfeatures)
var reg_val = 0.0
for (i <- 1 to numIters) {
val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42).map {
val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42+i).map {
case (y, features) =>
val featuresRow = new DoubleMatrix(features.length, 1, features:_*)
val (grad, loss) = gradient.compute(featuresRow, y, weights)
(grad, loss)
}.reduce((a, b) => (a._1.add(b._1), a._2 + b._2))
}.reduce((a, b) => (a._1.addi(b._1), a._2 + b._2))
lossHistory.append(lossSum / miniBatchSize + reg_val)
val update = updater.compute(weights, gradientSum.div(miniBatchSize), stepSize, i)
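For illustration only (not part of this diff), a call to runMiniBatchSGD with the pieces introduced elsewhere in this commit might look roughly like the sketch below; the SparkContext `sc` and the input path are assumed placeholders.

// Hypothetical usage sketch, assuming a SparkContext named sc and a placeholder path.
val points = MLUtils.loadData(sc, "data/points.txt")
val (finalWeights, lossHistory) = GradientDescent.runMiniBatchSGD(
  points, new LogisticGradient(), new SimpleUpdater(),
  1.0 /* stepSize */, 100 /* numIters */, 0.1 /* miniBatchFraction */)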

View file

@@ -1,7 +1,6 @@
package spark.ml
import spark.{Logging, RDD, SparkContext}
import spark.SparkContext._
import org.jblas.DoubleMatrix
@@ -14,9 +13,9 @@ class LogisticRegressionModel(
val losses: Array[Double]) extends RegressionModel(weights, intercept) {
override def predict(test_data: spark.RDD[Array[Double]]) = {
test_data.map { x =>
val margin = (new DoubleMatrix(1, x.length, x:_*).mmul(this.weights)).get(0) + this.intercept
1.0/(1.0 + math.exp(margin * -1))
test_data.map { x =>
val margin = new DoubleMatrix(1, x.length, x:_*).mmul(this.weights).get(0) + this.intercept
1.0 / (1.0 + math.exp(margin * -1))
}
}
}
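For illustration only (not part of this diff), using the trained model might look like the sketch below; `model` and `testFeatures` (an RDD of Array[Double]) are assumed to exist.

// Hypothetical usage: predict returns the probability of the positive label,
// which can be thresholded to obtain class predictions.
val probabilities = model.predict(testFeatures)
val predictions = probabilities.map(p => if (p > 0.5) 1.0 else 0.0)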
@@ -25,9 +24,9 @@ class LogisticRegressionData(data: RDD[(Double, Array[Double])]) extends Regress
override def normalizeData() = {
// Shift only the features for LogisticRegression
data.map { case(y, features) =>
val featuresNormalized = (0 until nfeatures).map(
column => (features(column) - xColMean(column)) / xColSd(column)
).toArray
val featuresNormalized = Array.tabulate(nfeatures) { column =>
(features(column) - xColMean(column)) / xColSd(column)
}
(y, featuresNormalized)
}
}
@@ -44,25 +43,41 @@ class LogisticRegressionData(data: RDD[(Double, Array[Double])]) extends Regress
}
}
object LogisticRegression extends Logging {
val STEP_SIZE = 1.0
val MINI_BATCH_FRACTION = 1.0
class LogisticRegression(stepSize: Double, miniBatchFraction: Double, numIters: Int)
extends Regression with Logging {
def train(input: RDD[(Double, Array[Double])], numIters: Int) = {
override def train(input: RDD[(Double, Array[Double])]): RegressionModel = {
input.cache()
val lrData = new LogisticRegressionData(input)
val data = lrData.normalizeData()
val (weights, losses) = GradientDescent.runMiniBatchSGD(
data, new LogisticGradient(), new SimpleUpdater(), STEP_SIZE, numIters, MINI_BATCH_FRACTION)
data, new LogisticGradient(), new SimpleUpdater(), stepSize, numIters, miniBatchFraction)
val computedModel = new LogisticRegressionModel(weights, 0, losses)
val model = lrData.scaleModel(computedModel)
logInfo("Final model weights " + model.weights)
logInfo("Final model intercept " + model.intercept)
logInfo("Last 10 losses " + model.losses.takeRight(10).mkString(","))
logInfo("Last 10 losses " + model.losses.takeRight(10).mkString(", "))
model
}
}
/**
* Helper object to build a LogisticRegression.
*/
object LogisticRegression {
/**
* Create a LogisticRegressionBuilder initialized with default arguments:
* stepSize = 1.0, miniBatchFraction = 1.0, numIters = 100.
*/
def builder() = {
new LogisticRegressionBuilder(1.0, 1.0, 100)
}
def main(args: Array[String]) {
if (args.length != 3) {
@@ -71,7 +86,42 @@ object LogisticRegression extends Logging {
}
val sc = new SparkContext(args(0), "LogisticRegression")
val data = MLUtils.loadData(sc, args(1))
val model = train(data, args(2).toInt)
val lr = LogisticRegression.builder()
.setStepSize(2.0)
.setNumIterations(args(2).toInt)
.build()
val model = lr.train(data)
sc.stop()
}
}
class LogisticRegressionBuilder(stepSize: Double, miniBatchFraction: Double, numIters: Int) {
/**
* Set the step size per-iteration of SGD. Default 1.0.
*/
def setStepSize(step: Double) = {
new LogisticRegressionBuilder(step, this.miniBatchFraction, this.numIters)
}
/**
* Set fraction of data to be used for each SGD iteration. Default 1.0.
*/
def setMiniBatchFraction(fraction: Double) = {
new LogisticRegressionBuilder(this.stepSize, fraction, this.numIters)
}
/**
* Set the number of iterations for SGD. Default 100.
*/
def setNumIterations(iters: Int) = {
new LogisticRegressionBuilder(this.stepSize, this.miniBatchFraction, iters)
}
/**
* Build a LogisticRegression object.
*/
def build() = {
new LogisticRegression(stepSize, miniBatchFraction, numIters)
}
}
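Since each setter returns a fresh LogisticRegressionBuilder, a partially configured builder can be reused without being mutated. A hypothetical usage sketch (not part of this diff):

// Every setter returns a new builder, so base stays unchanged.
val base = LogisticRegression.builder().setStepSize(0.5)
val quickLR = base.setMiniBatchFraction(0.1).setNumIterations(50).build()
val fullLR = base.setMiniBatchFraction(1.0).setNumIterations(500).build()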

View file

@@ -2,21 +2,27 @@ package spark.ml
import spark.{RDD, SparkContext}
/**
* Helper methods to load and save data
* Data format:
* <l>, <f1> <f2> ...
* where <f1>, <f2> are feature values in Double and <l> is the corresponding label as Double.
*/
object MLUtils {
// Helper methods to load and save data
// Data format:
// <l>, <f1> <f2> ...
// where <f1>, <f2> are feature values in Double and
// <l> is the corresponding label as Double
def loadData(sc: SparkContext, dir: String) = {
val data = sc.textFile(dir).map{ line =>
/**
* @param sc SparkContext
* @param dir Directory containing the input data files.
* @return An RDD of tuples. For each tuple, the first element is the label, and the second
* element represents the feature values (an array of Double).
*/
def loadData(sc: SparkContext, dir: String): RDD[(Double, Array[Double])] = {
sc.textFile(dir).map { line =>
val parts = line.split(",")
val label = parts(0).toDouble
val features = parts(1).trim().split(" ").map(_.toDouble)
(label, features)
}
data
}
def saveData(data: RDD[(Double, Array[Double])], dir: String) {
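For illustration only (not part of this diff), the input format documented above and a loadData call might look like this; the path and SparkContext `sc` are placeholders.

// Example records in the expected "<l>, <f1> <f2> ..." text format:
//   1.0, 0.5 2.3 -1.7
//   0.0, 1.1 0.0 4.2
// Hypothetical usage:
val points = MLUtils.loadData(sc, "data/points.txt") // RDD[(Double, Array[Double])]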

View file

@@ -7,47 +7,61 @@ import spark.SparkContext._
import org.jblas.DoubleMatrix
abstract class RegressionModel(
val weights: DoubleMatrix,
val intercept: Double) {
abstract class RegressionModel(val weights: DoubleMatrix, val intercept: Double) {
def predict(test_data: RDD[Array[Double]]): RDD[Double]
}
abstract class RegressionData(val data: RDD[(Double, Array[Double])]) extends Serializable {
var yMean: Double = 0.0
var xColMean: Array[Double] = null
var xColSd: Array[Double] = null
var nfeatures: Int = 0
var nexamples: Long = 0
val nfeatures: Int = data.take(1)(0)._2.length
val nexamples: Long = data.count()
// This will populate yMean, xColMean and xColSd
calculateStats()
val yMean: Double = data.map { case (y, features) => y }.reduce(_ + _) / nexamples
def normalizeData(): RDD[(Double, Array[Double])]
def scaleModel(model: RegressionModel): RegressionModel
def calculateStats() {
this.nexamples = data.count()
this.nfeatures = data.take(1)(0)._2.length
this.yMean = data.map { case (y, features) => y }.reduce(_ + _) / nexamples
// NOTE: We shuffle X by column here to compute column sum and sum of squares.
val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { case(y, features) =>
val nCols = features.length
// Traverse over every column and emit (col, value, value^2)
(0 until nCols).map(i => (i, (features(i), features(i)*features(i))))
}.reduceByKey { case(x1, x2) =>
(x1._1 + x2._1, x1._2 + x2._2)
// NOTE: We shuffle X by column here to compute column sum and sum of squares.
private val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { case(y, features) =>
val nCols = features.length
// Traverse over every column and emit (col, value, value^2)
Iterator.tabulate(nCols) { i =>
(i, (features(i), features(i)*features(i)))
}
val xColSumsMap = xColSumSq.collectAsMap()
// Compute mean and variance using column sums
this.xColMean = (0 until nfeatures).map(x => xColSumsMap(x)._1 / nexamples).toArray
this.xColSd = (0 until nfeatures).map {x =>
val v = (xColSumsMap(x)._2 - (math.pow(xColSumsMap(x)._1, 2) / nexamples)) / (nexamples)
math.sqrt(v)
}.toArray
}.reduceByKey { case(x1, x2) =>
(x1._1 + x2._1, x1._2 + x2._2)
}
private val xColSumsMap = xColSumSq.collectAsMap()
// Compute mean and variance using column sums
val xColMean: Array[Double] = Array.tabulate(nfeatures) { x =>
xColSumsMap(x)._1 / nexamples
}
val xColSd: Array[Double] = Array.tabulate(nfeatures) { x =>
val v = (xColSumsMap(x)._2 - (math.pow(xColSumsMap(x)._1, 2) / nexamples)) / (nexamples)
math.sqrt(v)
}
/**
* Normalize the provided input data. This function is typically called before
* training a model on the input dataset and should be used to center or scale the data
* appropriately.
*
* @return RDD containing the normalized data
*/
def normalizeData(): RDD[(Double, Array[Double])]
/**
* Scale the trained regression model. This function is usually called after training
* to adjust the model based on the normalization performed before.
*
* @return Regression model that can be used for prediction
*/
def scaleModel(model: RegressionModel): RegressionModel
}
trait Regression {
/**
* Train a model on the provided input dataset. Input data is an RDD of (Label, [Features])
*
* @return RegressionModel representing the model built.
*/
def train(input: RDD[(Double, Array[Double])]): RegressionModel
}
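The column statistics above recover each column's mean and variance from its running (sum, sum of squares) pair. A small self-contained check of that identity (not part of this diff):

// Var(x) = (sumSq - sum^2 / n) / n, matching the xColSd computation above.
val xs = Array(1.0, 2.0, 3.0, 4.0)
val n = xs.length
val sum = xs.sum
val sumSq = xs.map(x => x * x).sum
val variance = (sumSq - (sum * sum) / n) / n // 1.25 for this data
val sd = math.sqrt(variance) // ~1.118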

View file

@@ -10,13 +10,16 @@ import org.jblas.Solve
* Ridge Regression from Joseph Gonzalez's implementation in MLBase
*/
class RidgeRegressionModel(
weights: DoubleMatrix,
intercept: Double,
val lambdaOpt: Double,
val lambdas: List[(Double, Double, DoubleMatrix)]) extends RegressionModel(weights, intercept) {
weights: DoubleMatrix,
intercept: Double,
val lambdaOpt: Double,
val lambdas: List[(Double, Double, DoubleMatrix)])
extends RegressionModel(weights, intercept) {
override def predict(test_data: spark.RDD[Array[Double]]) = {
test_data.map(x => (new DoubleMatrix(1, x.length, x:_*).mmul(this.weights)).get(0) + this.intercept)
override def predict(test_data: RDD[Array[Double]]) = {
test_data.map { x =>
(new DoubleMatrix(1, x.length, x:_*).mmul(this.weights)).get(0) + this.intercept
}
}
}
@@ -24,9 +27,9 @@ class RidgeRegressionData(data: RDD[(Double, Array[Double])]) extends Regression
override def normalizeData() = {
data.map { case(y, features) =>
val yNormalized = y - yMean
val featuresNormalized = (0 until nfeatures).map(
column => (features(column) - xColMean(column)) / xColSd(column)
).toArray
val featuresNormalized = Array.tabulate(nfeatures) { column =>
(features(column) - xColMean(column)) / xColSd(column)
}
(yNormalized, featuresNormalized)
}
}
@@ -43,12 +46,9 @@ class RidgeRegressionData(data: RDD[(Double, Array[Double])]) extends Regression
}
}
object RidgeRegression extends Logging {
def train(inputData: RDD[(Double, Array[Double])],
lambdaLow: Double = 0.0,
lambdaHigh: Double = 10000.0) = {
class RidgeRegression(lambdaLow: Double, lambdaHigh: Double) extends Regression with Logging {
def train(inputData: RDD[(Double, Array[Double])]): RegressionModel = {
inputData.cache()
val ridgeData = new RidgeRegressionData(inputData)
val data = ridgeData.normalizeData()
@@ -125,6 +125,22 @@ object RidgeRegression extends Logging {
normModel
}
}
/**
* Helper object to build a RidgeRegression.
*/
object RidgeRegression {
/**
* Create a RidgeRegressionBuilder initialized with default arguments:
* lowLambda = 0.0, hiLambda = 100.0.
*/
def builder() = {
new RidgeRegressionBuilder(0.0, 100.0)
}
def main(args: Array[String]) {
if (args.length != 2) {
@@ -133,7 +149,36 @@ object RidgeRegression extends Logging {
}
val sc = new SparkContext(args(0), "RidgeRegression")
val data = MLUtils.loadData(sc, args(1))
val model = train(data, 0, 1000)
val ridgeReg = RidgeRegression.builder()
.setLowLambda(0)
.setHighLambda(1000)
.build()
val model = ridgeReg.train(data)
sc.stop()
}
}
class RidgeRegressionBuilder(lowLambda: Double, hiLambda: Double) {
/**
* Set the lower bound on the binary search for lambda. Default is 0.0.
*/
def setLowLambda(low: Double) = {
new RidgeRegressionBuilder(low, this.hiLambda)
}
/**
* Set the upper bound on the binary search for lambda. Default is 100.0.
*/
def setHighLambda(hi: Double) = {
new RidgeRegressionBuilder(this.lowLambda, hi)
}
/**
* Build a RidgeRegression object.
*/
def build() = {
new RidgeRegression(lowLambda, hiLambda)
}
}

View file

@@ -3,13 +3,24 @@ package spark.ml
import org.jblas.DoubleMatrix
abstract class Updater extends Serializable {
/**
* Compute an updated value for weights given the gradient, stepSize and iteration number.
*
* @param weightsOld - Column matrix of size nx1 where n is the number of features.
* @param gradient - Column matrix of size nx1 where n is the number of features.
* @param stepSize - step size across iterations
* @param iter - Iteration number
*
* @return weightsNew - Column matrix containing updated weights
* @return reg_val - regularization value
*/
def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int):
(DoubleMatrix, Double)
}
class SimpleUpdater extends Updater {
override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int):
(DoubleMatrix, Double) = {
override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
stepSize: Double, iter: Int): (DoubleMatrix, Double) = {
val normGradient = gradient.mul(stepSize / math.sqrt(iter))
(weightsOld.sub(normGradient), 0)
}
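SimpleUpdater applies a plain step-size-decayed gradient step and always reports 0 for the regularization value. For illustration only (not part of this diff), an updater that adds an L2 penalty and uses the reg_val slot might look like the sketch below; the class name and lambda parameter are assumptions.

// Hypothetical example, not in this commit: L2-regularized update with the same contract.
class L2RegularizedUpdater(lambda: Double) extends Updater {
  override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
      stepSize: Double, iter: Int): (DoubleMatrix, Double) = {
    val thisIterStepSize = stepSize / math.sqrt(iter)
    // Step along the gradient plus the gradient of (lambda / 2) * ||w||^2, i.e. lambda * w.
    val step = gradient.add(weightsOld.mul(lambda)).mul(thisIterStepSize)
    val weightsNew = weightsOld.sub(step)
    // Regularization value reported for the new weights.
    val regVal = 0.5 * lambda * weightsNew.dot(weightsNew)
    (weightsNew, regVal)
  }
}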