Import RidgeRegression example

Conflicts:
	run
This commit is contained in:
Matei Zaharia 2013-07-05 11:13:41 -07:00
parent 6ad85d0918
commit 729e463f64
3 changed files with 191 additions and 1 deletion

ml/src/main/scala/spark/ml/RidgeRegression.scala

@@ -0,0 +1,110 @@
package spark.ml

import spark._
import spark.SparkContext._

import org.apache.commons.math3.distribution.NormalDistribution
import org.jblas.DoubleMatrix
import org.jblas.Solve

/**
 * Ridge Regression from Joseph Gonzalez's implementation in MLBase
 */
class RidgeRegressionModel(
    val wOpt: DoubleMatrix,
    val lambdaOpt: Double,
    val lambdas: List[(Double, Double, DoubleMatrix)]) {

  def predict(testData: spark.RDD[Array[Double]]) = {
    testData.map(x => new DoubleMatrix(1, x.length, x: _*).mmul(this.wOpt))
  }
}

object RidgeRegression extends Logging {

  def train(data: spark.RDD[(Double, Array[Double])],
            lambdaLow: Double = 0.0,
            lambdaHigh: Double = 10000.0) = {

    data.cache()
    val nfeatures = data.take(1)(0)._2.length
    val nexamples = data.count

    // Compute XtX - size of XtX is nfeatures by nfeatures
    val XtX = data.map { case (y, features) =>
      val x = new DoubleMatrix(1, features.length, features: _*)
      x.transpose().mmul(x)
    }.reduce(_.add(_))

    // Compute Xt*y - size of Xty is nfeatures by 1
    val Xty = data.map { case (y, features) =>
      new DoubleMatrix(features.length, 1, features: _*).mul(y)
    }.reduce(_.add(_))
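
    // XtX and Xty are the sufficient statistics of the normal equations: once
    // they are computed, the closed-form solution for any lambda can be obtained
    // on the driver without rescanning the data (the CV error below still needs
    // one pass per lambda).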

    // Define a function to compute the leave-one-out cross validation error
    // across all examples for a given value of lambda.
    def crossValidate(lambda: Double) = {
      // Compute the ridge regression parameter (the penalized least-squares
      // solution): w = inv(XtX + lambda * I) * Xty
      val XtXlambda = DoubleMatrix.eye(nfeatures).muli(lambda).addi(XtX)
      val w = Solve.solveSymmetric(XtXlambda, Xty)
      val invXtX = Solve.solveSymmetric(XtXlambda, DoubleMatrix.eye(nfeatures))

      // Compute the leave-one-out cross validation score.
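      // The hat matrix identity makes this cheap: with H = X * inv(XtX + lambda*I) * Xt,
      // the leave-one-out residual of example i is (y_i - yhat_i) / (1 - H_ii),
      // so the full LOOCV error is computed without refitting the model n times.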
      val cvError = data.map { case (y, features) =>
        val x = new DoubleMatrix(features.length, 1, features: _*)
        val yhat = w.transpose().mmul(x).get(0)
        val H_ii = x.transpose().mmul(invXtX).mmul(x).get(0)
        val residual = (y - yhat) / (1.0 - H_ii)
        residual * residual
      }.reduce(_ + _)

      (lambda, cvError, w)
    }

    // Binary search for the lambda with the lowest CV error. This assumes the
    // CV error is roughly unimodal in lambda over [lambdaLow, lambdaHigh].
    def binSearch(low: Double, high: Double): List[(Double, Double, DoubleMatrix)] = {
      val mid = (high - low) / 2 + low
      val lowValue = crossValidate((mid - low) / 2 + low)
      val highValue = crossValidate((high - mid) / 2 + mid)
      val (newLow, newHigh) = if (lowValue._2 < highValue._2) {
        (low, mid + (high - low) / 4)
      } else {
        (mid - (high - low) / 4, high)
      }
      if (newHigh - newLow > 1.0E-7) {
        lowValue :: highValue :: binSearch(newLow, newHigh)
      } else {
        List(lowValue, highValue)
      }
    }

    // Actually compute the best lambda.
    val lambdas = binSearch(lambdaLow, lambdaHigh).sortBy(_._1)

    // Find the parameter set with the lowest CV error.
    val (lambdaOpt, cverror, wOpt) = lambdas.reduce((a, b) => if (a._2 < b._2) a else b)
    logInfo("RidgeRegression: optimal lambda " + lambdaOpt)

    // Return the model, which contains the solution.
    new RidgeRegressionModel(wOpt, lambdaOpt, lambdas)
  }

  def main(args: Array[String]) {
    if (args.length != 2) {
      println("Usage: RidgeRegression <master> <input_dir>")
      System.exit(1)
    }
    val sc = new SparkContext(args(0), "RidgeRegression")
    val data = RidgeRegressionGenerator.loadData(sc, args(1))
    val model = train(data, 0, 100)
    sc.stop()
  }
}
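
The returned model can be applied directly to an RDD of feature arrays. A minimal usage sketch (the master URL, data path, and lambda range here are illustrative, and the data directory is assumed to have been produced by RidgeRegressionGenerator below):

    val sc = new SparkContext("local[2]", "RidgeRegressionExample")
    val data = RidgeRegressionGenerator.loadData(sc, "/tmp/ridge-data")
    val model = RidgeRegression.train(data, 0.0, 1000.0)
    // Predict on the training features and report the chosen regularization strength.
    val predictions = model.predict(data.map(_._2))
    println("Made " + predictions.count() + " predictions; optimal lambda " + model.lambdaOpt)
    sc.stop()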

ml/src/main/scala/spark/ml/RidgeRegressionGenerator.scala

@@ -0,0 +1,70 @@
package spark.ml

import spark._
import spark.SparkContext._

import org.apache.commons.math3.distribution.NormalDistribution
import org.jblas.DoubleMatrix

object RidgeRegressionGenerator {

  // Helper methods to load and save data used for RidgeRegression.
  // Data format:
  //   <l>, <f1> <f2> ...
  // where <f1>, <f2>, ... are feature values as Doubles and
  // <l> is the corresponding label, also a Double.
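  //
  // For example, a hypothetical row with three features:
  //   2.37, 0.81 -0.42 1.05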
  def loadData(sc: SparkContext, dir: String) = {
    val data = sc.textFile(dir).map { line =>
      val parts = line.split(",")
      val label = parts(0).toDouble
      val features = parts(1).trim().split(" ").map(_.toDouble)
      (label, features)
    }
    data
  }

  def saveData(data: RDD[(Double, Array[Double])], dir: String) {
    val dataStr = data.map(x => x._1 + "," + x._2.mkString(" "))
    dataStr.saveAsTextFile(dir)
  }

  def main(args: Array[String]) {
    if (args.length != 2) {
      println("Usage: RidgeRegressionGenerator <master> <output_dir>")
      System.exit(1)
    }

    org.jblas.util.Random.seed(42)
    val sc = new SparkContext(args(0), "RidgeRegressionGenerator")

    val nexamples = 1000
    val nfeatures = 100
    val eps = 10
    val parts = 2

    // Weights drawn uniformly from [-0.5, 0.5], with two large planted coefficients.
    val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5)
    w.put(0, 0, 10)
    w.put(1, 0, 10)
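
    // Each partition draws its own examples from the linear model
    // y = X * w + noise, where the noise is N(0, eps); eps is the standard
    // deviation passed to commons-math3's NormalDistribution.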
    val data = sc.parallelize(0 until parts, parts).flatMap { p =>
      org.jblas.util.Random.seed(42 + p)
      val examplesInPartition = nexamples / parts

      val X = DoubleMatrix.rand(examplesInPartition, nfeatures)
      val y = X.mmul(w)

      val rnd = new NormalDistribution(0, eps)
      rnd.reseedRandomGenerator(42 + p)

      val normalValues = (0 until examplesInPartition).map(_ => rnd.sample())
      val yObs = new DoubleMatrix(examplesInPartition, 1, normalValues: _*).addi(y)

      (0 until examplesInPartition).map(i =>
        (yObs.get(i, 0), X.getRow(i).toArray)
      )
    }

    saveData(data, args(1))
    System.exit(0)
  }
}
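
To try the example end to end, first generate a dataset and then train on it. Assuming the repository's run script launches a class by name (the master URL and data path are illustrative):

    ./run spark.ml.RidgeRegressionGenerator local[2] /tmp/ridge-data
    ./run spark.ml.RidgeRegression local[2] /tmp/ridge-data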

project/SparkBuild.scala

@@ -25,7 +25,7 @@ object SparkBuild extends Build {
   //val HADOOP_MAJOR_VERSION = "2"
   //val HADOOP_YARN = true

-  lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel, streaming)
+  lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel, streaming, ml)

   lazy val core = Project("core", file("core"), settings = coreSettings)
@@ -37,6 +37,8 @@ object SparkBuild extends Build {
   lazy val streaming = Project("streaming", file("streaming"), settings = streamingSettings) dependsOn (core)

+  lazy val ml = Project("ml", file("ml"), settings = mlSettings) dependsOn (core)
+
   // A configuration to set an alternative publishLocalConfiguration
   lazy val MavenCompile = config("m2r") extend(Compile)
   lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
@@ -219,6 +221,14 @@ object SparkBuild extends Build {
   def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")

+  def mlSettings = examplesSettings ++ Seq(
+    name := "spark-ml",
+    libraryDependencies ++= Seq(
+      "org.jblas" % "jblas" % "1.2.3",
+      "org.apache.commons" % "commons-math3" % "3.2"
+    )
+  )
+
   def streamingSettings = sharedSettings ++ Seq(
     name := "spark-streaming",
     resolvers ++= Seq(