SPARK-1791 - SVM implementation does not use threshold parameter

Summary:
https://issues.apache.org/jira/browse/SPARK-1791

Simple fix, and backward compatible, since

- anyone who set the threshold was getting completely wrong answers.
- anyone who did not set the threshold had the default 0.0 value for the threshold anyway.

Test Plan:
Unit test added that is verified to fail under the old implementation,
and pass under the new implementation.

Reviewers:

CC:

Author: Andrew Tulloch <andrew@tullo.ch>

Closes #725 from ajtulloch/SPARK-1791-SVM and squashes the following commits:

770f55d [Andrew Tulloch] SPARK-1791 - SVM implementation does not use threshold parameter
This commit is contained in:
Andrew Tulloch 2014-05-13 17:31:27 -07:00 committed by Reynold Xin
parent 16ffadcc4a
commit d1e487473f
2 changed files with 38 additions and 1 deletions

View file

@ -65,7 +65,7 @@ class SVMModel private[mllib] (
intercept: Double) = { intercept: Double) = {
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
threshold match { threshold match {
case Some(t) => if (margin < 0) 0.0 else 1.0 case Some(t) => if (margin < t) 0.0 else 1.0
case None => margin case None => margin
} }
} }

View file

@ -69,6 +69,43 @@ class SVMSuite extends FunSuite with LocalSparkContext {
assert(numOffPredictions < input.length / 5) assert(numOffPredictions < input.length / 5)
} }
test("SVM with threshold") {
val nPoints = 10000
// NOTE: Intercept should be small for generating equal 0s and 1s
val A = 0.01
val B = -1.5
val C = 1.0
val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
val svm = new SVMWithSGD().setIntercept(true)
svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
val model = svm.run(testRDD)
val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
var predictions = model.predict(validationRDD.map(_.features)).collect()
assert(predictions.count(_ == 0.0) != predictions.length)
// High threshold makes all the predictions 0.0
model.setThreshold(10000.0)
predictions = model.predict(validationRDD.map(_.features)).collect()
assert(predictions.count(_ == 0.0) == predictions.length)
// Low threshold makes all the predictions 1.0
model.setThreshold(-10000.0)
predictions = model.predict(validationRDD.map(_.features)).collect()
assert(predictions.count(_ == 1.0) == predictions.length)
}
test("SVM using local random SGD") { test("SVM using local random SGD") {
val nPoints = 10000 val nPoints = 10000