SPARK-1791 - SVM implementation does not use threshold parameter
Summary: https://issues.apache.org/jira/browse/SPARK-1791 Simple fix, and backward compatible, since - anyone who set the threshold was getting completely wrong answers. - anyone who did not set the threshold had the default 0.0 value for the threshold anyway. Test Plan: Unit test added that is verified to fail under the old implementation, and pass under the new implementation. Reviewers: CC: Author: Andrew Tulloch <andrew@tullo.ch> Closes #725 from ajtulloch/SPARK-1791-SVM and squashes the following commits: 770f55d [Andrew Tulloch] SPARK-1791 - SVM implementation does not use threshold parameter
This commit is contained in:
parent
16ffadcc4a
commit
d1e487473f
|
@ -65,7 +65,7 @@ class SVMModel private[mllib] (
|
|||
intercept: Double) = {
|
||||
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
|
||||
threshold match {
|
||||
case Some(t) => if (margin < 0) 0.0 else 1.0
|
||||
case Some(t) => if (margin < t) 0.0 else 1.0
|
||||
case None => margin
|
||||
}
|
||||
}
|
||||
|
|
|
@ -69,6 +69,43 @@ class SVMSuite extends FunSuite with LocalSparkContext {
|
|||
assert(numOffPredictions < input.length / 5)
|
||||
}
|
||||
|
||||
test("SVM with threshold") {
|
||||
val nPoints = 10000
|
||||
|
||||
// NOTE: Intercept should be small for generating equal 0s and 1s
|
||||
val A = 0.01
|
||||
val B = -1.5
|
||||
val C = 1.0
|
||||
|
||||
val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42)
|
||||
|
||||
val testRDD = sc.parallelize(testData, 2)
|
||||
testRDD.cache()
|
||||
|
||||
val svm = new SVMWithSGD().setIntercept(true)
|
||||
svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
|
||||
|
||||
val model = svm.run(testRDD)
|
||||
|
||||
val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17)
|
||||
val validationRDD = sc.parallelize(validationData, 2)
|
||||
|
||||
// Test prediction on RDD.
|
||||
|
||||
var predictions = model.predict(validationRDD.map(_.features)).collect()
|
||||
assert(predictions.count(_ == 0.0) != predictions.length)
|
||||
|
||||
// High threshold makes all the predictions 0.0
|
||||
model.setThreshold(10000.0)
|
||||
predictions = model.predict(validationRDD.map(_.features)).collect()
|
||||
assert(predictions.count(_ == 0.0) == predictions.length)
|
||||
|
||||
// Low threshold makes all the predictions 1.0
|
||||
model.setThreshold(-10000.0)
|
||||
predictions = model.predict(validationRDD.map(_.features)).collect()
|
||||
assert(predictions.count(_ == 1.0) == predictions.length)
|
||||
}
|
||||
|
||||
test("SVM using local random SGD") {
|
||||
val nPoints = 10000
|
||||
|
||||
|
|
Loading…
Reference in a new issue