diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 02ede71137..155788f85e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.classification +import org.apache.spark.mllib.util.LocalSparkContext import scala.util.Random import scala.collection.JavaConversions._ @@ -66,19 +67,7 @@ object LogisticRegressionSuite { } -class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - +class LogisticRegressionSuite extends FunSuite with LocalSparkContext with ShouldMatchers { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => prediction != expected.label diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index b615f76e66..375196d97f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.classification +import org.apache.spark.mllib.util.LocalSparkContext import scala.util.Random import org.scalatest.BeforeAndAfterAll @@ -59,17 +60,7 @@ object NaiveBayesSuite { } } -class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } +class NaiveBayesSuite extends FunSuite with LocalSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOfPredictions = predictions.zip(input).count { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 3357b86f9b..bc7abb568a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -25,8 +25,9 @@ import org.scalatest.FunSuite import org.jblas.DoubleMatrix -import org.apache.spark.{SparkException, SparkContext} +import org.apache.spark.SparkException import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.util.LocalSparkContext object SVMSuite { @@ -58,17 +59,7 @@ object SVMSuite { } -class SVMSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } +class SVMSuite extends FunSuite with LocalSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 73657cac89..4ef1d1f64f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -21,20 +21,9 @@ package org.apache.spark.mllib.clustering import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite -import org.apache.spark.SparkContext +import org.apache.spark.mllib.util.LocalSparkContext - -class KMeansSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } +class KMeansSuite extends FunSuite with LocalSparkContext { val EPSILON = 1e-4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index a6028a1e98..a453de6767 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.matchers.ShouldMatchers import org.apache.spark.SparkContext import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.util.LocalSparkContext object GradientDescentSuite { @@ -62,17 +63,7 @@ object GradientDescentSuite { } } -class GradientDescentSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } +class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMatchers { test("Assert the loss is decreasing.") { val nPoints = 10000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 4e8dbde658..28a27b1d09 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -23,7 +23,7 @@ import scala.util.Random import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite -import org.apache.spark.SparkContext +import org.apache.spark.mllib.util.LocalSparkContext import org.jblas._ @@ -73,17 +73,7 @@ object ALSSuite { } -class ALSSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } +class ALSSuite extends FunSuite with LocalSparkContext { test("rank-1 matrices") { testALS(50, 100, 1, 15, 0.7, 0.3) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index b2c8df97a8..64e4cbb860 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -22,21 +22,9 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.apache.spark.SparkContext -import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} - -class LassoSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } +class LassoSuite extends FunSuite with LocalSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 406afbaa3e..648d89bc8c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -21,19 +21,9 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.apache.spark.SparkContext -import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} -class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } +class LinearRegressionSuite extends FunSuite with LocalSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 1d6a10b66e..7ec53eb772 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -23,19 +23,10 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.apache.spark.SparkContext -import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} -class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } +class RidgeRegressionSuite extends FunSuite with LocalSparkContext { def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { predictions.zip(input).map { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala new file mode 100644 index 0000000000..7d840043e5 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala @@ -0,0 +1,23 @@ +package org.apache.spark.mllib.util + +import org.scalatest.Suite +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkContext + +trait LocalSparkContext extends BeforeAndAfterAll { self: Suite => + @transient var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + super.beforeAll() + } + + override def afterAll() { + if (sc != null) { + sc.stop() + } + System.clearProperty("spark.driver.port") + super.afterAll() + } +}