LocalSparkContext for MLlib
This commit is contained in:
parent
fe8a3546f4
commit
720836a761
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.mllib.classification
|
||||
|
||||
import org.apache.spark.mllib.util.LocalSparkContext
|
||||
import scala.util.Random
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
|
@ -66,19 +67,7 @@ object LogisticRegressionSuite {
|
|||
|
||||
}
|
||||
|
||||
class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
|
||||
@transient private var sc: SparkContext = _
|
||||
|
||||
override def beforeAll() {
|
||||
sc = new SparkContext("local", "test")
|
||||
}
|
||||
|
||||
|
||||
override def afterAll() {
|
||||
sc.stop()
|
||||
System.clearProperty("spark.driver.port")
|
||||
}
|
||||
|
||||
class LogisticRegressionSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
|
||||
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
|
||||
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
|
||||
prediction != expected.label
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.mllib.classification
|
||||
|
||||
import org.apache.spark.mllib.util.LocalSparkContext
|
||||
import scala.util.Random
|
||||
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
@ -59,17 +60,7 @@ object NaiveBayesSuite {
|
|||
}
|
||||
}
|
||||
|
||||
class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
|
||||
@transient private var sc: SparkContext = _
|
||||
|
||||
override def beforeAll() {
|
||||
sc = new SparkContext("local", "test")
|
||||
}
|
||||
|
||||
override def afterAll() {
|
||||
sc.stop()
|
||||
System.clearProperty("spark.driver.port")
|
||||
}
|
||||
class NaiveBayesSuite extends FunSuite with LocalSparkContext {
|
||||
|
||||
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
|
||||
val numOfPredictions = predictions.zip(input).count {
|
||||
|
|
|
@ -25,8 +25,9 @@ import org.scalatest.FunSuite
|
|||
|
||||
import org.jblas.DoubleMatrix
|
||||
|
||||
import org.apache.spark.{SparkException, SparkContext}
|
||||
import org.apache.spark.SparkException
|
||||
import org.apache.spark.mllib.regression._
|
||||
import org.apache.spark.mllib.util.LocalSparkContext
|
||||
|
||||
object SVMSuite {
|
||||
|
||||
|
@ -58,17 +59,7 @@ object SVMSuite {
|
|||
|
||||
}
|
||||
|
||||
class SVMSuite extends FunSuite with BeforeAndAfterAll {
|
||||
@transient private var sc: SparkContext = _
|
||||
|
||||
override def beforeAll() {
|
||||
sc = new SparkContext("local", "test")
|
||||
}
|
||||
|
||||
override def afterAll() {
|
||||
sc.stop()
|
||||
System.clearProperty("spark.driver.port")
|
||||
}
|
||||
class SVMSuite extends FunSuite with LocalSparkContext {
|
||||
|
||||
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
|
||||
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
|
||||
|
|
|
@ -21,20 +21,9 @@ package org.apache.spark.mllib.clustering
|
|||
import org.scalatest.BeforeAndAfterAll
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.mllib.util.LocalSparkContext
|
||||
|
||||
|
||||
class KMeansSuite extends FunSuite with BeforeAndAfterAll {
|
||||
@transient private var sc: SparkContext = _
|
||||
|
||||
override def beforeAll() {
|
||||
sc = new SparkContext("local", "test")
|
||||
}
|
||||
|
||||
override def afterAll() {
|
||||
sc.stop()
|
||||
System.clearProperty("spark.driver.port")
|
||||
}
|
||||
class KMeansSuite extends FunSuite with LocalSparkContext {
|
||||
|
||||
val EPSILON = 1e-4
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.scalatest.matchers.ShouldMatchers
|
|||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.mllib.regression._
|
||||
import org.apache.spark.mllib.util.LocalSparkContext
|
||||
|
||||
object GradientDescentSuite {
|
||||
|
||||
|
@ -62,17 +63,7 @@ object GradientDescentSuite {
|
|||
}
|
||||
}
|
||||
|
||||
class GradientDescentSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
|
||||
@transient private var sc: SparkContext = _
|
||||
|
||||
override def beforeAll() {
|
||||
sc = new SparkContext("local", "test")
|
||||
}
|
||||
|
||||
override def afterAll() {
|
||||
sc.stop()
|
||||
System.clearProperty("spark.driver.port")
|
||||
}
|
||||
class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
|
||||
|
||||
test("Assert the loss is decreasing.") {
|
||||
val nPoints = 10000
|
||||
|
|
|
@ -23,7 +23,7 @@ import scala.util.Random
|
|||
import org.scalatest.BeforeAndAfterAll
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.mllib.util.LocalSparkContext
|
||||
|
||||
import org.jblas._
|
||||
|
||||
|
@ -73,17 +73,7 @@ object ALSSuite {
|
|||
}
|
||||
|
||||
|
||||
class ALSSuite extends FunSuite with BeforeAndAfterAll {
|
||||
@transient private var sc: SparkContext = _
|
||||
|
||||
override def beforeAll() {
|
||||
sc = new SparkContext("local", "test")
|
||||
}
|
||||
|
||||
override def afterAll() {
|
||||
sc.stop()
|
||||
System.clearProperty("spark.driver.port")
|
||||
}
|
||||
class ALSSuite extends FunSuite with LocalSparkContext {
|
||||
|
||||
test("rank-1 matrices") {
|
||||
testALS(50, 100, 1, 15, 0.7, 0.3)
|
||||
|
|
|
@ -22,21 +22,9 @@ import org.scalatest.BeforeAndAfterAll
|
|||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.mllib.util.LinearDataGenerator
|
||||
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
|
||||
|
||||
|
||||
class LassoSuite extends FunSuite with BeforeAndAfterAll {
|
||||
@transient private var sc: SparkContext = _
|
||||
|
||||
override def beforeAll() {
|
||||
sc = new SparkContext("local", "test")
|
||||
}
|
||||
|
||||
|
||||
override def afterAll() {
|
||||
sc.stop()
|
||||
System.clearProperty("spark.driver.port")
|
||||
}
|
||||
class LassoSuite extends FunSuite with LocalSparkContext {
|
||||
|
||||
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
|
||||
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
|
||||
|
|
|
@ -21,19 +21,9 @@ import org.scalatest.BeforeAndAfterAll
|
|||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.mllib.util.LinearDataGenerator
|
||||
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
|
||||
|
||||
class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll {
|
||||
@transient private var sc: SparkContext = _
|
||||
|
||||
override def beforeAll() {
|
||||
sc = new SparkContext("local", "test")
|
||||
}
|
||||
|
||||
override def afterAll() {
|
||||
sc.stop()
|
||||
System.clearProperty("spark.driver.port")
|
||||
}
|
||||
class LinearRegressionSuite extends FunSuite with LocalSparkContext {
|
||||
|
||||
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
|
||||
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
|
||||
|
|
|
@ -23,19 +23,10 @@ import org.scalatest.BeforeAndAfterAll
|
|||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.mllib.util.LinearDataGenerator
|
||||
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
|
||||
|
||||
class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll {
|
||||
@transient private var sc: SparkContext = _
|
||||
|
||||
override def beforeAll() {
|
||||
sc = new SparkContext("local", "test")
|
||||
}
|
||||
|
||||
override def afterAll() {
|
||||
sc.stop()
|
||||
System.clearProperty("spark.driver.port")
|
||||
}
|
||||
class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
|
||||
|
||||
def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
|
||||
predictions.zip(input).map { case (prediction, expected) =>
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
package org.apache.spark.mllib.util
|
||||
|
||||
import org.scalatest.Suite
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
|
||||
trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
|
||||
@transient var sc: SparkContext = _
|
||||
|
||||
override def beforeAll() {
|
||||
sc = new SparkContext("local", "test")
|
||||
super.beforeAll()
|
||||
}
|
||||
|
||||
override def afterAll() {
|
||||
if (sc != null) {
|
||||
sc.stop()
|
||||
}
|
||||
System.clearProperty("spark.driver.port")
|
||||
super.afterAll()
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue