[SPARK-8054] [MLLIB] Added several Java-friendly APIs + unit tests
Java-friendly APIs added:
* GaussianMixture.run()
* GaussianMixtureModel.predict()
* DistributedLDAModel.javaTopicDistributions()
* StreamingKMeans: trainOn, predictOn, predictOnValues
* Statistics.corr
* params
* added doc to w() since Java docs do not inherit doc
* removed non-Java-friendly w() from StringArrayParam and DoubleArrayParam
* made DoubleArrayParam Java-friendly w() actually Java-friendly
I generated the doc and verified all changes.
CC: mengxr
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #6562 from jkbradley/java-api-1.4 and squashes the following commits:
c16821b [Joseph K. Bradley] Small fixes based on code review.
d955581 [Joseph K. Bradley] unit test fixes
29b6b0d [Joseph K. Bradley] small fixes
fe6dcfe [Joseph K. Bradley] Added several Java-friendly APIs + unit tests: NaiveBayes, GaussianMixture, LDA, StreamingKMeans, Statistics.corr, params
(cherry picked from commit 20a26b595c
)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
This commit is contained in:
parent
1f90a06bda
commit
bfab61f39c
|
@ -69,14 +69,10 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** Creates a param pair with the given value (for Java). */
|
||||||
* Creates a param pair with the given value (for Java).
|
|
||||||
*/
|
|
||||||
def w(value: T): ParamPair[T] = this -> value
|
def w(value: T): ParamPair[T] = this -> value
|
||||||
|
|
||||||
/**
|
/** Creates a param pair with the given value (for Scala). */
|
||||||
* Creates a param pair with the given value (for Scala).
|
|
||||||
*/
|
|
||||||
def ->(value: T): ParamPair[T] = ParamPair(this, value)
|
def ->(value: T): ParamPair[T] = ParamPair(this, value)
|
||||||
|
|
||||||
override final def toString: String = s"${parent}__$name"
|
override final def toString: String = s"${parent}__$name"
|
||||||
|
@ -190,6 +186,7 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double =>
|
||||||
|
|
||||||
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
|
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
|
||||||
|
|
||||||
|
/** Creates a param pair with the given value (for Java). */
|
||||||
override def w(value: Double): ParamPair[Double] = super.w(value)
|
override def w(value: Double): ParamPair[Double] = super.w(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -209,6 +206,7 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea
|
||||||
|
|
||||||
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
|
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
|
||||||
|
|
||||||
|
/** Creates a param pair with the given value (for Java). */
|
||||||
override def w(value: Int): ParamPair[Int] = super.w(value)
|
override def w(value: Int): ParamPair[Int] = super.w(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -228,6 +226,7 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo
|
||||||
|
|
||||||
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
|
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
|
||||||
|
|
||||||
|
/** Creates a param pair with the given value (for Java). */
|
||||||
override def w(value: Float): ParamPair[Float] = super.w(value)
|
override def w(value: Float): ParamPair[Float] = super.w(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,6 +246,7 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool
|
||||||
|
|
||||||
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
|
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
|
||||||
|
|
||||||
|
/** Creates a param pair with the given value (for Java). */
|
||||||
override def w(value: Long): ParamPair[Long] = super.w(value)
|
override def w(value: Long): ParamPair[Long] = super.w(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -260,6 +260,7 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV
|
||||||
|
|
||||||
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
|
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
|
||||||
|
|
||||||
|
/** Creates a param pair with the given value (for Java). */
|
||||||
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
|
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -274,8 +275,6 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array
|
||||||
def this(parent: Params, name: String, doc: String) =
|
def this(parent: Params, name: String, doc: String) =
|
||||||
this(parent, name, doc, ParamValidators.alwaysTrue)
|
this(parent, name, doc, ParamValidators.alwaysTrue)
|
||||||
|
|
||||||
override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)
|
|
||||||
|
|
||||||
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
|
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
|
||||||
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
|
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
|
||||||
}
|
}
|
||||||
|
@ -291,10 +290,9 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
|
||||||
def this(parent: Params, name: String, doc: String) =
|
def this(parent: Params, name: String, doc: String) =
|
||||||
this(parent, name, doc, ParamValidators.alwaysTrue)
|
this(parent, name, doc, ParamValidators.alwaysTrue)
|
||||||
|
|
||||||
override def w(value: Array[Double]): ParamPair[Array[Double]] = super.w(value)
|
|
||||||
|
|
||||||
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
|
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
|
||||||
def w(value: java.util.List[Double]): ParamPair[Array[Double]] = w(value.asScala.toArray)
|
def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] =
|
||||||
|
w(value.asScala.map(_.asInstanceOf[Double]).toArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -22,6 +22,7 @@ import scala.collection.mutable.IndexedSeq
|
||||||
import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV}
|
import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV}
|
||||||
|
|
||||||
import org.apache.spark.annotation.Experimental
|
import org.apache.spark.annotation.Experimental
|
||||||
|
import org.apache.spark.api.java.JavaRDD
|
||||||
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors}
|
||||||
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
|
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
|
||||||
import org.apache.spark.mllib.util.MLUtils
|
import org.apache.spark.mllib.util.MLUtils
|
||||||
|
@ -188,6 +189,9 @@ class GaussianMixture private (
|
||||||
new GaussianMixtureModel(weights, gaussians)
|
new GaussianMixtureModel(weights, gaussians)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Java-friendly version of [[run()]] */
|
||||||
|
def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd)
|
||||||
|
|
||||||
/** Average of dense breeze vectors */
|
/** Average of dense breeze vectors */
|
||||||
private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
|
private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
|
||||||
val v = BDV.zeros[Double](x(0).length)
|
val v = BDV.zeros[Double](x(0).length)
|
||||||
|
|
|
@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._
|
||||||
|
|
||||||
import org.apache.spark.SparkContext
|
import org.apache.spark.SparkContext
|
||||||
import org.apache.spark.annotation.Experimental
|
import org.apache.spark.annotation.Experimental
|
||||||
|
import org.apache.spark.api.java.JavaRDD
|
||||||
import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix}
|
import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix}
|
||||||
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
|
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
|
||||||
import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable}
|
import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable}
|
||||||
|
@ -46,7 +47,7 @@ import org.apache.spark.sql.{SQLContext, Row}
|
||||||
@Experimental
|
@Experimental
|
||||||
class GaussianMixtureModel(
|
class GaussianMixtureModel(
|
||||||
val weights: Array[Double],
|
val weights: Array[Double],
|
||||||
val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{
|
val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable {
|
||||||
|
|
||||||
require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
|
require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
|
||||||
|
|
||||||
|
@ -65,6 +66,10 @@ class GaussianMixtureModel(
|
||||||
responsibilityMatrix.map(r => r.indexOf(r.max))
|
responsibilityMatrix.map(r => r.indexOf(r.max))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Java-friendly version of [[predict()]] */
|
||||||
|
def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
|
||||||
|
predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Given the input vectors, return the membership value of each vector
|
* Given the input vectors, return the membership value of each vector
|
||||||
* to all mixture components.
|
* to all mixture components.
|
||||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
|
||||||
import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
|
import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
|
||||||
|
|
||||||
import org.apache.spark.annotation.Experimental
|
import org.apache.spark.annotation.Experimental
|
||||||
|
import org.apache.spark.api.java.JavaPairRDD
|
||||||
import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
|
import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
|
||||||
import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
|
import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
|
@ -345,6 +346,11 @@ class DistributedLDAModel private (
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Java-friendly version of [[topicDistributions]] */
|
||||||
|
def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = {
|
||||||
|
JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
|
||||||
|
}
|
||||||
|
|
||||||
// TODO:
|
// TODO:
|
||||||
// override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
|
// override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
|
||||||
|
|
||||||
|
|
|
@ -21,8 +21,10 @@ import scala.reflect.ClassTag
|
||||||
|
|
||||||
import org.apache.spark.Logging
|
import org.apache.spark.Logging
|
||||||
import org.apache.spark.annotation.Experimental
|
import org.apache.spark.annotation.Experimental
|
||||||
|
import org.apache.spark.api.java.JavaSparkContext._
|
||||||
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
|
import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream}
|
||||||
import org.apache.spark.streaming.dstream.DStream
|
import org.apache.spark.streaming.dstream.DStream
|
||||||
import org.apache.spark.util.Utils
|
import org.apache.spark.util.Utils
|
||||||
import org.apache.spark.util.random.XORShiftRandom
|
import org.apache.spark.util.random.XORShiftRandom
|
||||||
|
@ -234,6 +236,9 @@ class StreamingKMeans(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Java-friendly version of `trainOn`. */
|
||||||
|
def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Use the clustering model to make predictions on batches of data from a DStream.
|
* Use the clustering model to make predictions on batches of data from a DStream.
|
||||||
*
|
*
|
||||||
|
@ -245,6 +250,11 @@ class StreamingKMeans(
|
||||||
data.map(model.predict)
|
data.map(model.predict)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Java-friendly version of `predictOn`. */
|
||||||
|
def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = {
|
||||||
|
JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]])
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Use the model to make predictions on the values of a DStream and carry over its keys.
|
* Use the model to make predictions on the values of a DStream and carry over its keys.
|
||||||
*
|
*
|
||||||
|
@ -257,6 +267,14 @@ class StreamingKMeans(
|
||||||
data.mapValues(model.predict)
|
data.mapValues(model.predict)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Java-friendly version of `predictOnValues`. */
|
||||||
|
def predictOnValues[K](
|
||||||
|
data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = {
|
||||||
|
implicit val tag = fakeClassTag[K]
|
||||||
|
JavaPairDStream.fromPairDStream(
|
||||||
|
predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Integer)]])
|
||||||
|
}
|
||||||
|
|
||||||
/** Check whether cluster centers have been initialized. */
|
/** Check whether cluster centers have been initialized. */
|
||||||
private[this] def assertInitialized(): Unit = {
|
private[this] def assertInitialized(): Unit = {
|
||||||
if (model.clusterCenters == null) {
|
if (model.clusterCenters == null) {
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
package org.apache.spark.mllib.stat
|
package org.apache.spark.mllib.stat
|
||||||
|
|
||||||
import org.apache.spark.annotation.Experimental
|
import org.apache.spark.annotation.Experimental
|
||||||
|
import org.apache.spark.api.java.JavaRDD
|
||||||
import org.apache.spark.mllib.linalg.distributed.RowMatrix
|
import org.apache.spark.mllib.linalg.distributed.RowMatrix
|
||||||
import org.apache.spark.mllib.linalg.{Matrix, Vector}
|
import org.apache.spark.mllib.linalg.{Matrix, Vector}
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
|
@ -80,6 +81,10 @@ object Statistics {
|
||||||
*/
|
*/
|
||||||
def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y)
|
def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y)
|
||||||
|
|
||||||
|
/** Java-friendly version of [[corr()]] */
|
||||||
|
def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double =
|
||||||
|
corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]])
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute the correlation for the input RDDs using the specified method.
|
* Compute the correlation for the input RDDs using the specified method.
|
||||||
* Methods currently supported: `pearson` (default), `spearman`.
|
* Methods currently supported: `pearson` (default), `spearman`.
|
||||||
|
@ -96,6 +101,10 @@ object Statistics {
|
||||||
*/
|
*/
|
||||||
def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method)
|
def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method)
|
||||||
|
|
||||||
|
/** Java-friendly version of [[corr()]] */
|
||||||
|
def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double =
|
||||||
|
corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Conduct Pearson's chi-squared goodness of fit test of the observed data against the
|
* Conduct Pearson's chi-squared goodness of fit test of the observed data against the
|
||||||
* expected distribution.
|
* expected distribution.
|
||||||
|
|
|
@ -50,6 +50,7 @@ public class JavaParamsSuite {
|
||||||
testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
|
testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
|
||||||
Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
|
Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
|
||||||
Assert.assertEquals(testParams.getMyStringParam(), "a");
|
Assert.assertEquals(testParams.getMyStringParam(), "a");
|
||||||
|
Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -51,7 +51,8 @@ public class JavaTestParams extends JavaParams {
|
||||||
public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); }
|
public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); }
|
||||||
|
|
||||||
public JavaTestParams setMyIntParam(int value) {
|
public JavaTestParams setMyIntParam(int value) {
|
||||||
set(myIntParam_, value); return this;
|
set(myIntParam_, value);
|
||||||
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
private DoubleParam myDoubleParam_;
|
private DoubleParam myDoubleParam_;
|
||||||
|
@ -60,7 +61,8 @@ public class JavaTestParams extends JavaParams {
|
||||||
public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); }
|
public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); }
|
||||||
|
|
||||||
public JavaTestParams setMyDoubleParam(double value) {
|
public JavaTestParams setMyDoubleParam(double value) {
|
||||||
set(myDoubleParam_, value); return this;
|
set(myDoubleParam_, value);
|
||||||
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Param<String> myStringParam_;
|
private Param<String> myStringParam_;
|
||||||
|
@ -69,7 +71,18 @@ public class JavaTestParams extends JavaParams {
|
||||||
public String getMyStringParam() { return getOrDefault(myStringParam_); }
|
public String getMyStringParam() { return getOrDefault(myStringParam_); }
|
||||||
|
|
||||||
public JavaTestParams setMyStringParam(String value) {
|
public JavaTestParams setMyStringParam(String value) {
|
||||||
set(myStringParam_, value); return this;
|
set(myStringParam_, value);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
private DoubleArrayParam myDoubleArrayParam_;
|
||||||
|
public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; }
|
||||||
|
|
||||||
|
public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); }
|
||||||
|
|
||||||
|
public JavaTestParams setMyDoubleArrayParam(double[] value) {
|
||||||
|
set(myDoubleArrayParam_, value);
|
||||||
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void init() {
|
private void init() {
|
||||||
|
@ -79,8 +92,14 @@ public class JavaTestParams extends JavaParams {
|
||||||
List<String> validStrings = Lists.newArrayList("a", "b");
|
List<String> validStrings = Lists.newArrayList("a", "b");
|
||||||
myStringParam_ = new Param<String>(this, "myStringParam", "this is a string param",
|
myStringParam_ = new Param<String>(this, "myStringParam", "this is a string param",
|
||||||
ParamValidators.inArray(validStrings));
|
ParamValidators.inArray(validStrings));
|
||||||
setDefault(myIntParam_, 1);
|
myDoubleArrayParam_ =
|
||||||
setDefault(myDoubleParam_, 0.5);
|
new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");
|
||||||
|
|
||||||
|
setDefault(myIntParam(), 1);
|
||||||
|
setDefault(myIntParam().w(1));
|
||||||
|
setDefault(myDoubleParam(), 0.5);
|
||||||
setDefault(myIntParam().w(1), myDoubleParam().w(0.5));
|
setDefault(myIntParam().w(1), myDoubleParam().w(0.5));
|
||||||
|
setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
|
||||||
|
setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.apache.spark.ml.classification;
|
package org.apache.spark.mllib.classification;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -28,7 +28,6 @@ import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD;
|
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
|
@ -0,0 +1,64 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.mllib.clustering;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import org.junit.After;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
|
|
||||||
|
public class JavaGaussianMixtureSuite implements Serializable {
|
||||||
|
private transient JavaSparkContext sc;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUp() {
|
||||||
|
sc = new JavaSparkContext("local", "JavaGaussianMixture");
|
||||||
|
}
|
||||||
|
|
||||||
|
@After
|
||||||
|
public void tearDown() {
|
||||||
|
sc.stop();
|
||||||
|
sc = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void runGaussianMixture() {
|
||||||
|
List<Vector> points = Lists.newArrayList(
|
||||||
|
Vectors.dense(1.0, 2.0, 6.0),
|
||||||
|
Vectors.dense(1.0, 3.0, 0.0),
|
||||||
|
Vectors.dense(1.0, 4.0, 6.0)
|
||||||
|
);
|
||||||
|
|
||||||
|
JavaRDD<Vector> data = sc.parallelize(points, 2);
|
||||||
|
GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234)
|
||||||
|
.run(data);
|
||||||
|
assertEquals(model.gaussians().length, 2);
|
||||||
|
JavaRDD<Integer> predictions = model.predict(data);
|
||||||
|
predictions.first();
|
||||||
|
}
|
||||||
|
}
|
|
@ -107,6 +107,10 @@ public class JavaLDASuite implements Serializable {
|
||||||
// Check: log probabilities
|
// Check: log probabilities
|
||||||
assert(model.logLikelihood() < 0.0);
|
assert(model.logLikelihood() < 0.0);
|
||||||
assert(model.logPrior() < 0.0);
|
assert(model.logPrior() < 0.0);
|
||||||
|
|
||||||
|
// Check: topic distributions
|
||||||
|
JavaPairRDD<Long, Vector> topicDistributions = model.javaTopicDistributions();
|
||||||
|
assertEquals(topicDistributions.count(), corpus.count());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.mllib.clustering;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import scala.Tuple2;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import org.junit.After;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import static org.apache.spark.streaming.JavaTestUtils.*;
|
||||||
|
|
||||||
|
import org.apache.spark.SparkConf;
|
||||||
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
|
import org.apache.spark.streaming.Duration;
|
||||||
|
import org.apache.spark.streaming.api.java.JavaDStream;
|
||||||
|
import org.apache.spark.streaming.api.java.JavaPairDStream;
|
||||||
|
import org.apache.spark.streaming.api.java.JavaStreamingContext;
|
||||||
|
|
||||||
|
public class JavaStreamingKMeansSuite implements Serializable {
|
||||||
|
|
||||||
|
protected transient JavaStreamingContext ssc;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUp() {
|
||||||
|
SparkConf conf = new SparkConf()
|
||||||
|
.setMaster("local[2]")
|
||||||
|
.setAppName("test")
|
||||||
|
.set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
|
||||||
|
ssc = new JavaStreamingContext(conf, new Duration(1000));
|
||||||
|
ssc.checkpoint("checkpoint");
|
||||||
|
}
|
||||||
|
|
||||||
|
@After
|
||||||
|
public void tearDown() {
|
||||||
|
ssc.stop();
|
||||||
|
ssc = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public void javaAPI() {
|
||||||
|
List<Vector> trainingBatch = Lists.newArrayList(
|
||||||
|
Vectors.dense(1.0),
|
||||||
|
Vectors.dense(0.0));
|
||||||
|
JavaDStream<Vector> training =
|
||||||
|
attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2);
|
||||||
|
List<Tuple2<Integer, Vector>> testBatch = Lists.newArrayList(
|
||||||
|
new Tuple2<Integer, Vector>(10, Vectors.dense(1.0)),
|
||||||
|
new Tuple2<Integer, Vector>(11, Vectors.dense(0.0)));
|
||||||
|
JavaPairDStream<Integer, Vector> test = JavaPairDStream.fromJavaDStream(
|
||||||
|
attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2));
|
||||||
|
StreamingKMeans skmeans = new StreamingKMeans()
|
||||||
|
.setK(1)
|
||||||
|
.setDecayFactor(1.0)
|
||||||
|
.setInitialCenters(new Vector[]{Vectors.dense(1.0)}, new double[]{0.0});
|
||||||
|
skmeans.trainOn(training);
|
||||||
|
JavaPairDStream<Integer, Integer> prediction = skmeans.predictOnValues(test);
|
||||||
|
attachTestOutputStream(prediction.count());
|
||||||
|
runStreams(ssc, 2, 2);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.mllib.stat;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import org.junit.After;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
|
||||||
|
public class JavaStatisticsSuite implements Serializable {
|
||||||
|
private transient JavaSparkContext sc;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUp() {
|
||||||
|
sc = new JavaSparkContext("local", "JavaStatistics");
|
||||||
|
}
|
||||||
|
|
||||||
|
@After
|
||||||
|
public void tearDown() {
|
||||||
|
sc.stop();
|
||||||
|
sc = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCorr() {
|
||||||
|
JavaRDD<Double> x = sc.parallelize(Lists.newArrayList(1.0, 2.0, 3.0, 4.0));
|
||||||
|
JavaRDD<Double> y = sc.parallelize(Lists.newArrayList(1.1, 2.2, 3.1, 4.3));
|
||||||
|
|
||||||
|
Double corr1 = Statistics.corr(x, y);
|
||||||
|
Double corr2 = Statistics.corr(x, y, "pearson");
|
||||||
|
// Check default method
|
||||||
|
assertEquals(corr1, corr2);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue