SPARK-1791 - SVM implementation does not use threshold parameter
Summary: https://issues.apache.org/jira/browse/SPARK-1791 Simple fix, and backward compatible, since - anyone who set the threshold was getting completely wrong answers. - anyone who did not set the threshold had the default 0.0 value for the threshold anyway. Test Plan: Unit test added that is verified to fail under the old implementation, and pass under the new implementation. Reviewers: CC: Author: Andrew Tulloch <andrew@tullo.ch> Closes #725 from ajtulloch/SPARK-1791-SVM and squashes the following commits: 770f55d [Andrew Tulloch] SPARK-1791 - SVM implementation does not use threshold parameter
This commit is contained in:
parent
16ffadcc4a
commit
d1e487473f
|
@@ -65,7 +65,7 @@ class SVMModel private[mllib] (
|
       intercept: Double) = {
     val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
     threshold match {
-      case Some(t) => if (margin < 0) 0.0 else 1.0
+      case Some(t) => if (margin < t) 0.0 else 1.0
       case None => margin
     }
   }
|
|
|
@@ -69,6 +69,43 @@ class SVMSuite extends FunSuite with LocalSparkContext {
|
     assert(numOffPredictions < input.length / 5)
   }

+  test("SVM with threshold") {
+    val nPoints = 10000
+
+    // NOTE: Intercept should be small for generating equal 0s and 1s
+    val A = 0.01
+    val B = -1.5
+    val C = 1.0
+
+    val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42)
+
+    val testRDD = sc.parallelize(testData, 2)
+    testRDD.cache()
+
+    val svm = new SVMWithSGD().setIntercept(true)
+    svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
+
+    val model = svm.run(testRDD)
+
+    val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17)
+    val validationRDD = sc.parallelize(validationData, 2)
+
+    // Test prediction on RDD.
+
+    var predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 0.0) != predictions.length)
+
+    // High threshold makes all the predictions 0.0
+    model.setThreshold(10000.0)
+    predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 0.0) == predictions.length)
+
+    // Low threshold makes all the predictions 1.0
+    model.setThreshold(-10000.0)
+    predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 1.0) == predictions.length)
+  }
+
   test("SVM using local random SGD") {
     val nPoints = 10000
Loading…
Reference in a new issue