[SPARK-4979][MLLIB] Streaming logisitic regression

This adds support for streaming logistic regression with stochastic gradient descent, in the same manner as the existing implementation of streaming linear regression. It is a relatively simple addition because most of the work is already done by the abstract class `StreamingLinearAlgorithm` and existing algorithms and models from MLlib.

The PR includes
- Streaming Logistic Regression algorithm
- Unit tests for accuracy, streaming convergence, and streaming prediction
- An example use

cc mengxr tdas

Author: freeman <the.freeman.lab@gmail.com>

Closes #4306 from freeman-lab/streaming-logisitic-regression and squashes the following commits:

5c2c70b [freeman] Use Option on model
5cca2bc [freeman] Merge remote-tracking branch 'upstream/master' into streaming-logisitic-regression
275f8bd [freeman] Make private to mllib
3926e4e [freeman] Line formatting
5ee8694 [freeman] Experimental tag for docs
2fc68ac [freeman] Fix example formatting
85320b1 [freeman] Fixed line length
d88f717 [freeman] Remove stray comment
59d7ecb [freeman] Add streaming logistic regression
e78fe28 [freeman] Add streaming logistic regression example
321cc66 [freeman] Set private and protected within mllib
This commit is contained in:
freeman 2015-02-02 22:42:15 -08:00 committed by Xiangrui Meng
parent c306555f49
commit eb0da6c4bd
7 changed files with 327 additions and 27 deletions

View file

@ -35,8 +35,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext}
*
* To run on your local machine using the two directories `trainingDir` and `testDir`,
* with updates every 5 seconds, and 2 features per data point, call:
* $ bin/run-example \
* org.apache.spark.examples.mllib.StreamingLinearRegression trainingDir testDir 5 2
* $ bin/run-example mllib.StreamingLinearRegression trainingDir testDir 5 2
*
* As you add text files to `trainingDir` the model will continuously update.
* Anytime you add text files to `testDir`, you'll see predictions from the current model.

View file

@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.mllib
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
/**
* Train a logistic regression model on one stream of data and make predictions
* on another stream, where the data streams arrive as text files
* into two different directories.
*
* The rows of the text files must be labeled data points in the form
* `(y,[x1,x2,x3,...,xn])`
* Where n is the number of features, y is a binary label, and
* n must be the same for train and test.
*
* Usage: StreamingLogisticRegression <trainingDir> <testDir> <batchDuration> <numFeatures>
*
* To run on your local machine using the two directories `trainingDir` and `testDir`,
* with updates every 5 seconds, and 2 features per data point, call:
* $ bin/run-example mllib.StreamingLogisticRegression trainingDir testDir 5 2
*
* As you add text files to `trainingDir` the model will continuously update.
* Anytime you add text files to `testDir`, you'll see predictions from the current model.
*
*/
object StreamingLogisticRegression {
def main(args: Array[String]) {
if (args.length != 4) {
System.err.println(
"Usage: StreamingLogisticRegression <trainingDir> <testDir> <batchDuration> <numFeatures>")
System.exit(1)
}
val conf = new SparkConf().setMaster("local").setAppName("StreamingLogisticRegression")
val ssc = new StreamingContext(conf, Seconds(args(2).toLong))
val trainingData = ssc.textFileStream(args(0)).map(LabeledPoint.parse)
val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)
val model = new StreamingLogisticRegressionWithSGD()
.setInitialWeights(Vectors.zeros(args(3).toInt))
model.trainOn(trainingData)
model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
ssc.start()
ssc.awaitTermination()
}
}

View file

@ -136,7 +136,7 @@ class LogisticRegressionModel (
* for k classes multi-label classification problem.
* Using [[LogisticRegressionWithLBFGS]] is recommended over this.
*/
class LogisticRegressionWithSGD private (
class LogisticRegressionWithSGD private[mllib] (
private var stepSize: Double,
private var numIterations: Int,
private var regParam: Double,
@ -158,7 +158,7 @@ class LogisticRegressionWithSGD private (
*/
def this() = this(1.0, 100, 0.01, 1.0)
override protected def createModel(weights: Vector, intercept: Double) = {
override protected[mllib] def createModel(weights: Vector, intercept: Double) = {
new LogisticRegressionModel(weights, intercept)
}
}

View file

@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.classification
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.StreamingLinearAlgorithm
/**
* :: Experimental ::
* Train or predict a logistic regression model on streaming data. Training uses
* Stochastic Gradient Descent to update the model based on each new batch of
* incoming data from a DStream (see `LogisticRegressionWithSGD` for model equation)
*
* Each batch of data is assumed to be an RDD of LabeledPoints.
* The number of data points per batch can vary, but the number
* of features must be constant. An initial weight
* vector must be provided.
*
* Use a builder pattern to construct a streaming logistic regression
* analysis in an application, like:
*
* val model = new StreamingLogisticRegressionWithSGD()
* .setStepSize(0.5)
* .setNumIterations(10)
* .setInitialWeights(Vectors.dense(...))
* .trainOn(DStream)
*
*/
@Experimental
class StreamingLogisticRegressionWithSGD private[mllib] (
private var stepSize: Double,
private var numIterations: Int,
private var miniBatchFraction: Double,
private var regParam: Double)
extends StreamingLinearAlgorithm[LogisticRegressionModel, LogisticRegressionWithSGD]
with Serializable {
/**
* Construct a StreamingLogisticRegression object with default parameters:
* {stepSize: 0.1, numIterations: 50, miniBatchFraction: 1.0, regParam: 0.0}.
* Initial weights must be set before using trainOn or predictOn
* (see `StreamingLinearAlgorithm`)
*/
def this() = this(0.1, 50, 1.0, 0.0)
val algorithm = new LogisticRegressionWithSGD(
stepSize, numIterations, regParam, miniBatchFraction)
/** Set the step size for gradient descent. Default: 0.1. */
def setStepSize(stepSize: Double): this.type = {
this.algorithm.optimizer.setStepSize(stepSize)
this
}
/** Set the number of iterations of gradient descent to run per update. Default: 50. */
def setNumIterations(numIterations: Int): this.type = {
this.algorithm.optimizer.setNumIterations(numIterations)
this
}
/** Set the fraction of each batch to use for updates. Default: 1.0. */
def setMiniBatchFraction(miniBatchFraction: Double): this.type = {
this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction)
this
}
/** Set the regularization parameter. Default: 0.0. */
def setRegParam(regParam: Double): this.type = {
this.algorithm.optimizer.setRegParam(regParam)
this
}
/** Set the initial weights. Default: [0.0, 0.0]. */
def setInitialWeights(initialWeights: Vector): this.type = {
this.model = Option(algorithm.createModel(initialWeights, 0.0))
this
}
}

View file

@ -58,14 +58,14 @@ abstract class StreamingLinearAlgorithm[
A <: GeneralizedLinearAlgorithm[M]] extends Logging {
/** The model to be updated and used for prediction. */
protected var model: M
protected var model: Option[M] = null
/** The algorithm to use for updating. */
protected val algorithm: A
/** Return the latest model. */
def latestModel(): M = {
model
model.get
}
/**
@ -77,16 +77,16 @@ abstract class StreamingLinearAlgorithm[
* @param data DStream containing labeled data
*/
def trainOn(data: DStream[LabeledPoint]) {
if (Option(model.weights) == None) {
logError("Initial weights must be set before starting training")
if (Option(model) == None) {
logError("Model must be initialized before starting training")
throw new IllegalArgumentException
}
data.foreachRDD { (rdd, time) =>
model = algorithm.run(rdd, model.weights)
model = Option(algorithm.run(rdd, model.get.weights))
logInfo("Model updated at time %s".format(time.toString))
val display = model.weights.size match {
case x if x > 100 => model.weights.toArray.take(100).mkString("[", ",", "...")
case _ => model.weights.toArray.mkString("[", ",", "]")
val display = model.get.weights.size match {
case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
case _ => model.get.weights.toArray.mkString("[", ",", "]")
}
logInfo("Current model: weights, %s".format (display))
}
@ -99,12 +99,12 @@ abstract class StreamingLinearAlgorithm[
* @return DStream containing predictions
*/
def predictOn(data: DStream[Vector]): DStream[Double] = {
if (Option(model.weights) == None) {
val msg = "Initial weights must be set before starting prediction"
if (Option(model) == None) {
val msg = "Model must be initialized before starting prediction"
logError(msg)
throw new IllegalArgumentException(msg)
}
data.map(model.predict)
data.map(model.get.predict)
}
/**
@ -114,11 +114,11 @@ abstract class StreamingLinearAlgorithm[
* @return DStream containing the input keys and the predictions as values
*/
def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = {
if (Option(model.weights) == None) {
val msg = "Initial weights must be set before starting prediction"
if (Option(model) == None) {
val msg = "Model must be initialized before starting prediction"
logError(msg)
throw new IllegalArgumentException(msg)
}
data.mapValues(model.predict)
data.mapValues(model.get.predict)
}
}

View file

@ -21,6 +21,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.Vector
/**
* :: Experimental ::
* Train or predict a linear regression model on streaming data. Training uses
* Stochastic Gradient Descent to update the model based on each new batch of
* incoming data from a DStream (see `LinearRegressionWithSGD` for model equation)
@ -41,13 +42,12 @@ import org.apache.spark.mllib.linalg.Vector
*
*/
@Experimental
class StreamingLinearRegressionWithSGD (
class StreamingLinearRegressionWithSGD private[mllib] (
private var stepSize: Double,
private var numIterations: Int,
private var miniBatchFraction: Double,
private var initialWeights: Vector)
extends StreamingLinearAlgorithm[
LinearRegressionModel, LinearRegressionWithSGD] with Serializable {
private var miniBatchFraction: Double)
extends StreamingLinearAlgorithm[LinearRegressionModel, LinearRegressionWithSGD]
with Serializable {
/**
* Construct a StreamingLinearRegression object with default parameters:
@ -55,12 +55,10 @@ class StreamingLinearRegressionWithSGD (
* Initial weights must be set before using trainOn or predictOn
* (see `StreamingLinearAlgorithm`)
*/
def this() = this(0.1, 50, 1.0, null)
def this() = this(0.1, 50, 1.0)
val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction)
var model = algorithm.createModel(initialWeights, 0.0)
/** Set the step size for gradient descent. Default: 0.1. */
def setStepSize(stepSize: Double): this.type = {
this.algorithm.optimizer.setStepSize(stepSize)
@ -81,7 +79,7 @@ class StreamingLinearRegressionWithSGD (
/** Set the initial weights. Default: [0.0, 0.0]. */
def setInitialWeights(initialWeights: Vector): this.type = {
this.model = algorithm.createModel(initialWeights, 0.0)
this.model = Option(algorithm.createModel(initialWeights, 0.0))
this
}

View file

@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.classification
import scala.collection.mutable.ArrayBuffer
import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.TestSuiteBase
class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase {
// use longer wait time to ensure job completion
override def maxWaitTimeMillis = 30000
// Test if we can accurately learn B for Y = logistic(BX) on streaming data
test("parameter accuracy") {
val nPoints = 100
val B = 1.5
// create model
val model = new StreamingLogisticRegressionWithSGD()
.setInitialWeights(Vectors.dense(0.0))
.setStepSize(0.2)
.setNumIterations(25)
// generate sequence of simulated data
val numBatches = 20
val input = (0 until numBatches).map { i =>
LogisticRegressionSuite.generateLogisticInput(0.0, B, nPoints, 42 * (i + 1))
}
// apply model training to input stream
val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
inputDStream.count()
})
runStreams(ssc, numBatches, numBatches)
// check accuracy of final parameter estimates
assert(model.latestModel().weights(0) ~== B relTol 0.1)
}
// Test that parameter estimates improve when learning Y = logistic(BX) on streaming data
test("parameter convergence") {
val B = 1.5
val nPoints = 100
// create model
val model = new StreamingLogisticRegressionWithSGD()
.setInitialWeights(Vectors.dense(0.0))
.setStepSize(0.2)
.setNumIterations(25)
// generate sequence of simulated data
val numBatches = 20
val input = (0 until numBatches).map { i =>
LogisticRegressionSuite.generateLogisticInput(0.0, B, nPoints, 42 * (i + 1))
}
// create buffer to store intermediate fits
val history = new ArrayBuffer[Double](numBatches)
// apply model training to input stream, storing the intermediate results
// (we add a count to ensure the result is a DStream)
val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - B)))
inputDStream.count()
})
runStreams(ssc, numBatches, numBatches)
// compute change in error
val deltas = history.drop(1).zip(history.dropRight(1))
// check error stability (it always either shrinks, or increases with small tol)
assert(deltas.forall(x => (x._1 - x._2) <= 0.1))
// check that error shrunk on at least 2 batches
assert(deltas.map(x => if ((x._1 - x._2) < 0) 1 else 0).sum > 1)
}
// Test predictions on a stream
test("predictions") {
val B = 1.5
val nPoints = 100
// create model initialized with true weights
val model = new StreamingLogisticRegressionWithSGD()
.setInitialWeights(Vectors.dense(1.5))
.setStepSize(0.2)
.setNumIterations(25)
// generate sequence of simulated data for testing
val numBatches = 10
val testInput = (0 until numBatches).map { i =>
LogisticRegressionSuite.generateLogisticInput(0.0, B, nPoints, 42 * (i + 1))
}
// apply model predictions to test stream
val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
})
// collect the output as (true, estimated) tuples
val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
// check that at least 60% of predictions are correct on all batches
val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints)
assert(errors.forall(x => x <= 0.4))
}
}