LocalSparkContext for MLlib

Andrew Tulloch 2014-01-19 17:51:00 +00:00
parent fe8a3546f4
commit 720836a761
10 changed files with 42 additions and 109 deletions

mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala

@@ -17,6 +17,7 @@
 package org.apache.spark.mllib.classification

+import org.apache.spark.mllib.util.LocalSparkContext
 import scala.util.Random
 import scala.collection.JavaConversions._
@@ -66,19 +67,7 @@ object LogisticRegressionSuite {
 }

-class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class LogisticRegressionSuite extends FunSuite with LocalSparkContext with ShouldMatchers {

   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
       prediction != expected.label
mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

@@ -17,6 +17,7 @@
 package org.apache.spark.mllib.classification

+import org.apache.spark.mllib.util.LocalSparkContext
 import scala.util.Random

 import org.scalatest.BeforeAndAfterAll
@@ -59,17 +60,7 @@ object NaiveBayesSuite {
   }
 }

-class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class NaiveBayesSuite extends FunSuite with LocalSparkContext {

   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOfPredictions = predictions.zip(input).count {

mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala

@@ -25,8 +25,9 @@ import org.scalatest.FunSuite
 import org.jblas.DoubleMatrix

-import org.apache.spark.{SparkException, SparkContext}
+import org.apache.spark.SparkException
 import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.LocalSparkContext

 object SVMSuite {
@@ -58,17 +59,7 @@ object SVMSuite {
 }

-class SVMSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class SVMSuite extends FunSuite with LocalSparkContext {

   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>

mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala

@@ -21,20 +21,9 @@ package org.apache.spark.mllib.clustering
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite

-import org.apache.spark.SparkContext
+import org.apache.spark.mllib.util.LocalSparkContext

-class KMeansSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class KMeansSuite extends FunSuite with LocalSparkContext {

   val EPSILON = 1e-4

mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala

@@ -26,6 +26,7 @@ import org.scalatest.matchers.ShouldMatchers
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.LocalSparkContext

 object GradientDescentSuite {
@@ -62,17 +63,7 @@ object GradientDescentSuite {
   }
 }

-class GradientDescentSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMatchers {

   test("Assert the loss is decreasing.") {
     val nPoints = 10000

mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala

@@ -23,7 +23,7 @@ import scala.util.Random
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite

-import org.apache.spark.SparkContext
+import org.apache.spark.mllib.util.LocalSparkContext

 import org.jblas._
@@ -73,17 +73,7 @@ object ALSSuite {
 }

-class ALSSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class ALSSuite extends FunSuite with LocalSparkContext {

   test("rank-1 matrices") {
     testALS(50, 100, 1, 15, 0.7, 0.3)

mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala

@@ -22,21 +22,9 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite

-import org.apache.spark.SparkContext
-import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}

-class LassoSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class LassoSuite extends FunSuite with LocalSparkContext {

   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>

mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala

@@ -21,19 +21,9 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite

-import org.apache.spark.SparkContext
-import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}

-class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class LinearRegressionSuite extends FunSuite with LocalSparkContext {

   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>

mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala

@@ -23,19 +23,10 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite

-import org.apache.spark.SparkContext
-import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}

-class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class RidgeRegressionSuite extends FunSuite with LocalSparkContext {

   def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
     predictions.zip(input).map { case (prediction, expected) =>

mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala

@@ -0,0 +1,23 @@
+package org.apache.spark.mllib.util
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkContext
+
+trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
+  @transient var sc: SparkContext = _
+
+  override def beforeAll() {
+    sc = new SparkContext("local", "test")
+    super.beforeAll()
+  }
+
+  override def afterAll() {
+    if (sc != null) {
+      sc.stop()
+    }
+    System.clearProperty("spark.driver.port")
+    super.afterAll()
+  }
+}
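
Usage sketch (not part of the commit; ExampleSuite is a hypothetical name): a suite now just mixes in LocalSparkContext and gets a local `sc` for free, instead of repeating the beforeAll/afterAll boilerplate deleted above.

package org.apache.spark.mllib.util

import org.scalatest.FunSuite

// Hypothetical suite: LocalSparkContext creates `sc` in beforeAll() and
// stops it (clearing spark.driver.port) in afterAll().
class ExampleSuite extends FunSuite with LocalSparkContext {
  test("parallelize and count") {
    val rdd = sc.parallelize(1 to 100)
    assert(rdd.count() == 100)
  }
}

Because the trait chains super.beforeAll() and super.afterAll(), it stacks with other BeforeAndAfterAll traits mixed into the same suite, and the null check in afterAll() keeps cleanup safe even if a suite never starts a context.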