diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index e05213536e..316ecd713b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -65,7 +65,7 @@ class SVMModel private[mllib] ( intercept: Double) = { val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept threshold match { - case Some(t) => if (margin < 0) 0.0 else 1.0 + case Some(t) => if (margin < t) 0.0 else 1.0 case None => margin } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 77d6f04b32..886c71dde3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -69,6 +69,43 @@ class SVMSuite extends FunSuite with LocalSparkContext { assert(numOffPredictions < input.length / 5) } + test("SVM with threshold") { + val nPoints = 10000 + + // NOTE: Intercept should be small for generating equal 0s and 1s + val A = 0.01 + val B = -1.5 + val C = 1.0 + + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val svm = new SVMWithSGD().setIntercept(true) + svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100) + + val model = svm.run(testRDD) + + val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + + // Test prediction on RDD. + + var predictions = model.predict(validationRDD.map(_.features)).collect() + assert(predictions.count(_ == 0.0) != predictions.length) + + // High threshold makes all the predictions 0.0 + model.setThreshold(10000.0) + predictions = model.predict(validationRDD.map(_.features)).collect() + assert(predictions.count(_ == 0.0) == predictions.length) + + // Low threshold makes all the predictions 1.0 + model.setThreshold(-10000.0) + predictions = model.predict(validationRDD.map(_.features)).collect() + assert(predictions.count(_ == 1.0) == predictions.length) + } + test("SVM using local random SGD") { val nPoints = 10000