SVMSuite and LassoSuite rewritten to follow closely with LogisticRegressionSuite
This commit is contained in:
parent
29e042940a
commit
67de051bbb
|
@ -1,3 +1,20 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package spark.mllib.classification
|
||||
|
||||
import scala.util.Random
|
||||
|
@ -7,7 +24,6 @@ import org.scalatest.BeforeAndAfterAll
|
|||
import org.scalatest.FunSuite
|
||||
|
||||
import spark.SparkContext
|
||||
import spark.SparkContext._
|
||||
|
||||
import java.io._
|
||||
|
||||
|
@ -19,43 +35,82 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
|
|||
System.clearProperty("spark.driver.port")
|
||||
}
|
||||
|
||||
// Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise)
|
||||
def generateSVMInput(
|
||||
intercept: Double,
|
||||
weights: Array[Double],
|
||||
nPoints: Int,
|
||||
seed: Int): Seq[(Double, Array[Double])] = {
|
||||
val rnd = new Random(seed)
|
||||
val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](weights.length)(rnd.nextGaussian()))
|
||||
val y = x.map(xi =>
|
||||
signum((xi zip weights).map(xw => xw._1*xw._2).reduce(_+_) + intercept + 0.1 * rnd.nextGaussian())
|
||||
)
|
||||
y zip x
|
||||
}
|
||||
|
||||
def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) {
|
||||
val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
|
||||
// A prediction is off if the prediction is more than 0.5 away from expected value.
|
||||
math.abs(prediction - expected) > 0.5
|
||||
}.size
|
||||
// At least 80% of the predictions should be on.
|
||||
assert(numOffPredictions < input.length / 5)
|
||||
}
|
||||
|
||||
test("SVMLocalRandomSGD") {
|
||||
val nPoints = 10000
|
||||
val rnd = new Random(42)
|
||||
|
||||
val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
|
||||
val x2 = Array.fill[Double](nPoints)(rnd.nextGaussian())
|
||||
|
||||
val A = 2.0
|
||||
val B = -1.5
|
||||
val C = 1.0
|
||||
|
||||
val y = (0 until nPoints).map { i =>
|
||||
signum(A + B * x1(i) + C * x2(i) + 0.0*rnd.nextGaussian())
|
||||
}
|
||||
|
||||
val testData = (0 until nPoints).map(i => (y(i).toDouble, Array(x1(i),x2(i)))).toArray
|
||||
val testData = generateSVMInput(A, Array[Double](B,C), nPoints, 42)
|
||||
|
||||
val testRDD = sc.parallelize(testData, 2)
|
||||
testRDD.cache()
|
||||
|
||||
val writer_data = new PrintWriter(new File("svmtest.dat"))
|
||||
testData.foreach(yx => {
|
||||
writer_data.write(yx._1 + "")
|
||||
yx._2.foreach(xi => writer_data.write("\t" + xi))
|
||||
writer_data.write("\n")})
|
||||
writer_data.close()
|
||||
|
||||
val svm = new SVMLocalRandomSGD().setStepSize(1.0)
|
||||
.setRegParam(1.0)
|
||||
.setNumIterations(100)
|
||||
val svm = new SVMLocalRandomSGD().setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
|
||||
|
||||
val model = svm.train(testRDD)
|
||||
|
||||
val yPredict = (0 until nPoints).map(i => model.predict(Array(x1(i),x2(i))))
|
||||
val validationData = generateSVMInput(A, Array[Double](B,C), nPoints, 17)
|
||||
val validationRDD = sc.parallelize(validationData,2)
|
||||
|
||||
val accuracy = ((y zip yPredict).map(yy => if (yy._1==yy._2) 1 else 0).reduceLeft(_+_).toDouble / nPoints.toDouble)
|
||||
// Test prediction on RDD.
|
||||
validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
|
||||
|
||||
assert(accuracy >= 0.90, "Accuracy (" + accuracy + ") too low")
|
||||
// Test prediction on Array.
|
||||
validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
|
||||
}
|
||||
|
||||
test("SVMLocalRandomSGD with initial weights") {
|
||||
val nPoints = 10000
|
||||
|
||||
val A = 2.0
|
||||
val B = -1.5
|
||||
val C = 1.0
|
||||
|
||||
val testData = generateSVMInput(A, Array[Double](B,C), nPoints, 42)
|
||||
|
||||
val initialB = -1.0
|
||||
val initialC = -1.0
|
||||
val initialWeights = Array(initialB,initialC)
|
||||
|
||||
val testRDD = sc.parallelize(testData, 2)
|
||||
testRDD.cache()
|
||||
|
||||
val svm = new SVMLocalRandomSGD().setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
|
||||
|
||||
val model = svm.train(testRDD, initialWeights)
|
||||
|
||||
val validationData = generateSVMInput(A, Array[Double](B,C), nPoints, 17)
|
||||
val validationRDD = sc.parallelize(validationData,2)
|
||||
|
||||
// Test prediction on RDD.
|
||||
validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
|
||||
|
||||
// Test prediction on Array.
|
||||
validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,20 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package spark.mllib.regression
|
||||
|
||||
import scala.util.Random
|
||||
|
@ -6,7 +23,6 @@ import org.scalatest.BeforeAndAfterAll
|
|||
import org.scalatest.FunSuite
|
||||
|
||||
import spark.SparkContext
|
||||
import spark.SparkContext._
|
||||
|
||||
|
||||
class LassoSuite extends FunSuite with BeforeAndAfterAll {
|
||||
|
@ -17,35 +33,90 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
|
|||
System.clearProperty("spark.driver.port")
|
||||
}
|
||||
|
||||
// Generate noisy input of the form Y = x.dot(weights) + intercept + noise
|
||||
def generateLassoInput(
|
||||
intercept: Double,
|
||||
weights: Array[Double],
|
||||
nPoints: Int,
|
||||
seed: Int): Seq[(Double, Array[Double])] = {
|
||||
val rnd = new Random(seed)
|
||||
val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](weights.length)(rnd.nextGaussian()))
|
||||
val y = x.map(xi => (xi zip weights).map(xw => xw._1*xw._2).reduce(_+_) + intercept + 0.1 * rnd.nextGaussian())
|
||||
y zip x
|
||||
}
|
||||
|
||||
def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) {
|
||||
val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
|
||||
// A prediction is off if the prediction is more than 0.5 away from expected value.
|
||||
math.abs(prediction - expected) > 0.5
|
||||
}.size
|
||||
// At least 80% of the predictions should be on.
|
||||
assert(numOffPredictions < input.length / 5)
|
||||
}
|
||||
|
||||
test("LassoLocalRandomSGD") {
|
||||
val nPoints = 10000
|
||||
val rnd = new Random(42)
|
||||
|
||||
val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
|
||||
val x2 = Array.fill[Double](nPoints)(rnd.nextGaussian())
|
||||
|
||||
val A = 2.0
|
||||
val B = -1.5
|
||||
val C = 1.0e-2
|
||||
|
||||
val y = (0 until nPoints).map { i =>
|
||||
A + B * x1(i) + C * x2(i) + 0.1*rnd.nextGaussian()
|
||||
}
|
||||
|
||||
val testData = (0 until nPoints).map(i => (y(i).toDouble, Array(x1(i),x2(i)))).toArray
|
||||
val testData = generateLassoInput(A, Array[Double](B,C), nPoints, 42)
|
||||
|
||||
val testRDD = sc.parallelize(testData, 2)
|
||||
testRDD.cache()
|
||||
val ls = new LassoLocalRandomSGD().setStepSize(1.0)
|
||||
.setRegParam(0.01)
|
||||
.setNumIterations(20)
|
||||
val ls = new LassoLocalRandomSGD().setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
|
||||
|
||||
val model = ls.train(testRDD)
|
||||
|
||||
val weight0 = model.weights(0)
|
||||
val weight1 = model.weights(1)
|
||||
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
|
||||
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
|
||||
assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
|
||||
|
||||
val validationData = generateLassoInput(A, Array[Double](B,C), nPoints, 17)
|
||||
val validationRDD = sc.parallelize(validationData,2)
|
||||
|
||||
// Test prediction on RDD.
|
||||
validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
|
||||
|
||||
// Test prediction on Array.
|
||||
validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
|
||||
}
|
||||
|
||||
test("LassoLocalRandomSGD with initial weights") {
|
||||
val nPoints = 10000
|
||||
|
||||
val A = 2.0
|
||||
val B = -1.5
|
||||
val C = 1.0e-2
|
||||
|
||||
val testData = generateLassoInput(A, Array[Double](B,C), nPoints, 42)
|
||||
|
||||
val initialB = -1.0
|
||||
val initialC = -1.0
|
||||
val initialWeights = Array(initialB,initialC)
|
||||
|
||||
val testRDD = sc.parallelize(testData, 2)
|
||||
testRDD.cache()
|
||||
val ls = new LassoLocalRandomSGD().setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
|
||||
|
||||
val model = ls.train(testRDD, initialWeights)
|
||||
|
||||
val weight0 = model.weights(0)
|
||||
val weight1 = model.weights(1)
|
||||
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
|
||||
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
|
||||
assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
|
||||
|
||||
val validationData = generateLassoInput(A, Array[Double](B,C), nPoints, 17)
|
||||
val validationRDD = sc.parallelize(validationData,2)
|
||||
|
||||
// Test prediction on RDD.
|
||||
validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
|
||||
|
||||
// Test prediction on Array.
|
||||
validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue