[SPARK-15037][SQL][MLLIB] Use SparkSession instead of SQLContext in Scala/Java TestSuites

## What changes were proposed in this pull request?
Use SparkSession instead of SQLContext in the Scala/Java TestSuites. As this PR is
already very big, the Python TestSuites are being handled in a different PR.
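
The migration follows the same setUp/tearDown pattern in every suite. Below is a minimal sketch of that pattern, based on the JavaPipelineSuite changes in the diff that follows (the wrapper class name here is illustrative, not part of the change):

```java
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.junit.After;
import org.junit.Before;

// Illustrative class name; each real suite keeps its own name and test methods.
public class JavaSparkSessionTestPattern {
  private transient SparkSession spark;
  private transient JavaSparkContext jsc;

  @Before
  public void setUp() {
    // Replaces: jsc = new JavaSparkContext("local", "..."); jsql = new SQLContext(jsc);
    spark = SparkSession.builder()
        .master("local")
        .appName("JavaPipelineSuite")
        .getOrCreate();
    // Suites that still need to parallelize RDDs derive the JavaSparkContext
    // from the session's underlying SparkContext.
    jsc = new JavaSparkContext(spark.sparkContext());
  }

  @After
  public void tearDown() {
    // Stopping the session also stops the underlying SparkContext.
    spark.stop();
    spark = null;
  }
}
```

DataFrame creation and SQL calls move accordingly, e.g. `jsql.createDataFrame(...)` becomes `spark.createDataFrame(...)` and `jsql.sql(...)` becomes `spark.sql(...)`, as shown throughout the diff.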

## How was this patch tested?
Existing tests

Author: Sandeep Singh <sandeep@techaddict.me>

Closes #12907 from techaddict/SPARK-15037.
Sandeep Singh, 2016-05-10 11:17:47 -07:00, committed by Andrew Or
parent bcfee153b1
commit ed0b4070fb
224 changed files with 2916 additions and 2593 deletions

View file

@@ -17,18 +17,18 @@
package org.apache.spark.ml; package org.apache.spark.ml;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.StandardScaler; import org.apache.spark.ml.feature.StandardScaler;
import org.apache.spark.sql.SQLContext; import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
/** /**
@@ -36,23 +36,26 @@ import static org.apache.spark.mllib.classification.LogisticRegressionSuite.gene
*/ */
public class JavaPipelineSuite { public class JavaPipelineSuite {
private transient SparkSession spark;
private transient JavaSparkContext jsc; private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient Dataset<Row> dataset; private transient Dataset<Row> dataset;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaPipelineSuite"); spark = SparkSession.builder()
jsql = new SQLContext(jsc); .master("local")
.appName("JavaPipelineSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
JavaRDD<LabeledPoint> points = JavaRDD<LabeledPoint> points =
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
dataset = jsql.createDataFrame(points, LabeledPoint.class); dataset = spark.createDataFrame(points, LabeledPoint.class);
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
@Test @Test
@@ -63,10 +66,10 @@ public class JavaPipelineSuite {
LogisticRegression lr = new LogisticRegression() LogisticRegression lr = new LogisticRegression()
.setFeaturesCol("scaledFeatures"); .setFeaturesCol("scaledFeatures");
Pipeline pipeline = new Pipeline() Pipeline pipeline = new Pipeline()
.setStages(new PipelineStage[] {scaler, lr}); .setStages(new PipelineStage[]{scaler, lr});
PipelineModel model = pipeline.fit(dataset); PipelineModel model = pipeline.fit(dataset);
model.transform(dataset).registerTempTable("prediction"); model.transform(dataset).registerTempTable("prediction");
Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList(); predictions.collectAsList();
} }
} }

View file

@@ -17,8 +17,8 @@
package org.apache.spark.ml.attribute; package org.apache.spark.ml.attribute;
import org.junit.Test;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test;
public class JavaAttributeSuite { public class JavaAttributeSuite {

View file

@@ -21,8 +21,6 @@ import java.io.Serializable;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@@ -32,21 +30,28 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class JavaDecisionTreeClassifierSuite implements Serializable { public class JavaDecisionTreeClassifierSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaDecisionTreeClassifierSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
@@ -55,7 +60,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
double A = 2.0; double A = 2.0;
double B = -1.5; double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize( JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>(); Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
@@ -70,7 +75,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
.setCacheNodeIds(false) .setCacheNodeIds(false)
.setCheckpointInterval(10) .setCheckpointInterval(10)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
for (String impurity: DecisionTreeClassifier.supportedImpurities()) { for (String impurity : DecisionTreeClassifier.supportedImpurities()) {
dt.setImpurity(impurity); dt.setImpurity(impurity);
} }
DecisionTreeClassificationModel model = dt.fit(dataFrame); DecisionTreeClassificationModel model = dt.fit(dataFrame);

View file

@@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class JavaGBTClassifierSuite implements Serializable { public class JavaGBTClassifierSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaGBTClassifierSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaGBTClassifierSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
@@ -55,7 +61,7 @@ public class JavaGBTClassifierSuite implements Serializable {
double A = 2.0; double A = 2.0;
double B = -1.5; double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize( JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>(); Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
@@ -74,7 +80,7 @@ public class JavaGBTClassifierSuite implements Serializable {
.setMaxIter(3) .setMaxIter(3)
.setStepSize(0.1) .setStepSize(0.1)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
for (String lossType: GBTClassifier.supportedLossTypes()) { for (String lossType : GBTClassifier.supportedLossTypes()) {
rf.setLossType(lossType); rf.setLossType(lossType);
} }
GBTClassificationModel model = rf.fit(dataFrame); GBTClassificationModel model = rf.fit(dataFrame);

View file

@@ -27,18 +27,17 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaLogisticRegressionSuite implements Serializable { public class JavaLogisticRegressionSuite implements Serializable {
private transient SparkSession spark;
private transient JavaSparkContext jsc; private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient Dataset<Row> dataset; private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD; private transient JavaRDD<LabeledPoint> datasetRDD;
@@ -46,18 +45,22 @@ public class JavaLogisticRegressionSuite implements Serializable {
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); spark = SparkSession.builder()
jsql = new SQLContext(jsc); .master("local")
.appName("JavaLogisticRegressionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42); List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2); datasetRDD = jsc.parallelize(points, 2);
dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset"); dataset.registerTempTable("dataset");
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
@Test @Test
@@ -66,7 +69,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(lr.getLabelCol(), "label"); Assert.assertEquals(lr.getLabelCol(), "label");
LogisticRegressionModel model = lr.fit(dataset); LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction"); model.transform(dataset).registerTempTable("prediction");
Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList(); predictions.collectAsList();
// Check defaults // Check defaults
Assert.assertEquals(0.5, model.getThreshold(), eps); Assert.assertEquals(0.5, model.getThreshold(), eps);
@@ -95,23 +98,23 @@ public class JavaLogisticRegressionSuite implements Serializable {
// Modify model params, and check that the params worked. // Modify model params, and check that the params worked.
model.setThreshold(1.0); model.setThreshold(1.0);
model.transform(dataset).registerTempTable("predAllZero"); model.transform(dataset).registerTempTable("predAllZero");
Dataset<Row> predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); Dataset<Row> predAllZero = spark.sql("SELECT prediction, myProbability FROM predAllZero");
for (Row r: predAllZero.collectAsList()) { for (Row r : predAllZero.collectAsList()) {
Assert.assertEquals(0.0, r.getDouble(0), eps); Assert.assertEquals(0.0, r.getDouble(0), eps);
} }
// Call transform with params, and check that the params worked. // Call transform with params, and check that the params worked.
model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
.registerTempTable("predNotAllZero"); .registerTempTable("predNotAllZero");
Dataset<Row> predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); Dataset<Row> predNotAllZero = spark.sql("SELECT prediction, myProb FROM predNotAllZero");
boolean foundNonZero = false; boolean foundNonZero = false;
for (Row r: predNotAllZero.collectAsList()) { for (Row r : predNotAllZero.collectAsList()) {
if (r.getDouble(0) != 0.0) foundNonZero = true; if (r.getDouble(0) != 0.0) foundNonZero = true;
} }
Assert.assertTrue(foundNonZero); Assert.assertTrue(foundNonZero);
// Call fit() with new params, and check as many params as we can. // Call fit() with new params, and check as many params as we can.
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
LogisticRegression parent2 = (LogisticRegression) model2.parent(); LogisticRegression parent2 = (LogisticRegression) model2.parent();
Assert.assertEquals(5, parent2.getMaxIter()); Assert.assertEquals(5, parent2.getMaxIter());
Assert.assertEquals(0.1, parent2.getRegParam(), eps); Assert.assertEquals(0.1, parent2.getRegParam(), eps);
@@ -128,10 +131,10 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(2, model.numClasses()); Assert.assertEquals(2, model.numClasses());
model.transform(dataset).registerTempTable("transformed"); model.transform(dataset).registerTempTable("transformed");
Dataset<Row> trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); Dataset<Row> trans1 = spark.sql("SELECT rawPrediction, probability FROM transformed");
for (Row row: trans1.collectAsList()) { for (Row row : trans1.collectAsList()) {
Vector raw = (Vector)row.get(0); Vector raw = (Vector) row.get(0);
Vector prob = (Vector)row.get(1); Vector prob = (Vector) row.get(1);
Assert.assertEquals(raw.size(), 2); Assert.assertEquals(raw.size(), 2);
Assert.assertEquals(prob.size(), 2); Assert.assertEquals(prob.size(), 2);
double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1))); double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
@@ -139,11 +142,11 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps); Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
} }
Dataset<Row> trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); Dataset<Row> trans2 = spark.sql("SELECT prediction, probability FROM transformed");
for (Row row: trans2.collectAsList()) { for (Row row : trans2.collectAsList()) {
double pred = row.getDouble(0); double pred = row.getDouble(0);
Vector prob = (Vector)row.get(1); Vector prob = (Vector) row.get(1);
double probOfPred = prob.apply((int)pred); double probOfPred = prob.apply((int) pred);
for (int i = 0; i < prob.size(); ++i) { for (int i = 0; i < prob.size(); ++i) {
Assert.assertTrue(probOfPred >= prob.apply(i)); Assert.assertTrue(probOfPred >= prob.apply(i));
} }

View file

@@ -26,49 +26,49 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
public class JavaMultilayerPerceptronClassifierSuite implements Serializable { public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
private transient JavaSparkContext jsc; private transient SparkSession spark;
private transient SQLContext sqlContext;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); spark = SparkSession.builder()
sqlContext = new SQLContext(jsc); .master("local")
.appName("JavaLogisticRegressionSuite")
.getOrCreate();
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
sqlContext = null;
} }
@Test @Test
public void testMLPC() { public void testMLPC() {
Dataset<Row> dataFrame = sqlContext.createDataFrame( List<LabeledPoint> data = Arrays.asList(
jsc.parallelize(Arrays.asList( new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), new LabeledPoint(0.0, Vectors.dense(1.0, 1.0))
new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))), );
LabeledPoint.class); Dataset<Row> dataFrame = spark.createDataFrame(data, LabeledPoint.class);
MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier() MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
.setLayers(new int[] {2, 5, 2}) .setLayers(new int[]{2, 5, 2})
.setBlockSize(1) .setBlockSize(1)
.setSeed(123L) .setSeed(123L)
.setMaxIter(100); .setMaxIter(100);
MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame); MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
Dataset<Row> result = model.transform(dataFrame); Dataset<Row> result = model.transform(dataFrame);
List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList(); List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList();
for (Row r: predictionAndLabels) { for (Row r : predictionAndLabels) {
Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1)); Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1));
} }
} }

View file

@@ -26,13 +26,12 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructField;
@@ -40,19 +39,20 @@ import org.apache.spark.sql.types.StructType;
public class JavaNaiveBayesSuite implements Serializable { public class JavaNaiveBayesSuite implements Serializable {
private transient JavaSparkContext jsc; private transient SparkSession spark;
private transient SQLContext jsql;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); spark = SparkSession.builder()
jsql = new SQLContext(jsc); .master("local")
.appName("JavaLogisticRegressionSuite")
.getOrCreate();
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
public void validatePrediction(Dataset<Row> predictionAndLabels) { public void validatePrediction(Dataset<Row> predictionAndLabels) {
@@ -88,7 +88,7 @@ public class JavaNaiveBayesSuite implements Serializable {
new StructField("features", new VectorUDT(), false, Metadata.empty()) new StructField("features", new VectorUDT(), false, Metadata.empty())
}); });
Dataset<Row> dataset = jsql.createDataFrame(data, schema); Dataset<Row> dataset = spark.createDataFrame(data, schema);
NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial"); NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
NaiveBayesModel model = nb.fit(dataset); NaiveBayesModel model = nb.fit(dataset);

View file

@@ -20,7 +20,6 @@ package org.apache.spark.ml.classification;
import java.io.Serializable; import java.io.Serializable;
import java.util.List; import java.util.List;
import org.apache.spark.sql.Row;
import scala.collection.JavaConverters; import scala.collection.JavaConverters;
import org.junit.After; import org.junit.After;
@@ -30,56 +29,61 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
public class JavaOneVsRestSuite implements Serializable { public class JavaOneVsRestSuite implements Serializable {
private transient JavaSparkContext jsc; private transient SparkSession spark;
private transient SQLContext jsql; private transient JavaSparkContext jsc;
private transient Dataset<Row> dataset; private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD; private transient JavaRDD<LabeledPoint> datasetRDD;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite"); spark = SparkSession.builder()
jsql = new SQLContext(jsc); .master("local")
int nPoints = 3; .appName("JavaLOneVsRestSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
// The following coefficients and xMean/xVariance are computed from iris dataset with int nPoints = 3;
// lambda=0.2.
// As a result, we are drawing samples from probability distribution of an actual model.
double[] coefficients = {
-0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
-0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
double[] xMean = {5.843, 3.057, 3.758, 1.199}; // The following coefficients and xMean/xVariance are computed from iris dataset with
double[] xVariance = {0.6856, 0.1899, 3.116, 0.581}; // lambda=0.2.
List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter( // As a result, we are drawing samples from probability distribution of an actual model.
generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42) double[] coefficients = {
).asJava(); -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
datasetRDD = jsc.parallelize(points, 2); -0.16624, -0.84355, -0.048509, -0.301789, 4.170682};
dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
}
@After double[] xMean = {5.843, 3.057, 3.758, 1.199};
public void tearDown() { double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
jsc.stop(); List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter(
jsc = null; generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42)
} ).asJava();
datasetRDD = jsc.parallelize(points, 2);
dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
}
@Test @After
public void oneVsRestDefaultParams() { public void tearDown() {
OneVsRest ova = new OneVsRest(); spark.stop();
ova.setClassifier(new LogisticRegression()); spark = null;
Assert.assertEquals(ova.getLabelCol() , "label"); }
Assert.assertEquals(ova.getPredictionCol() , "prediction");
OneVsRestModel ovaModel = ova.fit(dataset); @Test
Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction"); public void oneVsRestDefaultParams() {
predictions.collectAsList(); OneVsRest ova = new OneVsRest();
Assert.assertEquals(ovaModel.getLabelCol(), "label"); ova.setClassifier(new LogisticRegression());
Assert.assertEquals(ovaModel.getPredictionCol() , "prediction"); Assert.assertEquals(ova.getLabelCol(), "label");
} Assert.assertEquals(ova.getPredictionCol(), "prediction");
OneVsRestModel ovaModel = ova.fit(dataset);
Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction");
predictions.collectAsList();
Assert.assertEquals(ovaModel.getLabelCol(), "label");
Assert.assertEquals(ovaModel.getPredictionCol(), "prediction");
}
} }

View file

@@ -34,21 +34,27 @@ import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class JavaRandomForestClassifierSuite implements Serializable { public class JavaRandomForestClassifierSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaRandomForestClassifierSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
@@ -57,7 +63,7 @@ public class JavaRandomForestClassifierSuite implements Serializable {
double A = 2.0; double A = 2.0;
double B = -1.5; double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize( JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>(); Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
@@ -75,22 +81,22 @@ public class JavaRandomForestClassifierSuite implements Serializable {
.setSeed(1234) .setSeed(1234)
.setNumTrees(3) .setNumTrees(3)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
for (String impurity: RandomForestClassifier.supportedImpurities()) { for (String impurity : RandomForestClassifier.supportedImpurities()) {
rf.setImpurity(impurity); rf.setImpurity(impurity);
} }
for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) { for (String featureSubsetStrategy : RandomForestClassifier.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy); rf.setFeatureSubsetStrategy(featureSubsetStrategy);
} }
String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
for (String strategy: realStrategies) { for (String strategy : realStrategies) {
rf.setFeatureSubsetStrategy(strategy); rf.setFeatureSubsetStrategy(strategy);
} }
String[] integerStrategies = {"1", "10", "100", "1000", "10000"}; String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
for (String strategy: integerStrategies) { for (String strategy : integerStrategies) {
rf.setFeatureSubsetStrategy(strategy); rf.setFeatureSubsetStrategy(strategy);
} }
String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
for (String strategy: invalidStrategies) { for (String strategy : invalidStrategies) {
try { try {
rf.setFeatureSubsetStrategy(strategy); rf.setFeatureSubsetStrategy(strategy);
Assert.fail("Expected exception to be thrown for invalid strategies"); Assert.fail("Expected exception to be thrown for invalid strategies");

View file

@@ -21,37 +21,37 @@ import java.io.Serializable;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import org.apache.spark.api.java.JavaSparkContext; import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
public class JavaKMeansSuite implements Serializable { public class JavaKMeansSuite implements Serializable {
private transient int k = 5; private transient int k = 5;
private transient JavaSparkContext sc;
private transient Dataset<Row> dataset; private transient Dataset<Row> dataset;
private transient SQLContext sql; private transient SparkSession spark;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaKMeansSuite"); spark = SparkSession.builder()
sql = new SQLContext(sc); .master("local")
.appName("JavaKMeansSuite")
dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k); .getOrCreate();
dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k);
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
@@ -65,7 +65,7 @@ public class JavaKMeansSuite implements Serializable {
Dataset<Row> transformed = model.transform(dataset); Dataset<Row> transformed = model.transform(dataset);
List<String> columns = Arrays.asList(transformed.columns()); List<String> columns = Arrays.asList(transformed.columns());
List<String> expectedColumns = Arrays.asList("features", "prediction"); List<String> expectedColumns = Arrays.asList("features", "prediction");
for (String column: expectedColumns) { for (String column : expectedColumns) {
assertTrue(columns.contains(column)); assertTrue(columns.contains(column));
} }
} }

View file

@@ -25,40 +25,40 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.StructType;
public class JavaBucketizerSuite { public class JavaBucketizerSuite {
private transient JavaSparkContext jsc; private transient SparkSession spark;
private transient SQLContext jsql;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaBucketizerSuite"); spark = SparkSession.builder()
jsql = new SQLContext(jsc); .master("local")
.appName("JavaBucketizerSuite")
.getOrCreate();
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
@Test @Test
public void bucketizerTest() { public void bucketizerTest() {
double[] splits = {-0.5, 0.0, 0.5}; double[] splits = {-0.5, 0.0, 0.5};
StructType schema = new StructType(new StructField[] { StructType schema = new StructType(new StructField[]{
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
}); });
Dataset<Row> dataset = jsql.createDataFrame( Dataset<Row> dataset = spark.createDataFrame(
Arrays.asList( Arrays.asList(
RowFactory.create(-0.5), RowFactory.create(-0.5),
RowFactory.create(-0.3), RowFactory.create(-0.3),

View file

@@ -21,43 +21,44 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
import org.junit.After; import org.junit.After;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.StructType;
public class JavaDCTSuite { public class JavaDCTSuite {
private transient JavaSparkContext jsc; private transient SparkSession spark;
private transient SQLContext jsql;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaDCTSuite"); spark = SparkSession.builder()
jsql = new SQLContext(jsc); .master("local")
.appName("JavaDCTSuite")
.getOrCreate();
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
@Test @Test
public void javaCompatibilityTest() { public void javaCompatibilityTest() {
double[] input = new double[] {1D, 2D, 3D, 4D}; double[] input = new double[]{1D, 2D, 3D, 4D};
Dataset<Row> dataset = jsql.createDataFrame( Dataset<Row> dataset = spark.createDataFrame(
Arrays.asList(RowFactory.create(Vectors.dense(input))), Arrays.asList(RowFactory.create(Vectors.dense(input))),
new StructType(new StructField[]{ new StructType(new StructField[]{
new StructField("vec", (new VectorUDT()), false, Metadata.empty()) new StructField("vec", (new VectorUDT()), false, Metadata.empty())

View file

@@ -25,12 +25,11 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructField;
@@ -38,19 +37,20 @@ import org.apache.spark.sql.types.StructType;
public class JavaHashingTFSuite { public class JavaHashingTFSuite {
private transient JavaSparkContext jsc; private transient SparkSession spark;
private transient SQLContext jsql;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaHashingTFSuite"); spark = SparkSession.builder()
jsql = new SQLContext(jsc); .master("local")
.appName("JavaHashingTFSuite")
.getOrCreate();
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
@Test @Test
@@ -65,7 +65,7 @@ public class JavaHashingTFSuite {
new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
}); });
Dataset<Row> sentenceData = jsql.createDataFrame(data, schema); Dataset<Row> sentenceData = spark.createDataFrame(data, schema);
Tokenizer tokenizer = new Tokenizer() Tokenizer tokenizer = new Tokenizer()
.setInputCol("sentence") .setInputCol("sentence")
.setOutputCol("words"); .setOutputCol("words");

View file

@@ -23,27 +23,30 @@ import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
public class JavaNormalizerSuite { public class JavaNormalizerSuite {
private transient SparkSession spark;
private transient JavaSparkContext jsc; private transient JavaSparkContext jsc;
private transient SQLContext jsql;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaNormalizerSuite"); spark = SparkSession.builder()
jsql = new SQLContext(jsc); .master("local")
.appName("JavaNormalizerSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
@Test @Test
@@ -54,7 +57,7 @@ public class JavaNormalizerSuite {
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
)); ));
Dataset<Row> dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class); Dataset<Row> dataFrame = spark.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
Normalizer normalizer = new Normalizer() Normalizer normalizer = new Normalizer()
.setInputCol("features") .setInputCol("features")
.setOutputCol("normFeatures"); .setOutputCol("normFeatures");

View file

@@ -28,31 +28,34 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.distributed.RowMatrix; import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
public class JavaPCASuite implements Serializable { public class JavaPCASuite implements Serializable {
private transient SparkSession spark;
private transient JavaSparkContext jsc; private transient JavaSparkContext jsc;
private transient SQLContext sqlContext;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaPCASuite"); spark = SparkSession.builder()
sqlContext = new SQLContext(jsc); .master("local")
.appName("JavaPCASuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
public static class VectorPair implements Serializable { public static class VectorPair implements Serializable {
@@ -100,7 +103,7 @@ public class JavaPCASuite implements Serializable {
} }
); );
Dataset<Row> df = sqlContext.createDataFrame(featuresExpected, VectorPair.class); Dataset<Row> df = spark.createDataFrame(featuresExpected, VectorPair.class);
PCAModel pca = new PCA() PCAModel pca = new PCA()
.setInputCol("features") .setInputCol("features")
.setOutputCol("pca_features") .setOutputCol("pca_features")

View file

@@ -32,19 +32,22 @@ import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.StructType;
public class JavaPolynomialExpansionSuite { public class JavaPolynomialExpansionSuite {
private transient SparkSession spark;
private transient JavaSparkContext jsc; private transient JavaSparkContext jsc;
private transient SQLContext jsql;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaPolynomialExpansionSuite"); spark = SparkSession.builder()
jsql = new SQLContext(jsc); .master("local")
.appName("JavaPolynomialExpansionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
@@ -72,20 +75,20 @@ public class JavaPolynomialExpansionSuite {
) )
); );
StructType schema = new StructType(new StructField[] { StructType schema = new StructType(new StructField[]{
new StructField("features", new VectorUDT(), false, Metadata.empty()), new StructField("features", new VectorUDT(), false, Metadata.empty()),
new StructField("expected", new VectorUDT(), false, Metadata.empty()) new StructField("expected", new VectorUDT(), false, Metadata.empty())
}); });
Dataset<Row> dataset = jsql.createDataFrame(data, schema); Dataset<Row> dataset = spark.createDataFrame(data, schema);
List<Row> pairs = polyExpansion.transform(dataset) List<Row> pairs = polyExpansion.transform(dataset)
.select("polyFeatures", "expected") .select("polyFeatures", "expected")
.collectAsList(); .collectAsList();
for (Row r : pairs) { for (Row r : pairs) {
double[] polyFeatures = ((Vector)r.get(0)).toArray(); double[] polyFeatures = ((Vector) r.get(0)).toArray();
double[] expected = ((Vector)r.get(1)).toArray(); double[] expected = ((Vector) r.get(1)).toArray();
Assert.assertArrayEquals(polyFeatures, expected, 1e-1); Assert.assertArrayEquals(polyFeatures, expected, 1e-1);
} }
} }

View file

@@ -28,22 +28,25 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
public class JavaStandardScalerSuite { public class JavaStandardScalerSuite {
private transient SparkSession spark;
private transient JavaSparkContext jsc; private transient JavaSparkContext jsc;
private transient SQLContext jsql;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaStandardScalerSuite"); spark = SparkSession.builder()
jsql = new SQLContext(jsc); .master("local")
.appName("JavaStandardScalerSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
@Test @Test
@@ -54,7 +57,7 @@ public class JavaStandardScalerSuite {
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
); );
Dataset<Row> dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), Dataset<Row> dataFrame = spark.createDataFrame(jsc.parallelize(points, 2),
VectorIndexerSuite.FeatureData.class); VectorIndexerSuite.FeatureData.class);
StandardScaler scaler = new StandardScaler() StandardScaler scaler = new StandardScaler()
.setInputCol("features") .setInputCol("features")

View file

@@ -24,11 +24,10 @@ import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructField;
@@ -37,19 +36,20 @@ import org.apache.spark.sql.types.StructType;
public class JavaStopWordsRemoverSuite { public class JavaStopWordsRemoverSuite {
private transient JavaSparkContext jsc; private transient SparkSession spark;
private transient SQLContext jsql;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaStopWordsRemoverSuite"); spark = SparkSession.builder()
jsql = new SQLContext(jsc); .master("local")
.appName("JavaStopWordsRemoverSuite")
.getOrCreate();
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
@Test @Test
@@ -62,11 +62,11 @@ public class JavaStopWordsRemoverSuite {
RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")),
RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb"))
); );
StructType schema = new StructType(new StructField[] { StructType schema = new StructType(new StructField[]{
new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false,
Metadata.empty()) Metadata.empty())
}); });
Dataset<Row> dataset = jsql.createDataFrame(data, schema); Dataset<Row> dataset = spark.createDataFrame(data, schema);
remover.transform(dataset).collect(); remover.transform(dataset).collect();
} }

View file

@@ -25,40 +25,42 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.SparkConf;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.*; import static org.apache.spark.sql.types.DataTypes.*;
public class JavaStringIndexerSuite { public class JavaStringIndexerSuite {
private transient JavaSparkContext jsc; private transient SparkSession spark;
private transient SQLContext sqlContext;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaStringIndexerSuite"); SparkConf sparkConf = new SparkConf();
sqlContext = new SQLContext(jsc); sparkConf.setMaster("local");
sparkConf.setAppName("JavaStringIndexerSuite");
spark = SparkSession.builder().config(sparkConf).getOrCreate();
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
sqlContext = null; spark = null;
} }
@Test @Test
public void testStringIndexer() { public void testStringIndexer() {
StructType schema = createStructType(new StructField[] { StructType schema = createStructType(new StructField[]{
createStructField("id", IntegerType, false), createStructField("id", IntegerType, false),
createStructField("label", StringType, false) createStructField("label", StringType, false)
}); });
List<Row> data = Arrays.asList( List<Row> data = Arrays.asList(
cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c")); cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c"));
Dataset<Row> dataset = sqlContext.createDataFrame(data, schema); Dataset<Row> dataset = spark.createDataFrame(data, schema);
StringIndexer indexer = new StringIndexer() StringIndexer indexer = new StringIndexer()
.setInputCol("label") .setInputCol("label")
@@ -70,7 +72,9 @@ public class JavaStringIndexerSuite {
output.orderBy("id").select("id", "labelIndex").collectAsList()); output.orderBy("id").select("id", "labelIndex").collectAsList());
} }
/** An alias for RowFactory.create. */ /**
* An alias for RowFactory.create.
*/
private Row cr(Object... values) { private Row cr(Object... values) {
return RowFactory.create(values); return RowFactory.create(values);
} }

View file

@@ -29,22 +29,25 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
public class JavaTokenizerSuite { public class JavaTokenizerSuite {
private transient SparkSession spark;
private transient JavaSparkContext jsc; private transient JavaSparkContext jsc;
private transient SQLContext jsql;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaTokenizerSuite"); spark = SparkSession.builder()
jsql = new SQLContext(jsc); .master("local")
.appName("JavaTokenizerSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
@Test @Test
@@ -59,10 +62,10 @@ public class JavaTokenizerSuite {
JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Arrays.asList( JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Arrays.asList(
new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}), new TokenizerTestData("Test of tok.", new String[]{"Test", "tok."}),
new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"}) new TokenizerTestData("Te,st. punct", new String[]{"Te,st.", "punct"})
)); ));
Dataset<Row> dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); Dataset<Row> dataset = spark.createDataFrame(rdd, TokenizerTestData.class);
List<Row> pairs = myRegExTokenizer.transform(dataset) List<Row> pairs = myRegExTokenizer.transform(dataset)
.select("tokens", "wantedTokens") .select("tokens", "wantedTokens")

View file

@@ -24,36 +24,39 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.SparkConf;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.*; import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.*; import static org.apache.spark.sql.types.DataTypes.*;
public class JavaVectorAssemblerSuite { public class JavaVectorAssemblerSuite {
private transient JavaSparkContext jsc; private transient SparkSession spark;
private transient SQLContext sqlContext;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite"); SparkConf sparkConf = new SparkConf();
sqlContext = new SQLContext(jsc); sparkConf.setMaster("local");
sparkConf.setAppName("JavaVectorAssemblerSuite");
spark = SparkSession.builder().config(sparkConf).getOrCreate();
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
@Test @Test
public void testVectorAssembler() { public void testVectorAssembler() {
StructType schema = createStructType(new StructField[] { StructType schema = createStructType(new StructField[]{
createStructField("id", IntegerType, false), createStructField("id", IntegerType, false),
createStructField("x", DoubleType, false), createStructField("x", DoubleType, false),
createStructField("y", new VectorUDT(), false), createStructField("y", new VectorUDT(), false),
@@ -63,14 +66,14 @@ public class JavaVectorAssemblerSuite {
}); });
Row row = RowFactory.create( Row row = RowFactory.create(
0, 0.0, Vectors.dense(1.0, 2.0), "a", 0, 0.0, Vectors.dense(1.0, 2.0), "a",
Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L); Vectors.sparse(2, new int[]{1}, new double[]{3.0}), 10L);
Dataset<Row> dataset = sqlContext.createDataFrame(Arrays.asList(row), schema); Dataset<Row> dataset = spark.createDataFrame(Arrays.asList(row), schema);
VectorAssembler assembler = new VectorAssembler() VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[] {"x", "y", "z", "n"}) .setInputCols(new String[]{"x", "y", "z", "n"})
.setOutputCol("features"); .setOutputCol("features");
Dataset<Row> output = assembler.transform(dataset); Dataset<Row> output = assembler.transform(dataset);
Assert.assertEquals( Assert.assertEquals(
Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}), Vectors.sparse(6, new int[]{1, 2, 4, 5}, new double[]{1.0, 2.0, 3.0, 10.0}),
output.select("features").first().<Vector>getAs(0)); output.select("features").first().<Vector>getAs(0));
} }
} }
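Taken together, these hunks follow a single pattern: each suite's SQLContext (and, where present, the hand-built JavaSparkContext) is replaced by one SparkSession, with the JavaSparkContext derived from that session when an RDD is still needed. A minimal sketch of the resulting setUp/tearDown skeleton, assuming Spark 2.x and JUnit 4; the class and app names below are placeholders, not part of the commit:

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.junit.After;
import org.junit.Before;

public class JavaExampleSuite {
  private transient SparkSession spark;
  private transient JavaSparkContext jsc;

  @Before
  public void setUp() {
    // One SparkSession per suite; the builder call replaces new SQLContext(jsc).
    spark = SparkSession.builder()
      .master("local")
      .appName("JavaExampleSuite")
      .getOrCreate();
    // Derive the JavaSparkContext from the session instead of constructing it first.
    jsc = new JavaSparkContext(spark.sparkContext());
  }

  @After
  public void tearDown() {
    // Stopping the session also stops the underlying SparkContext.
    spark.stop();
    spark = null;
    jsc = null;
  }
}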


@ -32,21 +32,26 @@ import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
public class JavaVectorIndexerSuite implements Serializable { public class JavaVectorIndexerSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaVectorIndexerSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaVectorIndexerSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
@ -57,8 +62,7 @@ public class JavaVectorIndexerSuite implements Serializable {
new FeatureData(Vectors.dense(1.0, 3.0)), new FeatureData(Vectors.dense(1.0, 3.0)),
new FeatureData(Vectors.dense(1.0, 4.0)) new FeatureData(Vectors.dense(1.0, 4.0))
); );
SQLContext sqlContext = new SQLContext(sc); Dataset<Row> data = spark.createDataFrame(jsc.parallelize(points, 2), FeatureData.class);
Dataset<Row> data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class);
VectorIndexer indexer = new VectorIndexer() VectorIndexer indexer = new VectorIndexer()
.setInputCol("features") .setInputCol("features")
.setOutputCol("indexed") .setOutputCol("indexed")


@ -25,7 +25,6 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.attribute.Attribute; import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup; import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute; import org.apache.spark.ml.attribute.NumericAttribute;
@ -34,24 +33,25 @@ import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.StructType;
public class JavaVectorSlicerSuite { public class JavaVectorSlicerSuite {
private transient JavaSparkContext jsc; private transient SparkSession spark;
private transient SQLContext jsql;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite"); spark = SparkSession.builder()
jsql = new SQLContext(jsc); .master("local")
.appName("JavaVectorSlicerSuite")
.getOrCreate();
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
@Test @Test
@ -69,7 +69,7 @@ public class JavaVectorSlicerSuite {
); );
Dataset<Row> dataset = Dataset<Row> dataset =
jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); spark.createDataFrame(data, (new StructType()).add(group.toStructField()));
VectorSlicer vectorSlicer = new VectorSlicer() VectorSlicer vectorSlicer = new VectorSlicer()
.setInputCol("userFeatures").setOutputCol("features"); .setInputCol("userFeatures").setOutputCol("features");


@ -24,28 +24,28 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.*; import org.apache.spark.sql.types.*;
public class JavaWord2VecSuite { public class JavaWord2VecSuite {
private transient JavaSparkContext jsc; private transient SparkSession spark;
private transient SQLContext sqlContext;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaWord2VecSuite"); spark = SparkSession.builder()
sqlContext = new SQLContext(jsc); .master("local")
.appName("JavaWord2VecSuite")
.getOrCreate();
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
@Test @Test
@ -53,7 +53,7 @@ public class JavaWord2VecSuite {
StructType schema = new StructType(new StructField[]{ StructType schema = new StructType(new StructField[]{
new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
}); });
Dataset<Row> documentDF = sqlContext.createDataFrame( Dataset<Row> documentDF = spark.createDataFrame(
Arrays.asList( Arrays.asList(
RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
@ -68,8 +68,8 @@ public class JavaWord2VecSuite {
Word2VecModel model = word2Vec.fit(documentDF); Word2VecModel model = word2Vec.fit(documentDF);
Dataset<Row> result = model.transform(documentDF); Dataset<Row> result = model.transform(documentDF);
for (Row r: result.select("result").collectAsList()) { for (Row r : result.select("result").collectAsList()) {
double[] polyFeatures = ((Vector)r.get(0)).toArray(); double[] polyFeatures = ((Vector) r.get(0)).toArray();
Assert.assertEquals(polyFeatures.length, 3); Assert.assertEquals(polyFeatures.length, 3);
} }
} }
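The Word2Vec hunk only shows the fit and transform calls; for context, here is a sketch of an estimator configuration that would yield the 3-element vectors asserted above, assuming the spark.ml Word2Vec API and the documentDF built earlier in this suite (the parameter values are illustrative):

// Configure the estimator; setVectorSize(3) matches the length check in the loop above.
Word2Vec word2Vec = new Word2Vec()
  .setInputCol("text")
  .setOutputCol("result")
  .setVectorSize(3)
  .setMinCount(0);
Word2VecModel model = word2Vec.fit(documentDF);
Dataset<Row> result = model.transform(documentDF);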


@ -25,23 +25,29 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
/** /**
* Test Param and related classes in Java * Test Param and related classes in Java
*/ */
public class JavaParamsSuite { public class JavaParamsSuite {
private transient SparkSession spark;
private transient JavaSparkContext jsc; private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaParamsSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaParamsSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
@Test @Test
@ -51,7 +57,7 @@ public class JavaParamsSuite {
testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a"); testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0); Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
Assert.assertEquals(testParams.getMyStringParam(), "a"); Assert.assertEquals(testParams.getMyStringParam(), "a");
Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0); Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[]{1.0, 2.0}, 0.0);
} }
@Test @Test


@ -45,9 +45,14 @@ public class JavaTestParams extends JavaParams {
} }
private IntParam myIntParam_; private IntParam myIntParam_;
public IntParam myIntParam() { return myIntParam_; }
public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); } public IntParam myIntParam() {
return myIntParam_;
}
public int getMyIntParam() {
return (Integer) getOrDefault(myIntParam_);
}
public JavaTestParams setMyIntParam(int value) { public JavaTestParams setMyIntParam(int value) {
set(myIntParam_, value); set(myIntParam_, value);
@ -55,9 +60,14 @@ public class JavaTestParams extends JavaParams {
} }
private DoubleParam myDoubleParam_; private DoubleParam myDoubleParam_;
public DoubleParam myDoubleParam() { return myDoubleParam_; }
public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); } public DoubleParam myDoubleParam() {
return myDoubleParam_;
}
public double getMyDoubleParam() {
return (Double) getOrDefault(myDoubleParam_);
}
public JavaTestParams setMyDoubleParam(double value) { public JavaTestParams setMyDoubleParam(double value) {
set(myDoubleParam_, value); set(myDoubleParam_, value);
@ -65,9 +75,14 @@ public class JavaTestParams extends JavaParams {
} }
private Param<String> myStringParam_; private Param<String> myStringParam_;
public Param<String> myStringParam() { return myStringParam_; }
public String getMyStringParam() { return getOrDefault(myStringParam_); } public Param<String> myStringParam() {
return myStringParam_;
}
public String getMyStringParam() {
return getOrDefault(myStringParam_);
}
public JavaTestParams setMyStringParam(String value) { public JavaTestParams setMyStringParam(String value) {
set(myStringParam_, value); set(myStringParam_, value);
@ -75,9 +90,14 @@ public class JavaTestParams extends JavaParams {
} }
private DoubleArrayParam myDoubleArrayParam_; private DoubleArrayParam myDoubleArrayParam_;
public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; }
public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); } public DoubleArrayParam myDoubleArrayParam() {
return myDoubleArrayParam_;
}
public double[] getMyDoubleArrayParam() {
return getOrDefault(myDoubleArrayParam_);
}
public JavaTestParams setMyDoubleArrayParam(double[] value) { public JavaTestParams setMyDoubleArrayParam(double[] value) {
set(myDoubleArrayParam_, value); set(myDoubleArrayParam_, value);
@ -96,7 +116,7 @@ public class JavaTestParams extends JavaParams {
setDefault(myIntParam(), 1); setDefault(myIntParam(), 1);
setDefault(myDoubleParam(), 0.5); setDefault(myDoubleParam(), 0.5);
setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0}); setDefault(myDoubleArrayParam(), new double[]{1.0, 2.0});
} }
@Override @Override


@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class JavaDecisionTreeRegressorSuite implements Serializable { public class JavaDecisionTreeRegressorSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaDecisionTreeRegressorSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
@ -55,7 +61,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
double A = 2.0; double A = 2.0;
double B = -1.5; double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize( JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>(); Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
@ -70,7 +76,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
.setCacheNodeIds(false) .setCacheNodeIds(false)
.setCheckpointInterval(10) .setCheckpointInterval(10)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
for (String impurity: DecisionTreeRegressor.supportedImpurities()) { for (String impurity : DecisionTreeRegressor.supportedImpurities()) {
dt.setImpurity(impurity); dt.setImpurity(impurity);
} }
DecisionTreeRegressionModel model = dt.fit(dataFrame); DecisionTreeRegressionModel model = dt.fit(dataFrame);


@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class JavaGBTRegressorSuite implements Serializable { public class JavaGBTRegressorSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaGBTRegressorSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaGBTRegressorSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
@ -55,7 +61,7 @@ public class JavaGBTRegressorSuite implements Serializable {
double A = 2.0; double A = 2.0;
double B = -1.5; double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize( JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>(); Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
@ -73,7 +79,7 @@ public class JavaGBTRegressorSuite implements Serializable {
.setMaxIter(3) .setMaxIter(3)
.setStepSize(0.1) .setStepSize(0.1)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
for (String lossType: GBTRegressor.supportedLossTypes()) { for (String lossType : GBTRegressor.supportedLossTypes()) {
rf.setLossType(lossType); rf.setLossType(lossType);
} }
GBTRegressionModel model = rf.fit(dataFrame); GBTRegressionModel model = rf.fit(dataFrame);


@ -30,25 +30,26 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
.generateLogisticInputAsList;
public class JavaLinearRegressionSuite implements Serializable { public class JavaLinearRegressionSuite implements Serializable {
private transient SparkSession spark;
private transient JavaSparkContext jsc; private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient Dataset<Row> dataset; private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD; private transient JavaRDD<LabeledPoint> datasetRDD;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); spark = SparkSession.builder()
jsql = new SQLContext(jsc); .master("local")
.appName("JavaLinearRegressionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42); List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2); datasetRDD = jsc.parallelize(points, 2);
dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset"); dataset.registerTempTable("dataset");
} }
@ -65,7 +66,7 @@ public class JavaLinearRegressionSuite implements Serializable {
assertEquals("auto", lr.getSolver()); assertEquals("auto", lr.getSolver());
LinearRegressionModel model = lr.fit(dataset); LinearRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction"); model.transform(dataset).registerTempTable("prediction");
Dataset<Row> predictions = jsql.sql("SELECT label, prediction FROM prediction"); Dataset<Row> predictions = spark.sql("SELECT label, prediction FROM prediction");
predictions.collect(); predictions.collect();
// Check defaults // Check defaults
assertEquals("features", model.getFeaturesCol()); assertEquals("features", model.getFeaturesCol());
@ -76,8 +77,8 @@ public class JavaLinearRegressionSuite implements Serializable {
public void linearRegressionWithSetters() { public void linearRegressionWithSetters() {
// Set params, train, and check as many params as we can. // Set params, train, and check as many params as we can.
LinearRegression lr = new LinearRegression() LinearRegression lr = new LinearRegression()
.setMaxIter(10) .setMaxIter(10)
.setRegParam(1.0).setSolver("l-bfgs"); .setRegParam(1.0).setSolver("l-bfgs");
LinearRegressionModel model = lr.fit(dataset); LinearRegressionModel model = lr.fit(dataset);
LinearRegression parent = (LinearRegression) model.parent(); LinearRegression parent = (LinearRegression) model.parent();
assertEquals(10, parent.getMaxIter()); assertEquals(10, parent.getMaxIter());
@ -85,7 +86,7 @@ public class JavaLinearRegressionSuite implements Serializable {
// Call fit() with new params, and check as many params as we can. // Call fit() with new params, and check as many params as we can.
LinearRegressionModel model2 = LinearRegressionModel model2 =
lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
LinearRegression parent2 = (LinearRegression) model2.parent(); LinearRegression parent2 = (LinearRegression) model2.parent();
assertEquals(5, parent2.getMaxIter()); assertEquals(5, parent2.getMaxIter());
assertEquals(0.1, parent2.getRegParam(), 0.0); assertEquals(0.1, parent2.getRegParam(), 0.0);


@ -28,27 +28,33 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class JavaRandomForestRegressorSuite implements Serializable { public class JavaRandomForestRegressorSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaRandomForestRegressorSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
@ -57,7 +63,7 @@ public class JavaRandomForestRegressorSuite implements Serializable {
double A = 2.0; double A = 2.0;
double B = -1.5; double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize( JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>(); Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
@ -75,22 +81,22 @@ public class JavaRandomForestRegressorSuite implements Serializable {
.setSeed(1234) .setSeed(1234)
.setNumTrees(3) .setNumTrees(3)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
for (String impurity: RandomForestRegressor.supportedImpurities()) { for (String impurity : RandomForestRegressor.supportedImpurities()) {
rf.setImpurity(impurity); rf.setImpurity(impurity);
} }
for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) { for (String featureSubsetStrategy : RandomForestRegressor.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy); rf.setFeatureSubsetStrategy(featureSubsetStrategy);
} }
String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
for (String strategy: realStrategies) { for (String strategy : realStrategies) {
rf.setFeatureSubsetStrategy(strategy); rf.setFeatureSubsetStrategy(strategy);
} }
String[] integerStrategies = {"1", "10", "100", "1000", "10000"}; String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
for (String strategy: integerStrategies) { for (String strategy : integerStrategies) {
rf.setFeatureSubsetStrategy(strategy); rf.setFeatureSubsetStrategy(strategy);
} }
String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
for (String strategy: invalidStrategies) { for (String strategy : invalidStrategies) {
try { try {
rf.setFeatureSubsetStrategy(strategy); rf.setFeatureSubsetStrategy(strategy);
Assert.fail("Expected exception to be thrown for invalid strategies"); Assert.fail("Expected exception to be thrown for invalid strategies");


@ -28,12 +28,11 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.DenseVector; import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils; import org.apache.spark.util.Utils;
@ -41,16 +40,17 @@ import org.apache.spark.util.Utils;
* Test LibSVMRelation in Java. * Test LibSVMRelation in Java.
*/ */
public class JavaLibSVMRelationSuite { public class JavaLibSVMRelationSuite {
private transient JavaSparkContext jsc; private transient SparkSession spark;
private transient SQLContext sqlContext;
private File tempDir; private File tempDir;
private String path; private String path;
@Before @Before
public void setUp() throws IOException { public void setUp() throws IOException {
jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite"); spark = SparkSession.builder()
sqlContext = new SQLContext(jsc); .master("local")
.appName("JavaLibSVMRelationSuite")
.getOrCreate();
tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
File file = new File(tempDir, "part-00000"); File file = new File(tempDir, "part-00000");
@ -61,14 +61,14 @@ public class JavaLibSVMRelationSuite {
@After @After
public void tearDown() { public void tearDown() {
jsc.stop(); spark.stop();
jsc = null; spark = null;
Utils.deleteRecursively(tempDir); Utils.deleteRecursively(tempDir);
} }
@Test @Test
public void verifyLibSVMDF() { public void verifyLibSVMDF() {
Dataset<Row> dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") Dataset<Row> dataset = spark.read().format("libsvm").option("vectorType", "dense")
.load(path); .load(path);
Assert.assertEquals("label", dataset.columns()[0]); Assert.assertEquals("label", dataset.columns()[0]);
Assert.assertEquals("features", dataset.columns()[1]); Assert.assertEquals("features", dataset.columns()[1]);


@ -32,21 +32,25 @@ import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaCrossValidatorSuite implements Serializable { public class JavaCrossValidatorSuite implements Serializable {
private transient SparkSession spark;
private transient JavaSparkContext jsc; private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient Dataset<Row> dataset; private transient Dataset<Row> dataset;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite"); spark = SparkSession.builder()
jsql = new SQLContext(jsc); .master("local")
.appName("JavaCrossValidatorSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42); List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class); dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
} }
@After @After
@ -59,8 +63,8 @@ public class JavaCrossValidatorSuite implements Serializable {
public void crossValidationWithLogisticRegression() { public void crossValidationWithLogisticRegression() {
LogisticRegression lr = new LogisticRegression(); LogisticRegression lr = new LogisticRegression();
ParamMap[] lrParamMaps = new ParamGridBuilder() ParamMap[] lrParamMaps = new ParamGridBuilder()
.addGrid(lr.regParam(), new double[] {0.001, 1000.0}) .addGrid(lr.regParam(), new double[]{0.001, 1000.0})
.addGrid(lr.maxIter(), new int[] {0, 10}) .addGrid(lr.maxIter(), new int[]{0, 10})
.build(); .build();
BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator(); BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
CrossValidator cv = new CrossValidator() CrossValidator cv = new CrossValidator()


@ -37,4 +37,5 @@ object IdentifiableSuite {
class Test(override val uid: String) extends Identifiable { class Test(override val uid: String) extends Identifiable {
def this() = this(Identifiable.randomUID("test")) def this() = this(Identifiable.randomUID("test"))
} }
} }


@ -27,31 +27,34 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils; import org.apache.spark.util.Utils;
public class JavaDefaultReadWriteSuite { public class JavaDefaultReadWriteSuite {
JavaSparkContext jsc = null; JavaSparkContext jsc = null;
SQLContext sqlContext = null; SparkSession spark = null;
File tempDir = null; File tempDir = null;
@Before @Before
public void setUp() { public void setUp() {
jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite");
SQLContext.clearActive(); SQLContext.clearActive();
sqlContext = new SQLContext(jsc); spark = SparkSession.builder()
SQLContext.setActive(sqlContext); .master("local[2]")
.appName("JavaDefaultReadWriteSuite")
.getOrCreate();
SQLContext.setActive(spark.wrapped());
tempDir = Utils.createTempDir( tempDir = Utils.createTempDir(
System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite"); System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite");
} }
@After @After
public void tearDown() { public void tearDown() {
sqlContext = null;
SQLContext.clearActive(); SQLContext.clearActive();
if (jsc != null) { if (spark != null) {
jsc.stop(); spark.stop();
jsc = null; spark = null;
} }
Utils.deleteRecursively(tempDir); Utils.deleteRecursively(tempDir);
} }
@ -70,7 +73,7 @@ public class JavaDefaultReadWriteSuite {
} catch (IOException e) { } catch (IOException e) {
// expected // expected
} }
instance.write().context(sqlContext).overwrite().save(outputPath); instance.write().context(spark.wrapped()).overwrite().save(outputPath);
MyParams newInstance = MyParams.load(outputPath); MyParams newInstance = MyParams.load(outputPath);
Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
Assert.assertEquals("Params should be preserved.", Assert.assertEquals("Params should be preserved.",


@ -27,26 +27,31 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.SparkSession;
public class JavaLogisticRegressionSuite implements Serializable { public class JavaLogisticRegressionSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaLogisticRegressionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
int validatePrediction(List<LabeledPoint> validationData, LogisticRegressionModel model) { int validatePrediction(List<LabeledPoint> validationData, LogisticRegressionModel model) {
int numAccurate = 0; int numAccurate = 0;
for (LabeledPoint point: validationData) { for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features()); Double prediction = model.predict(point.features());
if (prediction == point.label()) { if (prediction == point.label()) {
numAccurate++; numAccurate++;
@ -61,16 +66,16 @@ public class JavaLogisticRegressionSuite implements Serializable {
double A = 2.0; double A = 2.0;
double B = -1.5; double B = -1.5;
JavaRDD<LabeledPoint> testRDD = sc.parallelize( JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
List<LabeledPoint> validationData = List<LabeledPoint> validationData =
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD(); LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD();
lrImpl.setIntercept(true); lrImpl.setIntercept(true);
lrImpl.optimizer().setStepSize(1.0) lrImpl.optimizer().setStepSize(1.0)
.setRegParam(1.0) .setRegParam(1.0)
.setNumIterations(100); .setNumIterations(100);
LogisticRegressionModel model = lrImpl.run(testRDD.rdd()); LogisticRegressionModel model = lrImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model); int numAccurate = validatePrediction(validationData, model);
@ -83,13 +88,13 @@ public class JavaLogisticRegressionSuite implements Serializable {
double A = 0.0; double A = 0.0;
double B = -2.5; double B = -2.5;
JavaRDD<LabeledPoint> testRDD = sc.parallelize( JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
List<LabeledPoint> validationData = List<LabeledPoint> validationData =
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
LogisticRegressionModel model = LogisticRegressionWithSGD.train( LogisticRegressionModel model = LogisticRegressionWithSGD.train(
testRDD.rdd(), 100, 1.0, 1.0); testRDD.rdd(), 100, 1.0, 1.0);
int numAccurate = validatePrediction(validationData, model); int numAccurate = validatePrediction(validationData, model);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);


@ -32,20 +32,26 @@ import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.SparkSession;
public class JavaNaiveBayesSuite implements Serializable { public class JavaNaiveBayesSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaNaiveBayesSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaNaiveBayesSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
private static final List<LabeledPoint> POINTS = Arrays.asList( private static final List<LabeledPoint> POINTS = Arrays.asList(
@ -59,7 +65,7 @@ public class JavaNaiveBayesSuite implements Serializable {
private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) { private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
int correct = 0; int correct = 0;
for (LabeledPoint p: points) { for (LabeledPoint p : points) {
if (model.predict(p.features()) == p.label()) { if (model.predict(p.features()) == p.label()) {
correct += 1; correct += 1;
} }
@ -69,7 +75,7 @@ public class JavaNaiveBayesSuite implements Serializable {
@Test @Test
public void runUsingConstructor() { public void runUsingConstructor() {
JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache(); JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
NaiveBayes nb = new NaiveBayes().setLambda(1.0); NaiveBayes nb = new NaiveBayes().setLambda(1.0);
NaiveBayesModel model = nb.run(testRDD.rdd()); NaiveBayesModel model = nb.run(testRDD.rdd());
@ -80,7 +86,7 @@ public class JavaNaiveBayesSuite implements Serializable {
@Test @Test
public void runUsingStaticMethods() { public void runUsingStaticMethods() {
JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache(); JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd()); NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
int numAccurate1 = validatePrediction(POINTS, model1); int numAccurate1 = validatePrediction(POINTS, model1);
@ -93,13 +99,14 @@ public class JavaNaiveBayesSuite implements Serializable {
@Test @Test
public void testPredictJavaRDD() { public void testPredictJavaRDD() {
JavaRDD<LabeledPoint> examples = sc.parallelize(POINTS, 2).cache(); JavaRDD<LabeledPoint> examples = jsc.parallelize(POINTS, 2).cache();
NaiveBayesModel model = NaiveBayes.train(examples.rdd()); NaiveBayesModel model = NaiveBayes.train(examples.rdd());
JavaRDD<Vector> vectors = examples.map(new Function<LabeledPoint, Vector>() { JavaRDD<Vector> vectors = examples.map(new Function<LabeledPoint, Vector>() {
@Override @Override
public Vector call(LabeledPoint v) throws Exception { public Vector call(LabeledPoint v) throws Exception {
return v.features(); return v.features();
}}); }
});
JavaRDD<Double> predictions = model.predict(vectors); JavaRDD<Double> predictions = model.predict(vectors);
// Should be able to get the first prediction. // Should be able to get the first prediction.
predictions.first(); predictions.first();
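The same hunk also reformats an anonymous Function into the project's brace style; assuming Java 8 is available, the feature extraction could be written as a lambda instead — a sketch, not part of this commit:

// Equivalent to the anonymous Function<LabeledPoint, Vector> above (Java 8 assumed).
JavaRDD<Vector> vectors = examples.map(p -> p.features());
JavaRDD<Double> predictions = model.predict(vectors);
predictions.first();  // should be able to get the first prediction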


@ -28,24 +28,30 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.SparkSession;
public class JavaSVMSuite implements Serializable { public class JavaSVMSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaSVMSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaSVMSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
int validatePrediction(List<LabeledPoint> validationData, SVMModel model) { int validatePrediction(List<LabeledPoint> validationData, SVMModel model) {
int numAccurate = 0; int numAccurate = 0;
for (LabeledPoint point: validationData) { for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features()); Double prediction = model.predict(point.features());
if (prediction == point.label()) { if (prediction == point.label()) {
numAccurate++; numAccurate++;
@ -60,16 +66,16 @@ public class JavaSVMSuite implements Serializable {
double A = 2.0; double A = 2.0;
double[] weights = {-1.5, 1.0}; double[] weights = {-1.5, 1.0};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, JavaRDD<LabeledPoint> testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A,
weights, nPoints, 42), 2).cache(); weights, nPoints, 42), 2).cache();
List<LabeledPoint> validationData = List<LabeledPoint> validationData =
SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
SVMWithSGD svmSGDImpl = new SVMWithSGD(); SVMWithSGD svmSGDImpl = new SVMWithSGD();
svmSGDImpl.setIntercept(true); svmSGDImpl.setIntercept(true);
svmSGDImpl.optimizer().setStepSize(1.0) svmSGDImpl.optimizer().setStepSize(1.0)
.setRegParam(1.0) .setRegParam(1.0)
.setNumIterations(100); .setNumIterations(100);
SVMModel model = svmSGDImpl.run(testRDD.rdd()); SVMModel model = svmSGDImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model); int numAccurate = validatePrediction(validationData, model);
@ -82,10 +88,10 @@ public class JavaSVMSuite implements Serializable {
double A = 0.0; double A = 0.0;
double[] weights = {-1.5, 1.0}; double[] weights = {-1.5, 1.0};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, JavaRDD<LabeledPoint> testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A,
weights, nPoints, 42), 2).cache(); weights, nPoints, 42), 2).cache();
List<LabeledPoint> validationData = List<LabeledPoint> validationData =
SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0); SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0);


@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering;
import java.io.Serializable; import java.io.Serializable;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import org.junit.After; import org.junit.After;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
@ -29,27 +30,33 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.SparkSession;
public class JavaBisectingKMeansSuite implements Serializable { public class JavaBisectingKMeansSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", this.getClass().getSimpleName()); spark = SparkSession.builder()
.master("local")
.appName("JavaBisectingKMeansSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
public void twoDimensionalData() { public void twoDimensionalData() {
JavaRDD<Vector> points = sc.parallelize(Lists.newArrayList( JavaRDD<Vector> points = jsc.parallelize(Lists.newArrayList(
Vectors.dense(4, -1), Vectors.dense(4, -1),
Vectors.dense(4, 1), Vectors.dense(4, 1),
Vectors.sparse(2, new int[] {0}, new double[] {1.0}) Vectors.sparse(2, new int[]{0}, new double[]{1.0})
), 2); ), 2);
BisectingKMeans bkm = new BisectingKMeans() BisectingKMeans bkm = new BisectingKMeans()
@ -58,15 +65,15 @@ public class JavaBisectingKMeansSuite implements Serializable {
.setSeed(1L); .setSeed(1L);
BisectingKMeansModel model = bkm.run(points); BisectingKMeansModel model = bkm.run(points);
Assert.assertEquals(3, model.k()); Assert.assertEquals(3, model.k());
Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12); Assert.assertArrayEquals(new double[]{3.0, 0.0}, model.root().center().toArray(), 1e-12);
for (ClusteringTreeNode child: model.root().children()) { for (ClusteringTreeNode child : model.root().children()) {
double[] center = child.center().toArray(); double[] center = child.center().toArray();
if (center[0] > 2) { if (center[0] > 2) {
Assert.assertEquals(2, child.size()); Assert.assertEquals(2, child.size());
Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12); Assert.assertArrayEquals(new double[]{4.0, 0.0}, center, 1e-12);
} else { } else {
Assert.assertEquals(1, child.size()); Assert.assertEquals(1, child.size());
Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12); Assert.assertArrayEquals(new double[]{1.0, 0.0}, center, 1e-12);
} }
} }
} }


@ -21,29 +21,35 @@ import java.io.Serializable;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.SparkSession;
public class JavaGaussianMixtureSuite implements Serializable { public class JavaGaussianMixtureSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaGaussianMixture"); spark = SparkSession.builder()
.master("local")
.appName("JavaGaussianMixture")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
@ -54,7 +60,7 @@ public class JavaGaussianMixtureSuite implements Serializable {
Vectors.dense(1.0, 4.0, 6.0) Vectors.dense(1.0, 4.0, 6.0)
); );
JavaRDD<Vector> data = sc.parallelize(points, 2); JavaRDD<Vector> data = jsc.parallelize(points, 2);
GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234) GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234)
.run(data); .run(data);
assertEquals(model.gaussians().length, 2); assertEquals(model.gaussians().length, 2);


@ -21,28 +21,35 @@ import java.io.Serializable;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.*;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.SparkSession;
public class JavaKMeansSuite implements Serializable { public class JavaKMeansSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaKMeans"); spark = SparkSession.builder()
.master("local")
.appName("JavaKMeans")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
@ -55,7 +62,7 @@ public class JavaKMeansSuite implements Serializable {
Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0); Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
JavaRDD<Vector> data = sc.parallelize(points, 2); JavaRDD<Vector> data = jsc.parallelize(points, 2);
KMeansModel model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.K_MEANS_PARALLEL()); KMeansModel model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.K_MEANS_PARALLEL());
assertEquals(1, model.clusterCenters().length); assertEquals(1, model.clusterCenters().length);
assertEquals(expectedCenter, model.clusterCenters()[0]); assertEquals(expectedCenter, model.clusterCenters()[0]);
@ -74,7 +81,7 @@ public class JavaKMeansSuite implements Serializable {
Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0); Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
JavaRDD<Vector> data = sc.parallelize(points, 2); JavaRDD<Vector> data = jsc.parallelize(points, 2);
KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd()); KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
assertEquals(1, model.clusterCenters().length); assertEquals(1, model.clusterCenters().length);
assertEquals(expectedCenter, model.clusterCenters()[0]); assertEquals(expectedCenter, model.clusterCenters()[0]);
@ -94,7 +101,7 @@ public class JavaKMeansSuite implements Serializable {
Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 3.0, 0.0),
Vectors.dense(1.0, 4.0, 6.0) Vectors.dense(1.0, 4.0, 6.0)
); );
JavaRDD<Vector> data = sc.parallelize(points, 2); JavaRDD<Vector> data = jsc.parallelize(points, 2);
KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd()); KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
JavaRDD<Integer> predictions = model.predict(data); JavaRDD<Integer> predictions = model.predict(data);
// Should be able to get the first prediction. // Should be able to get the first prediction.


@ -27,37 +27,42 @@ import scala.Tuple3;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.SparkSession;
public class JavaLDASuite implements Serializable { public class JavaLDASuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaLDA"); spark = SparkSession.builder()
.master("local")
.appName("JavaLDASuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
ArrayList<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<>(); ArrayList<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<>();
for (int i = 0; i < LDASuite.tinyCorpus().length; i++) { for (int i = 0; i < LDASuite.tinyCorpus().length; i++) {
tinyCorpus.add(new Tuple2<>((Long)LDASuite.tinyCorpus()[i]._1(), tinyCorpus.add(new Tuple2<>((Long) LDASuite.tinyCorpus()[i]._1(),
LDASuite.tinyCorpus()[i]._2())); LDASuite.tinyCorpus()[i]._2()));
} }
JavaRDD<Tuple2<Long, Vector>> tmpCorpus = sc.parallelize(tinyCorpus, 2); JavaRDD<Tuple2<Long, Vector>> tmpCorpus = jsc.parallelize(tinyCorpus, 2);
corpus = JavaPairRDD.fromJavaRDD(tmpCorpus); corpus = JavaPairRDD.fromJavaRDD(tmpCorpus);
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
@ -95,7 +100,7 @@ public class JavaLDASuite implements Serializable {
.setMaxIterations(5) .setMaxIterations(5)
.setSeed(12345); .setSeed(12345);
DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus); DistributedLDAModel model = (DistributedLDAModel) lda.run(corpus);
// Check: basic parameters // Check: basic parameters
LocalLDAModel localModel = model.toLocal(); LocalLDAModel localModel = model.toLocal();
@ -124,7 +129,7 @@ public class JavaLDASuite implements Serializable {
public Boolean call(Tuple2<Long, Vector> tuple2) { public Boolean call(Tuple2<Long, Vector> tuple2) {
return Vectors.norm(tuple2._2(), 1.0) != 0.0; return Vectors.norm(tuple2._2(), 1.0) != 0.0;
} }
}); });
assertEquals(topicDistributions.count(), nonEmptyCorpus.count()); assertEquals(topicDistributions.count(), nonEmptyCorpus.count());
// Check: javaTopTopicsPerDocuments // Check: javaTopTopicsPerDocuments
@ -179,7 +184,7 @@ public class JavaLDASuite implements Serializable {
@Test @Test
public void localLdaMethods() { public void localLdaMethods() {
JavaRDD<Tuple2<Long, Vector>> docs = sc.parallelize(toyData, 2); JavaRDD<Tuple2<Long, Vector>> docs = jsc.parallelize(toyData, 2);
JavaPairRDD<Long, Vector> pairedDocs = JavaPairRDD.fromJavaRDD(docs); JavaPairRDD<Long, Vector> pairedDocs = JavaPairRDD.fromJavaRDD(docs);
// check: topicDistributions // check: topicDistributions
@ -191,7 +196,7 @@ public class JavaLDASuite implements Serializable {
// check: logLikelihood. // check: logLikelihood.
ArrayList<Tuple2<Long, Vector>> docsSingleWord = new ArrayList<>(); ArrayList<Tuple2<Long, Vector>> docsSingleWord = new ArrayList<>();
docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0, 0.0, 0.0))); docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0, 0.0, 0.0)));
JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord)); JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(jsc.parallelize(docsSingleWord));
double logLikelihood = toyModel.logLikelihood(single); double logLikelihood = toyModel.logLikelihood(single);
} }
@ -199,7 +204,7 @@ public class JavaLDASuite implements Serializable {
private static int tinyVocabSize = LDASuite.tinyVocabSize(); private static int tinyVocabSize = LDASuite.tinyVocabSize();
private static Matrix tinyTopics = LDASuite.tinyTopics(); private static Matrix tinyTopics = LDASuite.tinyTopics();
private static Tuple2<int[], double[]>[] tinyTopicDescription = private static Tuple2<int[], double[]>[] tinyTopicDescription =
LDASuite.tinyTopicDescription(); LDASuite.tinyTopicDescription();
private JavaPairRDD<Long, Vector> corpus; private JavaPairRDD<Long, Vector> corpus;
private LocalLDAModel toyModel = LDASuite.toyModel(); private LocalLDAModel toyModel = LDASuite.toyModel();
private ArrayList<Tuple2<Long, Vector>> toyData = LDASuite.javaToyData(); private ArrayList<Tuple2<Long, Vector>> toyData = LDASuite.javaToyData();


@ -27,8 +27,6 @@ import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import static org.apache.spark.streaming.JavaTestUtils.*;
import org.apache.spark.SparkConf; import org.apache.spark.SparkConf;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
@ -36,6 +34,7 @@ import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.apache.spark.streaming.api.java.JavaStreamingContext;
import static org.apache.spark.streaming.JavaTestUtils.*;
public class JavaStreamingKMeansSuite implements Serializable { public class JavaStreamingKMeansSuite implements Serializable {


@ -31,27 +31,34 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
public class JavaRankingMetricsSuite implements Serializable { public class JavaRankingMetricsSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels; private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaRankingMetricsSuite"); spark = SparkSession.builder()
predictionAndLabels = sc.parallelize(Arrays.asList( .master("local")
.appName("JavaPCASuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
predictionAndLabels = jsc.parallelize(Arrays.asList(
Tuple2$.MODULE$.apply( Tuple2$.MODULE$.apply(
Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)), Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)),
Tuple2$.MODULE$.apply( Tuple2$.MODULE$.apply(
Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)), Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
Tuple2$.MODULE$.apply( Tuple2$.MODULE$.apply(
Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2); Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
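Every suite converted below follows the same shape: build a local SparkSession in setUp(), derive the JavaSparkContext that the RDD-based MLlib APIs still need from spark.sparkContext(), and stop the session in tearDown() (which also stops the SparkContext it wraps). A minimal JUnit 4 sketch of that pattern follows; the class, app, and test names are illustrative and not part of this patch.

// Illustrative only: not a file touched by this patch. Shows the shared
// SparkSession-based setUp/tearDown shape used by the suites in this diff.
import java.io.Serializable;
import java.util.Arrays;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;

public class JavaExampleSuite implements Serializable {
  private transient SparkSession spark;
  private transient JavaSparkContext jsc;

  @Before
  public void setUp() {
    // Build (or reuse) a local SparkSession instead of constructing a
    // JavaSparkContext directly.
    spark = SparkSession.builder()
      .master("local")
      .appName("JavaExampleSuite")
      .getOrCreate();
    // RDD-based MLlib APIs still take a JavaSparkContext; wrap the
    // session's underlying SparkContext.
    jsc = new JavaSparkContext(spark.sparkContext());
  }

  @After
  public void tearDown() {
    // Stopping the session also stops the SparkContext it wraps.
    spark.stop();
    spark = null;
  }

  @Test
  public void parallelizeAndCount() {
    JavaRDD<Integer> rdd = jsc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
    Assert.assertEquals(4, rdd.count());
  }
}

Because getOrCreate() returns an already-active session when one exists in the JVM, repeated setUp() calls across suites sharing a test JVM stay cheap.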
@ -29,19 +29,25 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.SparkSession;
public class JavaTfIdfSuite implements Serializable { public class JavaTfIdfSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaTfIdfSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaPCASuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
@ -49,7 +55,7 @@ public class JavaTfIdfSuite implements Serializable {
// The tests are to check Java compatibility. // The tests are to check Java compatibility.
HashingTF tf = new HashingTF(); HashingTF tf = new HashingTF();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList( JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList(
Arrays.asList("this is a sentence".split(" ")), Arrays.asList("this is a sentence".split(" ")),
Arrays.asList("this is another sentence".split(" ")), Arrays.asList("this is another sentence".split(" ")),
Arrays.asList("this is still a sentence".split(" "))), 2); Arrays.asList("this is still a sentence".split(" "))), 2);
@ -59,7 +65,7 @@ public class JavaTfIdfSuite implements Serializable {
JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs); JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
List<Vector> localTfIdfs = tfIdfs.collect(); List<Vector> localTfIdfs = tfIdfs.collect();
int indexOfThis = tf.indexOf("this"); int indexOfThis = tf.indexOf("this");
for (Vector v: localTfIdfs) { for (Vector v : localTfIdfs) {
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
} }
} }
@ -69,7 +75,7 @@ public class JavaTfIdfSuite implements Serializable {
// The tests are to check Java compatibility. // The tests are to check Java compatibility.
HashingTF tf = new HashingTF(); HashingTF tf = new HashingTF();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList( JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList(
Arrays.asList("this is a sentence".split(" ")), Arrays.asList("this is a sentence".split(" ")),
Arrays.asList("this is another sentence".split(" ")), Arrays.asList("this is another sentence".split(" ")),
Arrays.asList("this is still a sentence".split(" "))), 2); Arrays.asList("this is still a sentence".split(" "))), 2);
@ -79,7 +85,7 @@ public class JavaTfIdfSuite implements Serializable {
JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs); JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
List<Vector> localTfIdfs = tfIdfs.collect(); List<Vector> localTfIdfs = tfIdfs.collect();
int indexOfThis = tf.indexOf("this"); int indexOfThis = tf.indexOf("this");
for (Vector v: localTfIdfs) { for (Vector v : localTfIdfs) {
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
} }
} }
@ -21,9 +21,10 @@ import java.io.Serializable;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import com.google.common.base.Strings;
import scala.Tuple2; import scala.Tuple2;
import com.google.common.base.Strings;
import org.junit.After; import org.junit.After;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
@ -31,19 +32,25 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
public class JavaWord2VecSuite implements Serializable { public class JavaWord2VecSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaWord2VecSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaPCASuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
@ -53,7 +60,7 @@ public class JavaWord2VecSuite implements Serializable {
String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10); String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10);
List<String> words = Arrays.asList(sentence.split(" ")); List<String> words = Arrays.asList(sentence.split(" "));
List<List<String>> localDoc = Arrays.asList(words, words); List<List<String>> localDoc = Arrays.asList(words, words);
JavaRDD<List<String>> doc = sc.parallelize(localDoc); JavaRDD<List<String>> doc = jsc.parallelize(localDoc);
Word2Vec word2vec = new Word2Vec() Word2Vec word2vec = new Word2Vec()
.setVectorSize(10) .setVectorSize(10)
.setSeed(42L); .setSeed(42L);
@ -26,32 +26,37 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
import org.apache.spark.sql.SparkSession;
public class JavaAssociationRulesSuite implements Serializable { public class JavaAssociationRulesSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaFPGrowth"); spark = SparkSession.builder()
.master("local")
.appName("JavaAssociationRulesSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
public void runAssociationRules() { public void runAssociationRules() {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Arrays.asList( JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = jsc.parallelize(Arrays.asList(
new FreqItemset<String>(new String[] {"a"}, 15L), new FreqItemset<String>(new String[]{"a"}, 15L),
new FreqItemset<String>(new String[] {"b"}, 35L), new FreqItemset<String>(new String[]{"b"}, 35L),
new FreqItemset<String>(new String[] {"a", "b"}, 12L) new FreqItemset<String>(new String[]{"a", "b"}, 12L)
)); ));
JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets); JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets);
} }
} }
@ -22,34 +22,41 @@ import java.io.Serializable;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.*;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils; import org.apache.spark.util.Utils;
public class JavaFPGrowthSuite implements Serializable { public class JavaFPGrowthSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaFPGrowth"); spark = SparkSession.builder()
.master("local")
.appName("JavaFPGrowth")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
public void runFPGrowth() { public void runFPGrowth() {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList( JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList(
Arrays.asList("r z h k p".split(" ")), Arrays.asList("r z h k p".split(" ")),
Arrays.asList("z y x w v u t s".split(" ")), Arrays.asList("z y x w v u t s".split(" ")),
Arrays.asList("s x o n r".split(" ")), Arrays.asList("s x o n r".split(" ")),
@ -65,7 +72,7 @@ public class JavaFPGrowthSuite implements Serializable {
List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect(); List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
assertEquals(18, freqItemsets.size()); assertEquals(18, freqItemsets.size());
for (FPGrowth.FreqItemset<String> itemset: freqItemsets) { for (FPGrowth.FreqItemset<String> itemset : freqItemsets) {
// Test return types. // Test return types.
List<String> items = itemset.javaItems(); List<String> items = itemset.javaItems();
long freq = itemset.freq(); long freq = itemset.freq();
@ -76,7 +83,7 @@ public class JavaFPGrowthSuite implements Serializable {
public void runFPGrowthSaveLoad() { public void runFPGrowthSaveLoad() {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList( JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList(
Arrays.asList("r z h k p".split(" ")), Arrays.asList("r z h k p".split(" ")),
Arrays.asList("z y x w v u t s".split(" ")), Arrays.asList("z y x w v u t s".split(" ")),
Arrays.asList("s x o n r".split(" ")), Arrays.asList("s x o n r".split(" ")),
@ -94,15 +101,15 @@ public class JavaFPGrowthSuite implements Serializable {
String outputPath = tempDir.getPath(); String outputPath = tempDir.getPath();
try { try {
model.save(sc.sc(), outputPath); model.save(spark.sparkContext(), outputPath);
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
FPGrowthModel<String> newModel = FPGrowthModel<String> newModel =
(FPGrowthModel<String>) FPGrowthModel.load(sc.sc(), outputPath); (FPGrowthModel<String>) FPGrowthModel.load(spark.sparkContext(), outputPath);
List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD() List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD()
.collect(); .collect();
assertEquals(18, freqItemsets.size()); assertEquals(18, freqItemsets.size());
for (FPGrowth.FreqItemset<String> itemset: freqItemsets) { for (FPGrowth.FreqItemset<String> itemset : freqItemsets) {
// Test return types. // Test return types.
List<String> items = itemset.javaItems(); List<String> items = itemset.javaItems();
long freq = itemset.freq(); long freq = itemset.freq();
@ -29,25 +29,31 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence; import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils; import org.apache.spark.util.Utils;
public class JavaPrefixSpanSuite { public class JavaPrefixSpanSuite {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaPrefixSpan"); spark = SparkSession.builder()
.master("local")
.appName("JavaPrefixSpan")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
public void runPrefixSpan() { public void runPrefixSpan() {
JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList( JavaRDD<List<List<Integer>>> sequences = jsc.parallelize(Arrays.asList(
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
@ -61,7 +67,7 @@ public class JavaPrefixSpanSuite {
List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect(); List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
Assert.assertEquals(5, localFreqSeqs.size()); Assert.assertEquals(5, localFreqSeqs.size());
// Check that each frequent sequence could be materialized. // Check that each frequent sequence could be materialized.
for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) { for (PrefixSpan.FreqSequence<Integer> freqSeq : localFreqSeqs) {
List<List<Integer>> seq = freqSeq.javaSequence(); List<List<Integer>> seq = freqSeq.javaSequence();
long freq = freqSeq.freq(); long freq = freqSeq.freq();
} }
@ -69,7 +75,7 @@ public class JavaPrefixSpanSuite {
@Test @Test
public void runPrefixSpanSaveLoad() { public void runPrefixSpanSaveLoad() {
JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList( JavaRDD<List<List<Integer>>> sequences = jsc.parallelize(Arrays.asList(
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
@ -85,13 +91,13 @@ public class JavaPrefixSpanSuite {
String outputPath = tempDir.getPath(); String outputPath = tempDir.getPath();
try { try {
model.save(sc.sc(), outputPath); model.save(spark.sparkContext(), outputPath);
PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath); PrefixSpanModel newModel = PrefixSpanModel.load(spark.sparkContext(), outputPath);
JavaRDD<FreqSequence<Integer>> freqSeqs = newModel.freqSequences().toJavaRDD(); JavaRDD<FreqSequence<Integer>> freqSeqs = newModel.freqSequences().toJavaRDD();
List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect(); List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
Assert.assertEquals(5, localFreqSeqs.size()); Assert.assertEquals(5, localFreqSeqs.size());
// Check that each frequent sequence could be materialized. // Check that each frequent sequence could be materialized.
for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) { for (PrefixSpan.FreqSequence<Integer> freqSeq : localFreqSeqs) {
List<List<Integer>> seq = freqSeq.javaSequence(); List<List<Integer>> seq = freqSeq.javaSequence();
long freq = freqSeq.freq(); long freq = freqSeq.freq();
} }
@ -17,147 +17,149 @@
package org.apache.spark.mllib.linalg; package org.apache.spark.mllib.linalg;
import static org.junit.Assert.*;
import org.junit.Test;
import java.io.Serializable; import java.io.Serializable;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import org.junit.Test;
public class JavaMatricesSuite implements Serializable { public class JavaMatricesSuite implements Serializable {
@Test @Test
public void randMatrixConstruction() { public void randMatrixConstruction() {
Random rng = new Random(24); Random rng = new Random(24);
Matrix r = Matrices.rand(3, 4, rng); Matrix r = Matrices.rand(3, 4, rng);
rng.setSeed(24); rng.setSeed(24);
DenseMatrix dr = DenseMatrix.rand(3, 4, rng); DenseMatrix dr = DenseMatrix.rand(3, 4, rng);
assertArrayEquals(r.toArray(), dr.toArray(), 0.0); assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
rng.setSeed(24); rng.setSeed(24);
Matrix rn = Matrices.randn(3, 4, rng); Matrix rn = Matrices.randn(3, 4, rng);
rng.setSeed(24); rng.setSeed(24);
DenseMatrix drn = DenseMatrix.randn(3, 4, rng); DenseMatrix drn = DenseMatrix.randn(3, 4, rng);
assertArrayEquals(rn.toArray(), drn.toArray(), 0.0); assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);
rng.setSeed(24); rng.setSeed(24);
Matrix s = Matrices.sprand(3, 4, 0.5, rng); Matrix s = Matrices.sprand(3, 4, 0.5, rng);
rng.setSeed(24); rng.setSeed(24);
SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng); SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng);
assertArrayEquals(s.toArray(), sr.toArray(), 0.0); assertArrayEquals(s.toArray(), sr.toArray(), 0.0);
rng.setSeed(24); rng.setSeed(24);
Matrix sn = Matrices.sprandn(3, 4, 0.5, rng); Matrix sn = Matrices.sprandn(3, 4, 0.5, rng);
rng.setSeed(24); rng.setSeed(24);
SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng); SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng);
assertArrayEquals(sn.toArray(), srn.toArray(), 0.0); assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
} }
@Test @Test
public void identityMatrixConstruction() { public void identityMatrixConstruction() {
Matrix r = Matrices.eye(2); Matrix r = Matrices.eye(2);
DenseMatrix dr = DenseMatrix.eye(2); DenseMatrix dr = DenseMatrix.eye(2);
SparseMatrix sr = SparseMatrix.speye(2); SparseMatrix sr = SparseMatrix.speye(2);
assertArrayEquals(r.toArray(), dr.toArray(), 0.0); assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
assertArrayEquals(sr.toArray(), dr.toArray(), 0.0); assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0); assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
} }
@Test @Test
public void diagonalMatrixConstruction() { public void diagonalMatrixConstruction() {
Vector v = Vectors.dense(1.0, 0.0, 2.0); Vector v = Vectors.dense(1.0, 0.0, 2.0);
Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0}); Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});
Matrix m = Matrices.diag(v); Matrix m = Matrices.diag(v);
Matrix sm = Matrices.diag(sv); Matrix sm = Matrices.diag(sv);
DenseMatrix d = DenseMatrix.diag(v); DenseMatrix d = DenseMatrix.diag(v);
DenseMatrix sd = DenseMatrix.diag(sv); DenseMatrix sd = DenseMatrix.diag(sv);
SparseMatrix s = SparseMatrix.spdiag(v); SparseMatrix s = SparseMatrix.spdiag(v);
SparseMatrix ss = SparseMatrix.spdiag(sv); SparseMatrix ss = SparseMatrix.spdiag(sv);
assertArrayEquals(m.toArray(), sm.toArray(), 0.0); assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
assertArrayEquals(d.toArray(), sm.toArray(), 0.0); assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
assertArrayEquals(d.toArray(), sd.toArray(), 0.0); assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
assertArrayEquals(sd.toArray(), s.toArray(), 0.0); assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
assertArrayEquals(s.toArray(), ss.toArray(), 0.0); assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
assertArrayEquals(s.values(), ss.values(), 0.0); assertArrayEquals(s.values(), ss.values(), 0.0);
assertEquals(2, s.values().length); assertEquals(2, s.values().length);
assertEquals(2, ss.values().length); assertEquals(2, ss.values().length);
assertEquals(4, s.colPtrs().length); assertEquals(4, s.colPtrs().length);
assertEquals(4, ss.colPtrs().length); assertEquals(4, ss.colPtrs().length);
} }
@Test @Test
public void zerosMatrixConstruction() { public void zerosMatrixConstruction() {
Matrix z = Matrices.zeros(2, 2); Matrix z = Matrices.zeros(2, 2);
Matrix one = Matrices.ones(2, 2); Matrix one = Matrices.ones(2, 2);
DenseMatrix dz = DenseMatrix.zeros(2, 2); DenseMatrix dz = DenseMatrix.zeros(2, 2);
DenseMatrix done = DenseMatrix.ones(2, 2); DenseMatrix done = DenseMatrix.ones(2, 2);
assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
} }
@Test @Test
public void sparseDenseConversion() { public void sparseDenseConversion() {
int m = 3; int m = 3;
int n = 2; int n = 2;
double[] values = new double[]{1.0, 2.0, 4.0, 5.0}; double[] values = new double[]{1.0, 2.0, 4.0, 5.0};
double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0}; double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0};
int[] colPtrs = new int[]{0, 2, 4}; int[] colPtrs = new int[]{0, 2, 4};
int[] rowIndices = new int[]{0, 1, 1, 2}; int[] rowIndices = new int[]{0, 1, 1, 2};
SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values); SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values);
DenseMatrix deMat1 = new DenseMatrix(m, n, allValues); DenseMatrix deMat1 = new DenseMatrix(m, n, allValues);
SparseMatrix spMat2 = deMat1.toSparse(); SparseMatrix spMat2 = deMat1.toSparse();
DenseMatrix deMat2 = spMat1.toDense(); DenseMatrix deMat2 = spMat1.toDense();
assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0); assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0);
assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0); assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0);
} }
@Test @Test
public void concatenateMatrices() { public void concatenateMatrices() {
int m = 3; int m = 3;
int n = 2; int n = 2;
Random rng = new Random(42); Random rng = new Random(42);
SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng); SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng);
rng.setSeed(42); rng.setSeed(42);
DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng); DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng);
Matrix deMat2 = Matrices.eye(3); Matrix deMat2 = Matrices.eye(3);
Matrix spMat2 = Matrices.speye(3); Matrix spMat2 = Matrices.speye(3);
Matrix deMat3 = Matrices.eye(2); Matrix deMat3 = Matrices.eye(2);
Matrix spMat3 = Matrices.speye(2); Matrix spMat3 = Matrices.speye(2);
Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2}); Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2}); Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2}); Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2}); Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
assertEquals(3, deHorz1.numRows()); assertEquals(3, deHorz1.numRows());
assertEquals(3, deHorz2.numRows()); assertEquals(3, deHorz2.numRows());
assertEquals(3, deHorz3.numRows()); assertEquals(3, deHorz3.numRows());
assertEquals(3, spHorz.numRows()); assertEquals(3, spHorz.numRows());
assertEquals(5, deHorz1.numCols()); assertEquals(5, deHorz1.numCols());
assertEquals(5, deHorz2.numCols()); assertEquals(5, deHorz2.numCols());
assertEquals(5, deHorz3.numCols()); assertEquals(5, deHorz3.numCols());
assertEquals(5, spHorz.numCols()); assertEquals(5, spHorz.numCols());
Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3}); Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3}); Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3}); Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3}); Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
assertEquals(5, deVert1.numRows()); assertEquals(5, deVert1.numRows());
assertEquals(5, deVert2.numRows()); assertEquals(5, deVert2.numRows());
assertEquals(5, deVert3.numRows()); assertEquals(5, deVert3.numRows());
assertEquals(5, spVert.numRows()); assertEquals(5, spVert.numRows());
assertEquals(2, deVert1.numCols()); assertEquals(2, deVert1.numCols());
assertEquals(2, deVert2.numCols()); assertEquals(2, deVert2.numCols());
assertEquals(2, deVert3.numCols()); assertEquals(2, deVert3.numCols());
assertEquals(2, spVert.numCols()); assertEquals(2, spVert.numCols());
} }
} }
@ -20,10 +20,11 @@ package org.apache.spark.mllib.linalg;
import java.io.Serializable; import java.io.Serializable;
import java.util.Arrays; import java.util.Arrays;
import static org.junit.Assert.assertArrayEquals;
import scala.Tuple2; import scala.Tuple2;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.*;
public class JavaVectorsSuite implements Serializable { public class JavaVectorsSuite implements Serializable {
@ -37,8 +38,8 @@ public class JavaVectorsSuite implements Serializable {
public void sparseArrayConstruction() { public void sparseArrayConstruction() {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Vector v = Vectors.sparse(3, Arrays.asList( Vector v = Vectors.sparse(3, Arrays.asList(
new Tuple2<>(0, 2.0), new Tuple2<>(0, 2.0),
new Tuple2<>(2, 3.0))); new Tuple2<>(2, 3.0)));
assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0); assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0);
} }
} }
@ -20,29 +20,35 @@ package org.apache.spark.mllib.random;
import java.io.Serializable; import java.io.Serializable;
import java.util.Arrays; import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.junit.Assert;
import org.junit.After; import org.junit.After;
import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.random.RandomRDDs.*; import static org.apache.spark.mllib.random.RandomRDDs.*;
public class JavaRandomRDDsSuite { public class JavaRandomRDDsSuite {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaRandomRDDsSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaRandomRDDsSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
@ -50,10 +56,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L; long m = 1000L;
int p = 2; int p = 2;
long seed = 1L; long seed = 1L;
JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m); JavaDoubleRDD rdd1 = uniformJavaRDD(jsc, m);
JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p); JavaDoubleRDD rdd2 = uniformJavaRDD(jsc, m, p);
JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed); JavaDoubleRDD rdd3 = uniformJavaRDD(jsc, m, p, seed);
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count()); Assert.assertEquals(m, rdd.count());
} }
} }
@ -63,10 +69,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L; long m = 1000L;
int p = 2; int p = 2;
long seed = 1L; long seed = 1L;
JavaDoubleRDD rdd1 = normalJavaRDD(sc, m); JavaDoubleRDD rdd1 = normalJavaRDD(jsc, m);
JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p); JavaDoubleRDD rdd2 = normalJavaRDD(jsc, m, p);
JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed); JavaDoubleRDD rdd3 = normalJavaRDD(jsc, m, p, seed);
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count()); Assert.assertEquals(m, rdd.count());
} }
} }
@ -78,10 +84,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L; long m = 1000L;
int p = 2; int p = 2;
long seed = 1L; long seed = 1L;
JavaDoubleRDD rdd1 = logNormalJavaRDD(sc, mean, std, m); JavaDoubleRDD rdd1 = logNormalJavaRDD(jsc, mean, std, m);
JavaDoubleRDD rdd2 = logNormalJavaRDD(sc, mean, std, m, p); JavaDoubleRDD rdd2 = logNormalJavaRDD(jsc, mean, std, m, p);
JavaDoubleRDD rdd3 = logNormalJavaRDD(sc, mean, std, m, p, seed); JavaDoubleRDD rdd3 = logNormalJavaRDD(jsc, mean, std, m, p, seed);
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count()); Assert.assertEquals(m, rdd.count());
} }
} }
@ -92,10 +98,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L; long m = 1000L;
int p = 2; int p = 2;
long seed = 1L; long seed = 1L;
JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m); JavaDoubleRDD rdd1 = poissonJavaRDD(jsc, mean, m);
JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p); JavaDoubleRDD rdd2 = poissonJavaRDD(jsc, mean, m, p);
JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed); JavaDoubleRDD rdd3 = poissonJavaRDD(jsc, mean, m, p, seed);
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count()); Assert.assertEquals(m, rdd.count());
} }
} }
@ -106,10 +112,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L; long m = 1000L;
int p = 2; int p = 2;
long seed = 1L; long seed = 1L;
JavaDoubleRDD rdd1 = exponentialJavaRDD(sc, mean, m); JavaDoubleRDD rdd1 = exponentialJavaRDD(jsc, mean, m);
JavaDoubleRDD rdd2 = exponentialJavaRDD(sc, mean, m, p); JavaDoubleRDD rdd2 = exponentialJavaRDD(jsc, mean, m, p);
JavaDoubleRDD rdd3 = exponentialJavaRDD(sc, mean, m, p, seed); JavaDoubleRDD rdd3 = exponentialJavaRDD(jsc, mean, m, p, seed);
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count()); Assert.assertEquals(m, rdd.count());
} }
} }
@ -117,14 +123,14 @@ public class JavaRandomRDDsSuite {
@Test @Test
public void testGammaRDD() { public void testGammaRDD() {
double shape = 1.0; double shape = 1.0;
double scale = 2.0; double jscale = 2.0;
long m = 1000L; long m = 1000L;
int p = 2; int p = 2;
long seed = 1L; long seed = 1L;
JavaDoubleRDD rdd1 = gammaJavaRDD(sc, shape, scale, m); JavaDoubleRDD rdd1 = gammaJavaRDD(jsc, shape, jscale, m);
JavaDoubleRDD rdd2 = gammaJavaRDD(sc, shape, scale, m, p); JavaDoubleRDD rdd2 = gammaJavaRDD(jsc, shape, jscale, m, p);
JavaDoubleRDD rdd3 = gammaJavaRDD(sc, shape, scale, m, p, seed); JavaDoubleRDD rdd3 = gammaJavaRDD(jsc, shape, jscale, m, p, seed);
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count()); Assert.assertEquals(m, rdd.count());
} }
} }
@ -137,10 +143,10 @@ public class JavaRandomRDDsSuite {
int n = 10; int n = 10;
int p = 2; int p = 2;
long seed = 1L; long seed = 1L;
JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(sc, m, n); JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(jsc, m, n);
JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(sc, m, n, p); JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(jsc, m, n, p);
JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(sc, m, n, p, seed); JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(jsc, m, n, p, seed);
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count()); Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size()); Assert.assertEquals(n, rdd.first().size());
} }
@ -153,10 +159,10 @@ public class JavaRandomRDDsSuite {
int n = 10; int n = 10;
int p = 2; int p = 2;
long seed = 1L; long seed = 1L;
JavaRDD<Vector> rdd1 = normalJavaVectorRDD(sc, m, n); JavaRDD<Vector> rdd1 = normalJavaVectorRDD(jsc, m, n);
JavaRDD<Vector> rdd2 = normalJavaVectorRDD(sc, m, n, p); JavaRDD<Vector> rdd2 = normalJavaVectorRDD(jsc, m, n, p);
JavaRDD<Vector> rdd3 = normalJavaVectorRDD(sc, m, n, p, seed); JavaRDD<Vector> rdd3 = normalJavaVectorRDD(jsc, m, n, p, seed);
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count()); Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size()); Assert.assertEquals(n, rdd.first().size());
} }
@ -171,10 +177,10 @@ public class JavaRandomRDDsSuite {
int n = 10; int n = 10;
int p = 2; int p = 2;
long seed = 1L; long seed = 1L;
JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(sc, mean, std, m, n); JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(jsc, mean, std, m, n);
JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(sc, mean, std, m, n, p); JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p);
JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(sc, mean, std, m, n, p, seed); JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p, seed);
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count()); Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size()); Assert.assertEquals(n, rdd.first().size());
} }
@ -188,10 +194,10 @@ public class JavaRandomRDDsSuite {
int n = 10; int n = 10;
int p = 2; int p = 2;
long seed = 1L; long seed = 1L;
JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(sc, mean, m, n); JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(jsc, mean, m, n);
JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p); JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(jsc, mean, m, n, p);
JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed); JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(jsc, mean, m, n, p, seed);
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count()); Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size()); Assert.assertEquals(n, rdd.first().size());
} }
@ -205,10 +211,10 @@ public class JavaRandomRDDsSuite {
int n = 10; int n = 10;
int p = 2; int p = 2;
long seed = 1L; long seed = 1L;
JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(sc, mean, m, n); JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(jsc, mean, m, n);
JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(sc, mean, m, n, p); JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(jsc, mean, m, n, p);
JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(sc, mean, m, n, p, seed); JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(jsc, mean, m, n, p, seed);
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count()); Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size()); Assert.assertEquals(n, rdd.first().size());
} }
@ -218,15 +224,15 @@ public class JavaRandomRDDsSuite {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void testGammaVectorRDD() { public void testGammaVectorRDD() {
double shape = 1.0; double shape = 1.0;
double scale = 2.0; double jscale = 2.0;
long m = 100L; long m = 100L;
int n = 10; int n = 10;
int p = 2; int p = 2;
long seed = 1L; long seed = 1L;
JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(sc, shape, scale, m, n); JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(jsc, shape, jscale, m, n);
JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(sc, shape, scale, m, n, p); JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p);
JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(sc, shape, scale, m, n, p, seed); JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p, seed);
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count()); Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size()); Assert.assertEquals(n, rdd.first().size());
} }
@ -238,10 +244,10 @@ public class JavaRandomRDDsSuite {
long seed = 1L; long seed = 1L;
int numPartitions = 0; int numPartitions = 0;
StringGenerator gen = new StringGenerator(); StringGenerator gen = new StringGenerator();
JavaRDD<String> rdd1 = randomJavaRDD(sc, gen, size); JavaRDD<String> rdd1 = randomJavaRDD(jsc, gen, size);
JavaRDD<String> rdd2 = randomJavaRDD(sc, gen, size, numPartitions); JavaRDD<String> rdd2 = randomJavaRDD(jsc, gen, size, numPartitions);
JavaRDD<String> rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed); JavaRDD<String> rdd3 = randomJavaRDD(jsc, gen, size, numPartitions, seed);
for (JavaRDD<String> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { for (JavaRDD<String> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(size, rdd.count()); Assert.assertEquals(size, rdd.count());
Assert.assertEquals(2, rdd.first().length()); Assert.assertEquals(2, rdd.first().length());
} }
@ -255,10 +261,10 @@ public class JavaRandomRDDsSuite {
int n = 10; int n = 10;
int p = 2; int p = 2;
long seed = 1L; long seed = 1L;
JavaRDD<Vector> rdd1 = randomJavaVectorRDD(sc, generator, m, n); JavaRDD<Vector> rdd1 = randomJavaVectorRDD(jsc, generator, m, n);
JavaRDD<Vector> rdd2 = randomJavaVectorRDD(sc, generator, m, n, p); JavaRDD<Vector> rdd2 = randomJavaVectorRDD(jsc, generator, m, n, p);
JavaRDD<Vector> rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed); JavaRDD<Vector> rdd3 = randomJavaVectorRDD(jsc, generator, m, n, p, seed);
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count()); Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size()); Assert.assertEquals(n, rdd.first().size());
} }
@ -271,10 +277,12 @@ class StringGenerator implements RandomDataGenerator<String>, Serializable {
public String nextValue() { public String nextValue() {
return "42"; return "42";
} }
@Override @Override
public StringGenerator copy() { public StringGenerator copy() {
return new StringGenerator(); return new StringGenerator();
} }
@Override @Override
public void setSeed(long seed) { public void setSeed(long seed) {
} }
@ -32,40 +32,46 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
public class JavaALSSuite implements Serializable { public class JavaALSSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaALS"); spark = SparkSession.builder()
.master("local")
.appName("JavaALS")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
private void validatePrediction( private void validatePrediction(
MatrixFactorizationModel model, MatrixFactorizationModel model,
int users, int users,
int products, int products,
double[] trueRatings, double[] trueRatings,
double matchThreshold, double matchThreshold,
boolean implicitPrefs, boolean implicitPrefs,
double[] truePrefs) { double[] truePrefs) {
List<Tuple2<Integer, Integer>> localUsersProducts = new ArrayList<>(users * products); List<Tuple2<Integer, Integer>> localUsersProducts = new ArrayList<>(users * products);
for (int u=0; u < users; ++u) { for (int u = 0; u < users; ++u) {
for (int p=0; p < products; ++p) { for (int p = 0; p < products; ++p) {
localUsersProducts.add(new Tuple2<>(u, p)); localUsersProducts.add(new Tuple2<>(u, p));
} }
} }
JavaPairRDD<Integer, Integer> usersProducts = sc.parallelizePairs(localUsersProducts); JavaPairRDD<Integer, Integer> usersProducts = jsc.parallelizePairs(localUsersProducts);
List<Rating> predictedRatings = model.predict(usersProducts).collect(); List<Rating> predictedRatings = model.predict(usersProducts).collect();
Assert.assertEquals(users * products, predictedRatings.size()); Assert.assertEquals(users * products, predictedRatings.size());
if (!implicitPrefs) { if (!implicitPrefs) {
for (Rating r: predictedRatings) { for (Rating r : predictedRatings) {
double prediction = r.rating(); double prediction = r.rating();
double correct = trueRatings[r.product() * users + r.user()]; double correct = trueRatings[r.product() * users + r.user()];
Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
@ -76,7 +82,7 @@ public class JavaALSSuite implements Serializable {
// (ref Mahout's implicit ALS tests) // (ref Mahout's implicit ALS tests)
double sqErr = 0.0; double sqErr = 0.0;
double denom = 0.0; double denom = 0.0;
for (Rating r: predictedRatings) { for (Rating r : predictedRatings) {
double prediction = r.rating(); double prediction = r.rating();
double truePref = truePrefs[r.product() * users + r.user()]; double truePref = truePrefs[r.product() * users + r.user()];
double confidence = 1.0 + double confidence = 1.0 +
@ -98,9 +104,9 @@ public class JavaALSSuite implements Serializable {
int users = 50; int users = 50;
int products = 100; int products = 100;
Tuple3<List<Rating>, double[], double[]> testData = Tuple3<List<Rating>, double[], double[]> testData =
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
JavaRDD<Rating> data = sc.parallelize(testData._1()); JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3()); validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
} }
@ -112,9 +118,9 @@ public class JavaALSSuite implements Serializable {
int users = 100; int users = 100;
int products = 200; int products = 200;
Tuple3<List<Rating>, double[], double[]> testData = Tuple3<List<Rating>, double[], double[]> testData =
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
JavaRDD<Rating> data = sc.parallelize(testData._1()); JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features) MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations) .setIterations(iterations)
@ -129,9 +135,9 @@ public class JavaALSSuite implements Serializable {
int users = 80; int users = 80;
int products = 160; int products = 160;
Tuple3<List<Rating>, double[], double[]> testData = Tuple3<List<Rating>, double[], double[]> testData =
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
JavaRDD<Rating> data = sc.parallelize(testData._1()); JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
} }
@ -143,9 +149,9 @@ public class JavaALSSuite implements Serializable {
int users = 100; int users = 100;
int products = 200; int products = 200;
Tuple3<List<Rating>, double[], double[]> testData = Tuple3<List<Rating>, double[], double[]> testData =
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
JavaRDD<Rating> data = sc.parallelize(testData._1()); JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features) MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations) .setIterations(iterations)
@ -161,9 +167,9 @@ public class JavaALSSuite implements Serializable {
int users = 80; int users = 80;
int products = 160; int products = 160;
Tuple3<List<Rating>, double[], double[]> testData = Tuple3<List<Rating>, double[], double[]> testData =
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true); ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true);
JavaRDD<Rating> data = sc.parallelize(testData._1()); JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features) MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations) .setIterations(iterations)
.setImplicitPrefs(true) .setImplicitPrefs(true)
@ -179,8 +185,8 @@ public class JavaALSSuite implements Serializable {
int users = 200; int users = 200;
int products = 50; int products = 50;
List<Rating> testData = ALSSuite.generateRatingsAsJava( List<Rating> testData = ALSSuite.generateRatingsAsJava(
users, products, features, 0.7, true, false)._1(); users, products, features, 0.7, true, false)._1();
JavaRDD<Rating> data = sc.parallelize(testData); JavaRDD<Rating> data = jsc.parallelize(testData);
MatrixFactorizationModel model = new ALS().setRank(features) MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations) .setIterations(iterations)
.setImplicitPrefs(true) .setImplicitPrefs(true)
@ -193,7 +199,7 @@ public class JavaALSSuite implements Serializable {
private static void validateRecommendations(Rating[] recommendations, int howMany) { private static void validateRecommendations(Rating[] recommendations, int howMany) {
Assert.assertEquals(howMany, recommendations.length); Assert.assertEquals(howMany, recommendations.length);
for (int i = 1; i < recommendations.length; i++) { for (int i = 1; i < recommendations.length; i++) {
Assert.assertTrue(recommendations[i-1].rating() >= recommendations[i].rating()); Assert.assertTrue(recommendations[i - 1].rating() >= recommendations[i].rating());
} }
Assert.assertTrue(recommendations[0].rating() > 0.7); Assert.assertTrue(recommendations[0].rating() > 0.7);
} }
@ -32,15 +32,17 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
public class JavaIsotonicRegressionSuite implements Serializable { public class JavaIsotonicRegressionSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
private static List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) { private static List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) {
List<Tuple3<Double, Double, Double>> input = new ArrayList<>(labels.length); List<Tuple3<Double, Double, Double>> input = new ArrayList<>(labels.length);
for (int i = 1; i <= labels.length; i++) { for (int i = 1; i <= labels.length; i++) {
input.add(new Tuple3<>(labels[i-1], (double) i, 1.0)); input.add(new Tuple3<>(labels[i - 1], (double) i, 1.0));
} }
return input; return input;
@ -48,20 +50,24 @@ public class JavaIsotonicRegressionSuite implements Serializable {
private IsotonicRegressionModel runIsotonicRegression(double[] labels) { private IsotonicRegressionModel runIsotonicRegression(double[] labels) {
JavaRDD<Tuple3<Double, Double, Double>> trainRDD = JavaRDD<Tuple3<Double, Double, Double>> trainRDD =
sc.parallelize(generateIsotonicInput(labels), 2).cache(); jsc.parallelize(generateIsotonicInput(labels), 2).cache();
return new IsotonicRegression().run(trainRDD); return new IsotonicRegression().run(trainRDD);
} }
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaLinearRegressionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
@Test @Test
@ -70,7 +76,7 @@ public class JavaIsotonicRegressionSuite implements Serializable {
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
Assert.assertArrayEquals( Assert.assertArrayEquals(
new double[] {1, 2, 7.0/3, 7.0/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14); new double[]{1, 2, 7.0 / 3, 7.0 / 3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14);
} }
@Test @Test
@ -78,7 +84,7 @@ public class JavaIsotonicRegressionSuite implements Serializable {
IsotonicRegressionModel model = IsotonicRegressionModel model =
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
JavaDoubleRDD testRDD = sc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0)); JavaDoubleRDD testRDD = jsc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0));
List<Double> predictions = model.predict(testRDD).collect(); List<Double> predictions = model.predict(testRDD).collect();
Assert.assertEquals(1.0, predictions.get(0).doubleValue(), 1.0e-14); Assert.assertEquals(1.0, predictions.get(0).doubleValue(), 1.0e-14);
@ -28,24 +28,30 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.util.LinearDataGenerator; import org.apache.spark.mllib.util.LinearDataGenerator;
import org.apache.spark.sql.SparkSession;
public class JavaLassoSuite implements Serializable { public class JavaLassoSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaLassoSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaLassoSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
int validatePrediction(List<LabeledPoint> validationData, LassoModel model) { int validatePrediction(List<LabeledPoint> validationData, LassoModel model) {
int numAccurate = 0; int numAccurate = 0;
for (LabeledPoint point: validationData) { for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features()); Double prediction = model.predict(point.features());
// A prediction is off if the prediction is more than 0.5 away from expected value. // A prediction is off if the prediction is more than 0.5 away from expected value.
if (Math.abs(prediction - point.label()) <= 0.5) { if (Math.abs(prediction - point.label()) <= 0.5) {
@ -61,15 +67,15 @@ public class JavaLassoSuite implements Serializable {
double A = 0.0; double A = 0.0;
double[] weights = {-1.5, 1.0e-2}; double[] weights = {-1.5, 1.0e-2};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, JavaRDD<LabeledPoint> testRDD = jsc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
weights, nPoints, 42, 0.1), 2).cache(); weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData = List<LabeledPoint> validationData =
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LassoWithSGD lassoSGDImpl = new LassoWithSGD(); LassoWithSGD lassoSGDImpl = new LassoWithSGD();
lassoSGDImpl.optimizer().setStepSize(1.0) lassoSGDImpl.optimizer().setStepSize(1.0)
.setRegParam(0.01) .setRegParam(0.01)
.setNumIterations(20); .setNumIterations(20);
LassoModel model = lassoSGDImpl.run(testRDD.rdd()); LassoModel model = lassoSGDImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model); int numAccurate = validatePrediction(validationData, model);
@ -82,10 +88,10 @@ public class JavaLassoSuite implements Serializable {
double A = 0.0; double A = 0.0;
double[] weights = {-1.5, 1.0e-2}; double[] weights = {-1.5, 1.0e-2};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, JavaRDD<LabeledPoint> testRDD = jsc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
weights, nPoints, 42, 0.1), 2).cache(); weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData = List<LabeledPoint> validationData =
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0); LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0);
@ -25,34 +25,40 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.util.LinearDataGenerator; import org.apache.spark.mllib.util.LinearDataGenerator;
import org.apache.spark.sql.SparkSession;
public class JavaLinearRegressionSuite implements Serializable { public class JavaLinearRegressionSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaLinearRegressionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) { int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) {
int numAccurate = 0; int numAccurate = 0;
for (LabeledPoint point: validationData) { for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features()); Double prediction = model.predict(point.features());
// A prediction is off if the prediction is more than 0.5 away from expected value. // A prediction is off if the prediction is more than 0.5 away from expected value.
if (Math.abs(prediction - point.label()) <= 0.5) { if (Math.abs(prediction - point.label()) <= 0.5) {
numAccurate++; numAccurate++;
} }
} }
return numAccurate; return numAccurate;
} }
@ -63,10 +69,10 @@ public class JavaLinearRegressionSuite implements Serializable {
double A = 3.0; double A = 3.0;
double[] weights = {10, 10}; double[] weights = {10, 10};
JavaRDD<LabeledPoint> testRDD = sc.parallelize( JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData = List<LabeledPoint> validationData =
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
linSGDImpl.setIntercept(true); linSGDImpl.setIntercept(true);
@ -82,10 +88,10 @@ public class JavaLinearRegressionSuite implements Serializable {
double A = 0.0; double A = 0.0;
double[] weights = {10, 10}; double[] weights = {10, 10};
JavaRDD<LabeledPoint> testRDD = sc.parallelize( JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData = List<LabeledPoint> validationData =
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100); LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100);
@ -98,7 +104,7 @@ public class JavaLinearRegressionSuite implements Serializable {
int nPoints = 100; int nPoints = 100;
double A = 0.0; double A = 0.0;
double[] weights = {10, 10}; double[] weights = {10, 10};
JavaRDD<LabeledPoint> testRDD = sc.parallelize( JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
LinearRegressionModel model = linSGDImpl.run(testRDD.rdd()); LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
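
The Java regression suites above all migrate the same way: instead of constructing a JavaSparkContext (and a SQLContext) directly, each suite builds a SparkSession, derives its JavaSparkContext from the session's SparkContext, and stops the session (not the context) in teardown. A minimal sketch of that lifecycle, written in Scala for brevity (the class and app names are illustrative, not part of this patch):

```scala
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.SparkSession

// Illustrative lifecycle only; the real suites wire this up via JUnit's @Before/@After.
class ExampleSuiteLifecycle {
  private var spark: SparkSession = _
  private var jsc: JavaSparkContext = _

  def setUp(): Unit = {
    spark = SparkSession.builder()
      .master("local")
      .appName("ExampleSuite") // hypothetical app name
      .getOrCreate()
    // The JavaSparkContext is derived from the session rather than built on its own.
    jsc = new JavaSparkContext(spark.sparkContext)
  }

  def tearDown(): Unit = {
    // Stopping the session also stops the underlying SparkContext.
    spark.stop()
    spark = null
    jsc = null
  }
}
```

Because the context is owned by the session, teardown only needs spark.stop(); nulling the fields simply keeps the suite from holding references between tests.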


@ -29,25 +29,31 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.util.LinearDataGenerator; import org.apache.spark.mllib.util.LinearDataGenerator;
import org.apache.spark.sql.SparkSession;
public class JavaRidgeRegressionSuite implements Serializable { public class JavaRidgeRegressionSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaRidgeRegressionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
private static double predictionError(List<LabeledPoint> validationData, private static double predictionError(List<LabeledPoint> validationData,
RidgeRegressionModel model) { RidgeRegressionModel model) {
double errorSum = 0; double errorSum = 0;
for (LabeledPoint point: validationData) { for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features()); Double prediction = model.predict(point.features());
errorSum += (prediction - point.label()) * (prediction - point.label()); errorSum += (prediction - point.label()) * (prediction - point.label());
} }
@ -68,9 +74,9 @@ public class JavaRidgeRegressionSuite implements Serializable {
public void runRidgeRegressionUsingConstructor() { public void runRidgeRegressionUsingConstructor() {
int numExamples = 50; int numExamples = 50;
int numFeatures = 20; int numFeatures = 20;
List<LabeledPoint> data = generateRidgeData(2*numExamples, numFeatures, 10.0); List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples)); JavaRDD<LabeledPoint> testRDD = jsc.parallelize(data.subList(0, numExamples));
List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples); List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(); RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
@ -94,7 +100,7 @@ public class JavaRidgeRegressionSuite implements Serializable {
int numFeatures = 20; int numFeatures = 20;
List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0); List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples)); JavaRDD<LabeledPoint> testRDD = jsc.parallelize(data.subList(0, numExamples));
List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples); List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0); RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);


@ -24,13 +24,11 @@ import java.util.List;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import static org.apache.spark.streaming.JavaTestUtils.*;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import org.apache.spark.SparkConf; import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
@ -38,36 +36,42 @@ import org.apache.spark.mllib.stat.test.BinarySample;
import org.apache.spark.mllib.stat.test.ChiSqTestResult; import org.apache.spark.mllib.stat.test.ChiSqTestResult;
import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult;
import org.apache.spark.mllib.stat.test.StreamingTest; import org.apache.spark.mllib.stat.test.StreamingTest;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.apache.spark.streaming.api.java.JavaStreamingContext;
import static org.apache.spark.streaming.JavaTestUtils.*;
public class JavaStatisticsSuite implements Serializable { public class JavaStatisticsSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient JavaStreamingContext ssc; private transient JavaStreamingContext ssc;
@Before @Before
public void setUp() { public void setUp() {
SparkConf conf = new SparkConf() SparkConf conf = new SparkConf()
.setMaster("local[2]")
.setAppName("JavaStatistics")
.set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
sc = new JavaSparkContext(conf); spark = SparkSession.builder()
ssc = new JavaStreamingContext(sc, new Duration(1000)); .master("local[2]")
.appName("JavaStatistics")
.config(conf)
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
ssc = new JavaStreamingContext(jsc, new Duration(1000));
ssc.checkpoint("checkpoint"); ssc.checkpoint("checkpoint");
} }
@After @After
public void tearDown() { public void tearDown() {
spark.stop();
ssc.stop(); ssc.stop();
ssc = null; spark = null;
sc = null;
} }
@Test @Test
public void testCorr() { public void testCorr() {
JavaRDD<Double> x = sc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0)); JavaRDD<Double> x = jsc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0));
JavaRDD<Double> y = sc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3)); JavaRDD<Double> y = jsc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3));
Double corr1 = Statistics.corr(x, y); Double corr1 = Statistics.corr(x, y);
Double corr2 = Statistics.corr(x, y, "pearson"); Double corr2 = Statistics.corr(x, y, "pearson");
@ -77,7 +81,7 @@ public class JavaStatisticsSuite implements Serializable {
@Test @Test
public void kolmogorovSmirnovTest() { public void kolmogorovSmirnovTest() {
JavaDoubleRDD data = sc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0)); JavaDoubleRDD data = jsc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0));
KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm"); KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm");
KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest( KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest(
data, "norm", 0.0, 1.0); data, "norm", 0.0, 1.0);
@ -85,7 +89,7 @@ public class JavaStatisticsSuite implements Serializable {
@Test @Test
public void chiSqTest() { public void chiSqTest() {
JavaRDD<LabeledPoint> data = sc.parallelize(Arrays.asList( JavaRDD<LabeledPoint> data = jsc.parallelize(Arrays.asList(
new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)), new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)),
new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)), new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)),
new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); new LabeledPoint(0.0, Vectors.dense(2.4, 8.1))));
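
JavaStatisticsSuite adds two wrinkles: extra settings (the manual streaming clock) are handed to the builder through config(conf), and the streaming context is layered on top of the session's SparkContext. A rough Scala sketch under those assumptions (the object and app names are made up):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.streaming.{Seconds, StreamingContext}

object StreamingStatsSketch {
  def main(args: Array[String]): Unit = {
    // Extra settings ride along on the builder via config(conf).
    val conf = new SparkConf()
      .set("spark.streaming.clock", "org.apache.spark.util.ManualClock")
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("StreamingStatsSketch") // hypothetical app name
      .config(conf)
      .getOrCreate()

    // The streaming context reuses the session's SparkContext.
    val ssc = new StreamingContext(spark.sparkContext, Seconds(1))
    ssc.checkpoint("checkpoint")

    // ... run the streaming assertions here ...

    ssc.stop(stopSparkContext = false) // leave the shared context for the session to stop
    spark.stop()
  }
}
```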


@ -35,25 +35,31 @@ import org.apache.spark.mllib.tree.configuration.Algo;
import org.apache.spark.mllib.tree.configuration.Strategy; import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Gini; import org.apache.spark.mllib.tree.impurity.Gini;
import org.apache.spark.mllib.tree.model.DecisionTreeModel; import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.sql.SparkSession;
public class JavaDecisionTreeSuite implements Serializable { public class JavaDecisionTreeSuite implements Serializable {
private transient JavaSparkContext sc; private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before @Before
public void setUp() { public void setUp() {
sc = new JavaSparkContext("local", "JavaDecisionTreeSuite"); spark = SparkSession.builder()
.master("local")
.appName("JavaDecisionTreeSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
} }
@After @After
public void tearDown() { public void tearDown() {
sc.stop(); spark.stop();
sc = null; spark = null;
} }
int validatePrediction(List<LabeledPoint> validationData, DecisionTreeModel model) { int validatePrediction(List<LabeledPoint> validationData, DecisionTreeModel model) {
int numCorrect = 0; int numCorrect = 0;
for (LabeledPoint point: validationData) { for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features()); Double prediction = model.predict(point.features());
if (prediction == point.label()) { if (prediction == point.label()) {
numCorrect++; numCorrect++;
@ -65,7 +71,7 @@ public class JavaDecisionTreeSuite implements Serializable {
@Test @Test
public void runDTUsingConstructor() { public void runDTUsingConstructor() {
List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
JavaRDD<LabeledPoint> rdd = sc.parallelize(arr); JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>(); HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
@ -73,7 +79,7 @@ public class JavaDecisionTreeSuite implements Serializable {
int numClasses = 2; int numClasses = 2;
int maxBins = 100; int maxBins = 100;
Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses, Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
maxBins, categoricalFeaturesInfo); maxBins, categoricalFeaturesInfo);
DecisionTree learner = new DecisionTree(strategy); DecisionTree learner = new DecisionTree(strategy);
DecisionTreeModel model = learner.run(rdd.rdd()); DecisionTreeModel model = learner.run(rdd.rdd());
@ -85,7 +91,7 @@ public class JavaDecisionTreeSuite implements Serializable {
@Test @Test
public void runDTUsingStaticMethods() { public void runDTUsingStaticMethods() {
List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
JavaRDD<LabeledPoint> rdd = sc.parallelize(arr); JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>(); HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
@ -93,7 +99,7 @@ public class JavaDecisionTreeSuite implements Serializable {
int numClasses = 2; int numClasses = 2;
int maxBins = 100; int maxBins = 100;
Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses, Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
maxBins, categoricalFeaturesInfo); maxBins, categoricalFeaturesInfo);
DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy); DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);


@ -183,7 +183,7 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
} }
test("pipeline validateParams") { test("pipeline validateParams") {
val df = sqlContext.createDataFrame( val df = spark.createDataFrame(
Seq( Seq(
(1, Vectors.dense(0.0, 1.0, 4.0), 1.0), (1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
(2, Vectors.dense(1.0, 0.0, 4.0), 2.0), (2, Vectors.dense(1.0, 0.0, 4.0), 2.0),
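
From this point the changes are in Scala suites that mix in MLlibTestSparkContext, which after this patch exposes a SparkSession named spark; every sqlContext.createDataFrame(...) call simply becomes spark.createDataFrame(...). A self-contained sketch of that call outside the test harness (object name and data are illustrative):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SparkSession

object CreateDataFrameSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local")
      .appName("CreateDataFrameSketch")
      .getOrCreate()

    // createDataFrame is now invoked on the session; the surrounding test code is unchanged.
    val df = spark.createDataFrame(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0, 4.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 4.0))))

    df.show()
    spark.stop()
  }
}
```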


@ -32,7 +32,7 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
test("extractLabeledPoints") { test("extractLabeledPoints") {
def getTestData(labels: Seq[Double]): DataFrame = { def getTestData(labels: Seq[Double]): DataFrame = {
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
sqlContext.createDataFrame(data) spark.createDataFrame(data)
} }
val c = new MockClassifier val c = new MockClassifier
@ -72,7 +72,7 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
test("getNumClasses") { test("getNumClasses") {
def getTestData(labels: Seq[Double]): DataFrame = { def getTestData(labels: Seq[Double]): DataFrame = {
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
sqlContext.createDataFrame(data) spark.createDataFrame(data)
} }
val c = new MockClassifier val c = new MockClassifier


@ -337,13 +337,13 @@ class DecisionTreeClassifierSuite
test("should support all NumericType labels and not support other types") { test("should support all NumericType labels and not support other types") {
val dt = new DecisionTreeClassifier().setMaxDepth(1) val dt = new DecisionTreeClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier]( MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier](
dt, isClassification = true, sqlContext) { (expected, actual) => dt, isClassification = true, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual) TreeTests.checkEqual(expected, actual)
} }
} }
test("Fitting without numClasses in metadata") { test("Fitting without numClasses in metadata") {
val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc)) val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc))
val dt = new DecisionTreeClassifier().setMaxDepth(1) val dt = new DecisionTreeClassifier().setMaxDepth(1)
dt.fit(df) dt.fit(df)
} }


@ -106,7 +106,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
test("should support all NumericType labels and not support other types") { test("should support all NumericType labels and not support other types") {
val gbt = new GBTClassifier().setMaxDepth(1) val gbt = new GBTClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier]( MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier](
gbt, isClassification = true, sqlContext) { (expected, actual) => gbt, isClassification = true, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual) TreeTests.checkEqual(expected, actual)
} }
} }
@ -130,7 +130,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
*/ */
test("Fitting without numClasses in metadata") { test("Fitting without numClasses in metadata") {
val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc)) val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc))
val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1) val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
gbt.fit(df) gbt.fit(df)
} }
@ -138,7 +138,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
test("extractLabeledPoints with bad data") { test("extractLabeledPoints with bad data") {
def getTestData(labels: Seq[Double]): DataFrame = { def getTestData(labels: Seq[Double]): DataFrame = {
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
sqlContext.createDataFrame(data) spark.createDataFrame(data)
} }
val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1) val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)


@ -42,7 +42,7 @@ class LogisticRegressionSuite
override def beforeAll(): Unit = { override def beforeAll(): Unit = {
super.beforeAll() super.beforeAll()
dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) dataset = spark.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42))
binaryDataset = { binaryDataset = {
val nPoints = 10000 val nPoints = 10000
@ -54,7 +54,7 @@ class LogisticRegressionSuite
generateMultinomialLogisticInput(coefficients, xMean, xVariance, generateMultinomialLogisticInput(coefficients, xMean, xVariance,
addIntercept = true, nPoints, 42) addIntercept = true, nPoints, 42)
sqlContext.createDataFrame(sc.parallelize(testData, 4)) spark.createDataFrame(sc.parallelize(testData, 4))
} }
} }
@ -202,7 +202,7 @@ class LogisticRegressionSuite
} }
test("logistic regression: Predictor, Classifier methods") { test("logistic regression: Predictor, Classifier methods") {
val sqlContext = this.sqlContext val spark = this.spark
val lr = new LogisticRegression val lr = new LogisticRegression
val model = lr.fit(dataset) val model = lr.fit(dataset)
@ -864,8 +864,8 @@ class LogisticRegressionSuite
} }
} }
(sqlContext.createDataFrame(sc.parallelize(data1, 4)), (spark.createDataFrame(sc.parallelize(data1, 4)),
sqlContext.createDataFrame(sc.parallelize(data2, 4))) spark.createDataFrame(sc.parallelize(data2, 4)))
} }
val trainer1a = (new LogisticRegression).setFitIntercept(true) val trainer1a = (new LogisticRegression).setFitIntercept(true)
@ -938,7 +938,7 @@ class LogisticRegressionSuite
test("should support all NumericType labels and not support other types") { test("should support all NumericType labels and not support other types") {
val lr = new LogisticRegression().setMaxIter(1) val lr = new LogisticRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression]( MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression](
lr, isClassification = true, sqlContext) { (expected, actual) => lr, isClassification = true, spark) { (expected, actual) =>
assert(expected.intercept === actual.intercept) assert(expected.intercept === actual.intercept)
assert(expected.coefficients.toArray === actual.coefficients.toArray) assert(expected.coefficients.toArray === actual.coefficients.toArray)
} }


@ -36,7 +36,7 @@ class MultilayerPerceptronClassifierSuite
override def beforeAll(): Unit = { override def beforeAll(): Unit = {
super.beforeAll() super.beforeAll()
dataset = sqlContext.createDataFrame(Seq( dataset = spark.createDataFrame(Seq(
(Vectors.dense(0.0, 0.0), 0.0), (Vectors.dense(0.0, 0.0), 0.0),
(Vectors.dense(0.0, 1.0), 1.0), (Vectors.dense(0.0, 1.0), 1.0),
(Vectors.dense(1.0, 0.0), 1.0), (Vectors.dense(1.0, 0.0), 1.0),
@ -77,7 +77,7 @@ class MultilayerPerceptronClassifierSuite
} }
test("Test setWeights by training restart") { test("Test setWeights by training restart") {
val dataFrame = sqlContext.createDataFrame(Seq( val dataFrame = spark.createDataFrame(Seq(
(Vectors.dense(0.0, 0.0), 0.0), (Vectors.dense(0.0, 0.0), 0.0),
(Vectors.dense(0.0, 1.0), 1.0), (Vectors.dense(0.0, 1.0), 1.0),
(Vectors.dense(1.0, 0.0), 1.0), (Vectors.dense(1.0, 0.0), 1.0),
@ -113,7 +113,7 @@ class MultilayerPerceptronClassifierSuite
// the input seed is somewhat magic, to make this test pass // the input seed is somewhat magic, to make this test pass
val rdd = sc.parallelize(generateMultinomialLogisticInput( val rdd = sc.parallelize(generateMultinomialLogisticInput(
coefficients, xMean, xVariance, true, nPoints, 1), 2) coefficients, xMean, xVariance, true, nPoints, 1), 2)
val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features") val dataFrame = spark.createDataFrame(rdd).toDF("label", "features")
val numClasses = 3 val numClasses = 3
val numIterations = 100 val numIterations = 100
val layers = Array[Int](4, 5, 4, numClasses) val layers = Array[Int](4, 5, 4, numClasses)
@ -169,7 +169,7 @@ class MultilayerPerceptronClassifierSuite
val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1) val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1)
MLTestingUtils.checkNumericTypes[ MLTestingUtils.checkNumericTypes[
MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier]( MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier](
mpc, isClassification = true, sqlContext) { (expected, actual) => mpc, isClassification = true, spark) { (expected, actual) =>
assert(expected.layers === actual.layers) assert(expected.layers === actual.layers)
assert(expected.weights === actual.weights) assert(expected.weights === actual.weights)
} }


@ -43,7 +43,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
Array(0.10, 0.10, 0.70, 0.10) // label 2 Array(0.10, 0.10, 0.70, 0.10) // label 2
).map(_.map(math.log)) ).map(_.map(math.log))
dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42)) dataset = spark.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42))
} }
def validatePrediction(predictionAndLabels: DataFrame): Unit = { def validatePrediction(predictionAndLabels: DataFrame): Unit = {
@ -127,7 +127,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val pi = Vectors.dense(piArray) val pi = Vectors.dense(piArray)
val theta = new DenseMatrix(3, 4, thetaArray.flatten, true) val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( val testDataset = spark.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 42, "multinomial")) piArray, thetaArray, nPoints, 42, "multinomial"))
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
val model = nb.fit(testDataset) val model = nb.fit(testDataset)
@ -135,7 +135,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
validateModelFit(pi, theta, model) validateModelFit(pi, theta, model)
assert(model.hasParent) assert(model.hasParent)
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( val validationDataset = spark.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 17, "multinomial")) piArray, thetaArray, nPoints, 17, "multinomial"))
val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
@ -157,7 +157,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val pi = Vectors.dense(piArray) val pi = Vectors.dense(piArray)
val theta = new DenseMatrix(3, 12, thetaArray.flatten, true) val theta = new DenseMatrix(3, 12, thetaArray.flatten, true)
val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( val testDataset = spark.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 45, "bernoulli")) piArray, thetaArray, nPoints, 45, "bernoulli"))
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli") val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli")
val model = nb.fit(testDataset) val model = nb.fit(testDataset)
@ -165,7 +165,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
validateModelFit(pi, theta, model) validateModelFit(pi, theta, model)
assert(model.hasParent) assert(model.hasParent)
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( val validationDataset = spark.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 20, "bernoulli")) piArray, thetaArray, nPoints, 20, "bernoulli"))
val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
@ -188,7 +188,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
test("should support all NumericType labels and not support other types") { test("should support all NumericType labels and not support other types") {
val nb = new NaiveBayes() val nb = new NaiveBayes()
MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes]( MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes](
nb, isClassification = true, sqlContext) { (expected, actual) => nb, isClassification = true, spark) { (expected, actual) =>
assert(expected.pi === actual.pi) assert(expected.pi === actual.pi)
assert(expected.theta === actual.theta) assert(expected.theta === actual.theta)
} }


@ -53,7 +53,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
rdd = sc.parallelize(generateMultinomialLogisticInput( rdd = sc.parallelize(generateMultinomialLogisticInput(
coefficients, xMean, xVariance, true, nPoints, 42), 2) coefficients, xMean, xVariance, true, nPoints, 42), 2)
dataset = sqlContext.createDataFrame(rdd) dataset = spark.createDataFrame(rdd)
} }
test("params") { test("params") {
@ -228,7 +228,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("should support all NumericType labels and not support other types") { test("should support all NumericType labels and not support other types") {
val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1)) val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1))
MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest]( MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest](
ovr, isClassification = true, sqlContext) { (expected, actual) => ovr, isClassification = true, spark) { (expected, actual) =>
val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel]) val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel])
val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel]) val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel])
assert(expectedModels.length === actualModels.length) assert(expectedModels.length === actualModels.length)


@ -155,7 +155,7 @@ class RandomForestClassifierSuite
} }
test("Fitting without numClasses in metadata") { test("Fitting without numClasses in metadata") {
val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc)) val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc))
val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1) val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1)
rf.fit(df) rf.fit(df)
} }
@ -189,7 +189,7 @@ class RandomForestClassifierSuite
test("should support all NumericType labels and not support other types") { test("should support all NumericType labels and not support other types") {
val rf = new RandomForestClassifier().setMaxDepth(1) val rf = new RandomForestClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier]( MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier](
rf, isClassification = true, sqlContext) { (expected, actual) => rf, isClassification = true, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual) TreeTests.checkEqual(expected, actual)
} }
} }


@ -30,7 +30,7 @@ class BisectingKMeansSuite
override def beforeAll(): Unit = { override def beforeAll(): Unit = {
super.beforeAll() super.beforeAll()
dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
} }
test("default parameters") { test("default parameters") {


@ -32,7 +32,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
override def beforeAll(): Unit = { override def beforeAll(): Unit = {
super.beforeAll() super.beforeAll()
dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
} }
test("default parameters") { test("default parameters") {


@ -22,7 +22,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, SQLContext} import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
private[clustering] case class TestRow(features: Vector) private[clustering] case class TestRow(features: Vector)
@ -34,7 +34,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
override def beforeAll(): Unit = { override def beforeAll(): Unit = {
super.beforeAll() super.beforeAll()
dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
} }
test("default parameters") { test("default parameters") {
@ -142,11 +142,11 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
} }
object KMeansSuite { object KMeansSuite {
def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { def generateKMeansData(spark: SparkSession, rows: Int, dim: Int, k: Int): DataFrame = {
val sc = sql.sparkContext val sc = spark.sparkContext
val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
.map(v => new TestRow(v)) .map(v => new TestRow(v))
sql.createDataFrame(rdd) spark.createDataFrame(rdd)
} }
/** /**

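The KMeansSuite and LDASuite data generators change signature rather than just call site: they now take the SparkSession and pull the SparkContext from it. A trimmed sketch of that shape (TestRow here stands in for the suite's private case class, shown only for illustration):

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.sql.{DataFrame, SparkSession}

object KMeansDataSketch {
  // Stand-in for the suite's private TestRow(features: Vector).
  case class TestRow(features: Vector)

  def generateKMeansData(spark: SparkSession, rows: Int, dim: Int, k: Int): DataFrame = {
    val sc = spark.sparkContext // the context now comes from the session
    val rdd = sc.parallelize(1 to rows)
      .map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
      .map(v => TestRow(v))
    spark.createDataFrame(rdd)
  }
}
```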

@ -17,30 +17,30 @@
package org.apache.spark.ml.clustering package org.apache.spark.ml.clustering
import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.fs.Path
import org.apache.spark.SparkFunSuite import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} import org.apache.spark.sql._
object LDASuite { object LDASuite {
def generateLDAData( def generateLDAData(
sql: SQLContext, spark: SparkSession,
rows: Int, rows: Int,
k: Int, k: Int,
vocabSize: Int): DataFrame = { vocabSize: Int): DataFrame = {
val avgWC = 1 // average instances of each word in a doc val avgWC = 1 // average instances of each word in a doc
val sc = sql.sparkContext val sc = spark.sparkContext
val rng = new java.util.Random() val rng = new java.util.Random()
rng.setSeed(1) rng.setSeed(1)
val rdd = sc.parallelize(1 to rows).map { i => val rdd = sc.parallelize(1 to rows).map { i =>
Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble)) Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble))
}.map(v => new TestRow(v)) }.map(v => new TestRow(v))
sql.createDataFrame(rdd) spark.createDataFrame(rdd)
} }
/** /**
@ -68,7 +68,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
override def beforeAll(): Unit = { override def beforeAll(): Unit = {
super.beforeAll() super.beforeAll()
dataset = LDASuite.generateLDAData(sqlContext, 50, k, vocabSize) dataset = LDASuite.generateLDAData(spark, 50, k, vocabSize)
} }
test("default parameters") { test("default parameters") {
@ -140,7 +140,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
new LDA().setTopicConcentration(-1.1) new LDA().setTopicConcentration(-1.1)
} }
val dummyDF = sqlContext.createDataFrame(Seq( val dummyDF = spark.createDataFrame(Seq(
(1, Vectors.dense(1.0, 2.0)))).toDF("id", "features") (1, Vectors.dense(1.0, 2.0)))).toDF("id", "features")
// validate parameters // validate parameters
lda.transformSchema(dummyDF.schema) lda.transformSchema(dummyDF.schema)
@ -274,7 +274,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
// There should be 1 checkpoint remaining. // There should be 1 checkpoint remaining.
assert(model.getCheckpointFiles.length === 1) assert(model.getCheckpointFiles.length === 1)
val checkpointFile = new Path(model.getCheckpointFiles.head) val checkpointFile = new Path(model.getCheckpointFiles.head)
val fs = checkpointFile.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) val fs = checkpointFile.getFileSystem(spark.sparkContext.hadoopConfiguration)
assert(fs.exists(checkpointFile)) assert(fs.exists(checkpointFile))
model.deleteCheckpointFiles() model.deleteCheckpointFiles()
assert(model.getCheckpointFiles.isEmpty) assert(model.getCheckpointFiles.isEmpty)


@ -42,21 +42,21 @@ class BinaryClassificationEvaluatorSuite
val evaluator = new BinaryClassificationEvaluator() val evaluator = new BinaryClassificationEvaluator()
.setMetricName("areaUnderPR") .setMetricName("areaUnderPR")
val vectorDF = sqlContext.createDataFrame(Seq( val vectorDF = spark.createDataFrame(Seq(
(0d, Vectors.dense(12, 2.5)), (0d, Vectors.dense(12, 2.5)),
(1d, Vectors.dense(1, 3)), (1d, Vectors.dense(1, 3)),
(0d, Vectors.dense(10, 2)) (0d, Vectors.dense(10, 2))
)).toDF("label", "rawPrediction") )).toDF("label", "rawPrediction")
assert(evaluator.evaluate(vectorDF) === 1.0) assert(evaluator.evaluate(vectorDF) === 1.0)
val doubleDF = sqlContext.createDataFrame(Seq( val doubleDF = spark.createDataFrame(Seq(
(0d, 0d), (0d, 0d),
(1d, 1d), (1d, 1d),
(0d, 0d) (0d, 0d)
)).toDF("label", "rawPrediction") )).toDF("label", "rawPrediction")
assert(evaluator.evaluate(doubleDF) === 1.0) assert(evaluator.evaluate(doubleDF) === 1.0)
val stringDF = sqlContext.createDataFrame(Seq( val stringDF = spark.createDataFrame(Seq(
(0d, "0d"), (0d, "0d"),
(1d, "1d"), (1d, "1d"),
(0d, "0d") (0d, "0d")
@ -71,6 +71,6 @@ class BinaryClassificationEvaluatorSuite
test("should support all NumericType labels and not support other types") { test("should support all NumericType labels and not support other types") {
val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction") val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction")
MLTestingUtils.checkNumericTypes(evaluator, sqlContext) MLTestingUtils.checkNumericTypes(evaluator, spark)
} }
} }


@ -38,6 +38,6 @@ class MulticlassClassificationEvaluatorSuite
} }
test("should support all NumericType labels and not support other types") { test("should support all NumericType labels and not support other types") {
MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, sqlContext) MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, spark)
} }
} }


@ -42,7 +42,7 @@ class RegressionEvaluatorSuite
* data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)) * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1))
* .saveAsTextFile("path") * .saveAsTextFile("path")
*/ */
val dataset = sqlContext.createDataFrame( val dataset = spark.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput( sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
@ -85,6 +85,6 @@ class RegressionEvaluatorSuite
} }
test("should support all NumericType labels and not support other types") { test("should support all NumericType labels and not support other types") {
MLTestingUtils.checkNumericTypes(new RegressionEvaluator, sqlContext) MLTestingUtils.checkNumericTypes(new RegressionEvaluator, spark)
} }
} }


@ -39,7 +39,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("Binarize continuous features with default parameter") { test("Binarize continuous features with default parameter") {
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame( val dataFrame: DataFrame = spark.createDataFrame(
data.zip(defaultBinarized)).toDF("feature", "expected") data.zip(defaultBinarized)).toDF("feature", "expected")
val binarizer: Binarizer = new Binarizer() val binarizer: Binarizer = new Binarizer()
@ -55,7 +55,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("Binarize continuous features with setter") { test("Binarize continuous features with setter") {
val threshold: Double = 0.2 val threshold: Double = 0.2
val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame( val dataFrame: DataFrame = spark.createDataFrame(
data.zip(thresholdBinarized)).toDF("feature", "expected") data.zip(thresholdBinarized)).toDF("feature", "expected")
val binarizer: Binarizer = new Binarizer() val binarizer: Binarizer = new Binarizer()
@ -71,7 +71,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("Binarize vector of continuous features with default parameter") { test("Binarize vector of continuous features with default parameter") {
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(Seq( val dataFrame: DataFrame = spark.createDataFrame(Seq(
(Vectors.dense(data), Vectors.dense(defaultBinarized)) (Vectors.dense(data), Vectors.dense(defaultBinarized))
)).toDF("feature", "expected") )).toDF("feature", "expected")
@ -88,7 +88,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("Binarize vector of continuous features with setter") { test("Binarize vector of continuous features with setter") {
val threshold: Double = 0.2 val threshold: Double = 0.2
val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(Seq( val dataFrame: DataFrame = spark.createDataFrame(Seq(
(Vectors.dense(data), Vectors.dense(defaultBinarized)) (Vectors.dense(data), Vectors.dense(defaultBinarized))
)).toDF("feature", "expected") )).toDF("feature", "expected")


@ -39,7 +39,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val validData = Array(-0.5, -0.3, 0.0, 0.2) val validData = Array(-0.5, -0.3, 0.0, 0.2)
val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0) val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0)
val dataFrame: DataFrame = val dataFrame: DataFrame =
sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
val bucketizer: Bucketizer = new Bucketizer() val bucketizer: Bucketizer = new Bucketizer()
.setInputCol("feature") .setInputCol("feature")
@ -55,13 +55,13 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
// Check for exceptions when using a set of invalid feature values. // Check for exceptions when using a set of invalid feature values.
val invalidData1: Array[Double] = Array(-0.9) ++ validData val invalidData1: Array[Double] = Array(-0.9) ++ validData
val invalidData2 = Array(0.51) ++ validData val invalidData2 = Array(0.51) ++ validData
val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx") val badDF1 = spark.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
withClue("Invalid feature value -0.9 was not caught as an invalid feature!") { withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
intercept[SparkException] { intercept[SparkException] {
bucketizer.transform(badDF1).collect() bucketizer.transform(badDF1).collect()
} }
} }
val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx") val badDF2 = spark.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
withClue("Invalid feature value 0.51 was not caught as an invalid feature!") { withClue("Invalid feature value 0.51 was not caught as an invalid feature!") {
intercept[SparkException] { intercept[SparkException] {
bucketizer.transform(badDF2).collect() bucketizer.transform(badDF2).collect()
@ -74,7 +74,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9) val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0) val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
val dataFrame: DataFrame = val dataFrame: DataFrame =
sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
val bucketizer: Bucketizer = new Bucketizer() val bucketizer: Bucketizer = new Bucketizer()
.setInputCol("feature") .setInputCol("feature")


@ -24,14 +24,17 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.{Row, SparkSession}
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
with DefaultReadWriteTest { with DefaultReadWriteTest {
test("Test Chi-Square selector") { test("Test Chi-Square selector") {
val sqlContext = SQLContext.getOrCreate(sc) val spark = SparkSession.builder
import sqlContext.implicits._ .master("local[2]")
.appName("ChiSqSelectorSuite")
.getOrCreate()
import spark.implicits._
val data = Seq( val data = Seq(
LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
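
ChiSqSelectorSuite is the one suite in this stretch of the diff that builds its own session inside the test body; there, import sqlContext.implicits._ becomes import spark.implicits._ on the freshly built session. A minimal sketch of that idiom (object and column names are illustrative):

```scala
import org.apache.spark.sql.SparkSession

object ImplicitsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .master("local[2]")
      .appName("ImplicitsSketch")
      .getOrCreate()

    // toDF/toDS conversions now come from the session's implicits.
    import spark.implicits._
    val df = Seq((0, "a"), (1, "b"), (2, "c")).toDF("id", "value")
    df.show()

    spark.stop()
  }
}
```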


@ -35,7 +35,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
private def split(s: String): Seq[String] = s.split("\\s+") private def split(s: String): Seq[String] = s.split("\\s+")
test("CountVectorizerModel common cases") { test("CountVectorizerModel common cases") {
val df = sqlContext.createDataFrame(Seq( val df = spark.createDataFrame(Seq(
(0, split("a b c d"), (0, split("a b c d"),
Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
(1, split("a b b c d a"), (1, split("a b b c d a"),
@ -55,7 +55,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
} }
test("CountVectorizer common cases") { test("CountVectorizer common cases") {
val df = sqlContext.createDataFrame(Seq( val df = spark.createDataFrame(Seq(
(0, split("a b c d e"), (0, split("a b c d e"),
Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))),
(1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))),
@ -76,7 +76,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
} }
test("CountVectorizer vocabSize and minDF") { test("CountVectorizer vocabSize and minDF") {
val df = sqlContext.createDataFrame(Seq( val df = spark.createDataFrame(Seq(
(0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
(1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
(2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
@ -118,7 +118,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
test("CountVectorizer throws exception when vocab is empty") { test("CountVectorizer throws exception when vocab is empty") {
intercept[IllegalArgumentException] { intercept[IllegalArgumentException] {
val df = sqlContext.createDataFrame(Seq( val df = spark.createDataFrame(Seq(
(0, split("a a b b c c")), (0, split("a a b b c c")),
(1, split("aa bb cc"))) (1, split("aa bb cc")))
).toDF("id", "words") ).toDF("id", "words")
@ -132,7 +132,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
} }
test("CountVectorizerModel with minTF count") { test("CountVectorizerModel with minTF count") {
val df = sqlContext.createDataFrame(Seq( val df = spark.createDataFrame(Seq(
(0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
(1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))), (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
(2, split("a"), Vectors.sparse(4, Seq())), (2, split("a"), Vectors.sparse(4, Seq())),
@ -151,7 +151,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
} }
test("CountVectorizerModel with minTF freq") { test("CountVectorizerModel with minTF freq") {
val df = sqlContext.createDataFrame(Seq( val df = spark.createDataFrame(Seq(
(0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
(1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))), (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
(2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))), (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))),
@ -170,7 +170,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
} }
test("CountVectorizerModel and CountVectorizer with binary") { test("CountVectorizerModel and CountVectorizer with binary") {
val df = sqlContext.createDataFrame(Seq( val df = spark.createDataFrame(Seq(
(0, split("a a a a b b b b c d"), (0, split("a a a a b b b b c d"),
Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
(1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))), (1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))),


@ -63,7 +63,7 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
} }
val expectedResult = Vectors.dense(expectedResultBuffer) val expectedResult = Vectors.dense(expectedResultBuffer)
val dataset = sqlContext.createDataFrame(Seq( val dataset = spark.createDataFrame(Seq(
DCTTestData(data, expectedResult) DCTTestData(data, expectedResult)
)) ))


@ -34,7 +34,7 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
} }
test("hashingTF") { test("hashingTF") {
val df = sqlContext.createDataFrame(Seq( val df = spark.createDataFrame(Seq(
(0, "a a b b c d".split(" ").toSeq) (0, "a a b b c d".split(" ").toSeq)
)).toDF("id", "words") )).toDF("id", "words")
val n = 100 val n = 100
@ -54,7 +54,7 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
} }
test("applying binary term freqs") { test("applying binary term freqs") {
val df = sqlContext.createDataFrame(Seq( val df = spark.createDataFrame(Seq(
(0, "a a b c c c".split(" ").toSeq) (0, "a a b c c c".split(" ").toSeq)
)).toDF("id", "words") )).toDF("id", "words")
val n = 100 val n = 100


@ -60,7 +60,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
}) })
val expected = scaleDataWithIDF(data, idf) val expected = scaleDataWithIDF(data, idf)
val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
val idfModel = new IDF() val idfModel = new IDF()
.setInputCol("features") .setInputCol("features")
@ -86,7 +86,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
}) })
val expected = scaleDataWithIDF(data, idf) val expected = scaleDataWithIDF(data, idf)
val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
val idfModel = new IDF() val idfModel = new IDF()
.setInputCol("features") .setInputCol("features")


@ -59,7 +59,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
} }
test("numeric interaction") { test("numeric interaction") {
val data = sqlContext.createDataFrame( val data = spark.createDataFrame(
Seq( Seq(
(2, Vectors.dense(3.0, 4.0)), (2, Vectors.dense(3.0, 4.0)),
(1, Vectors.dense(1.0, 5.0))) (1, Vectors.dense(1.0, 5.0)))
@ -74,7 +74,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
col("b").as("b", groupAttr.toMetadata())) col("b").as("b", groupAttr.toMetadata()))
val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features")
val res = trans.transform(df) val res = trans.transform(df)
val expected = sqlContext.createDataFrame( val expected = spark.createDataFrame(
Seq( Seq(
(2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)),
(1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0))) (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)))
@ -90,7 +90,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
} }
test("nominal interaction") { test("nominal interaction") {
val data = sqlContext.createDataFrame( val data = spark.createDataFrame(
Seq( Seq(
(2, Vectors.dense(3.0, 4.0)), (2, Vectors.dense(3.0, 4.0)),
(1, Vectors.dense(1.0, 5.0))) (1, Vectors.dense(1.0, 5.0)))
@ -106,7 +106,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
col("b").as("b", groupAttr.toMetadata())) col("b").as("b", groupAttr.toMetadata()))
val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features")
val res = trans.transform(df) val res = trans.transform(df)
val expected = sqlContext.createDataFrame( val expected = spark.createDataFrame(
Seq( Seq(
(2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)),
(1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0))) (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)))
@ -126,7 +126,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
} }
test("default attr names") { test("default attr names") {
val data = sqlContext.createDataFrame( val data = spark.createDataFrame(
Seq( Seq(
(2, Vectors.dense(0.0, 4.0), 1.0), (2, Vectors.dense(0.0, 4.0), 1.0),
(1, Vectors.dense(1.0, 5.0), 10.0)) (1, Vectors.dense(1.0, 5.0), 10.0))
@ -142,7 +142,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
col("c").as("c", NumericAttribute.defaultAttr.toMetadata())) col("c").as("c", NumericAttribute.defaultAttr.toMetadata()))
val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features") val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features")
val res = trans.transform(df) val res = trans.transform(df)
val expected = sqlContext.createDataFrame( val expected = spark.createDataFrame(
Seq( Seq(
(2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)), (2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)),
(1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0))) (1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0)))


@ -36,7 +36,7 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
Vectors.sparse(3, Array(0, 2), Array(-1, -1)), Vectors.sparse(3, Array(0, 2), Array(-1, -1)),
Vectors.sparse(3, Array(0), Array(-0.75))) Vectors.sparse(3, Array(0), Array(-0.75)))
val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
val scaler = new MaxAbsScaler() val scaler = new MaxAbsScaler()
.setInputCol("features") .setInputCol("features")
.setOutputCol("scaled") .setOutputCol("scaled")


@ -38,7 +38,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
Vectors.sparse(3, Array(0, 2), Array(5, 5)), Vectors.sparse(3, Array(0, 2), Array(5, 5)),
Vectors.sparse(3, Array(0), Array(-2.5))) Vectors.sparse(3, Array(0), Array(-2.5)))
val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
val scaler = new MinMaxScaler() val scaler = new MinMaxScaler()
.setInputCol("features") .setInputCol("features")
.setOutputCol("scaled") .setOutputCol("scaled")
@ -57,7 +57,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
test("MinMaxScaler arguments max must be larger than min") { test("MinMaxScaler arguments max must be larger than min") {
withClue("arguments max must be larger than min") { withClue("arguments max must be larger than min") {
val dummyDF = sqlContext.createDataFrame(Seq( val dummyDF = spark.createDataFrame(Seq(
(1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature") (1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature")
intercept[IllegalArgumentException] { intercept[IllegalArgumentException] {
val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature") val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature")


@ -34,7 +34,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
val nGram = new NGram() val nGram = new NGram()
.setInputCol("inputTokens") .setInputCol("inputTokens")
.setOutputCol("nGrams") .setOutputCol("nGrams")
val dataset = sqlContext.createDataFrame(Seq( val dataset = spark.createDataFrame(Seq(
NGramTestData( NGramTestData(
Array("Test", "for", "ngram", "."), Array("Test", "for", "ngram", "."),
Array("Test for", "for ngram", "ngram .") Array("Test for", "for ngram", "ngram .")
@ -47,7 +47,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
.setInputCol("inputTokens") .setInputCol("inputTokens")
.setOutputCol("nGrams") .setOutputCol("nGrams")
.setN(4) .setN(4)
val dataset = sqlContext.createDataFrame(Seq( val dataset = spark.createDataFrame(Seq(
NGramTestData( NGramTestData(
Array("a", "b", "c", "d", "e"), Array("a", "b", "c", "d", "e"),
Array("a b c d", "b c d e") Array("a b c d", "b c d e")
@ -60,7 +60,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
.setInputCol("inputTokens") .setInputCol("inputTokens")
.setOutputCol("nGrams") .setOutputCol("nGrams")
.setN(4) .setN(4)
val dataset = sqlContext.createDataFrame(Seq( val dataset = spark.createDataFrame(Seq(
NGramTestData( NGramTestData(
Array(), Array(),
Array() Array()
@ -73,7 +73,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
.setInputCol("inputTokens") .setInputCol("inputTokens")
.setOutputCol("nGrams") .setOutputCol("nGrams")
.setN(6) .setN(6)
val dataset = sqlContext.createDataFrame(Seq( val dataset = spark.createDataFrame(Seq(
NGramTestData( NGramTestData(
Array("a", "b", "c", "d", "e"), Array("a", "b", "c", "d", "e"),
Array() Array()


@ -61,7 +61,7 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
Vectors.sparse(3, Seq()) Vectors.sparse(3, Seq())
) )
dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) dataFrame = spark.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData))
normalizer = new Normalizer() normalizer = new Normalizer()
.setInputCol("features") .setInputCol("features")
.setOutputCol("normalized_features") .setOutputCol("normalized_features")


@ -32,7 +32,7 @@ class OneHotEncoderSuite
def stringIndexed(): DataFrame = { def stringIndexed(): DataFrame = {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label") val df = spark.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer() val indexer = new StringIndexer()
.setInputCol("label") .setInputCol("label")
.setOutputCol("labelIndex") .setOutputCol("labelIndex")
@ -81,7 +81,7 @@ class OneHotEncoderSuite
test("input column with ML attribute") { test("input column with ML attribute") {
val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size") val df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
.select(col("size").as("size", attr.toMetadata())) .select(col("size").as("size", attr.toMetadata()))
val encoder = new OneHotEncoder() val encoder = new OneHotEncoder()
.setInputCol("size") .setInputCol("size")
@ -94,7 +94,7 @@ class OneHotEncoderSuite
} }
test("input column without ML attribute") { test("input column without ML attribute") {
val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index") val df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index")
val encoder = new OneHotEncoder() val encoder = new OneHotEncoder()
.setInputCol("index") .setInputCol("index")
.setOutputCol("encoded") .setOutputCol("encoded")


@ -49,7 +49,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val pc = mat.computePrincipalComponents(3) val pc = mat.computePrincipalComponents(3)
val expected = mat.multiply(pc).rows val expected = mat.multiply(pc).rows
val df = sqlContext.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected") val df = spark.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected")
val pca = new PCA() val pca = new PCA()
.setInputCol("features") .setInputCol("features")


@ -59,7 +59,7 @@ class PolynomialExpansionSuite
Vectors.sparse(19, Array.empty, Array.empty)) Vectors.sparse(19, Array.empty, Array.empty))
test("Polynomial expansion with default parameter") { test("Polynomial expansion with default parameter") {
val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected") val df = spark.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected")
val polynomialExpansion = new PolynomialExpansion() val polynomialExpansion = new PolynomialExpansion()
.setInputCol("features") .setInputCol("features")
@ -76,7 +76,7 @@ class PolynomialExpansionSuite
} }
test("Polynomial expansion with setter") { test("Polynomial expansion with setter") {
val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected") val df = spark.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected")
val polynomialExpansion = new PolynomialExpansion() val polynomialExpansion = new PolynomialExpansion()
.setInputCol("features") .setInputCol("features")
@ -94,7 +94,7 @@ class PolynomialExpansionSuite
} }
test("Polynomial expansion with degree 1 is identity on vectors") { test("Polynomial expansion with degree 1 is identity on vectors") {
val df = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") val df = spark.createDataFrame(data.zip(data)).toDF("features", "expected")
val polynomialExpansion = new PolynomialExpansion() val polynomialExpansion = new PolynomialExpansion()
.setInputCol("features") .setInputCol("features")


@ -32,12 +32,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("transform numeric data") { test("transform numeric data") {
val formula = new RFormula().setFormula("id ~ v1 + v2") val formula = new RFormula().setFormula("id ~ v1 + v2")
val original = sqlContext.createDataFrame( val original = spark.createDataFrame(
Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
val model = formula.fit(original) val model = formula.fit(original)
val result = model.transform(original) val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema) val resultSchema = model.transformSchema(original.schema)
val expected = sqlContext.createDataFrame( val expected = spark.createDataFrame(
Seq( Seq(
(0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0),
(2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0)) (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0))
@ -50,7 +50,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("features column already exists") { test("features column already exists") {
val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x") val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
intercept[IllegalArgumentException] { intercept[IllegalArgumentException] {
formula.fit(original) formula.fit(original)
} }
@ -61,7 +61,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("label column already exists") { test("label column already exists") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
val model = formula.fit(original) val model = formula.fit(original)
val resultSchema = model.transformSchema(original.schema) val resultSchema = model.transformSchema(original.schema)
assert(resultSchema.length == 3) assert(resultSchema.length == 3)
@ -70,7 +70,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("label column already exists but is not double type") { test("label column already exists but is not double type") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") val original = spark.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
val model = formula.fit(original) val model = formula.fit(original)
intercept[IllegalArgumentException] { intercept[IllegalArgumentException] {
model.transformSchema(original.schema) model.transformSchema(original.schema)
@ -82,7 +82,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("allow missing label column for test datasets") { test("allow missing label column for test datasets") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("label") val formula = new RFormula().setFormula("y ~ x").setLabelCol("label")
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y") val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
val model = formula.fit(original) val model = formula.fit(original)
val resultSchema = model.transformSchema(original.schema) val resultSchema = model.transformSchema(original.schema)
assert(resultSchema.length == 3) assert(resultSchema.length == 3)
@ -91,14 +91,14 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
} }
test("allow empty label") { test("allow empty label") {
val original = sqlContext.createDataFrame( val original = spark.createDataFrame(
Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0)) Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0))
).toDF("id", "a", "b") ).toDF("id", "a", "b")
val formula = new RFormula().setFormula("~ a + b") val formula = new RFormula().setFormula("~ a + b")
val model = formula.fit(original) val model = formula.fit(original)
val result = model.transform(original) val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema) val resultSchema = model.transformSchema(original.schema)
val expected = sqlContext.createDataFrame( val expected = spark.createDataFrame(
Seq( Seq(
(1, 2.0, 3.0, Vectors.dense(2.0, 3.0)), (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)),
(4, 5.0, 6.0, Vectors.dense(5.0, 6.0)), (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)),
@ -110,13 +110,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("encodes string terms") { test("encodes string terms") {
val formula = new RFormula().setFormula("id ~ a + b") val formula = new RFormula().setFormula("id ~ a + b")
val original = sqlContext.createDataFrame( val original = spark.createDataFrame(
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
).toDF("id", "a", "b") ).toDF("id", "a", "b")
val model = formula.fit(original) val model = formula.fit(original)
val result = model.transform(original) val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema) val resultSchema = model.transformSchema(original.schema)
val expected = sqlContext.createDataFrame( val expected = spark.createDataFrame(
Seq( Seq(
(1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
(2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
@ -129,13 +129,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("index string label") { test("index string label") {
val formula = new RFormula().setFormula("id ~ a + b") val formula = new RFormula().setFormula("id ~ a + b")
val original = sqlContext.createDataFrame( val original = spark.createDataFrame(
Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5))
).toDF("id", "a", "b") ).toDF("id", "a", "b")
val model = formula.fit(original) val model = formula.fit(original)
val result = model.transform(original) val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema) val resultSchema = model.transformSchema(original.schema)
val expected = sqlContext.createDataFrame( val expected = spark.createDataFrame(
Seq( Seq(
("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
@ -148,7 +148,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("attribute generation") { test("attribute generation") {
val formula = new RFormula().setFormula("id ~ a + b") val formula = new RFormula().setFormula("id ~ a + b")
val original = sqlContext.createDataFrame( val original = spark.createDataFrame(
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
).toDF("id", "a", "b") ).toDF("id", "a", "b")
val model = formula.fit(original) val model = formula.fit(original)
@ -165,7 +165,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("vector attribute generation") { test("vector attribute generation") {
val formula = new RFormula().setFormula("id ~ vec") val formula = new RFormula().setFormula("id ~ vec")
val original = sqlContext.createDataFrame( val original = spark.createDataFrame(
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
).toDF("id", "vec") ).toDF("id", "vec")
val model = formula.fit(original) val model = formula.fit(original)
@ -181,7 +181,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("vector attribute generation with unnamed input attrs") { test("vector attribute generation with unnamed input attrs") {
val formula = new RFormula().setFormula("id ~ vec2") val formula = new RFormula().setFormula("id ~ vec2")
val base = sqlContext.createDataFrame( val base = spark.createDataFrame(
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
).toDF("id", "vec") ).toDF("id", "vec")
val metadata = new AttributeGroup( val metadata = new AttributeGroup(
@ -203,12 +203,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("numeric interaction") { test("numeric interaction") {
val formula = new RFormula().setFormula("a ~ b:c:d") val formula = new RFormula().setFormula("a ~ b:c:d")
val original = sqlContext.createDataFrame( val original = spark.createDataFrame(
Seq((1, 2, 4, 2), (2, 3, 4, 1)) Seq((1, 2, 4, 2), (2, 3, 4, 1))
).toDF("a", "b", "c", "d") ).toDF("a", "b", "c", "d")
val model = formula.fit(original) val model = formula.fit(original)
val result = model.transform(original) val result = model.transform(original)
val expected = sqlContext.createDataFrame( val expected = spark.createDataFrame(
Seq( Seq(
(1, 2, 4, 2, Vectors.dense(16.0), 1.0), (1, 2, 4, 2, Vectors.dense(16.0), 1.0),
(2, 3, 4, 1, Vectors.dense(12.0), 2.0)) (2, 3, 4, 1, Vectors.dense(12.0), 2.0))
@ -223,12 +223,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("factor numeric interaction") { test("factor numeric interaction") {
val formula = new RFormula().setFormula("id ~ a:b") val formula = new RFormula().setFormula("id ~ a:b")
val original = sqlContext.createDataFrame( val original = spark.createDataFrame(
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5))
).toDF("id", "a", "b") ).toDF("id", "a", "b")
val model = formula.fit(original) val model = formula.fit(original)
val result = model.transform(original) val result = model.transform(original)
val expected = sqlContext.createDataFrame( val expected = spark.createDataFrame(
Seq( Seq(
(1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0),
(2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0),
@ -250,12 +250,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("factor factor interaction") { test("factor factor interaction") {
val formula = new RFormula().setFormula("id ~ a:b") val formula = new RFormula().setFormula("id ~ a:b")
val original = sqlContext.createDataFrame( val original = spark.createDataFrame(
Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
).toDF("id", "a", "b") ).toDF("id", "a", "b")
val model = formula.fit(original) val model = formula.fit(original)
val result = model.transform(original) val result = model.transform(original)
val expected = sqlContext.createDataFrame( val expected = spark.createDataFrame(
Seq( Seq(
(1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0), (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0),
(2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0),
@ -299,7 +299,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
} }
} }
val dataset = sqlContext.createDataFrame( val dataset = spark.createDataFrame(
Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
).toDF("id", "a", "b") ).toDF("id", "a", "b")


@ -31,13 +31,13 @@ class SQLTransformerSuite
} }
test("transform numeric data") { test("transform numeric data") {
val original = sqlContext.createDataFrame( val original = spark.createDataFrame(
Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
val sqlTrans = new SQLTransformer().setStatement( val sqlTrans = new SQLTransformer().setStatement(
"SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
val result = sqlTrans.transform(original) val result = sqlTrans.transform(original)
val resultSchema = sqlTrans.transformSchema(original.schema) val resultSchema = sqlTrans.transformSchema(original.schema)
val expected = sqlContext.createDataFrame( val expected = spark.createDataFrame(
Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))) Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)))
.toDF("id", "v1", "v2", "v3", "v4") .toDF("id", "v1", "v2", "v3", "v4")
assert(result.schema.toString == resultSchema.toString) assert(result.schema.toString == resultSchema.toString)
@ -52,7 +52,7 @@ class SQLTransformerSuite
} }
test("transformSchema") { test("transformSchema") {
val df = sqlContext.range(10) val df = spark.range(10)
val outputSchema = new SQLTransformer() val outputSchema = new SQLTransformer()
.setStatement("SELECT id + 1 AS id1 FROM __THIS__") .setStatement("SELECT id + 1 AS id1 FROM __THIS__")
.transformSchema(df.schema) .transformSchema(df.schema)
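
The SQLTransformerSuite hunks also show that helpers the tests used to reach through sqlContext, such as range and sql, live directly on SparkSession, so no separate SQLContext has to be constructed. A standalone sketch of those calls, not taken from the patch (names are placeholders):

import org.apache.spark.sql.SparkSession

object SparkSessionCallsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SparkSessionCallsSketch")
      .getOrCreate()
    try {
      val ids = spark.range(10)               // was sqlContext.range(10)
      val one = spark.sql("SELECT 1 AS one")  // ad-hoc SQL also goes through the session
      assert(ids.count() == 10 && one.count() == 1)
    } finally {
      spark.stop()
    }
  }
}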


@ -73,7 +73,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
} }
test("Standardization with default parameter") { test("Standardization with default parameter") {
val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected") val df0 = spark.createDataFrame(data.zip(resWithStd)).toDF("features", "expected")
val standardScaler0 = new StandardScaler() val standardScaler0 = new StandardScaler()
.setInputCol("features") .setInputCol("features")
@ -84,9 +84,9 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
} }
test("Standardization with setter") { test("Standardization with setter") {
val df1 = sqlContext.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected") val df1 = spark.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected")
val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected") val df2 = spark.createDataFrame(data.zip(resWithMean)).toDF("features", "expected")
val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") val df3 = spark.createDataFrame(data.zip(data)).toDF("features", "expected")
val standardScaler1 = new StandardScaler() val standardScaler1 = new StandardScaler()
.setInputCol("features") .setInputCol("features")


@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.{Dataset, Row}
object StopWordsRemoverSuite extends SparkFunSuite { object StopWordsRemoverSuite extends SparkFunSuite {
def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = { def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = {
@ -42,7 +42,7 @@ class StopWordsRemoverSuite
val remover = new StopWordsRemover() val remover = new StopWordsRemover()
.setInputCol("raw") .setInputCol("raw")
.setOutputCol("filtered") .setOutputCol("filtered")
val dataSet = sqlContext.createDataFrame(Seq( val dataSet = spark.createDataFrame(Seq(
(Seq("test", "test"), Seq("test", "test")), (Seq("test", "test"), Seq("test", "test")),
(Seq("a", "b", "c", "d"), Seq("b", "c")), (Seq("a", "b", "c", "d"), Seq("b", "c")),
(Seq("a", "the", "an"), Seq()), (Seq("a", "the", "an"), Seq()),
@ -60,7 +60,7 @@ class StopWordsRemoverSuite
.setInputCol("raw") .setInputCol("raw")
.setOutputCol("filtered") .setOutputCol("filtered")
.setStopWords(stopWords) .setStopWords(stopWords)
val dataSet = sqlContext.createDataFrame(Seq( val dataSet = spark.createDataFrame(Seq(
(Seq("test", "test"), Seq()), (Seq("test", "test"), Seq()),
(Seq("a", "b", "c", "d"), Seq("b", "c", "d")), (Seq("a", "b", "c", "d"), Seq("b", "c", "d")),
(Seq("a", "the", "an"), Seq()), (Seq("a", "the", "an"), Seq()),
@ -77,7 +77,7 @@ class StopWordsRemoverSuite
.setInputCol("raw") .setInputCol("raw")
.setOutputCol("filtered") .setOutputCol("filtered")
.setCaseSensitive(true) .setCaseSensitive(true)
val dataSet = sqlContext.createDataFrame(Seq( val dataSet = spark.createDataFrame(Seq(
(Seq("A"), Seq("A")), (Seq("A"), Seq("A")),
(Seq("The", "the"), Seq("The")) (Seq("The", "the"), Seq("The"))
)).toDF("raw", "expected") )).toDF("raw", "expected")
@ -98,7 +98,7 @@ class StopWordsRemoverSuite
.setInputCol("raw") .setInputCol("raw")
.setOutputCol("filtered") .setOutputCol("filtered")
.setStopWords(stopWords) .setStopWords(stopWords)
val dataSet = sqlContext.createDataFrame(Seq( val dataSet = spark.createDataFrame(Seq(
(Seq("acaba", "ama", "biri"), Seq()), (Seq("acaba", "ama", "biri"), Seq()),
(Seq("hep", "her", "scala"), Seq("scala")) (Seq("hep", "her", "scala"), Seq("scala"))
)).toDF("raw", "expected") )).toDF("raw", "expected")
@ -112,7 +112,7 @@ class StopWordsRemoverSuite
.setInputCol("raw") .setInputCol("raw")
.setOutputCol("filtered") .setOutputCol("filtered")
.setStopWords(stopWords.toArray) .setStopWords(stopWords.toArray)
val dataSet = sqlContext.createDataFrame(Seq( val dataSet = spark.createDataFrame(Seq(
(Seq("python", "scala", "a"), Seq("python", "scala", "a")), (Seq("python", "scala", "a"), Seq("python", "scala", "a")),
(Seq("Python", "Scala", "swift"), Seq("Python", "Scala", "swift")) (Seq("Python", "Scala", "swift"), Seq("Python", "Scala", "swift"))
)).toDF("raw", "expected") )).toDF("raw", "expected")
@ -126,7 +126,7 @@ class StopWordsRemoverSuite
.setInputCol("raw") .setInputCol("raw")
.setOutputCol("filtered") .setOutputCol("filtered")
.setStopWords(stopWords.toArray) .setStopWords(stopWords.toArray)
val dataSet = sqlContext.createDataFrame(Seq( val dataSet = spark.createDataFrame(Seq(
(Seq("python", "scala", "a"), Seq()), (Seq("python", "scala", "a"), Seq()),
(Seq("Python", "Scala", "swift"), Seq("swift")) (Seq("Python", "Scala", "swift"), Seq("swift"))
)).toDF("raw", "expected") )).toDF("raw", "expected")
@ -148,7 +148,7 @@ class StopWordsRemoverSuite
val remover = new StopWordsRemover() val remover = new StopWordsRemover()
.setInputCol("raw") .setInputCol("raw")
.setOutputCol(outputCol) .setOutputCol(outputCol)
val dataSet = sqlContext.createDataFrame(Seq( val dataSet = spark.createDataFrame(Seq(
(Seq("The", "the", "swift"), Seq("swift")) (Seq("The", "the", "swift"), Seq("swift"))
)).toDF("raw", outputCol) )).toDF("raw", outputCol)


@ -39,7 +39,7 @@ class StringIndexerSuite
test("StringIndexer") { test("StringIndexer") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label") val df = spark.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer() val indexer = new StringIndexer()
.setInputCol("label") .setInputCol("label")
.setOutputCol("labelIndex") .setOutputCol("labelIndex")
@ -63,8 +63,8 @@ class StringIndexerSuite
test("StringIndexerUnseen") { test("StringIndexerUnseen") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2) val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2)
val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2) val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label") val df = spark.createDataFrame(data).toDF("id", "label")
val df2 = sqlContext.createDataFrame(data2).toDF("id", "label") val df2 = spark.createDataFrame(data2).toDF("id", "label")
val indexer = new StringIndexer() val indexer = new StringIndexer()
.setInputCol("label") .setInputCol("label")
.setOutputCol("labelIndex") .setOutputCol("labelIndex")
@ -93,7 +93,7 @@ class StringIndexerSuite
test("StringIndexer with a numeric input column") { test("StringIndexer with a numeric input column") {
val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label") val df = spark.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer() val indexer = new StringIndexer()
.setInputCol("label") .setInputCol("label")
.setOutputCol("labelIndex") .setOutputCol("labelIndex")
@ -114,12 +114,12 @@ class StringIndexerSuite
val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
.setInputCol("label") .setInputCol("label")
.setOutputCol("labelIndex") .setOutputCol("labelIndex")
val df = sqlContext.range(0L, 10L).toDF() val df = spark.range(0L, 10L).toDF()
assert(indexerModel.transform(df).collect().toSet === df.collect().toSet) assert(indexerModel.transform(df).collect().toSet === df.collect().toSet)
} }
test("StringIndexerModel can't overwrite output column") { test("StringIndexerModel can't overwrite output column") {
val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output") val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
val indexer = new StringIndexer() val indexer = new StringIndexer()
.setInputCol("input") .setInputCol("input")
.setOutputCol("output") .setOutputCol("output")
@ -153,7 +153,7 @@ class StringIndexerSuite
test("IndexToString.transform") { test("IndexToString.transform") {
val labels = Array("a", "b", "c") val labels = Array("a", "b", "c")
val df0 = sqlContext.createDataFrame(Seq( val df0 = spark.createDataFrame(Seq(
(0, "a"), (1, "b"), (2, "c"), (0, "a") (0, "a"), (1, "b"), (2, "c"), (0, "a")
)).toDF("index", "expected") )).toDF("index", "expected")
@ -180,7 +180,7 @@ class StringIndexerSuite
test("StringIndexer, IndexToString are inverses") { test("StringIndexer, IndexToString are inverses") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label") val df = spark.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer() val indexer = new StringIndexer()
.setInputCol("label") .setInputCol("label")
.setOutputCol("labelIndex") .setOutputCol("labelIndex")
@ -213,7 +213,7 @@ class StringIndexerSuite
test("StringIndexer metadata") { test("StringIndexer metadata") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label") val df = spark.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer() val indexer = new StringIndexer()
.setInputCol("label") .setInputCol("label")
.setOutputCol("labelIndex") .setOutputCol("labelIndex")


@ -57,13 +57,13 @@ class RegexTokenizerSuite
.setPattern("\\w+|\\p{Punct}") .setPattern("\\w+|\\p{Punct}")
.setInputCol("rawText") .setInputCol("rawText")
.setOutputCol("tokens") .setOutputCol("tokens")
val dataset0 = sqlContext.createDataFrame(Seq( val dataset0 = spark.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")), TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")),
TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct")) TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct"))
)) ))
testRegexTokenizer(tokenizer0, dataset0) testRegexTokenizer(tokenizer0, dataset0)
val dataset1 = sqlContext.createDataFrame(Seq( val dataset1 = spark.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")), TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")),
TokenizerTestData("Te,st. punct", Array("punct")) TokenizerTestData("Te,st. punct", Array("punct"))
)) ))
@ -73,7 +73,7 @@ class RegexTokenizerSuite
val tokenizer2 = new RegexTokenizer() val tokenizer2 = new RegexTokenizer()
.setInputCol("rawText") .setInputCol("rawText")
.setOutputCol("tokens") .setOutputCol("tokens")
val dataset2 = sqlContext.createDataFrame(Seq( val dataset2 = spark.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")), TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")),
TokenizerTestData("Te,st. punct", Array("te,st.", "punct")) TokenizerTestData("Te,st. punct", Array("te,st.", "punct"))
)) ))
@ -85,7 +85,7 @@ class RegexTokenizerSuite
.setInputCol("rawText") .setInputCol("rawText")
.setOutputCol("tokens") .setOutputCol("tokens")
.setToLowercase(false) .setToLowercase(false)
val dataset = sqlContext.createDataFrame(Seq( val dataset = spark.createDataFrame(Seq(
TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")), TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")),
TokenizerTestData("java scala", Array("java", "scala")) TokenizerTestData("java scala", Array("java", "scala"))
)) ))


@ -57,7 +57,7 @@ class VectorAssemblerSuite
} }
test("VectorAssembler") { test("VectorAssembler") {
val df = sqlContext.createDataFrame(Seq( val df = spark.createDataFrame(Seq(
(0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L) (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
)).toDF("id", "x", "y", "name", "z", "n") )).toDF("id", "x", "y", "name", "z", "n")
val assembler = new VectorAssembler() val assembler = new VectorAssembler()
@ -70,7 +70,7 @@ class VectorAssemblerSuite
} }
test("transform should throw an exception in case of unsupported type") { test("transform should throw an exception in case of unsupported type") {
val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c") val df = spark.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c")
val assembler = new VectorAssembler() val assembler = new VectorAssembler()
.setInputCols(Array("a", "b", "c")) .setInputCols(Array("a", "b", "c"))
.setOutputCol("features") .setOutputCol("features")
@ -87,7 +87,7 @@ class VectorAssemblerSuite
NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"), NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"),
NumericAttribute.defaultAttr.withName("salary"))) NumericAttribute.defaultAttr.withName("salary")))
val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0))) val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0)))
val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad") val df = spark.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
.select( .select(
col("browser").as("browser", browser.toMetadata()), col("browser").as("browser", browser.toMetadata()),
col("hour").as("hour", hour.toMetadata()), col("hour").as("hour", hour.toMetadata()),


@ -85,11 +85,11 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
checkPair(densePoints1Seq, sparsePoints1Seq) checkPair(densePoints1Seq, sparsePoints1Seq)
checkPair(densePoints2Seq, sparsePoints2Seq) checkPair(densePoints2Seq, sparsePoints2Seq)
densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) densePoints1 = spark.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) sparsePoints1 = spark.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) densePoints2 = spark.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData)) sparsePoints2 = spark.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData))
badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData)) badPoints = spark.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData))
} }
private def getIndexer: VectorIndexer = private def getIndexer: VectorIndexer =
@ -102,7 +102,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
} }
test("Cannot fit an empty DataFrame") { test("Cannot fit an empty DataFrame") {
val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) val rdd = spark.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
val vectorIndexer = getIndexer val vectorIndexer = getIndexer
intercept[IllegalArgumentException] { intercept[IllegalArgumentException] {
vectorIndexer.fit(rdd) vectorIndexer.fit(rdd)
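
The VectorIndexerSuite hunks use the RDD-based overload: an RDD of case-class instances is handed to the session exactly as it was handed to sqlContext before. A small sketch of that overload, again assuming MLlibTestSparkContext provides `spark` (the FeatureData case class here is illustrative):

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext

object RddSketchData {
  case class FeatureData(features: Vector)
}

class RddCreateDataFrameSketchSuite extends SparkFunSuite with MLlibTestSparkContext {
  import RddSketchData.FeatureData

  test("DataFrame from an RDD of case-class rows") {
    val rdd = sc.parallelize(
      Seq(Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0)), 2).map(FeatureData)
    // Reflection-based overload: the column name comes from the case-class field.
    val df = spark.createDataFrame(rdd)
    assert(df.columns.toSeq == Seq("features"))
  }
}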


@ -79,7 +79,7 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De
val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]]) val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]])
val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) } val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) }
val df = sqlContext.createDataFrame(rdd, val df = spark.createDataFrame(rdd,
StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField()))) StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField())))
val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")
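
The VectorSlicerSuite hunk relies on a third shape of createDataFrame: an RDD of Rows plus an explicit schema, which is how that suite attaches ML attribute metadata to its columns. A compact sketch of the overload under the same trait assumption (field names are illustrative):

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

class RowSchemaSketchSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("DataFrame from an RDD of Rows plus an explicit schema") {
    val rdd = sc.parallelize(Seq(Row(1.0, 2.0), Row(3.0, 4.0)))
    val schema = StructType(Array(
      StructField("a", DoubleType, nullable = false),
      StructField("expected", DoubleType, nullable = false)))
    // The schema (and any column metadata) is supplied explicitly rather than inferred.
    val df = spark.createDataFrame(rdd, schema)
    assert(df.schema == schema)
  }
}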


@ -36,8 +36,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("Word2Vec") { test("Word2Vec") {
val sqlContext = this.sqlContext val spark = this.spark
import sqlContext.implicits._ import spark.implicits._
val sentence = "a b " * 100 + "a c " * 10 val sentence = "a b " * 100 + "a c " * 10
val numOfWords = sentence.split(" ").size val numOfWords = sentence.split(" ").size
@ -78,8 +78,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("getVectors") { test("getVectors") {
val sqlContext = this.sqlContext val spark = this.spark
import sqlContext.implicits._ import spark.implicits._
val sentence = "a b " * 100 + "a c " * 10 val sentence = "a b " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
@ -119,8 +119,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("findSynonyms") { test("findSynonyms") {
val sqlContext = this.sqlContext val spark = this.spark
import sqlContext.implicits._ import spark.implicits._
val sentence = "a b " * 100 + "a c " * 10 val sentence = "a b " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
@ -146,8 +146,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("window size") { test("window size") {
val sqlContext = this.sqlContext val spark = this.spark
import sqlContext.implicits._ import spark.implicits._
val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
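
The Word2VecSuite hunks keep the idiom of copying the session into a local val before importing its implicits: a Scala import needs a stable identifier, so importing implicits straight off a var-typed test fixture would not compile. A sketch of the idiom, assuming the trait stores the session in a var named spark:

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

class ImplicitsSketchSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("import implicits via a local stable identifier") {
    val spark = this.spark      // copy the var into a val...
    import spark.implicits._    // ...so its implicits (toDF, encoders, $"col") can be imported
    val df = sc.parallelize(Seq((1, "a b c"), (2, "a b b"))).toDF("id", "text")
    assert(df.columns.toSeq == Seq("id", "text"))
  }
}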


@ -38,7 +38,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils import org.apache.spark.util.Utils
@ -305,8 +305,8 @@ class ALSSuite
numUserBlocks: Int = 2, numUserBlocks: Int = 2,
numItemBlocks: Int = 3, numItemBlocks: Int = 3,
targetRMSE: Double = 0.05): Unit = { targetRMSE: Double = 0.05): Unit = {
val sqlContext = this.sqlContext val spark = this.spark
import sqlContext.implicits._ import spark.implicits._
val als = new ALS() val als = new ALS()
.setRank(rank) .setRank(rank)
.setRegParam(regParam) .setRegParam(regParam)
@ -460,8 +460,8 @@ class ALSSuite
allEstimatorParamSettings.foreach { case (p, v) => allEstimatorParamSettings.foreach { case (p, v) =>
als.set(als.getParam(p), v) als.set(als.getParam(p), v)
} }
val sqlContext = this.sqlContext val spark = this.spark
import sqlContext.implicits._ import spark.implicits._
val model = als.fit(ratings.toDF()) val model = als.fit(ratings.toDF())
// Test Estimator save/load // Test Estimator save/load
@ -535,8 +535,11 @@ class ALSCleanerSuite extends SparkFunSuite {
// Generate test data // Generate test data
val (training, _) = ALSSuite.genImplicitTestData(sc, 20, 5, 1, 0.2, 0) val (training, _) = ALSSuite.genImplicitTestData(sc, 20, 5, 1, 0.2, 0)
// Implicitly test the cleaning of parents during ALS training // Implicitly test the cleaning of parents during ALS training
val sqlContext = new SQLContext(sc) val spark = SparkSession.builder
.master("local[2]")
.appName("ALSCleanerSuite")
.getOrCreate()
import sqlContext.implicits._ import spark.implicits._
val als = new ALS() val als = new ALS()
.setRank(1) .setRank(1)
.setRegParam(1e-5) .setRegParam(1e-5)
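
ALSCleanerSuite does not mix in the shared ML test trait, so instead of wrapping the SparkContext in a SQLContext it now builds its own session with SparkSession.builder, as the hunk above shows. A minimal standalone sketch of that setup and teardown (master, app name, and data are placeholders):

import org.apache.spark.sql.SparkSession

object SessionLifecycleSketch {
  def main(args: Array[String]): Unit = {
    // getOrCreate() attaches to an already-active session/SparkContext if one is running,
    // otherwise it starts a new local one with these settings.
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SessionLifecycleSketch")
      .getOrCreate()
    try {
      import spark.implicits._
      val ratings = Seq((0, 0, 1.0), (0, 1, 2.0)).toDF("user", "item", "rating")
      assert(ratings.count() == 2)
    } finally {
      spark.stop()
    }
  }
}
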
@ -577,8 +580,8 @@ class ALSStorageSuite
} }
test("default and non-default storage params set correct RDD StorageLevels") { test("default and non-default storage params set correct RDD StorageLevels") {
val sqlContext = this.sqlContext val spark = this.spark
import sqlContext.implicits._ import spark.implicits._
val data = Seq( val data = Seq(
(0, 0, 1.0), (0, 0, 1.0),
(0, 1, 2.0), (0, 1, 2.0),

Some files were not shown because too many files have changed in this diff.