[SPARK-15037][SQL][MLLIB] Use SparkSession instead of SQLContext in Scala/Java TestSuites
## What changes were proposed in this pull request?

Use SparkSession instead of SQLContext in the Scala/Java TestSuites. Because this PR is already very big, the Python TestSuites are being handled in a different PR.

## How was this patch tested?

Existing tests.

Author: Sandeep Singh <sandeep@techaddict.me>

Closes #12907 from techaddict/SPARK-15037.
This commit is contained in:
parent
bcfee153b1
commit
ed0b4070fb
|
@ -17,18 +17,18 @@
|
||||||
|
|
||||||
package org.apache.spark.ml;
|
package org.apache.spark.ml;
|
||||||
|
|
||||||
import org.apache.spark.sql.Dataset;
|
|
||||||
import org.apache.spark.sql.Row;
|
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
|
||||||
import org.apache.spark.ml.classification.LogisticRegression;
|
import org.apache.spark.ml.classification.LogisticRegression;
|
||||||
import org.apache.spark.ml.feature.StandardScaler;
|
import org.apache.spark.ml.feature.StandardScaler;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
|
import org.apache.spark.sql.Dataset;
|
||||||
|
import org.apache.spark.sql.Row;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
|
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -36,23 +36,26 @@ import static org.apache.spark.mllib.classification.LogisticRegressionSuite.gene
|
||||||
*/
|
*/
|
||||||
public class JavaPipelineSuite {
|
public class JavaPipelineSuite {
|
||||||
|
|
||||||
|
private transient SparkSession spark;
|
||||||
private transient JavaSparkContext jsc;
|
private transient JavaSparkContext jsc;
|
||||||
private transient SQLContext jsql;
|
|
||||||
private transient Dataset<Row> dataset;
|
private transient Dataset<Row> dataset;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaPipelineSuite");
|
spark = SparkSession.builder()
|
||||||
jsql = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaPipelineSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
JavaRDD<LabeledPoint> points =
|
JavaRDD<LabeledPoint> points =
|
||||||
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
|
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
|
||||||
dataset = jsql.createDataFrame(points, LabeledPoint.class);
|
dataset = spark.createDataFrame(points, LabeledPoint.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -63,10 +66,10 @@ public class JavaPipelineSuite {
|
||||||
LogisticRegression lr = new LogisticRegression()
|
LogisticRegression lr = new LogisticRegression()
|
||||||
.setFeaturesCol("scaledFeatures");
|
.setFeaturesCol("scaledFeatures");
|
||||||
Pipeline pipeline = new Pipeline()
|
Pipeline pipeline = new Pipeline()
|
||||||
.setStages(new PipelineStage[] {scaler, lr});
|
.setStages(new PipelineStage[]{scaler, lr});
|
||||||
PipelineModel model = pipeline.fit(dataset);
|
PipelineModel model = pipeline.fit(dataset);
|
||||||
model.transform(dataset).registerTempTable("prediction");
|
model.transform(dataset).registerTempTable("prediction");
|
||||||
Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
|
Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction");
|
||||||
predictions.collectAsList();
|
predictions.collectAsList();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,8 +17,8 @@
|
||||||
|
|
||||||
package org.apache.spark.ml.attribute;
|
package org.apache.spark.ml.attribute;
|
||||||
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
public class JavaAttributeSuite {
|
public class JavaAttributeSuite {
|
||||||
|
|
||||||
|
|
|
@ -21,8 +21,6 @@ import java.io.Serializable;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import org.apache.spark.sql.Dataset;
|
|
||||||
import org.apache.spark.sql.Row;
|
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
@ -32,21 +30,28 @@ import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.ml.tree.impl.TreeTests;
|
import org.apache.spark.ml.tree.impl.TreeTests;
|
||||||
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
|
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
|
import org.apache.spark.sql.Dataset;
|
||||||
|
import org.apache.spark.sql.Row;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaDecisionTreeClassifierSuite implements Serializable {
|
public class JavaDecisionTreeClassifierSuite implements Serializable {
|
||||||
|
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaDecisionTreeClassifierSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -55,7 +60,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
|
||||||
double A = 2.0;
|
double A = 2.0;
|
||||||
double B = -1.5;
|
double B = -1.5;
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> data = sc.parallelize(
|
JavaRDD<LabeledPoint> data = jsc.parallelize(
|
||||||
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
||||||
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
|
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
|
||||||
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
|
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
|
||||||
|
@ -70,7 +75,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
|
||||||
.setCacheNodeIds(false)
|
.setCacheNodeIds(false)
|
||||||
.setCheckpointInterval(10)
|
.setCheckpointInterval(10)
|
||||||
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
|
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
|
||||||
for (String impurity: DecisionTreeClassifier.supportedImpurities()) {
|
for (String impurity : DecisionTreeClassifier.supportedImpurities()) {
|
||||||
dt.setImpurity(impurity);
|
dt.setImpurity(impurity);
|
||||||
}
|
}
|
||||||
DecisionTreeClassificationModel model = dt.fit(dataFrame);
|
DecisionTreeClassificationModel model = dt.fit(dataFrame);
|
||||||
|
|
|
@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
|
|
||||||
public class JavaGBTClassifierSuite implements Serializable {
|
public class JavaGBTClassifierSuite implements Serializable {
|
||||||
|
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaGBTClassifierSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaGBTClassifierSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -55,7 +61,7 @@ public class JavaGBTClassifierSuite implements Serializable {
|
||||||
double A = 2.0;
|
double A = 2.0;
|
||||||
double B = -1.5;
|
double B = -1.5;
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> data = sc.parallelize(
|
JavaRDD<LabeledPoint> data = jsc.parallelize(
|
||||||
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
||||||
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
|
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
|
||||||
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
|
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
|
||||||
|
@ -74,7 +80,7 @@ public class JavaGBTClassifierSuite implements Serializable {
|
||||||
.setMaxIter(3)
|
.setMaxIter(3)
|
||||||
.setStepSize(0.1)
|
.setStepSize(0.1)
|
||||||
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
|
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
|
||||||
for (String lossType: GBTClassifier.supportedLossTypes()) {
|
for (String lossType : GBTClassifier.supportedLossTypes()) {
|
||||||
rf.setLossType(lossType);
|
rf.setLossType(lossType);
|
||||||
}
|
}
|
||||||
GBTClassificationModel model = rf.fit(dataFrame);
|
GBTClassificationModel model = rf.fit(dataFrame);
|
||||||
|
|
|
@ -27,18 +27,17 @@ import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
|
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
|
||||||
|
|
||||||
public class JavaLogisticRegressionSuite implements Serializable {
|
public class JavaLogisticRegressionSuite implements Serializable {
|
||||||
|
|
||||||
|
private transient SparkSession spark;
|
||||||
private transient JavaSparkContext jsc;
|
private transient JavaSparkContext jsc;
|
||||||
private transient SQLContext jsql;
|
|
||||||
private transient Dataset<Row> dataset;
|
private transient Dataset<Row> dataset;
|
||||||
|
|
||||||
private transient JavaRDD<LabeledPoint> datasetRDD;
|
private transient JavaRDD<LabeledPoint> datasetRDD;
|
||||||
|
@ -46,18 +45,22 @@ public class JavaLogisticRegressionSuite implements Serializable {
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
|
spark = SparkSession.builder()
|
||||||
jsql = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaLogisticRegressionSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
|
|
||||||
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
|
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
|
||||||
datasetRDD = jsc.parallelize(points, 2);
|
datasetRDD = jsc.parallelize(points, 2);
|
||||||
dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
|
dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
|
||||||
dataset.registerTempTable("dataset");
|
dataset.registerTempTable("dataset");
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -66,7 +69,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
|
||||||
Assert.assertEquals(lr.getLabelCol(), "label");
|
Assert.assertEquals(lr.getLabelCol(), "label");
|
||||||
LogisticRegressionModel model = lr.fit(dataset);
|
LogisticRegressionModel model = lr.fit(dataset);
|
||||||
model.transform(dataset).registerTempTable("prediction");
|
model.transform(dataset).registerTempTable("prediction");
|
||||||
Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
|
Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction");
|
||||||
predictions.collectAsList();
|
predictions.collectAsList();
|
||||||
// Check defaults
|
// Check defaults
|
||||||
Assert.assertEquals(0.5, model.getThreshold(), eps);
|
Assert.assertEquals(0.5, model.getThreshold(), eps);
|
||||||
|
@ -95,23 +98,23 @@ public class JavaLogisticRegressionSuite implements Serializable {
|
||||||
// Modify model params, and check that the params worked.
|
// Modify model params, and check that the params worked.
|
||||||
model.setThreshold(1.0);
|
model.setThreshold(1.0);
|
||||||
model.transform(dataset).registerTempTable("predAllZero");
|
model.transform(dataset).registerTempTable("predAllZero");
|
||||||
Dataset<Row> predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
|
Dataset<Row> predAllZero = spark.sql("SELECT prediction, myProbability FROM predAllZero");
|
||||||
for (Row r: predAllZero.collectAsList()) {
|
for (Row r : predAllZero.collectAsList()) {
|
||||||
Assert.assertEquals(0.0, r.getDouble(0), eps);
|
Assert.assertEquals(0.0, r.getDouble(0), eps);
|
||||||
}
|
}
|
||||||
// Call transform with params, and check that the params worked.
|
// Call transform with params, and check that the params worked.
|
||||||
model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
|
model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
|
||||||
.registerTempTable("predNotAllZero");
|
.registerTempTable("predNotAllZero");
|
||||||
Dataset<Row> predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
|
Dataset<Row> predNotAllZero = spark.sql("SELECT prediction, myProb FROM predNotAllZero");
|
||||||
boolean foundNonZero = false;
|
boolean foundNonZero = false;
|
||||||
for (Row r: predNotAllZero.collectAsList()) {
|
for (Row r : predNotAllZero.collectAsList()) {
|
||||||
if (r.getDouble(0) != 0.0) foundNonZero = true;
|
if (r.getDouble(0) != 0.0) foundNonZero = true;
|
||||||
}
|
}
|
||||||
Assert.assertTrue(foundNonZero);
|
Assert.assertTrue(foundNonZero);
|
||||||
|
|
||||||
// Call fit() with new params, and check as many params as we can.
|
// Call fit() with new params, and check as many params as we can.
|
||||||
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
|
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
|
||||||
lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
|
lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
|
||||||
LogisticRegression parent2 = (LogisticRegression) model2.parent();
|
LogisticRegression parent2 = (LogisticRegression) model2.parent();
|
||||||
Assert.assertEquals(5, parent2.getMaxIter());
|
Assert.assertEquals(5, parent2.getMaxIter());
|
||||||
Assert.assertEquals(0.1, parent2.getRegParam(), eps);
|
Assert.assertEquals(0.1, parent2.getRegParam(), eps);
|
||||||
|
@ -128,10 +131,10 @@ public class JavaLogisticRegressionSuite implements Serializable {
|
||||||
Assert.assertEquals(2, model.numClasses());
|
Assert.assertEquals(2, model.numClasses());
|
||||||
|
|
||||||
model.transform(dataset).registerTempTable("transformed");
|
model.transform(dataset).registerTempTable("transformed");
|
||||||
Dataset<Row> trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
|
Dataset<Row> trans1 = spark.sql("SELECT rawPrediction, probability FROM transformed");
|
||||||
for (Row row: trans1.collectAsList()) {
|
for (Row row : trans1.collectAsList()) {
|
||||||
Vector raw = (Vector)row.get(0);
|
Vector raw = (Vector) row.get(0);
|
||||||
Vector prob = (Vector)row.get(1);
|
Vector prob = (Vector) row.get(1);
|
||||||
Assert.assertEquals(raw.size(), 2);
|
Assert.assertEquals(raw.size(), 2);
|
||||||
Assert.assertEquals(prob.size(), 2);
|
Assert.assertEquals(prob.size(), 2);
|
||||||
double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
|
double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
|
||||||
|
@ -139,11 +142,11 @@ public class JavaLogisticRegressionSuite implements Serializable {
|
||||||
Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
|
Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
|
||||||
}
|
}
|
||||||
|
|
||||||
Dataset<Row> trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
|
Dataset<Row> trans2 = spark.sql("SELECT prediction, probability FROM transformed");
|
||||||
for (Row row: trans2.collectAsList()) {
|
for (Row row : trans2.collectAsList()) {
|
||||||
double pred = row.getDouble(0);
|
double pred = row.getDouble(0);
|
||||||
Vector prob = (Vector)row.get(1);
|
Vector prob = (Vector) row.get(1);
|
||||||
double probOfPred = prob.apply((int)pred);
|
double probOfPred = prob.apply((int) pred);
|
||||||
for (int i = 0; i < prob.size(); ++i) {
|
for (int i = 0; i < prob.size(); ++i) {
|
||||||
Assert.assertTrue(probOfPred >= prob.apply(i));
|
Assert.assertTrue(probOfPred >= prob.apply(i));
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,49 +26,49 @@ import org.junit.Assert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
|
public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
|
||||||
|
|
||||||
private transient JavaSparkContext jsc;
|
private transient SparkSession spark;
|
||||||
private transient SQLContext sqlContext;
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
|
spark = SparkSession.builder()
|
||||||
sqlContext = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaLogisticRegressionSuite")
|
||||||
|
.getOrCreate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
sqlContext = null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMLPC() {
|
public void testMLPC() {
|
||||||
Dataset<Row> dataFrame = sqlContext.createDataFrame(
|
List<LabeledPoint> data = Arrays.asList(
|
||||||
jsc.parallelize(Arrays.asList(
|
new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
|
||||||
new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
|
new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
|
||||||
new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
|
new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
|
||||||
new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
|
new LabeledPoint(0.0, Vectors.dense(1.0, 1.0))
|
||||||
new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))),
|
);
|
||||||
LabeledPoint.class);
|
Dataset<Row> dataFrame = spark.createDataFrame(data, LabeledPoint.class);
|
||||||
|
|
||||||
MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
|
MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
|
||||||
.setLayers(new int[] {2, 5, 2})
|
.setLayers(new int[]{2, 5, 2})
|
||||||
.setBlockSize(1)
|
.setBlockSize(1)
|
||||||
.setSeed(123L)
|
.setSeed(123L)
|
||||||
.setMaxIter(100);
|
.setMaxIter(100);
|
||||||
MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
|
MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
|
||||||
Dataset<Row> result = model.transform(dataFrame);
|
Dataset<Row> result = model.transform(dataFrame);
|
||||||
List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList();
|
List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList();
|
||||||
for (Row r: predictionAndLabels) {
|
for (Row r : predictionAndLabels) {
|
||||||
Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1));
|
Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,13 +26,12 @@ import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
|
||||||
import org.apache.spark.mllib.linalg.VectorUDT;
|
import org.apache.spark.mllib.linalg.VectorUDT;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.RowFactory;
|
import org.apache.spark.sql.RowFactory;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.apache.spark.sql.types.DataTypes;
|
import org.apache.spark.sql.types.DataTypes;
|
||||||
import org.apache.spark.sql.types.Metadata;
|
import org.apache.spark.sql.types.Metadata;
|
||||||
import org.apache.spark.sql.types.StructField;
|
import org.apache.spark.sql.types.StructField;
|
||||||
|
@ -40,19 +39,20 @@ import org.apache.spark.sql.types.StructType;
|
||||||
|
|
||||||
public class JavaNaiveBayesSuite implements Serializable {
|
public class JavaNaiveBayesSuite implements Serializable {
|
||||||
|
|
||||||
private transient JavaSparkContext jsc;
|
private transient SparkSession spark;
|
||||||
private transient SQLContext jsql;
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
|
spark = SparkSession.builder()
|
||||||
jsql = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaLogisticRegressionSuite")
|
||||||
|
.getOrCreate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void validatePrediction(Dataset<Row> predictionAndLabels) {
|
public void validatePrediction(Dataset<Row> predictionAndLabels) {
|
||||||
|
@ -88,7 +88,7 @@ public class JavaNaiveBayesSuite implements Serializable {
|
||||||
new StructField("features", new VectorUDT(), false, Metadata.empty())
|
new StructField("features", new VectorUDT(), false, Metadata.empty())
|
||||||
});
|
});
|
||||||
|
|
||||||
Dataset<Row> dataset = jsql.createDataFrame(data, schema);
|
Dataset<Row> dataset = spark.createDataFrame(data, schema);
|
||||||
NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
|
NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
|
||||||
NaiveBayesModel model = nb.fit(dataset);
|
NaiveBayesModel model = nb.fit(dataset);
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,6 @@ package org.apache.spark.ml.classification;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import org.apache.spark.sql.Row;
|
|
||||||
import scala.collection.JavaConverters;
|
import scala.collection.JavaConverters;
|
||||||
|
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
|
@ -30,56 +29,61 @@ import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
|
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.Row;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
|
||||||
|
|
||||||
public class JavaOneVsRestSuite implements Serializable {
|
public class JavaOneVsRestSuite implements Serializable {
|
||||||
|
|
||||||
private transient JavaSparkContext jsc;
|
private transient SparkSession spark;
|
||||||
private transient SQLContext jsql;
|
private transient JavaSparkContext jsc;
|
||||||
private transient Dataset<Row> dataset;
|
private transient Dataset<Row> dataset;
|
||||||
private transient JavaRDD<LabeledPoint> datasetRDD;
|
private transient JavaRDD<LabeledPoint> datasetRDD;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite");
|
spark = SparkSession.builder()
|
||||||
jsql = new SQLContext(jsc);
|
.master("local")
|
||||||
int nPoints = 3;
|
.appName("JavaLOneVsRestSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
|
|
||||||
// The following coefficients and xMean/xVariance are computed from iris dataset with
|
int nPoints = 3;
|
||||||
// lambda=0.2.
|
|
||||||
// As a result, we are drawing samples from probability distribution of an actual model.
|
|
||||||
double[] coefficients = {
|
|
||||||
-0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
|
|
||||||
-0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
|
|
||||||
|
|
||||||
double[] xMean = {5.843, 3.057, 3.758, 1.199};
|
// The following coefficients and xMean/xVariance are computed from iris dataset with
|
||||||
double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
|
// lambda=0.2.
|
||||||
List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter(
|
// As a result, we are drawing samples from probability distribution of an actual model.
|
||||||
generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42)
|
double[] coefficients = {
|
||||||
).asJava();
|
-0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
|
||||||
datasetRDD = jsc.parallelize(points, 2);
|
-0.16624, -0.84355, -0.048509, -0.301789, 4.170682};
|
||||||
dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
|
|
||||||
}
|
|
||||||
|
|
||||||
@After
|
double[] xMean = {5.843, 3.057, 3.758, 1.199};
|
||||||
public void tearDown() {
|
double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
|
||||||
jsc.stop();
|
List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter(
|
||||||
jsc = null;
|
generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42)
|
||||||
}
|
).asJava();
|
||||||
|
datasetRDD = jsc.parallelize(points, 2);
|
||||||
|
dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@After
|
||||||
public void oneVsRestDefaultParams() {
|
public void tearDown() {
|
||||||
OneVsRest ova = new OneVsRest();
|
spark.stop();
|
||||||
ova.setClassifier(new LogisticRegression());
|
spark = null;
|
||||||
Assert.assertEquals(ova.getLabelCol() , "label");
|
}
|
||||||
Assert.assertEquals(ova.getPredictionCol() , "prediction");
|
|
||||||
OneVsRestModel ovaModel = ova.fit(dataset);
|
@Test
|
||||||
Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction");
|
public void oneVsRestDefaultParams() {
|
||||||
predictions.collectAsList();
|
OneVsRest ova = new OneVsRest();
|
||||||
Assert.assertEquals(ovaModel.getLabelCol(), "label");
|
ova.setClassifier(new LogisticRegression());
|
||||||
Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");
|
Assert.assertEquals(ova.getLabelCol(), "label");
|
||||||
}
|
Assert.assertEquals(ova.getPredictionCol(), "prediction");
|
||||||
|
OneVsRestModel ovaModel = ova.fit(dataset);
|
||||||
|
Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction");
|
||||||
|
predictions.collectAsList();
|
||||||
|
Assert.assertEquals(ovaModel.getLabelCol(), "label");
|
||||||
|
Assert.assertEquals(ovaModel.getPredictionCol(), "prediction");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,21 +34,27 @@ import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
|
|
||||||
public class JavaRandomForestClassifierSuite implements Serializable {
|
public class JavaRandomForestClassifierSuite implements Serializable {
|
||||||
|
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaRandomForestClassifierSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -57,7 +63,7 @@ public class JavaRandomForestClassifierSuite implements Serializable {
|
||||||
double A = 2.0;
|
double A = 2.0;
|
||||||
double B = -1.5;
|
double B = -1.5;
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> data = sc.parallelize(
|
JavaRDD<LabeledPoint> data = jsc.parallelize(
|
||||||
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
||||||
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
|
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
|
||||||
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
|
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
|
||||||
|
@ -75,22 +81,22 @@ public class JavaRandomForestClassifierSuite implements Serializable {
|
||||||
.setSeed(1234)
|
.setSeed(1234)
|
||||||
.setNumTrees(3)
|
.setNumTrees(3)
|
||||||
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
|
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
|
||||||
for (String impurity: RandomForestClassifier.supportedImpurities()) {
|
for (String impurity : RandomForestClassifier.supportedImpurities()) {
|
||||||
rf.setImpurity(impurity);
|
rf.setImpurity(impurity);
|
||||||
}
|
}
|
||||||
for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) {
|
for (String featureSubsetStrategy : RandomForestClassifier.supportedFeatureSubsetStrategies()) {
|
||||||
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
|
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
|
||||||
}
|
}
|
||||||
String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
|
String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
|
||||||
for (String strategy: realStrategies) {
|
for (String strategy : realStrategies) {
|
||||||
rf.setFeatureSubsetStrategy(strategy);
|
rf.setFeatureSubsetStrategy(strategy);
|
||||||
}
|
}
|
||||||
String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
|
String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
|
||||||
for (String strategy: integerStrategies) {
|
for (String strategy : integerStrategies) {
|
||||||
rf.setFeatureSubsetStrategy(strategy);
|
rf.setFeatureSubsetStrategy(strategy);
|
||||||
}
|
}
|
||||||
String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
|
String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
|
||||||
for (String strategy: invalidStrategies) {
|
for (String strategy : invalidStrategies) {
|
||||||
try {
|
try {
|
||||||
rf.setFeatureSubsetStrategy(strategy);
|
rf.setFeatureSubsetStrategy(strategy);
|
||||||
Assert.fail("Expected exception to be thrown for invalid strategies");
|
Assert.fail("Expected exception to be thrown for invalid strategies");
|
||||||
|
|
|
@ -21,37 +21,37 @@ import java.io.Serializable;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import org.junit.After;
|
|
||||||
import org.junit.Before;
|
|
||||||
import org.junit.Test;
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.junit.After;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaKMeansSuite implements Serializable {
|
public class JavaKMeansSuite implements Serializable {
|
||||||
|
|
||||||
private transient int k = 5;
|
private transient int k = 5;
|
||||||
private transient JavaSparkContext sc;
|
|
||||||
private transient Dataset<Row> dataset;
|
private transient Dataset<Row> dataset;
|
||||||
private transient SQLContext sql;
|
private transient SparkSession spark;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaKMeansSuite");
|
spark = SparkSession.builder()
|
||||||
sql = new SQLContext(sc);
|
.master("local")
|
||||||
|
.appName("JavaKMeansSuite")
|
||||||
dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k);
|
.getOrCreate();
|
||||||
|
dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -65,7 +65,7 @@ public class JavaKMeansSuite implements Serializable {
|
||||||
Dataset<Row> transformed = model.transform(dataset);
|
Dataset<Row> transformed = model.transform(dataset);
|
||||||
List<String> columns = Arrays.asList(transformed.columns());
|
List<String> columns = Arrays.asList(transformed.columns());
|
||||||
List<String> expectedColumns = Arrays.asList("features", "prediction");
|
List<String> expectedColumns = Arrays.asList("features", "prediction");
|
||||||
for (String column: expectedColumns) {
|
for (String column : expectedColumns) {
|
||||||
assertTrue(columns.contains(column));
|
assertTrue(columns.contains(column));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,40 +25,40 @@ import org.junit.Assert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.RowFactory;
|
import org.apache.spark.sql.RowFactory;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.apache.spark.sql.types.DataTypes;
|
import org.apache.spark.sql.types.DataTypes;
|
||||||
import org.apache.spark.sql.types.Metadata;
|
import org.apache.spark.sql.types.Metadata;
|
||||||
import org.apache.spark.sql.types.StructField;
|
import org.apache.spark.sql.types.StructField;
|
||||||
import org.apache.spark.sql.types.StructType;
|
import org.apache.spark.sql.types.StructType;
|
||||||
|
|
||||||
public class JavaBucketizerSuite {
|
public class JavaBucketizerSuite {
|
||||||
private transient JavaSparkContext jsc;
|
private transient SparkSession spark;
|
||||||
private transient SQLContext jsql;
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaBucketizerSuite");
|
spark = SparkSession.builder()
|
||||||
jsql = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaBucketizerSuite")
|
||||||
|
.getOrCreate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void bucketizerTest() {
|
public void bucketizerTest() {
|
||||||
double[] splits = {-0.5, 0.0, 0.5};
|
double[] splits = {-0.5, 0.0, 0.5};
|
||||||
|
|
||||||
StructType schema = new StructType(new StructField[] {
|
StructType schema = new StructType(new StructField[]{
|
||||||
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
|
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
|
||||||
});
|
});
|
||||||
Dataset<Row> dataset = jsql.createDataFrame(
|
Dataset<Row> dataset = spark.createDataFrame(
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
RowFactory.create(-0.5),
|
RowFactory.create(-0.5),
|
||||||
RowFactory.create(-0.3),
|
RowFactory.create(-0.3),
|
||||||
|
|
|
@ -21,43 +21,44 @@ import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
|
import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
|
||||||
|
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.mllib.linalg.VectorUDT;
|
import org.apache.spark.mllib.linalg.VectorUDT;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.RowFactory;
|
import org.apache.spark.sql.RowFactory;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.apache.spark.sql.types.Metadata;
|
import org.apache.spark.sql.types.Metadata;
|
||||||
import org.apache.spark.sql.types.StructField;
|
import org.apache.spark.sql.types.StructField;
|
||||||
import org.apache.spark.sql.types.StructType;
|
import org.apache.spark.sql.types.StructType;
|
||||||
|
|
||||||
public class JavaDCTSuite {
|
public class JavaDCTSuite {
|
||||||
private transient JavaSparkContext jsc;
|
private transient SparkSession spark;
|
||||||
private transient SQLContext jsql;
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaDCTSuite");
|
spark = SparkSession.builder()
|
||||||
jsql = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaDCTSuite")
|
||||||
|
.getOrCreate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void javaCompatibilityTest() {
|
public void javaCompatibilityTest() {
|
||||||
double[] input = new double[] {1D, 2D, 3D, 4D};
|
double[] input = new double[]{1D, 2D, 3D, 4D};
|
||||||
Dataset<Row> dataset = jsql.createDataFrame(
|
Dataset<Row> dataset = spark.createDataFrame(
|
||||||
Arrays.asList(RowFactory.create(Vectors.dense(input))),
|
Arrays.asList(RowFactory.create(Vectors.dense(input))),
|
||||||
new StructType(new StructField[]{
|
new StructType(new StructField[]{
|
||||||
new StructField("vec", (new VectorUDT()), false, Metadata.empty())
|
new StructField("vec", (new VectorUDT()), false, Metadata.empty())
|
||||||
|
|
|
@ -25,12 +25,11 @@ import org.junit.Assert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.RowFactory;
|
import org.apache.spark.sql.RowFactory;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.apache.spark.sql.types.DataTypes;
|
import org.apache.spark.sql.types.DataTypes;
|
||||||
import org.apache.spark.sql.types.Metadata;
|
import org.apache.spark.sql.types.Metadata;
|
||||||
import org.apache.spark.sql.types.StructField;
|
import org.apache.spark.sql.types.StructField;
|
||||||
|
@ -38,19 +37,20 @@ import org.apache.spark.sql.types.StructType;
|
||||||
|
|
||||||
|
|
||||||
public class JavaHashingTFSuite {
|
public class JavaHashingTFSuite {
|
||||||
private transient JavaSparkContext jsc;
|
private transient SparkSession spark;
|
||||||
private transient SQLContext jsql;
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaHashingTFSuite");
|
spark = SparkSession.builder()
|
||||||
jsql = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaHashingTFSuite")
|
||||||
|
.getOrCreate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -65,7 +65,7 @@ public class JavaHashingTFSuite {
|
||||||
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
|
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
|
||||||
});
|
});
|
||||||
|
|
||||||
Dataset<Row> sentenceData = jsql.createDataFrame(data, schema);
|
Dataset<Row> sentenceData = spark.createDataFrame(data, schema);
|
||||||
Tokenizer tokenizer = new Tokenizer()
|
Tokenizer tokenizer = new Tokenizer()
|
||||||
.setInputCol("sentence")
|
.setInputCol("sentence")
|
||||||
.setOutputCol("words");
|
.setOutputCol("words");
|
||||||
|
|
|
@ -23,27 +23,30 @@ import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaNormalizerSuite {
|
public class JavaNormalizerSuite {
|
||||||
|
private transient SparkSession spark;
|
||||||
private transient JavaSparkContext jsc;
|
private transient JavaSparkContext jsc;
|
||||||
private transient SQLContext jsql;
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaNormalizerSuite");
|
spark = SparkSession.builder()
|
||||||
jsql = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaNormalizerSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -54,7 +57,7 @@ public class JavaNormalizerSuite {
|
||||||
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
|
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
|
||||||
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
|
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
|
||||||
));
|
));
|
||||||
Dataset<Row> dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
|
Dataset<Row> dataFrame = spark.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
|
||||||
Normalizer normalizer = new Normalizer()
|
Normalizer normalizer = new Normalizer()
|
||||||
.setInputCol("features")
|
.setInputCol("features")
|
||||||
.setOutputCol("normFeatures");
|
.setOutputCol("normFeatures");
|
||||||
|
|
|
@ -28,31 +28,34 @@ import org.junit.Assert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.function.Function;
|
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
|
import org.apache.spark.api.java.function.Function;
|
||||||
import org.apache.spark.mllib.linalg.Matrix;
|
import org.apache.spark.mllib.linalg.Matrix;
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
|
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaPCASuite implements Serializable {
|
public class JavaPCASuite implements Serializable {
|
||||||
|
private transient SparkSession spark;
|
||||||
private transient JavaSparkContext jsc;
|
private transient JavaSparkContext jsc;
|
||||||
private transient SQLContext sqlContext;
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaPCASuite");
|
spark = SparkSession.builder()
|
||||||
sqlContext = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaPCASuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class VectorPair implements Serializable {
|
public static class VectorPair implements Serializable {
|
||||||
|
@ -100,7 +103,7 @@ public class JavaPCASuite implements Serializable {
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
Dataset<Row> df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
|
Dataset<Row> df = spark.createDataFrame(featuresExpected, VectorPair.class);
|
||||||
PCAModel pca = new PCA()
|
PCAModel pca = new PCA()
|
||||||
.setInputCol("features")
|
.setInputCol("features")
|
||||||
.setOutputCol("pca_features")
|
.setOutputCol("pca_features")
|
||||||
|
|
|
@ -32,19 +32,22 @@ import org.apache.spark.mllib.linalg.Vectors;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.RowFactory;
|
import org.apache.spark.sql.RowFactory;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.apache.spark.sql.types.Metadata;
|
import org.apache.spark.sql.types.Metadata;
|
||||||
import org.apache.spark.sql.types.StructField;
|
import org.apache.spark.sql.types.StructField;
|
||||||
import org.apache.spark.sql.types.StructType;
|
import org.apache.spark.sql.types.StructType;
|
||||||
|
|
||||||
public class JavaPolynomialExpansionSuite {
|
public class JavaPolynomialExpansionSuite {
|
||||||
|
private transient SparkSession spark;
|
||||||
private transient JavaSparkContext jsc;
|
private transient JavaSparkContext jsc;
|
||||||
private transient SQLContext jsql;
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaPolynomialExpansionSuite");
|
spark = SparkSession.builder()
|
||||||
jsql = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaPolynomialExpansionSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
|
@ -72,20 +75,20 @@ public class JavaPolynomialExpansionSuite {
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
StructType schema = new StructType(new StructField[] {
|
StructType schema = new StructType(new StructField[]{
|
||||||
new StructField("features", new VectorUDT(), false, Metadata.empty()),
|
new StructField("features", new VectorUDT(), false, Metadata.empty()),
|
||||||
new StructField("expected", new VectorUDT(), false, Metadata.empty())
|
new StructField("expected", new VectorUDT(), false, Metadata.empty())
|
||||||
});
|
});
|
||||||
|
|
||||||
Dataset<Row> dataset = jsql.createDataFrame(data, schema);
|
Dataset<Row> dataset = spark.createDataFrame(data, schema);
|
||||||
|
|
||||||
List<Row> pairs = polyExpansion.transform(dataset)
|
List<Row> pairs = polyExpansion.transform(dataset)
|
||||||
.select("polyFeatures", "expected")
|
.select("polyFeatures", "expected")
|
||||||
.collectAsList();
|
.collectAsList();
|
||||||
|
|
||||||
for (Row r : pairs) {
|
for (Row r : pairs) {
|
||||||
double[] polyFeatures = ((Vector)r.get(0)).toArray();
|
double[] polyFeatures = ((Vector) r.get(0)).toArray();
|
||||||
double[] expected = ((Vector)r.get(1)).toArray();
|
double[] expected = ((Vector) r.get(1)).toArray();
|
||||||
Assert.assertArrayEquals(polyFeatures, expected, 1e-1);
|
Assert.assertArrayEquals(polyFeatures, expected, 1e-1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,22 +28,25 @@ import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaStandardScalerSuite {
|
public class JavaStandardScalerSuite {
|
||||||
|
private transient SparkSession spark;
|
||||||
private transient JavaSparkContext jsc;
|
private transient JavaSparkContext jsc;
|
||||||
private transient SQLContext jsql;
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaStandardScalerSuite");
|
spark = SparkSession.builder()
|
||||||
jsql = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaStandardScalerSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -54,7 +57,7 @@ public class JavaStandardScalerSuite {
|
||||||
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
|
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
|
||||||
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
|
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
|
||||||
);
|
);
|
||||||
Dataset<Row> dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
|
Dataset<Row> dataFrame = spark.createDataFrame(jsc.parallelize(points, 2),
|
||||||
VectorIndexerSuite.FeatureData.class);
|
VectorIndexerSuite.FeatureData.class);
|
||||||
StandardScaler scaler = new StandardScaler()
|
StandardScaler scaler = new StandardScaler()
|
||||||
.setInputCol("features")
|
.setInputCol("features")
|
||||||
|
|
|
@ -24,11 +24,10 @@ import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.RowFactory;
|
import org.apache.spark.sql.RowFactory;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.apache.spark.sql.types.DataTypes;
|
import org.apache.spark.sql.types.DataTypes;
|
||||||
import org.apache.spark.sql.types.Metadata;
|
import org.apache.spark.sql.types.Metadata;
|
||||||
import org.apache.spark.sql.types.StructField;
|
import org.apache.spark.sql.types.StructField;
|
||||||
|
@ -37,19 +36,20 @@ import org.apache.spark.sql.types.StructType;
|
||||||
|
|
||||||
public class JavaStopWordsRemoverSuite {
|
public class JavaStopWordsRemoverSuite {
|
||||||
|
|
||||||
private transient JavaSparkContext jsc;
|
private transient SparkSession spark;
|
||||||
private transient SQLContext jsql;
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaStopWordsRemoverSuite");
|
spark = SparkSession.builder()
|
||||||
jsql = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaStopWordsRemoverSuite")
|
||||||
|
.getOrCreate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -62,11 +62,11 @@ public class JavaStopWordsRemoverSuite {
|
||||||
RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")),
|
RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")),
|
||||||
RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb"))
|
RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb"))
|
||||||
);
|
);
|
||||||
StructType schema = new StructType(new StructField[] {
|
StructType schema = new StructType(new StructField[]{
|
||||||
new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false,
|
new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false,
|
||||||
Metadata.empty())
|
Metadata.empty())
|
||||||
});
|
});
|
||||||
Dataset<Row> dataset = jsql.createDataFrame(data, schema);
|
Dataset<Row> dataset = spark.createDataFrame(data, schema);
|
||||||
|
|
||||||
remover.transform(dataset).collect();
|
remover.transform(dataset).collect();
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,40 +25,42 @@ import org.junit.Assert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.RowFactory;
|
import org.apache.spark.sql.RowFactory;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.apache.spark.sql.types.StructField;
|
import org.apache.spark.sql.types.StructField;
|
||||||
import org.apache.spark.sql.types.StructType;
|
import org.apache.spark.sql.types.StructType;
|
||||||
import static org.apache.spark.sql.types.DataTypes.*;
|
import static org.apache.spark.sql.types.DataTypes.*;
|
||||||
|
|
||||||
public class JavaStringIndexerSuite {
|
public class JavaStringIndexerSuite {
|
||||||
private transient JavaSparkContext jsc;
|
private transient SparkSession spark;
|
||||||
private transient SQLContext sqlContext;
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaStringIndexerSuite");
|
SparkConf sparkConf = new SparkConf();
|
||||||
sqlContext = new SQLContext(jsc);
|
sparkConf.setMaster("local");
|
||||||
|
sparkConf.setAppName("JavaStringIndexerSuite");
|
||||||
|
|
||||||
|
spark = SparkSession.builder().config(sparkConf).getOrCreate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
sqlContext = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testStringIndexer() {
|
public void testStringIndexer() {
|
||||||
StructType schema = createStructType(new StructField[] {
|
StructType schema = createStructType(new StructField[]{
|
||||||
createStructField("id", IntegerType, false),
|
createStructField("id", IntegerType, false),
|
||||||
createStructField("label", StringType, false)
|
createStructField("label", StringType, false)
|
||||||
});
|
});
|
||||||
List<Row> data = Arrays.asList(
|
List<Row> data = Arrays.asList(
|
||||||
cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c"));
|
cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c"));
|
||||||
Dataset<Row> dataset = sqlContext.createDataFrame(data, schema);
|
Dataset<Row> dataset = spark.createDataFrame(data, schema);
|
||||||
|
|
||||||
StringIndexer indexer = new StringIndexer()
|
StringIndexer indexer = new StringIndexer()
|
||||||
.setInputCol("label")
|
.setInputCol("label")
|
||||||
|
@ -70,7 +72,9 @@ public class JavaStringIndexerSuite {
|
||||||
output.orderBy("id").select("id", "labelIndex").collectAsList());
|
output.orderBy("id").select("id", "labelIndex").collectAsList());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** An alias for RowFactory.create. */
|
/**
|
||||||
|
* An alias for RowFactory.create.
|
||||||
|
*/
|
||||||
private Row cr(Object... values) {
|
private Row cr(Object... values) {
|
||||||
return RowFactory.create(values);
|
return RowFactory.create(values);
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,22 +29,25 @@ import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaTokenizerSuite {
|
public class JavaTokenizerSuite {
|
||||||
|
private transient SparkSession spark;
|
||||||
private transient JavaSparkContext jsc;
|
private transient JavaSparkContext jsc;
|
||||||
private transient SQLContext jsql;
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaTokenizerSuite");
|
spark = SparkSession.builder()
|
||||||
jsql = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaTokenizerSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -59,10 +62,10 @@ public class JavaTokenizerSuite {
|
||||||
|
|
||||||
|
|
||||||
JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Arrays.asList(
|
JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Arrays.asList(
|
||||||
new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
|
new TokenizerTestData("Test of tok.", new String[]{"Test", "tok."}),
|
||||||
new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"})
|
new TokenizerTestData("Te,st. punct", new String[]{"Te,st.", "punct"})
|
||||||
));
|
));
|
||||||
Dataset<Row> dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);
|
Dataset<Row> dataset = spark.createDataFrame(rdd, TokenizerTestData.class);
|
||||||
|
|
||||||
List<Row> pairs = myRegExTokenizer.transform(dataset)
|
List<Row> pairs = myRegExTokenizer.transform(dataset)
|
||||||
.select("tokens", "wantedTokens")
|
.select("tokens", "wantedTokens")
|
||||||
|
|
|
@ -24,36 +24,39 @@ import org.junit.Assert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.mllib.linalg.VectorUDT;
|
import org.apache.spark.mllib.linalg.VectorUDT;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.RowFactory;
|
import org.apache.spark.sql.RowFactory;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.apache.spark.sql.types.*;
|
import org.apache.spark.sql.types.StructField;
|
||||||
|
import org.apache.spark.sql.types.StructType;
|
||||||
import static org.apache.spark.sql.types.DataTypes.*;
|
import static org.apache.spark.sql.types.DataTypes.*;
|
||||||
|
|
||||||
public class JavaVectorAssemblerSuite {
|
public class JavaVectorAssemblerSuite {
|
||||||
private transient JavaSparkContext jsc;
|
private transient SparkSession spark;
|
||||||
private transient SQLContext sqlContext;
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite");
|
SparkConf sparkConf = new SparkConf();
|
||||||
sqlContext = new SQLContext(jsc);
|
sparkConf.setMaster("local");
|
||||||
|
sparkConf.setAppName("JavaVectorAssemblerSuite");
|
||||||
|
|
||||||
|
spark = SparkSession.builder().config(sparkConf).getOrCreate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVectorAssembler() {
|
public void testVectorAssembler() {
|
||||||
StructType schema = createStructType(new StructField[] {
|
StructType schema = createStructType(new StructField[]{
|
||||||
createStructField("id", IntegerType, false),
|
createStructField("id", IntegerType, false),
|
||||||
createStructField("x", DoubleType, false),
|
createStructField("x", DoubleType, false),
|
||||||
createStructField("y", new VectorUDT(), false),
|
createStructField("y", new VectorUDT(), false),
|
||||||
|
@ -63,14 +66,14 @@ public class JavaVectorAssemblerSuite {
|
||||||
});
|
});
|
||||||
Row row = RowFactory.create(
|
Row row = RowFactory.create(
|
||||||
0, 0.0, Vectors.dense(1.0, 2.0), "a",
|
0, 0.0, Vectors.dense(1.0, 2.0), "a",
|
||||||
Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
|
Vectors.sparse(2, new int[]{1}, new double[]{3.0}), 10L);
|
||||||
Dataset<Row> dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
|
Dataset<Row> dataset = spark.createDataFrame(Arrays.asList(row), schema);
|
||||||
VectorAssembler assembler = new VectorAssembler()
|
VectorAssembler assembler = new VectorAssembler()
|
||||||
.setInputCols(new String[] {"x", "y", "z", "n"})
|
.setInputCols(new String[]{"x", "y", "z", "n"})
|
||||||
.setOutputCol("features");
|
.setOutputCol("features");
|
||||||
Dataset<Row> output = assembler.transform(dataset);
|
Dataset<Row> output = assembler.transform(dataset);
|
||||||
Assert.assertEquals(
|
Assert.assertEquals(
|
||||||
Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}),
|
Vectors.sparse(6, new int[]{1, 2, 4, 5}, new double[]{1.0, 2.0, 3.0, 10.0}),
|
||||||
output.select("features").first().<Vector>getAs(0));
|
output.select("features").first().<Vector>getAs(0));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,21 +32,26 @@ import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
|
|
||||||
public class JavaVectorIndexerSuite implements Serializable {
|
public class JavaVectorIndexerSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaVectorIndexerSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaVectorIndexerSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -57,8 +62,7 @@ public class JavaVectorIndexerSuite implements Serializable {
|
||||||
new FeatureData(Vectors.dense(1.0, 3.0)),
|
new FeatureData(Vectors.dense(1.0, 3.0)),
|
||||||
new FeatureData(Vectors.dense(1.0, 4.0))
|
new FeatureData(Vectors.dense(1.0, 4.0))
|
||||||
);
|
);
|
||||||
SQLContext sqlContext = new SQLContext(sc);
|
Dataset<Row> data = spark.createDataFrame(jsc.parallelize(points, 2), FeatureData.class);
|
||||||
Dataset<Row> data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class);
|
|
||||||
VectorIndexer indexer = new VectorIndexer()
|
VectorIndexer indexer = new VectorIndexer()
|
||||||
.setInputCol("features")
|
.setInputCol("features")
|
||||||
.setOutputCol("indexed")
|
.setOutputCol("indexed")
|
||||||
|
|
|
@ -25,7 +25,6 @@ import org.junit.Assert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
|
||||||
import org.apache.spark.ml.attribute.Attribute;
|
import org.apache.spark.ml.attribute.Attribute;
|
||||||
import org.apache.spark.ml.attribute.AttributeGroup;
|
import org.apache.spark.ml.attribute.AttributeGroup;
|
||||||
import org.apache.spark.ml.attribute.NumericAttribute;
|
import org.apache.spark.ml.attribute.NumericAttribute;
|
||||||
|
@ -34,24 +33,25 @@ import org.apache.spark.mllib.linalg.Vectors;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.RowFactory;
|
import org.apache.spark.sql.RowFactory;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.apache.spark.sql.types.StructType;
|
import org.apache.spark.sql.types.StructType;
|
||||||
|
|
||||||
|
|
||||||
public class JavaVectorSlicerSuite {
|
public class JavaVectorSlicerSuite {
|
||||||
private transient JavaSparkContext jsc;
|
private transient SparkSession spark;
|
||||||
private transient SQLContext jsql;
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite");
|
spark = SparkSession.builder()
|
||||||
jsql = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaVectorSlicerSuite")
|
||||||
|
.getOrCreate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -69,7 +69,7 @@ public class JavaVectorSlicerSuite {
|
||||||
);
|
);
|
||||||
|
|
||||||
Dataset<Row> dataset =
|
Dataset<Row> dataset =
|
||||||
jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));
|
spark.createDataFrame(data, (new StructType()).add(group.toStructField()));
|
||||||
|
|
||||||
VectorSlicer vectorSlicer = new VectorSlicer()
|
VectorSlicer vectorSlicer = new VectorSlicer()
|
||||||
.setInputCol("userFeatures").setOutputCol("features");
|
.setInputCol("userFeatures").setOutputCol("features");
|
||||||
|
|
|
@ -24,28 +24,28 @@ import org.junit.Assert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.RowFactory;
|
import org.apache.spark.sql.RowFactory;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.apache.spark.sql.types.*;
|
import org.apache.spark.sql.types.*;
|
||||||
|
|
||||||
public class JavaWord2VecSuite {
|
public class JavaWord2VecSuite {
|
||||||
private transient JavaSparkContext jsc;
|
private transient SparkSession spark;
|
||||||
private transient SQLContext sqlContext;
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaWord2VecSuite");
|
spark = SparkSession.builder()
|
||||||
sqlContext = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaWord2VecSuite")
|
||||||
|
.getOrCreate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -53,7 +53,7 @@ public class JavaWord2VecSuite {
|
||||||
StructType schema = new StructType(new StructField[]{
|
StructType schema = new StructType(new StructField[]{
|
||||||
new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
|
new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
|
||||||
});
|
});
|
||||||
Dataset<Row> documentDF = sqlContext.createDataFrame(
|
Dataset<Row> documentDF = spark.createDataFrame(
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
|
RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
|
||||||
RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
|
RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
|
||||||
|
@ -68,8 +68,8 @@ public class JavaWord2VecSuite {
|
||||||
Word2VecModel model = word2Vec.fit(documentDF);
|
Word2VecModel model = word2Vec.fit(documentDF);
|
||||||
Dataset<Row> result = model.transform(documentDF);
|
Dataset<Row> result = model.transform(documentDF);
|
||||||
|
|
||||||
for (Row r: result.select("result").collectAsList()) {
|
for (Row r : result.select("result").collectAsList()) {
|
||||||
double[] polyFeatures = ((Vector)r.get(0)).toArray();
|
double[] polyFeatures = ((Vector) r.get(0)).toArray();
|
||||||
Assert.assertEquals(polyFeatures.length, 3);
|
Assert.assertEquals(polyFeatures.length, 3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,23 +25,29 @@ import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Test Param and related classes in Java
|
* Test Param and related classes in Java
|
||||||
*/
|
*/
|
||||||
public class JavaParamsSuite {
|
public class JavaParamsSuite {
|
||||||
|
|
||||||
|
private transient SparkSession spark;
|
||||||
private transient JavaSparkContext jsc;
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaParamsSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaParamsSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -51,7 +57,7 @@ public class JavaParamsSuite {
|
||||||
testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
|
testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
|
||||||
Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
|
Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
|
||||||
Assert.assertEquals(testParams.getMyStringParam(), "a");
|
Assert.assertEquals(testParams.getMyStringParam(), "a");
|
||||||
Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0);
|
Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[]{1.0, 2.0}, 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -45,9 +45,14 @@ public class JavaTestParams extends JavaParams {
|
||||||
}
|
}
|
||||||
|
|
||||||
private IntParam myIntParam_;
|
private IntParam myIntParam_;
|
||||||
public IntParam myIntParam() { return myIntParam_; }
|
|
||||||
|
|
||||||
public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); }
|
public IntParam myIntParam() {
|
||||||
|
return myIntParam_;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getMyIntParam() {
|
||||||
|
return (Integer) getOrDefault(myIntParam_);
|
||||||
|
}
|
||||||
|
|
||||||
public JavaTestParams setMyIntParam(int value) {
|
public JavaTestParams setMyIntParam(int value) {
|
||||||
set(myIntParam_, value);
|
set(myIntParam_, value);
|
||||||
|
@ -55,9 +60,14 @@ public class JavaTestParams extends JavaParams {
|
||||||
}
|
}
|
||||||
|
|
||||||
private DoubleParam myDoubleParam_;
|
private DoubleParam myDoubleParam_;
|
||||||
public DoubleParam myDoubleParam() { return myDoubleParam_; }
|
|
||||||
|
|
||||||
public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); }
|
public DoubleParam myDoubleParam() {
|
||||||
|
return myDoubleParam_;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getMyDoubleParam() {
|
||||||
|
return (Double) getOrDefault(myDoubleParam_);
|
||||||
|
}
|
||||||
|
|
||||||
public JavaTestParams setMyDoubleParam(double value) {
|
public JavaTestParams setMyDoubleParam(double value) {
|
||||||
set(myDoubleParam_, value);
|
set(myDoubleParam_, value);
|
||||||
|
@ -65,9 +75,14 @@ public class JavaTestParams extends JavaParams {
|
||||||
}
|
}
|
||||||
|
|
||||||
private Param<String> myStringParam_;
|
private Param<String> myStringParam_;
|
||||||
public Param<String> myStringParam() { return myStringParam_; }
|
|
||||||
|
|
||||||
public String getMyStringParam() { return getOrDefault(myStringParam_); }
|
public Param<String> myStringParam() {
|
||||||
|
return myStringParam_;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getMyStringParam() {
|
||||||
|
return getOrDefault(myStringParam_);
|
||||||
|
}
|
||||||
|
|
||||||
public JavaTestParams setMyStringParam(String value) {
|
public JavaTestParams setMyStringParam(String value) {
|
||||||
set(myStringParam_, value);
|
set(myStringParam_, value);
|
||||||
|
@ -75,9 +90,14 @@ public class JavaTestParams extends JavaParams {
|
||||||
}
|
}
|
||||||
|
|
||||||
private DoubleArrayParam myDoubleArrayParam_;
|
private DoubleArrayParam myDoubleArrayParam_;
|
||||||
public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; }
|
|
||||||
|
|
||||||
public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); }
|
public DoubleArrayParam myDoubleArrayParam() {
|
||||||
|
return myDoubleArrayParam_;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double[] getMyDoubleArrayParam() {
|
||||||
|
return getOrDefault(myDoubleArrayParam_);
|
||||||
|
}
|
||||||
|
|
||||||
public JavaTestParams setMyDoubleArrayParam(double[] value) {
|
public JavaTestParams setMyDoubleArrayParam(double[] value) {
|
||||||
set(myDoubleArrayParam_, value);
|
set(myDoubleArrayParam_, value);
|
||||||
|
@ -96,7 +116,7 @@ public class JavaTestParams extends JavaParams {
|
||||||
|
|
||||||
setDefault(myIntParam(), 1);
|
setDefault(myIntParam(), 1);
|
||||||
setDefault(myDoubleParam(), 0.5);
|
setDefault(myDoubleParam(), 0.5);
|
||||||
setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
|
setDefault(myDoubleArrayParam(), new double[]{1.0, 2.0});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
|
|
||||||
public class JavaDecisionTreeRegressorSuite implements Serializable {
|
public class JavaDecisionTreeRegressorSuite implements Serializable {
|
||||||
|
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaDecisionTreeRegressorSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -55,7 +61,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
|
||||||
double A = 2.0;
|
double A = 2.0;
|
||||||
double B = -1.5;
|
double B = -1.5;
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> data = sc.parallelize(
|
JavaRDD<LabeledPoint> data = jsc.parallelize(
|
||||||
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
||||||
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
|
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
|
||||||
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
|
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
|
||||||
|
@ -70,7 +76,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
|
||||||
.setCacheNodeIds(false)
|
.setCacheNodeIds(false)
|
||||||
.setCheckpointInterval(10)
|
.setCheckpointInterval(10)
|
||||||
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
|
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
|
||||||
for (String impurity: DecisionTreeRegressor.supportedImpurities()) {
|
for (String impurity : DecisionTreeRegressor.supportedImpurities()) {
|
||||||
dt.setImpurity(impurity);
|
dt.setImpurity(impurity);
|
||||||
}
|
}
|
||||||
DecisionTreeRegressionModel model = dt.fit(dataFrame);
|
DecisionTreeRegressionModel model = dt.fit(dataFrame);
|
||||||
|
|
|
@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
|
|
||||||
public class JavaGBTRegressorSuite implements Serializable {
|
public class JavaGBTRegressorSuite implements Serializable {
|
||||||
|
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaGBTRegressorSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaGBTRegressorSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -55,7 +61,7 @@ public class JavaGBTRegressorSuite implements Serializable {
|
||||||
double A = 2.0;
|
double A = 2.0;
|
||||||
double B = -1.5;
|
double B = -1.5;
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> data = sc.parallelize(
|
JavaRDD<LabeledPoint> data = jsc.parallelize(
|
||||||
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
||||||
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
|
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
|
||||||
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
|
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
|
||||||
|
@ -73,7 +79,7 @@ public class JavaGBTRegressorSuite implements Serializable {
|
||||||
.setMaxIter(3)
|
.setMaxIter(3)
|
||||||
.setStepSize(0.1)
|
.setStepSize(0.1)
|
||||||
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
|
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
|
||||||
for (String lossType: GBTRegressor.supportedLossTypes()) {
|
for (String lossType : GBTRegressor.supportedLossTypes()) {
|
||||||
rf.setLossType(lossType);
|
rf.setLossType(lossType);
|
||||||
}
|
}
|
||||||
GBTRegressionModel model = rf.fit(dataFrame);
|
GBTRegressionModel model = rf.fit(dataFrame);
|
||||||
|
|
|
@ -30,25 +30,26 @@ import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
import static org.apache.spark.mllib.classification.LogisticRegressionSuite
|
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
|
||||||
.generateLogisticInputAsList;
|
|
||||||
|
|
||||||
|
|
||||||
public class JavaLinearRegressionSuite implements Serializable {
|
public class JavaLinearRegressionSuite implements Serializable {
|
||||||
|
|
||||||
|
private transient SparkSession spark;
|
||||||
private transient JavaSparkContext jsc;
|
private transient JavaSparkContext jsc;
|
||||||
private transient SQLContext jsql;
|
|
||||||
private transient Dataset<Row> dataset;
|
private transient Dataset<Row> dataset;
|
||||||
private transient JavaRDD<LabeledPoint> datasetRDD;
|
private transient JavaRDD<LabeledPoint> datasetRDD;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
|
spark = SparkSession.builder()
|
||||||
jsql = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaLinearRegressionSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
|
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
|
||||||
datasetRDD = jsc.parallelize(points, 2);
|
datasetRDD = jsc.parallelize(points, 2);
|
||||||
dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
|
dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
|
||||||
dataset.registerTempTable("dataset");
|
dataset.registerTempTable("dataset");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,7 +66,7 @@ public class JavaLinearRegressionSuite implements Serializable {
|
||||||
assertEquals("auto", lr.getSolver());
|
assertEquals("auto", lr.getSolver());
|
||||||
LinearRegressionModel model = lr.fit(dataset);
|
LinearRegressionModel model = lr.fit(dataset);
|
||||||
model.transform(dataset).registerTempTable("prediction");
|
model.transform(dataset).registerTempTable("prediction");
|
||||||
Dataset<Row> predictions = jsql.sql("SELECT label, prediction FROM prediction");
|
Dataset<Row> predictions = spark.sql("SELECT label, prediction FROM prediction");
|
||||||
predictions.collect();
|
predictions.collect();
|
||||||
// Check defaults
|
// Check defaults
|
||||||
assertEquals("features", model.getFeaturesCol());
|
assertEquals("features", model.getFeaturesCol());
|
||||||
|
@ -76,8 +77,8 @@ public class JavaLinearRegressionSuite implements Serializable {
|
||||||
public void linearRegressionWithSetters() {
|
public void linearRegressionWithSetters() {
|
||||||
// Set params, train, and check as many params as we can.
|
// Set params, train, and check as many params as we can.
|
||||||
LinearRegression lr = new LinearRegression()
|
LinearRegression lr = new LinearRegression()
|
||||||
.setMaxIter(10)
|
.setMaxIter(10)
|
||||||
.setRegParam(1.0).setSolver("l-bfgs");
|
.setRegParam(1.0).setSolver("l-bfgs");
|
||||||
LinearRegressionModel model = lr.fit(dataset);
|
LinearRegressionModel model = lr.fit(dataset);
|
||||||
LinearRegression parent = (LinearRegression) model.parent();
|
LinearRegression parent = (LinearRegression) model.parent();
|
||||||
assertEquals(10, parent.getMaxIter());
|
assertEquals(10, parent.getMaxIter());
|
||||||
|
@ -85,7 +86,7 @@ public class JavaLinearRegressionSuite implements Serializable {
|
||||||
|
|
||||||
// Call fit() with new params, and check as many params as we can.
|
// Call fit() with new params, and check as many params as we can.
|
||||||
LinearRegressionModel model2 =
|
LinearRegressionModel model2 =
|
||||||
lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
|
lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
|
||||||
LinearRegression parent2 = (LinearRegression) model2.parent();
|
LinearRegression parent2 = (LinearRegression) model2.parent();
|
||||||
assertEquals(5, parent2.getMaxIter());
|
assertEquals(5, parent2.getMaxIter());
|
||||||
assertEquals(0.1, parent2.getRegParam(), 0.0);
|
assertEquals(0.1, parent2.getRegParam(), 0.0);
|
||||||
|
|
|
@ -28,27 +28,33 @@ import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
|
|
||||||
import org.apache.spark.ml.tree.impl.TreeTests;
|
import org.apache.spark.ml.tree.impl.TreeTests;
|
||||||
|
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
|
|
||||||
public class JavaRandomForestRegressorSuite implements Serializable {
|
public class JavaRandomForestRegressorSuite implements Serializable {
|
||||||
|
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaRandomForestRegressorSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -57,7 +63,7 @@ public class JavaRandomForestRegressorSuite implements Serializable {
|
||||||
double A = 2.0;
|
double A = 2.0;
|
||||||
double B = -1.5;
|
double B = -1.5;
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> data = sc.parallelize(
|
JavaRDD<LabeledPoint> data = jsc.parallelize(
|
||||||
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
||||||
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
|
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
|
||||||
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
|
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
|
||||||
|
@ -75,22 +81,22 @@ public class JavaRandomForestRegressorSuite implements Serializable {
|
||||||
.setSeed(1234)
|
.setSeed(1234)
|
||||||
.setNumTrees(3)
|
.setNumTrees(3)
|
||||||
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
|
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
|
||||||
for (String impurity: RandomForestRegressor.supportedImpurities()) {
|
for (String impurity : RandomForestRegressor.supportedImpurities()) {
|
||||||
rf.setImpurity(impurity);
|
rf.setImpurity(impurity);
|
||||||
}
|
}
|
||||||
for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) {
|
for (String featureSubsetStrategy : RandomForestRegressor.supportedFeatureSubsetStrategies()) {
|
||||||
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
|
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
|
||||||
}
|
}
|
||||||
String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
|
String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
|
||||||
for (String strategy: realStrategies) {
|
for (String strategy : realStrategies) {
|
||||||
rf.setFeatureSubsetStrategy(strategy);
|
rf.setFeatureSubsetStrategy(strategy);
|
||||||
}
|
}
|
||||||
String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
|
String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
|
||||||
for (String strategy: integerStrategies) {
|
for (String strategy : integerStrategies) {
|
||||||
rf.setFeatureSubsetStrategy(strategy);
|
rf.setFeatureSubsetStrategy(strategy);
|
||||||
}
|
}
|
||||||
String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
|
String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
|
||||||
for (String strategy: invalidStrategies) {
|
for (String strategy : invalidStrategies) {
|
||||||
try {
|
try {
|
||||||
rf.setFeatureSubsetStrategy(strategy);
|
rf.setFeatureSubsetStrategy(strategy);
|
||||||
Assert.fail("Expected exception to be thrown for invalid strategies");
|
Assert.fail("Expected exception to be thrown for invalid strategies");
|
||||||
|
|
|
@ -28,12 +28,11 @@ import org.junit.Assert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
|
||||||
import org.apache.spark.mllib.linalg.DenseVector;
|
import org.apache.spark.mllib.linalg.DenseVector;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.apache.spark.util.Utils;
|
import org.apache.spark.util.Utils;
|
||||||
|
|
||||||
|
|
||||||
|
@ -41,16 +40,17 @@ import org.apache.spark.util.Utils;
|
||||||
* Test LibSVMRelation in Java.
|
* Test LibSVMRelation in Java.
|
||||||
*/
|
*/
|
||||||
public class JavaLibSVMRelationSuite {
|
public class JavaLibSVMRelationSuite {
|
||||||
private transient JavaSparkContext jsc;
|
private transient SparkSession spark;
|
||||||
private transient SQLContext sqlContext;
|
|
||||||
|
|
||||||
private File tempDir;
|
private File tempDir;
|
||||||
private String path;
|
private String path;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() throws IOException {
|
public void setUp() throws IOException {
|
||||||
jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
|
spark = SparkSession.builder()
|
||||||
sqlContext = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaLibSVMRelationSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
|
||||||
tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
|
tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
|
||||||
File file = new File(tempDir, "part-00000");
|
File file = new File(tempDir, "part-00000");
|
||||||
|
@ -61,14 +61,14 @@ public class JavaLibSVMRelationSuite {
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
Utils.deleteRecursively(tempDir);
|
Utils.deleteRecursively(tempDir);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void verifyLibSVMDF() {
|
public void verifyLibSVMDF() {
|
||||||
Dataset<Row> dataset = sqlContext.read().format("libsvm").option("vectorType", "dense")
|
Dataset<Row> dataset = spark.read().format("libsvm").option("vectorType", "dense")
|
||||||
.load(path);
|
.load(path);
|
||||||
Assert.assertEquals("label", dataset.columns()[0]);
|
Assert.assertEquals("label", dataset.columns()[0]);
|
||||||
Assert.assertEquals("features", dataset.columns()[1]);
|
Assert.assertEquals("features", dataset.columns()[1]);
|
||||||
|
|
|
@ -32,21 +32,25 @@ import org.apache.spark.ml.param.ParamMap;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
import org.apache.spark.sql.Dataset;
|
import org.apache.spark.sql.Dataset;
|
||||||
import org.apache.spark.sql.Row;
|
import org.apache.spark.sql.Row;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SparkSession;
|
||||||
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
|
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
|
||||||
|
|
||||||
public class JavaCrossValidatorSuite implements Serializable {
|
public class JavaCrossValidatorSuite implements Serializable {
|
||||||
|
|
||||||
|
private transient SparkSession spark;
|
||||||
private transient JavaSparkContext jsc;
|
private transient JavaSparkContext jsc;
|
||||||
private transient SQLContext jsql;
|
|
||||||
private transient Dataset<Row> dataset;
|
private transient Dataset<Row> dataset;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
|
spark = SparkSession.builder()
|
||||||
jsql = new SQLContext(jsc);
|
.master("local")
|
||||||
|
.appName("JavaCrossValidatorSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
|
|
||||||
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
|
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
|
||||||
dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
|
dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
|
@ -59,8 +63,8 @@ public class JavaCrossValidatorSuite implements Serializable {
|
||||||
public void crossValidationWithLogisticRegression() {
|
public void crossValidationWithLogisticRegression() {
|
||||||
LogisticRegression lr = new LogisticRegression();
|
LogisticRegression lr = new LogisticRegression();
|
||||||
ParamMap[] lrParamMaps = new ParamGridBuilder()
|
ParamMap[] lrParamMaps = new ParamGridBuilder()
|
||||||
.addGrid(lr.regParam(), new double[] {0.001, 1000.0})
|
.addGrid(lr.regParam(), new double[]{0.001, 1000.0})
|
||||||
.addGrid(lr.maxIter(), new int[] {0, 10})
|
.addGrid(lr.maxIter(), new int[]{0, 10})
|
||||||
.build();
|
.build();
|
||||||
BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
|
BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
|
||||||
CrossValidator cv = new CrossValidator()
|
CrossValidator cv = new CrossValidator()
|
||||||
|
|
|
@ -37,4 +37,5 @@ object IdentifiableSuite {
|
||||||
class Test(override val uid: String) extends Identifiable {
|
class Test(override val uid: String) extends Identifiable {
|
||||||
def this() = this(Identifiable.randomUID("test"))
|
def this() = this(Identifiable.randomUID("test"))
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,31 +27,34 @@ import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.sql.SQLContext;
|
import org.apache.spark.sql.SQLContext;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.apache.spark.util.Utils;
|
import org.apache.spark.util.Utils;
|
||||||
|
|
||||||
public class JavaDefaultReadWriteSuite {
|
public class JavaDefaultReadWriteSuite {
|
||||||
|
|
||||||
JavaSparkContext jsc = null;
|
JavaSparkContext jsc = null;
|
||||||
SQLContext sqlContext = null;
|
SparkSession spark = null;
|
||||||
File tempDir = null;
|
File tempDir = null;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite");
|
|
||||||
SQLContext.clearActive();
|
SQLContext.clearActive();
|
||||||
sqlContext = new SQLContext(jsc);
|
spark = SparkSession.builder()
|
||||||
SQLContext.setActive(sqlContext);
|
.master("local[2]")
|
||||||
|
.appName("JavaDefaultReadWriteSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
SQLContext.setActive(spark.wrapped());
|
||||||
|
|
||||||
tempDir = Utils.createTempDir(
|
tempDir = Utils.createTempDir(
|
||||||
System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite");
|
System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite");
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sqlContext = null;
|
|
||||||
SQLContext.clearActive();
|
SQLContext.clearActive();
|
||||||
if (jsc != null) {
|
if (spark != null) {
|
||||||
jsc.stop();
|
spark.stop();
|
||||||
jsc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
Utils.deleteRecursively(tempDir);
|
Utils.deleteRecursively(tempDir);
|
||||||
}
|
}
|
||||||
|
@ -70,7 +73,7 @@ public class JavaDefaultReadWriteSuite {
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
// expected
|
// expected
|
||||||
}
|
}
|
||||||
instance.write().context(sqlContext).overwrite().save(outputPath);
|
instance.write().context(spark.wrapped()).overwrite().save(outputPath);
|
||||||
MyParams newInstance = MyParams.load(outputPath);
|
MyParams newInstance = MyParams.load(outputPath);
|
||||||
Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
|
Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
|
||||||
Assert.assertEquals("Params should be preserved.",
|
Assert.assertEquals("Params should be preserved.",
|
||||||
|
|
|
@ -27,26 +27,31 @@ import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaLogisticRegressionSuite implements Serializable {
|
public class JavaLogisticRegressionSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaLogisticRegressionSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
int validatePrediction(List<LabeledPoint> validationData, LogisticRegressionModel model) {
|
int validatePrediction(List<LabeledPoint> validationData, LogisticRegressionModel model) {
|
||||||
int numAccurate = 0;
|
int numAccurate = 0;
|
||||||
for (LabeledPoint point: validationData) {
|
for (LabeledPoint point : validationData) {
|
||||||
Double prediction = model.predict(point.features());
|
Double prediction = model.predict(point.features());
|
||||||
if (prediction == point.label()) {
|
if (prediction == point.label()) {
|
||||||
numAccurate++;
|
numAccurate++;
|
||||||
|
@ -61,16 +66,16 @@ public class JavaLogisticRegressionSuite implements Serializable {
|
||||||
double A = 2.0;
|
double A = 2.0;
|
||||||
double B = -1.5;
|
double B = -1.5;
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
|
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
|
||||||
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
||||||
List<LabeledPoint> validationData =
|
List<LabeledPoint> validationData =
|
||||||
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
|
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
|
||||||
|
|
||||||
LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD();
|
LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD();
|
||||||
lrImpl.setIntercept(true);
|
lrImpl.setIntercept(true);
|
||||||
lrImpl.optimizer().setStepSize(1.0)
|
lrImpl.optimizer().setStepSize(1.0)
|
||||||
.setRegParam(1.0)
|
.setRegParam(1.0)
|
||||||
.setNumIterations(100);
|
.setNumIterations(100);
|
||||||
LogisticRegressionModel model = lrImpl.run(testRDD.rdd());
|
LogisticRegressionModel model = lrImpl.run(testRDD.rdd());
|
||||||
|
|
||||||
int numAccurate = validatePrediction(validationData, model);
|
int numAccurate = validatePrediction(validationData, model);
|
||||||
|
@ -83,13 +88,13 @@ public class JavaLogisticRegressionSuite implements Serializable {
|
||||||
double A = 0.0;
|
double A = 0.0;
|
||||||
double B = -2.5;
|
double B = -2.5;
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
|
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
|
||||||
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
|
||||||
List<LabeledPoint> validationData =
|
List<LabeledPoint> validationData =
|
||||||
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
|
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
|
||||||
|
|
||||||
LogisticRegressionModel model = LogisticRegressionWithSGD.train(
|
LogisticRegressionModel model = LogisticRegressionWithSGD.train(
|
||||||
testRDD.rdd(), 100, 1.0, 1.0);
|
testRDD.rdd(), 100, 1.0, 1.0);
|
||||||
|
|
||||||
int numAccurate = validatePrediction(validationData, model);
|
int numAccurate = validatePrediction(validationData, model);
|
||||||
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
|
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
|
||||||
|
|
|
@ -32,20 +32,26 @@ import org.apache.spark.api.java.function.Function;
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
|
|
||||||
public class JavaNaiveBayesSuite implements Serializable {
|
public class JavaNaiveBayesSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaNaiveBayesSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final List<LabeledPoint> POINTS = Arrays.asList(
|
private static final List<LabeledPoint> POINTS = Arrays.asList(
|
||||||
|
@ -59,7 +65,7 @@ public class JavaNaiveBayesSuite implements Serializable {
|
||||||
|
|
||||||
private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
|
private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
|
||||||
int correct = 0;
|
int correct = 0;
|
||||||
for (LabeledPoint p: points) {
|
for (LabeledPoint p : points) {
|
||||||
if (model.predict(p.features()) == p.label()) {
|
if (model.predict(p.features()) == p.label()) {
|
||||||
correct += 1;
|
correct += 1;
|
||||||
}
|
}
|
||||||
|
@ -69,7 +75,7 @@ public class JavaNaiveBayesSuite implements Serializable {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void runUsingConstructor() {
|
public void runUsingConstructor() {
|
||||||
JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
|
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
|
||||||
|
|
||||||
NaiveBayes nb = new NaiveBayes().setLambda(1.0);
|
NaiveBayes nb = new NaiveBayes().setLambda(1.0);
|
||||||
NaiveBayesModel model = nb.run(testRDD.rdd());
|
NaiveBayesModel model = nb.run(testRDD.rdd());
|
||||||
|
@ -80,7 +86,7 @@ public class JavaNaiveBayesSuite implements Serializable {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void runUsingStaticMethods() {
|
public void runUsingStaticMethods() {
|
||||||
JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
|
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
|
||||||
|
|
||||||
NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
|
NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
|
||||||
int numAccurate1 = validatePrediction(POINTS, model1);
|
int numAccurate1 = validatePrediction(POINTS, model1);
|
||||||
|
@ -93,13 +99,14 @@ public class JavaNaiveBayesSuite implements Serializable {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testPredictJavaRDD() {
|
public void testPredictJavaRDD() {
|
||||||
JavaRDD<LabeledPoint> examples = sc.parallelize(POINTS, 2).cache();
|
JavaRDD<LabeledPoint> examples = jsc.parallelize(POINTS, 2).cache();
|
||||||
NaiveBayesModel model = NaiveBayes.train(examples.rdd());
|
NaiveBayesModel model = NaiveBayes.train(examples.rdd());
|
||||||
JavaRDD<Vector> vectors = examples.map(new Function<LabeledPoint, Vector>() {
|
JavaRDD<Vector> vectors = examples.map(new Function<LabeledPoint, Vector>() {
|
||||||
@Override
|
@Override
|
||||||
public Vector call(LabeledPoint v) throws Exception {
|
public Vector call(LabeledPoint v) throws Exception {
|
||||||
return v.features();
|
return v.features();
|
||||||
}});
|
}
|
||||||
|
});
|
||||||
JavaRDD<Double> predictions = model.predict(vectors);
|
JavaRDD<Double> predictions = model.predict(vectors);
|
||||||
// Should be able to get the first prediction.
|
// Should be able to get the first prediction.
|
||||||
predictions.first();
|
predictions.first();
|
||||||
|
|
|
@ -28,24 +28,30 @@ import org.junit.Test;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaSVMSuite implements Serializable {
|
public class JavaSVMSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaSVMSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaSVMSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
int validatePrediction(List<LabeledPoint> validationData, SVMModel model) {
|
int validatePrediction(List<LabeledPoint> validationData, SVMModel model) {
|
||||||
int numAccurate = 0;
|
int numAccurate = 0;
|
||||||
for (LabeledPoint point: validationData) {
|
for (LabeledPoint point : validationData) {
|
||||||
Double prediction = model.predict(point.features());
|
Double prediction = model.predict(point.features());
|
||||||
if (prediction == point.label()) {
|
if (prediction == point.label()) {
|
||||||
numAccurate++;
|
numAccurate++;
|
||||||
|
@ -60,16 +66,16 @@ public class JavaSVMSuite implements Serializable {
|
||||||
double A = 2.0;
|
double A = 2.0;
|
||||||
double[] weights = {-1.5, 1.0};
|
double[] weights = {-1.5, 1.0};
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A,
|
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A,
|
||||||
weights, nPoints, 42), 2).cache();
|
weights, nPoints, 42), 2).cache();
|
||||||
List<LabeledPoint> validationData =
|
List<LabeledPoint> validationData =
|
||||||
SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
|
SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
|
||||||
|
|
||||||
SVMWithSGD svmSGDImpl = new SVMWithSGD();
|
SVMWithSGD svmSGDImpl = new SVMWithSGD();
|
||||||
svmSGDImpl.setIntercept(true);
|
svmSGDImpl.setIntercept(true);
|
||||||
svmSGDImpl.optimizer().setStepSize(1.0)
|
svmSGDImpl.optimizer().setStepSize(1.0)
|
||||||
.setRegParam(1.0)
|
.setRegParam(1.0)
|
||||||
.setNumIterations(100);
|
.setNumIterations(100);
|
||||||
SVMModel model = svmSGDImpl.run(testRDD.rdd());
|
SVMModel model = svmSGDImpl.run(testRDD.rdd());
|
||||||
|
|
||||||
int numAccurate = validatePrediction(validationData, model);
|
int numAccurate = validatePrediction(validationData, model);
|
||||||
|
@ -82,10 +88,10 @@ public class JavaSVMSuite implements Serializable {
|
||||||
double A = 0.0;
|
double A = 0.0;
|
||||||
double[] weights = {-1.5, 1.0};
|
double[] weights = {-1.5, 1.0};
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A,
|
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A,
|
||||||
weights, nPoints, 42), 2).cache();
|
weights, nPoints, 42), 2).cache();
|
||||||
List<LabeledPoint> validationData =
|
List<LabeledPoint> validationData =
|
||||||
SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
|
SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
|
||||||
|
|
||||||
SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0);
|
SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0);
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
|
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
@ -29,27 +30,33 @@ import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaBisectingKMeansSuite implements Serializable {
|
public class JavaBisectingKMeansSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", this.getClass().getSimpleName());
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaBisectingKMeansSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void twoDimensionalData() {
|
public void twoDimensionalData() {
|
||||||
JavaRDD<Vector> points = sc.parallelize(Lists.newArrayList(
|
JavaRDD<Vector> points = jsc.parallelize(Lists.newArrayList(
|
||||||
Vectors.dense(4, -1),
|
Vectors.dense(4, -1),
|
||||||
Vectors.dense(4, 1),
|
Vectors.dense(4, 1),
|
||||||
Vectors.sparse(2, new int[] {0}, new double[] {1.0})
|
Vectors.sparse(2, new int[]{0}, new double[]{1.0})
|
||||||
), 2);
|
), 2);
|
||||||
|
|
||||||
BisectingKMeans bkm = new BisectingKMeans()
|
BisectingKMeans bkm = new BisectingKMeans()
|
||||||
|
@ -58,15 +65,15 @@ public class JavaBisectingKMeansSuite implements Serializable {
|
||||||
.setSeed(1L);
|
.setSeed(1L);
|
||||||
BisectingKMeansModel model = bkm.run(points);
|
BisectingKMeansModel model = bkm.run(points);
|
||||||
Assert.assertEquals(3, model.k());
|
Assert.assertEquals(3, model.k());
|
||||||
Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12);
|
Assert.assertArrayEquals(new double[]{3.0, 0.0}, model.root().center().toArray(), 1e-12);
|
||||||
for (ClusteringTreeNode child: model.root().children()) {
|
for (ClusteringTreeNode child : model.root().children()) {
|
||||||
double[] center = child.center().toArray();
|
double[] center = child.center().toArray();
|
||||||
if (center[0] > 2) {
|
if (center[0] > 2) {
|
||||||
Assert.assertEquals(2, child.size());
|
Assert.assertEquals(2, child.size());
|
||||||
Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12);
|
Assert.assertArrayEquals(new double[]{4.0, 0.0}, center, 1e-12);
|
||||||
} else {
|
} else {
|
||||||
Assert.assertEquals(1, child.size());
|
Assert.assertEquals(1, child.size());
|
||||||
Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12);
|
Assert.assertArrayEquals(new double[]{1.0, 0.0}, center, 1e-12);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,29 +21,35 @@ import java.io.Serializable;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaGaussianMixtureSuite implements Serializable {
|
public class JavaGaussianMixtureSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaGaussianMixture");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaGaussianMixture")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -54,7 +60,7 @@ public class JavaGaussianMixtureSuite implements Serializable {
|
||||||
Vectors.dense(1.0, 4.0, 6.0)
|
Vectors.dense(1.0, 4.0, 6.0)
|
||||||
);
|
);
|
||||||
|
|
||||||
JavaRDD<Vector> data = sc.parallelize(points, 2);
|
JavaRDD<Vector> data = jsc.parallelize(points, 2);
|
||||||
GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234)
|
GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234)
|
||||||
.run(data);
|
.run(data);
|
||||||
assertEquals(model.gaussians().length, 2);
|
assertEquals(model.gaussians().length, 2);
|
||||||
|
|
|
@ -21,28 +21,35 @@ import java.io.Serializable;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import static org.junit.Assert.*;
|
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaKMeansSuite implements Serializable {
|
public class JavaKMeansSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaKMeans");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaKMeans")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -55,7 +62,7 @@ public class JavaKMeansSuite implements Serializable {
|
||||||
|
|
||||||
Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
|
Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
|
||||||
|
|
||||||
JavaRDD<Vector> data = sc.parallelize(points, 2);
|
JavaRDD<Vector> data = jsc.parallelize(points, 2);
|
||||||
KMeansModel model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.K_MEANS_PARALLEL());
|
KMeansModel model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.K_MEANS_PARALLEL());
|
||||||
assertEquals(1, model.clusterCenters().length);
|
assertEquals(1, model.clusterCenters().length);
|
||||||
assertEquals(expectedCenter, model.clusterCenters()[0]);
|
assertEquals(expectedCenter, model.clusterCenters()[0]);
|
||||||
|
@ -74,7 +81,7 @@ public class JavaKMeansSuite implements Serializable {
|
||||||
|
|
||||||
Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
|
Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
|
||||||
|
|
||||||
JavaRDD<Vector> data = sc.parallelize(points, 2);
|
JavaRDD<Vector> data = jsc.parallelize(points, 2);
|
||||||
KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
|
KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
|
||||||
assertEquals(1, model.clusterCenters().length);
|
assertEquals(1, model.clusterCenters().length);
|
||||||
assertEquals(expectedCenter, model.clusterCenters()[0]);
|
assertEquals(expectedCenter, model.clusterCenters()[0]);
|
||||||
|
@ -94,7 +101,7 @@ public class JavaKMeansSuite implements Serializable {
|
||||||
Vectors.dense(1.0, 3.0, 0.0),
|
Vectors.dense(1.0, 3.0, 0.0),
|
||||||
Vectors.dense(1.0, 4.0, 6.0)
|
Vectors.dense(1.0, 4.0, 6.0)
|
||||||
);
|
);
|
||||||
JavaRDD<Vector> data = sc.parallelize(points, 2);
|
JavaRDD<Vector> data = jsc.parallelize(points, 2);
|
||||||
KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
|
KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
|
||||||
JavaRDD<Integer> predictions = model.predict(data);
|
JavaRDD<Integer> predictions = model.predict(data);
|
||||||
// Should be able to get the first prediction.
|
// Should be able to get the first prediction.
|
||||||
|
|
|
@ -27,37 +27,42 @@ import scala.Tuple3;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
import static org.junit.Assert.*;
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
import org.apache.spark.api.java.function.Function;
|
|
||||||
import org.apache.spark.api.java.JavaPairRDD;
|
import org.apache.spark.api.java.JavaPairRDD;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
import org.apache.spark.api.java.function.Function;
|
||||||
import org.apache.spark.mllib.linalg.Matrix;
|
import org.apache.spark.mllib.linalg.Matrix;
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaLDASuite implements Serializable {
|
public class JavaLDASuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaLDA");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaLDASuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
|
|
||||||
ArrayList<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<>();
|
ArrayList<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<>();
|
||||||
for (int i = 0; i < LDASuite.tinyCorpus().length; i++) {
|
for (int i = 0; i < LDASuite.tinyCorpus().length; i++) {
|
||||||
tinyCorpus.add(new Tuple2<>((Long)LDASuite.tinyCorpus()[i]._1(),
|
tinyCorpus.add(new Tuple2<>((Long) LDASuite.tinyCorpus()[i]._1(),
|
||||||
LDASuite.tinyCorpus()[i]._2()));
|
LDASuite.tinyCorpus()[i]._2()));
|
||||||
}
|
}
|
||||||
JavaRDD<Tuple2<Long, Vector>> tmpCorpus = sc.parallelize(tinyCorpus, 2);
|
JavaRDD<Tuple2<Long, Vector>> tmpCorpus = jsc.parallelize(tinyCorpus, 2);
|
||||||
corpus = JavaPairRDD.fromJavaRDD(tmpCorpus);
|
corpus = JavaPairRDD.fromJavaRDD(tmpCorpus);
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -95,7 +100,7 @@ public class JavaLDASuite implements Serializable {
|
||||||
.setMaxIterations(5)
|
.setMaxIterations(5)
|
||||||
.setSeed(12345);
|
.setSeed(12345);
|
||||||
|
|
||||||
DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus);
|
DistributedLDAModel model = (DistributedLDAModel) lda.run(corpus);
|
||||||
|
|
||||||
// Check: basic parameters
|
// Check: basic parameters
|
||||||
LocalLDAModel localModel = model.toLocal();
|
LocalLDAModel localModel = model.toLocal();
|
||||||
|
@ -124,7 +129,7 @@ public class JavaLDASuite implements Serializable {
|
||||||
public Boolean call(Tuple2<Long, Vector> tuple2) {
|
public Boolean call(Tuple2<Long, Vector> tuple2) {
|
||||||
return Vectors.norm(tuple2._2(), 1.0) != 0.0;
|
return Vectors.norm(tuple2._2(), 1.0) != 0.0;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
assertEquals(topicDistributions.count(), nonEmptyCorpus.count());
|
assertEquals(topicDistributions.count(), nonEmptyCorpus.count());
|
||||||
|
|
||||||
// Check: javaTopTopicsPerDocuments
|
// Check: javaTopTopicsPerDocuments
|
||||||
|
@ -179,7 +184,7 @@ public class JavaLDASuite implements Serializable {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void localLdaMethods() {
|
public void localLdaMethods() {
|
||||||
JavaRDD<Tuple2<Long, Vector>> docs = sc.parallelize(toyData, 2);
|
JavaRDD<Tuple2<Long, Vector>> docs = jsc.parallelize(toyData, 2);
|
||||||
JavaPairRDD<Long, Vector> pairedDocs = JavaPairRDD.fromJavaRDD(docs);
|
JavaPairRDD<Long, Vector> pairedDocs = JavaPairRDD.fromJavaRDD(docs);
|
||||||
|
|
||||||
// check: topicDistributions
|
// check: topicDistributions
|
||||||
|
@ -191,7 +196,7 @@ public class JavaLDASuite implements Serializable {
|
||||||
// check: logLikelihood.
|
// check: logLikelihood.
|
||||||
ArrayList<Tuple2<Long, Vector>> docsSingleWord = new ArrayList<>();
|
ArrayList<Tuple2<Long, Vector>> docsSingleWord = new ArrayList<>();
|
||||||
docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0, 0.0, 0.0)));
|
docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0, 0.0, 0.0)));
|
||||||
JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord));
|
JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(jsc.parallelize(docsSingleWord));
|
||||||
double logLikelihood = toyModel.logLikelihood(single);
|
double logLikelihood = toyModel.logLikelihood(single);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -199,7 +204,7 @@ public class JavaLDASuite implements Serializable {
|
||||||
private static int tinyVocabSize = LDASuite.tinyVocabSize();
|
private static int tinyVocabSize = LDASuite.tinyVocabSize();
|
||||||
private static Matrix tinyTopics = LDASuite.tinyTopics();
|
private static Matrix tinyTopics = LDASuite.tinyTopics();
|
||||||
private static Tuple2<int[], double[]>[] tinyTopicDescription =
|
private static Tuple2<int[], double[]>[] tinyTopicDescription =
|
||||||
LDASuite.tinyTopicDescription();
|
LDASuite.tinyTopicDescription();
|
||||||
private JavaPairRDD<Long, Vector> corpus;
|
private JavaPairRDD<Long, Vector> corpus;
|
||||||
private LocalLDAModel toyModel = LDASuite.toyModel();
|
private LocalLDAModel toyModel = LDASuite.toyModel();
|
||||||
private ArrayList<Tuple2<Long, Vector>> toyData = LDASuite.javaToyData();
|
private ArrayList<Tuple2<Long, Vector>> toyData = LDASuite.javaToyData();
|
||||||
|
|
|
@ -27,8 +27,6 @@ import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import static org.apache.spark.streaming.JavaTestUtils.*;
|
|
||||||
|
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
|
@ -36,6 +34,7 @@ import org.apache.spark.streaming.Duration;
|
||||||
import org.apache.spark.streaming.api.java.JavaDStream;
|
import org.apache.spark.streaming.api.java.JavaDStream;
|
||||||
import org.apache.spark.streaming.api.java.JavaPairDStream;
|
import org.apache.spark.streaming.api.java.JavaPairDStream;
|
||||||
import org.apache.spark.streaming.api.java.JavaStreamingContext;
|
import org.apache.spark.streaming.api.java.JavaStreamingContext;
|
||||||
|
import static org.apache.spark.streaming.JavaTestUtils.*;
|
||||||
|
|
||||||
public class JavaStreamingKMeansSuite implements Serializable {
|
public class JavaStreamingKMeansSuite implements Serializable {
|
||||||
|
|
||||||
|
|
|
@ -31,27 +31,34 @@ import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaRankingMetricsSuite implements Serializable {
|
public class JavaRankingMetricsSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels;
|
private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaRankingMetricsSuite");
|
spark = SparkSession.builder()
|
||||||
predictionAndLabels = sc.parallelize(Arrays.asList(
|
.master("local")
|
||||||
|
.appName("JavaPCASuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
|
|
||||||
|
predictionAndLabels = jsc.parallelize(Arrays.asList(
|
||||||
Tuple2$.MODULE$.apply(
|
Tuple2$.MODULE$.apply(
|
||||||
Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)),
|
Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)),
|
||||||
Tuple2$.MODULE$.apply(
|
Tuple2$.MODULE$.apply(
|
||||||
Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
|
Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
|
||||||
Tuple2$.MODULE$.apply(
|
Tuple2$.MODULE$.apply(
|
||||||
Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
|
Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -29,19 +29,25 @@ import org.junit.Test;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaTfIdfSuite implements Serializable {
|
public class JavaTfIdfSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaTfIdfSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaPCASuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -49,7 +55,7 @@ public class JavaTfIdfSuite implements Serializable {
|
||||||
// The tests are to check Java compatibility.
|
// The tests are to check Java compatibility.
|
||||||
HashingTF tf = new HashingTF();
|
HashingTF tf = new HashingTF();
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList(
|
JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList(
|
||||||
Arrays.asList("this is a sentence".split(" ")),
|
Arrays.asList("this is a sentence".split(" ")),
|
||||||
Arrays.asList("this is another sentence".split(" ")),
|
Arrays.asList("this is another sentence".split(" ")),
|
||||||
Arrays.asList("this is still a sentence".split(" "))), 2);
|
Arrays.asList("this is still a sentence".split(" "))), 2);
|
||||||
|
@ -59,7 +65,7 @@ public class JavaTfIdfSuite implements Serializable {
|
||||||
JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
|
JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
|
||||||
List<Vector> localTfIdfs = tfIdfs.collect();
|
List<Vector> localTfIdfs = tfIdfs.collect();
|
||||||
int indexOfThis = tf.indexOf("this");
|
int indexOfThis = tf.indexOf("this");
|
||||||
for (Vector v: localTfIdfs) {
|
for (Vector v : localTfIdfs) {
|
||||||
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
|
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -69,7 +75,7 @@ public class JavaTfIdfSuite implements Serializable {
|
||||||
// The tests are to check Java compatibility.
|
// The tests are to check Java compatibility.
|
||||||
HashingTF tf = new HashingTF();
|
HashingTF tf = new HashingTF();
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList(
|
JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList(
|
||||||
Arrays.asList("this is a sentence".split(" ")),
|
Arrays.asList("this is a sentence".split(" ")),
|
||||||
Arrays.asList("this is another sentence".split(" ")),
|
Arrays.asList("this is another sentence".split(" ")),
|
||||||
Arrays.asList("this is still a sentence".split(" "))), 2);
|
Arrays.asList("this is still a sentence".split(" "))), 2);
|
||||||
|
@ -79,7 +85,7 @@ public class JavaTfIdfSuite implements Serializable {
|
||||||
JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
|
JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
|
||||||
List<Vector> localTfIdfs = tfIdfs.collect();
|
List<Vector> localTfIdfs = tfIdfs.collect();
|
||||||
int indexOfThis = tf.indexOf("this");
|
int indexOfThis = tf.indexOf("this");
|
||||||
for (Vector v: localTfIdfs) {
|
for (Vector v : localTfIdfs) {
|
||||||
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
|
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,9 +21,10 @@ import java.io.Serializable;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import com.google.common.base.Strings;
|
||||||
|
|
||||||
import scala.Tuple2;
|
import scala.Tuple2;
|
||||||
|
|
||||||
import com.google.common.base.Strings;
|
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
@ -31,19 +32,25 @@ import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaWord2VecSuite implements Serializable {
|
public class JavaWord2VecSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaWord2VecSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaPCASuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -53,7 +60,7 @@ public class JavaWord2VecSuite implements Serializable {
|
||||||
String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10);
|
String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10);
|
||||||
List<String> words = Arrays.asList(sentence.split(" "));
|
List<String> words = Arrays.asList(sentence.split(" "));
|
||||||
List<List<String>> localDoc = Arrays.asList(words, words);
|
List<List<String>> localDoc = Arrays.asList(words, words);
|
||||||
JavaRDD<List<String>> doc = sc.parallelize(localDoc);
|
JavaRDD<List<String>> doc = jsc.parallelize(localDoc);
|
||||||
Word2Vec word2vec = new Word2Vec()
|
Word2Vec word2vec = new Word2Vec()
|
||||||
.setVectorSize(10)
|
.setVectorSize(10)
|
||||||
.setSeed(42L);
|
.setSeed(42L);
|
||||||
|
|
|
@ -26,32 +26,37 @@ import org.junit.Test;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
|
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaAssociationRulesSuite implements Serializable {
|
public class JavaAssociationRulesSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaFPGrowth");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaAssociationRulesSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void runAssociationRules() {
|
public void runAssociationRules() {
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Arrays.asList(
|
JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = jsc.parallelize(Arrays.asList(
|
||||||
new FreqItemset<String>(new String[] {"a"}, 15L),
|
new FreqItemset<String>(new String[]{"a"}, 15L),
|
||||||
new FreqItemset<String>(new String[] {"b"}, 35L),
|
new FreqItemset<String>(new String[]{"b"}, 35L),
|
||||||
new FreqItemset<String>(new String[] {"a", "b"}, 12L)
|
new FreqItemset<String>(new String[]{"a", "b"}, 12L)
|
||||||
));
|
));
|
||||||
|
|
||||||
JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets);
|
JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,34 +22,41 @@ import java.io.Serializable;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import static org.junit.Assert.*;
|
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.apache.spark.util.Utils;
|
import org.apache.spark.util.Utils;
|
||||||
|
|
||||||
public class JavaFPGrowthSuite implements Serializable {
|
public class JavaFPGrowthSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaFPGrowth");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaFPGrowth")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void runFPGrowth() {
|
public void runFPGrowth() {
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
|
JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList(
|
||||||
Arrays.asList("r z h k p".split(" ")),
|
Arrays.asList("r z h k p".split(" ")),
|
||||||
Arrays.asList("z y x w v u t s".split(" ")),
|
Arrays.asList("z y x w v u t s".split(" ")),
|
||||||
Arrays.asList("s x o n r".split(" ")),
|
Arrays.asList("s x o n r".split(" ")),
|
||||||
|
@ -65,7 +72,7 @@ public class JavaFPGrowthSuite implements Serializable {
|
||||||
List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
|
List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
|
||||||
assertEquals(18, freqItemsets.size());
|
assertEquals(18, freqItemsets.size());
|
||||||
|
|
||||||
for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
|
for (FPGrowth.FreqItemset<String> itemset : freqItemsets) {
|
||||||
// Test return types.
|
// Test return types.
|
||||||
List<String> items = itemset.javaItems();
|
List<String> items = itemset.javaItems();
|
||||||
long freq = itemset.freq();
|
long freq = itemset.freq();
|
||||||
|
@ -76,7 +83,7 @@ public class JavaFPGrowthSuite implements Serializable {
|
||||||
public void runFPGrowthSaveLoad() {
|
public void runFPGrowthSaveLoad() {
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
|
JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList(
|
||||||
Arrays.asList("r z h k p".split(" ")),
|
Arrays.asList("r z h k p".split(" ")),
|
||||||
Arrays.asList("z y x w v u t s".split(" ")),
|
Arrays.asList("z y x w v u t s".split(" ")),
|
||||||
Arrays.asList("s x o n r".split(" ")),
|
Arrays.asList("s x o n r".split(" ")),
|
||||||
|
@ -94,15 +101,15 @@ public class JavaFPGrowthSuite implements Serializable {
|
||||||
String outputPath = tempDir.getPath();
|
String outputPath = tempDir.getPath();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
model.save(sc.sc(), outputPath);
|
model.save(spark.sparkContext(), outputPath);
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
FPGrowthModel<String> newModel =
|
FPGrowthModel<String> newModel =
|
||||||
(FPGrowthModel<String>) FPGrowthModel.load(sc.sc(), outputPath);
|
(FPGrowthModel<String>) FPGrowthModel.load(spark.sparkContext(), outputPath);
|
||||||
List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD()
|
List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD()
|
||||||
.collect();
|
.collect();
|
||||||
assertEquals(18, freqItemsets.size());
|
assertEquals(18, freqItemsets.size());
|
||||||
|
|
||||||
for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
|
for (FPGrowth.FreqItemset<String> itemset : freqItemsets) {
|
||||||
// Test return types.
|
// Test return types.
|
||||||
List<String> items = itemset.javaItems();
|
List<String> items = itemset.javaItems();
|
||||||
long freq = itemset.freq();
|
long freq = itemset.freq();
|
||||||
|
|
|
@ -29,25 +29,31 @@ import org.junit.Test;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
|
import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.apache.spark.util.Utils;
|
import org.apache.spark.util.Utils;
|
||||||
|
|
||||||
public class JavaPrefixSpanSuite {
|
public class JavaPrefixSpanSuite {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaPrefixSpan");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaPrefixSpan")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void runPrefixSpan() {
|
public void runPrefixSpan() {
|
||||||
JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
|
JavaRDD<List<List<Integer>>> sequences = jsc.parallelize(Arrays.asList(
|
||||||
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
|
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
|
||||||
Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
|
Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
|
||||||
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
|
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
|
||||||
|
@ -61,7 +67,7 @@ public class JavaPrefixSpanSuite {
|
||||||
List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
|
List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
|
||||||
Assert.assertEquals(5, localFreqSeqs.size());
|
Assert.assertEquals(5, localFreqSeqs.size());
|
||||||
// Check that each frequent sequence could be materialized.
|
// Check that each frequent sequence could be materialized.
|
||||||
for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
|
for (PrefixSpan.FreqSequence<Integer> freqSeq : localFreqSeqs) {
|
||||||
List<List<Integer>> seq = freqSeq.javaSequence();
|
List<List<Integer>> seq = freqSeq.javaSequence();
|
||||||
long freq = freqSeq.freq();
|
long freq = freqSeq.freq();
|
||||||
}
|
}
|
||||||
|
@ -69,7 +75,7 @@ public class JavaPrefixSpanSuite {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void runPrefixSpanSaveLoad() {
|
public void runPrefixSpanSaveLoad() {
|
||||||
JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
|
JavaRDD<List<List<Integer>>> sequences = jsc.parallelize(Arrays.asList(
|
||||||
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
|
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
|
||||||
Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
|
Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
|
||||||
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
|
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
|
||||||
|
@ -85,13 +91,13 @@ public class JavaPrefixSpanSuite {
|
||||||
String outputPath = tempDir.getPath();
|
String outputPath = tempDir.getPath();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
model.save(sc.sc(), outputPath);
|
model.save(spark.sparkContext(), outputPath);
|
||||||
PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath);
|
PrefixSpanModel newModel = PrefixSpanModel.load(spark.sparkContext(), outputPath);
|
||||||
JavaRDD<FreqSequence<Integer>> freqSeqs = newModel.freqSequences().toJavaRDD();
|
JavaRDD<FreqSequence<Integer>> freqSeqs = newModel.freqSequences().toJavaRDD();
|
||||||
List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
|
List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
|
||||||
Assert.assertEquals(5, localFreqSeqs.size());
|
Assert.assertEquals(5, localFreqSeqs.size());
|
||||||
// Check that each frequent sequence could be materialized.
|
// Check that each frequent sequence could be materialized.
|
||||||
for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
|
for (PrefixSpan.FreqSequence<Integer> freqSeq : localFreqSeqs) {
|
||||||
List<List<Integer>> seq = freqSeq.javaSequence();
|
List<List<Integer>> seq = freqSeq.javaSequence();
|
||||||
long freq = freqSeq.freq();
|
long freq = freqSeq.freq();
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,147 +17,149 @@
|
||||||
|
|
||||||
package org.apache.spark.mllib.linalg;
|
package org.apache.spark.mllib.linalg;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
public class JavaMatricesSuite implements Serializable {
|
public class JavaMatricesSuite implements Serializable {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void randMatrixConstruction() {
|
public void randMatrixConstruction() {
|
||||||
Random rng = new Random(24);
|
Random rng = new Random(24);
|
||||||
Matrix r = Matrices.rand(3, 4, rng);
|
Matrix r = Matrices.rand(3, 4, rng);
|
||||||
rng.setSeed(24);
|
rng.setSeed(24);
|
||||||
DenseMatrix dr = DenseMatrix.rand(3, 4, rng);
|
DenseMatrix dr = DenseMatrix.rand(3, 4, rng);
|
||||||
assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
|
assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
|
||||||
|
|
||||||
rng.setSeed(24);
|
rng.setSeed(24);
|
||||||
Matrix rn = Matrices.randn(3, 4, rng);
|
Matrix rn = Matrices.randn(3, 4, rng);
|
||||||
rng.setSeed(24);
|
rng.setSeed(24);
|
||||||
DenseMatrix drn = DenseMatrix.randn(3, 4, rng);
|
DenseMatrix drn = DenseMatrix.randn(3, 4, rng);
|
||||||
assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);
|
assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);
|
||||||
|
|
||||||
rng.setSeed(24);
|
rng.setSeed(24);
|
||||||
Matrix s = Matrices.sprand(3, 4, 0.5, rng);
|
Matrix s = Matrices.sprand(3, 4, 0.5, rng);
|
||||||
rng.setSeed(24);
|
rng.setSeed(24);
|
||||||
SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng);
|
SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng);
|
||||||
assertArrayEquals(s.toArray(), sr.toArray(), 0.0);
|
assertArrayEquals(s.toArray(), sr.toArray(), 0.0);
|
||||||
|
|
||||||
rng.setSeed(24);
|
rng.setSeed(24);
|
||||||
Matrix sn = Matrices.sprandn(3, 4, 0.5, rng);
|
Matrix sn = Matrices.sprandn(3, 4, 0.5, rng);
|
||||||
rng.setSeed(24);
|
rng.setSeed(24);
|
||||||
SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng);
|
SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng);
|
||||||
assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
|
assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void identityMatrixConstruction() {
|
public void identityMatrixConstruction() {
|
||||||
Matrix r = Matrices.eye(2);
|
Matrix r = Matrices.eye(2);
|
||||||
DenseMatrix dr = DenseMatrix.eye(2);
|
DenseMatrix dr = DenseMatrix.eye(2);
|
||||||
SparseMatrix sr = SparseMatrix.speye(2);
|
SparseMatrix sr = SparseMatrix.speye(2);
|
||||||
assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
|
assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
|
||||||
assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
|
assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
|
||||||
assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
|
assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void diagonalMatrixConstruction() {
|
public void diagonalMatrixConstruction() {
|
||||||
Vector v = Vectors.dense(1.0, 0.0, 2.0);
|
Vector v = Vectors.dense(1.0, 0.0, 2.0);
|
||||||
Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});
|
Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});
|
||||||
|
|
||||||
Matrix m = Matrices.diag(v);
|
Matrix m = Matrices.diag(v);
|
||||||
Matrix sm = Matrices.diag(sv);
|
Matrix sm = Matrices.diag(sv);
|
||||||
DenseMatrix d = DenseMatrix.diag(v);
|
DenseMatrix d = DenseMatrix.diag(v);
|
||||||
DenseMatrix sd = DenseMatrix.diag(sv);
|
DenseMatrix sd = DenseMatrix.diag(sv);
|
||||||
SparseMatrix s = SparseMatrix.spdiag(v);
|
SparseMatrix s = SparseMatrix.spdiag(v);
|
||||||
SparseMatrix ss = SparseMatrix.spdiag(sv);
|
SparseMatrix ss = SparseMatrix.spdiag(sv);
|
||||||
|
|
||||||
assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
|
assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
|
||||||
assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
|
assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
|
||||||
assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
|
assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
|
||||||
assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
|
assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
|
||||||
assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
|
assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
|
||||||
assertArrayEquals(s.values(), ss.values(), 0.0);
|
assertArrayEquals(s.values(), ss.values(), 0.0);
|
||||||
assertEquals(2, s.values().length);
|
assertEquals(2, s.values().length);
|
||||||
assertEquals(2, ss.values().length);
|
assertEquals(2, ss.values().length);
|
||||||
assertEquals(4, s.colPtrs().length);
|
assertEquals(4, s.colPtrs().length);
|
||||||
assertEquals(4, ss.colPtrs().length);
|
assertEquals(4, ss.colPtrs().length);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void zerosMatrixConstruction() {
|
public void zerosMatrixConstruction() {
|
||||||
Matrix z = Matrices.zeros(2, 2);
|
Matrix z = Matrices.zeros(2, 2);
|
||||||
Matrix one = Matrices.ones(2, 2);
|
Matrix one = Matrices.ones(2, 2);
|
||||||
DenseMatrix dz = DenseMatrix.zeros(2, 2);
|
DenseMatrix dz = DenseMatrix.zeros(2, 2);
|
||||||
DenseMatrix done = DenseMatrix.ones(2, 2);
|
DenseMatrix done = DenseMatrix.ones(2, 2);
|
||||||
|
|
||||||
assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
|
assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
|
||||||
assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
|
assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
|
||||||
assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
|
assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
|
||||||
assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
|
assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void sparseDenseConversion() {
|
public void sparseDenseConversion() {
|
||||||
int m = 3;
|
int m = 3;
|
||||||
int n = 2;
|
int n = 2;
|
||||||
double[] values = new double[]{1.0, 2.0, 4.0, 5.0};
|
double[] values = new double[]{1.0, 2.0, 4.0, 5.0};
|
||||||
double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0};
|
double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0};
|
||||||
int[] colPtrs = new int[]{0, 2, 4};
|
int[] colPtrs = new int[]{0, 2, 4};
|
||||||
int[] rowIndices = new int[]{0, 1, 1, 2};
|
int[] rowIndices = new int[]{0, 1, 1, 2};
|
||||||
|
|
||||||
SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values);
|
SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values);
|
||||||
DenseMatrix deMat1 = new DenseMatrix(m, n, allValues);
|
DenseMatrix deMat1 = new DenseMatrix(m, n, allValues);
|
||||||
|
|
||||||
SparseMatrix spMat2 = deMat1.toSparse();
|
SparseMatrix spMat2 = deMat1.toSparse();
|
||||||
DenseMatrix deMat2 = spMat1.toDense();
|
DenseMatrix deMat2 = spMat1.toDense();
|
||||||
|
|
||||||
assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0);
|
assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0);
|
||||||
assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0);
|
assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void concatenateMatrices() {
|
public void concatenateMatrices() {
|
||||||
int m = 3;
|
int m = 3;
|
||||||
int n = 2;
|
int n = 2;
|
||||||
|
|
||||||
Random rng = new Random(42);
|
Random rng = new Random(42);
|
||||||
SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng);
|
SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng);
|
||||||
rng.setSeed(42);
|
rng.setSeed(42);
|
||||||
DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng);
|
DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng);
|
||||||
Matrix deMat2 = Matrices.eye(3);
|
Matrix deMat2 = Matrices.eye(3);
|
||||||
Matrix spMat2 = Matrices.speye(3);
|
Matrix spMat2 = Matrices.speye(3);
|
||||||
Matrix deMat3 = Matrices.eye(2);
|
Matrix deMat3 = Matrices.eye(2);
|
||||||
Matrix spMat3 = Matrices.speye(2);
|
Matrix spMat3 = Matrices.speye(2);
|
||||||
|
|
||||||
Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
|
Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
|
||||||
Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
|
Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
|
||||||
Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
|
Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
|
||||||
Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
|
Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
|
||||||
|
|
||||||
assertEquals(3, deHorz1.numRows());
|
assertEquals(3, deHorz1.numRows());
|
||||||
assertEquals(3, deHorz2.numRows());
|
assertEquals(3, deHorz2.numRows());
|
||||||
assertEquals(3, deHorz3.numRows());
|
assertEquals(3, deHorz3.numRows());
|
||||||
assertEquals(3, spHorz.numRows());
|
assertEquals(3, spHorz.numRows());
|
||||||
assertEquals(5, deHorz1.numCols());
|
assertEquals(5, deHorz1.numCols());
|
||||||
assertEquals(5, deHorz2.numCols());
|
assertEquals(5, deHorz2.numCols());
|
||||||
assertEquals(5, deHorz3.numCols());
|
assertEquals(5, deHorz3.numCols());
|
||||||
assertEquals(5, spHorz.numCols());
|
assertEquals(5, spHorz.numCols());
|
||||||
|
|
||||||
Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
|
Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
|
||||||
Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
|
Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
|
||||||
Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
|
Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
|
||||||
Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
|
Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
|
||||||
|
|
||||||
assertEquals(5, deVert1.numRows());
|
assertEquals(5, deVert1.numRows());
|
||||||
assertEquals(5, deVert2.numRows());
|
assertEquals(5, deVert2.numRows());
|
||||||
assertEquals(5, deVert3.numRows());
|
assertEquals(5, deVert3.numRows());
|
||||||
assertEquals(5, spVert.numRows());
|
assertEquals(5, spVert.numRows());
|
||||||
assertEquals(2, deVert1.numCols());
|
assertEquals(2, deVert1.numCols());
|
||||||
assertEquals(2, deVert2.numCols());
|
assertEquals(2, deVert2.numCols());
|
||||||
assertEquals(2, deVert3.numCols());
|
assertEquals(2, deVert3.numCols());
|
||||||
assertEquals(2, spVert.numCols());
|
assertEquals(2, spVert.numCols());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,10 +20,11 @@ package org.apache.spark.mllib.linalg;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
|
||||||
import scala.Tuple2;
|
import scala.Tuple2;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import static org.junit.Assert.*;
|
|
||||||
|
|
||||||
public class JavaVectorsSuite implements Serializable {
|
public class JavaVectorsSuite implements Serializable {
|
||||||
|
|
||||||
|
@ -37,8 +38,8 @@ public class JavaVectorsSuite implements Serializable {
|
||||||
public void sparseArrayConstruction() {
|
public void sparseArrayConstruction() {
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
Vector v = Vectors.sparse(3, Arrays.asList(
|
Vector v = Vectors.sparse(3, Arrays.asList(
|
||||||
new Tuple2<>(0, 2.0),
|
new Tuple2<>(0, 2.0),
|
||||||
new Tuple2<>(2, 3.0)));
|
new Tuple2<>(2, 3.0)));
|
||||||
assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0);
|
assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,29 +20,35 @@ package org.apache.spark.mllib.random;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
|
import org.junit.Assert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaDoubleRDD;
|
import org.apache.spark.api.java.JavaDoubleRDD;
|
||||||
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
import static org.apache.spark.mllib.random.RandomRDDs.*;
|
import static org.apache.spark.mllib.random.RandomRDDs.*;
|
||||||
|
|
||||||
public class JavaRandomRDDsSuite {
|
public class JavaRandomRDDsSuite {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaRandomRDDsSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaRandomRDDsSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -50,10 +56,10 @@ public class JavaRandomRDDsSuite {
|
||||||
long m = 1000L;
|
long m = 1000L;
|
||||||
int p = 2;
|
int p = 2;
|
||||||
long seed = 1L;
|
long seed = 1L;
|
||||||
JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m);
|
JavaDoubleRDD rdd1 = uniformJavaRDD(jsc, m);
|
||||||
JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p);
|
JavaDoubleRDD rdd2 = uniformJavaRDD(jsc, m, p);
|
||||||
JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed);
|
JavaDoubleRDD rdd3 = uniformJavaRDD(jsc, m, p, seed);
|
||||||
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
|
for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
|
||||||
Assert.assertEquals(m, rdd.count());
|
Assert.assertEquals(m, rdd.count());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -63,10 +69,10 @@ public class JavaRandomRDDsSuite {
|
||||||
long m = 1000L;
|
long m = 1000L;
|
||||||
int p = 2;
|
int p = 2;
|
||||||
long seed = 1L;
|
long seed = 1L;
|
||||||
JavaDoubleRDD rdd1 = normalJavaRDD(sc, m);
|
JavaDoubleRDD rdd1 = normalJavaRDD(jsc, m);
|
||||||
JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p);
|
JavaDoubleRDD rdd2 = normalJavaRDD(jsc, m, p);
|
||||||
JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed);
|
JavaDoubleRDD rdd3 = normalJavaRDD(jsc, m, p, seed);
|
||||||
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
|
for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
|
||||||
Assert.assertEquals(m, rdd.count());
|
Assert.assertEquals(m, rdd.count());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -78,10 +84,10 @@ public class JavaRandomRDDsSuite {
|
||||||
long m = 1000L;
|
long m = 1000L;
|
||||||
int p = 2;
|
int p = 2;
|
||||||
long seed = 1L;
|
long seed = 1L;
|
||||||
JavaDoubleRDD rdd1 = logNormalJavaRDD(sc, mean, std, m);
|
JavaDoubleRDD rdd1 = logNormalJavaRDD(jsc, mean, std, m);
|
||||||
JavaDoubleRDD rdd2 = logNormalJavaRDD(sc, mean, std, m, p);
|
JavaDoubleRDD rdd2 = logNormalJavaRDD(jsc, mean, std, m, p);
|
||||||
JavaDoubleRDD rdd3 = logNormalJavaRDD(sc, mean, std, m, p, seed);
|
JavaDoubleRDD rdd3 = logNormalJavaRDD(jsc, mean, std, m, p, seed);
|
||||||
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
|
for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
|
||||||
Assert.assertEquals(m, rdd.count());
|
Assert.assertEquals(m, rdd.count());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -92,10 +98,10 @@ public class JavaRandomRDDsSuite {
|
||||||
long m = 1000L;
|
long m = 1000L;
|
||||||
int p = 2;
|
int p = 2;
|
||||||
long seed = 1L;
|
long seed = 1L;
|
||||||
JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m);
|
JavaDoubleRDD rdd1 = poissonJavaRDD(jsc, mean, m);
|
||||||
JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p);
|
JavaDoubleRDD rdd2 = poissonJavaRDD(jsc, mean, m, p);
|
||||||
JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed);
|
JavaDoubleRDD rdd3 = poissonJavaRDD(jsc, mean, m, p, seed);
|
||||||
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
|
for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
|
||||||
Assert.assertEquals(m, rdd.count());
|
Assert.assertEquals(m, rdd.count());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -106,10 +112,10 @@ public class JavaRandomRDDsSuite {
|
||||||
long m = 1000L;
|
long m = 1000L;
|
||||||
int p = 2;
|
int p = 2;
|
||||||
long seed = 1L;
|
long seed = 1L;
|
||||||
JavaDoubleRDD rdd1 = exponentialJavaRDD(sc, mean, m);
|
JavaDoubleRDD rdd1 = exponentialJavaRDD(jsc, mean, m);
|
||||||
JavaDoubleRDD rdd2 = exponentialJavaRDD(sc, mean, m, p);
|
JavaDoubleRDD rdd2 = exponentialJavaRDD(jsc, mean, m, p);
|
||||||
JavaDoubleRDD rdd3 = exponentialJavaRDD(sc, mean, m, p, seed);
|
JavaDoubleRDD rdd3 = exponentialJavaRDD(jsc, mean, m, p, seed);
|
||||||
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
|
for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
|
||||||
Assert.assertEquals(m, rdd.count());
|
Assert.assertEquals(m, rdd.count());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -117,14 +123,14 @@ public class JavaRandomRDDsSuite {
|
||||||
@Test
|
@Test
|
||||||
public void testGammaRDD() {
|
public void testGammaRDD() {
|
||||||
double shape = 1.0;
|
double shape = 1.0;
|
||||||
double scale = 2.0;
|
double jscale = 2.0;
|
||||||
long m = 1000L;
|
long m = 1000L;
|
||||||
int p = 2;
|
int p = 2;
|
||||||
long seed = 1L;
|
long seed = 1L;
|
||||||
JavaDoubleRDD rdd1 = gammaJavaRDD(sc, shape, scale, m);
|
JavaDoubleRDD rdd1 = gammaJavaRDD(jsc, shape, jscale, m);
|
||||||
JavaDoubleRDD rdd2 = gammaJavaRDD(sc, shape, scale, m, p);
|
JavaDoubleRDD rdd2 = gammaJavaRDD(jsc, shape, jscale, m, p);
|
||||||
JavaDoubleRDD rdd3 = gammaJavaRDD(sc, shape, scale, m, p, seed);
|
JavaDoubleRDD rdd3 = gammaJavaRDD(jsc, shape, jscale, m, p, seed);
|
||||||
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
|
for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
|
||||||
Assert.assertEquals(m, rdd.count());
|
Assert.assertEquals(m, rdd.count());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -137,10 +143,10 @@ public class JavaRandomRDDsSuite {
|
||||||
int n = 10;
|
int n = 10;
|
||||||
int p = 2;
|
int p = 2;
|
||||||
long seed = 1L;
|
long seed = 1L;
|
||||||
JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(sc, m, n);
|
JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(jsc, m, n);
|
||||||
JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(sc, m, n, p);
|
JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(jsc, m, n, p);
|
||||||
JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(sc, m, n, p, seed);
|
JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(jsc, m, n, p, seed);
|
||||||
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
|
for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
|
||||||
Assert.assertEquals(m, rdd.count());
|
Assert.assertEquals(m, rdd.count());
|
||||||
Assert.assertEquals(n, rdd.first().size());
|
Assert.assertEquals(n, rdd.first().size());
|
||||||
}
|
}
|
||||||
|
@ -153,10 +159,10 @@ public class JavaRandomRDDsSuite {
|
||||||
int n = 10;
|
int n = 10;
|
||||||
int p = 2;
|
int p = 2;
|
||||||
long seed = 1L;
|
long seed = 1L;
|
||||||
JavaRDD<Vector> rdd1 = normalJavaVectorRDD(sc, m, n);
|
JavaRDD<Vector> rdd1 = normalJavaVectorRDD(jsc, m, n);
|
||||||
JavaRDD<Vector> rdd2 = normalJavaVectorRDD(sc, m, n, p);
|
JavaRDD<Vector> rdd2 = normalJavaVectorRDD(jsc, m, n, p);
|
||||||
JavaRDD<Vector> rdd3 = normalJavaVectorRDD(sc, m, n, p, seed);
|
JavaRDD<Vector> rdd3 = normalJavaVectorRDD(jsc, m, n, p, seed);
|
||||||
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
|
for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
|
||||||
Assert.assertEquals(m, rdd.count());
|
Assert.assertEquals(m, rdd.count());
|
||||||
Assert.assertEquals(n, rdd.first().size());
|
Assert.assertEquals(n, rdd.first().size());
|
||||||
}
|
}
|
||||||
|
@ -171,10 +177,10 @@ public class JavaRandomRDDsSuite {
|
||||||
int n = 10;
|
int n = 10;
|
||||||
int p = 2;
|
int p = 2;
|
||||||
long seed = 1L;
|
long seed = 1L;
|
||||||
JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(sc, mean, std, m, n);
|
JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(jsc, mean, std, m, n);
|
||||||
JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(sc, mean, std, m, n, p);
|
JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p);
|
||||||
JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(sc, mean, std, m, n, p, seed);
|
JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p, seed);
|
||||||
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
|
for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
|
||||||
Assert.assertEquals(m, rdd.count());
|
Assert.assertEquals(m, rdd.count());
|
||||||
Assert.assertEquals(n, rdd.first().size());
|
Assert.assertEquals(n, rdd.first().size());
|
||||||
}
|
}
|
||||||
|
@ -188,10 +194,10 @@ public class JavaRandomRDDsSuite {
|
||||||
int n = 10;
|
int n = 10;
|
||||||
int p = 2;
|
int p = 2;
|
||||||
long seed = 1L;
|
long seed = 1L;
|
||||||
JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(sc, mean, m, n);
|
JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(jsc, mean, m, n);
|
||||||
JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p);
|
JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(jsc, mean, m, n, p);
|
||||||
JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed);
|
JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(jsc, mean, m, n, p, seed);
|
||||||
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
|
for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
|
||||||
Assert.assertEquals(m, rdd.count());
|
Assert.assertEquals(m, rdd.count());
|
||||||
Assert.assertEquals(n, rdd.first().size());
|
Assert.assertEquals(n, rdd.first().size());
|
||||||
}
|
}
|
||||||
|
@ -205,10 +211,10 @@ public class JavaRandomRDDsSuite {
|
||||||
int n = 10;
|
int n = 10;
|
||||||
int p = 2;
|
int p = 2;
|
||||||
long seed = 1L;
|
long seed = 1L;
|
||||||
JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(sc, mean, m, n);
|
JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(jsc, mean, m, n);
|
||||||
JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(sc, mean, m, n, p);
|
JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(jsc, mean, m, n, p);
|
||||||
JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(sc, mean, m, n, p, seed);
|
JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(jsc, mean, m, n, p, seed);
|
||||||
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
|
for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
|
||||||
Assert.assertEquals(m, rdd.count());
|
Assert.assertEquals(m, rdd.count());
|
||||||
Assert.assertEquals(n, rdd.first().size());
|
Assert.assertEquals(n, rdd.first().size());
|
||||||
}
|
}
|
||||||
|
@ -218,15 +224,15 @@ public class JavaRandomRDDsSuite {
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
public void testGammaVectorRDD() {
|
public void testGammaVectorRDD() {
|
||||||
double shape = 1.0;
|
double shape = 1.0;
|
||||||
double scale = 2.0;
|
double jscale = 2.0;
|
||||||
long m = 100L;
|
long m = 100L;
|
||||||
int n = 10;
|
int n = 10;
|
||||||
int p = 2;
|
int p = 2;
|
||||||
long seed = 1L;
|
long seed = 1L;
|
||||||
JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(sc, shape, scale, m, n);
|
JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(jsc, shape, jscale, m, n);
|
||||||
JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(sc, shape, scale, m, n, p);
|
JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p);
|
||||||
JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(sc, shape, scale, m, n, p, seed);
|
JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p, seed);
|
||||||
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
|
for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
|
||||||
Assert.assertEquals(m, rdd.count());
|
Assert.assertEquals(m, rdd.count());
|
||||||
Assert.assertEquals(n, rdd.first().size());
|
Assert.assertEquals(n, rdd.first().size());
|
||||||
}
|
}
|
||||||
|
@ -238,10 +244,10 @@ public class JavaRandomRDDsSuite {
|
||||||
long seed = 1L;
|
long seed = 1L;
|
||||||
int numPartitions = 0;
|
int numPartitions = 0;
|
||||||
StringGenerator gen = new StringGenerator();
|
StringGenerator gen = new StringGenerator();
|
||||||
JavaRDD<String> rdd1 = randomJavaRDD(sc, gen, size);
|
JavaRDD<String> rdd1 = randomJavaRDD(jsc, gen, size);
|
||||||
JavaRDD<String> rdd2 = randomJavaRDD(sc, gen, size, numPartitions);
|
JavaRDD<String> rdd2 = randomJavaRDD(jsc, gen, size, numPartitions);
|
||||||
JavaRDD<String> rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed);
|
JavaRDD<String> rdd3 = randomJavaRDD(jsc, gen, size, numPartitions, seed);
|
||||||
for (JavaRDD<String> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
|
for (JavaRDD<String> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
|
||||||
Assert.assertEquals(size, rdd.count());
|
Assert.assertEquals(size, rdd.count());
|
||||||
Assert.assertEquals(2, rdd.first().length());
|
Assert.assertEquals(2, rdd.first().length());
|
||||||
}
|
}
|
||||||
|
@ -255,10 +261,10 @@ public class JavaRandomRDDsSuite {
|
||||||
int n = 10;
|
int n = 10;
|
||||||
int p = 2;
|
int p = 2;
|
||||||
long seed = 1L;
|
long seed = 1L;
|
||||||
JavaRDD<Vector> rdd1 = randomJavaVectorRDD(sc, generator, m, n);
|
JavaRDD<Vector> rdd1 = randomJavaVectorRDD(jsc, generator, m, n);
|
||||||
JavaRDD<Vector> rdd2 = randomJavaVectorRDD(sc, generator, m, n, p);
|
JavaRDD<Vector> rdd2 = randomJavaVectorRDD(jsc, generator, m, n, p);
|
||||||
JavaRDD<Vector> rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed);
|
JavaRDD<Vector> rdd3 = randomJavaVectorRDD(jsc, generator, m, n, p, seed);
|
||||||
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
|
for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
|
||||||
Assert.assertEquals(m, rdd.count());
|
Assert.assertEquals(m, rdd.count());
|
||||||
Assert.assertEquals(n, rdd.first().size());
|
Assert.assertEquals(n, rdd.first().size());
|
||||||
}
|
}
|
||||||
|
@ -271,10 +277,12 @@ class StringGenerator implements RandomDataGenerator<String>, Serializable {
|
||||||
public String nextValue() {
|
public String nextValue() {
|
||||||
return "42";
|
return "42";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public StringGenerator copy() {
|
public StringGenerator copy() {
|
||||||
return new StringGenerator();
|
return new StringGenerator();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setSeed(long seed) {
|
public void setSeed(long seed) {
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,40 +32,46 @@ import org.junit.Test;
|
||||||
import org.apache.spark.api.java.JavaPairRDD;
|
import org.apache.spark.api.java.JavaPairRDD;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaALSSuite implements Serializable {
|
public class JavaALSSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaALS");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaALS")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void validatePrediction(
|
private void validatePrediction(
|
||||||
MatrixFactorizationModel model,
|
MatrixFactorizationModel model,
|
||||||
int users,
|
int users,
|
||||||
int products,
|
int products,
|
||||||
double[] trueRatings,
|
double[] trueRatings,
|
||||||
double matchThreshold,
|
double matchThreshold,
|
||||||
boolean implicitPrefs,
|
boolean implicitPrefs,
|
||||||
double[] truePrefs) {
|
double[] truePrefs) {
|
||||||
List<Tuple2<Integer, Integer>> localUsersProducts = new ArrayList<>(users * products);
|
List<Tuple2<Integer, Integer>> localUsersProducts = new ArrayList<>(users * products);
|
||||||
for (int u=0; u < users; ++u) {
|
for (int u = 0; u < users; ++u) {
|
||||||
for (int p=0; p < products; ++p) {
|
for (int p = 0; p < products; ++p) {
|
||||||
localUsersProducts.add(new Tuple2<>(u, p));
|
localUsersProducts.add(new Tuple2<>(u, p));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
JavaPairRDD<Integer, Integer> usersProducts = sc.parallelizePairs(localUsersProducts);
|
JavaPairRDD<Integer, Integer> usersProducts = jsc.parallelizePairs(localUsersProducts);
|
||||||
List<Rating> predictedRatings = model.predict(usersProducts).collect();
|
List<Rating> predictedRatings = model.predict(usersProducts).collect();
|
||||||
Assert.assertEquals(users * products, predictedRatings.size());
|
Assert.assertEquals(users * products, predictedRatings.size());
|
||||||
if (!implicitPrefs) {
|
if (!implicitPrefs) {
|
||||||
for (Rating r: predictedRatings) {
|
for (Rating r : predictedRatings) {
|
||||||
double prediction = r.rating();
|
double prediction = r.rating();
|
||||||
double correct = trueRatings[r.product() * users + r.user()];
|
double correct = trueRatings[r.product() * users + r.user()];
|
||||||
Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
|
Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
|
||||||
|
@ -76,7 +82,7 @@ public class JavaALSSuite implements Serializable {
|
||||||
// (ref Mahout's implicit ALS tests)
|
// (ref Mahout's implicit ALS tests)
|
||||||
double sqErr = 0.0;
|
double sqErr = 0.0;
|
||||||
double denom = 0.0;
|
double denom = 0.0;
|
||||||
for (Rating r: predictedRatings) {
|
for (Rating r : predictedRatings) {
|
||||||
double prediction = r.rating();
|
double prediction = r.rating();
|
||||||
double truePref = truePrefs[r.product() * users + r.user()];
|
double truePref = truePrefs[r.product() * users + r.user()];
|
||||||
double confidence = 1.0 +
|
double confidence = 1.0 +
|
||||||
|
@ -98,9 +104,9 @@ public class JavaALSSuite implements Serializable {
|
||||||
int users = 50;
|
int users = 50;
|
||||||
int products = 100;
|
int products = 100;
|
||||||
Tuple3<List<Rating>, double[], double[]> testData =
|
Tuple3<List<Rating>, double[], double[]> testData =
|
||||||
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
|
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
|
||||||
|
|
||||||
JavaRDD<Rating> data = sc.parallelize(testData._1());
|
JavaRDD<Rating> data = jsc.parallelize(testData._1());
|
||||||
MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
|
MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
|
||||||
validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
|
validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
|
||||||
}
|
}
|
||||||
|
@ -112,9 +118,9 @@ public class JavaALSSuite implements Serializable {
|
||||||
int users = 100;
|
int users = 100;
|
||||||
int products = 200;
|
int products = 200;
|
||||||
Tuple3<List<Rating>, double[], double[]> testData =
|
Tuple3<List<Rating>, double[], double[]> testData =
|
||||||
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
|
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
|
||||||
|
|
||||||
JavaRDD<Rating> data = sc.parallelize(testData._1());
|
JavaRDD<Rating> data = jsc.parallelize(testData._1());
|
||||||
|
|
||||||
MatrixFactorizationModel model = new ALS().setRank(features)
|
MatrixFactorizationModel model = new ALS().setRank(features)
|
||||||
.setIterations(iterations)
|
.setIterations(iterations)
|
||||||
|
@ -129,9 +135,9 @@ public class JavaALSSuite implements Serializable {
|
||||||
int users = 80;
|
int users = 80;
|
||||||
int products = 160;
|
int products = 160;
|
||||||
Tuple3<List<Rating>, double[], double[]> testData =
|
Tuple3<List<Rating>, double[], double[]> testData =
|
||||||
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
|
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
|
||||||
|
|
||||||
JavaRDD<Rating> data = sc.parallelize(testData._1());
|
JavaRDD<Rating> data = jsc.parallelize(testData._1());
|
||||||
MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
|
MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
|
||||||
validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
|
validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
|
||||||
}
|
}
|
||||||
|
@ -143,9 +149,9 @@ public class JavaALSSuite implements Serializable {
|
||||||
int users = 100;
|
int users = 100;
|
||||||
int products = 200;
|
int products = 200;
|
||||||
Tuple3<List<Rating>, double[], double[]> testData =
|
Tuple3<List<Rating>, double[], double[]> testData =
|
||||||
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
|
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
|
||||||
|
|
||||||
JavaRDD<Rating> data = sc.parallelize(testData._1());
|
JavaRDD<Rating> data = jsc.parallelize(testData._1());
|
||||||
|
|
||||||
MatrixFactorizationModel model = new ALS().setRank(features)
|
MatrixFactorizationModel model = new ALS().setRank(features)
|
||||||
.setIterations(iterations)
|
.setIterations(iterations)
|
||||||
|
@ -161,9 +167,9 @@ public class JavaALSSuite implements Serializable {
|
||||||
int users = 80;
|
int users = 80;
|
||||||
int products = 160;
|
int products = 160;
|
||||||
Tuple3<List<Rating>, double[], double[]> testData =
|
Tuple3<List<Rating>, double[], double[]> testData =
|
||||||
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true);
|
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true);
|
||||||
|
|
||||||
JavaRDD<Rating> data = sc.parallelize(testData._1());
|
JavaRDD<Rating> data = jsc.parallelize(testData._1());
|
||||||
MatrixFactorizationModel model = new ALS().setRank(features)
|
MatrixFactorizationModel model = new ALS().setRank(features)
|
||||||
.setIterations(iterations)
|
.setIterations(iterations)
|
||||||
.setImplicitPrefs(true)
|
.setImplicitPrefs(true)
|
||||||
|
@ -179,8 +185,8 @@ public class JavaALSSuite implements Serializable {
|
||||||
int users = 200;
|
int users = 200;
|
||||||
int products = 50;
|
int products = 50;
|
||||||
List<Rating> testData = ALSSuite.generateRatingsAsJava(
|
List<Rating> testData = ALSSuite.generateRatingsAsJava(
|
||||||
users, products, features, 0.7, true, false)._1();
|
users, products, features, 0.7, true, false)._1();
|
||||||
JavaRDD<Rating> data = sc.parallelize(testData);
|
JavaRDD<Rating> data = jsc.parallelize(testData);
|
||||||
MatrixFactorizationModel model = new ALS().setRank(features)
|
MatrixFactorizationModel model = new ALS().setRank(features)
|
||||||
.setIterations(iterations)
|
.setIterations(iterations)
|
||||||
.setImplicitPrefs(true)
|
.setImplicitPrefs(true)
|
||||||
|
@ -193,7 +199,7 @@ public class JavaALSSuite implements Serializable {
|
||||||
private static void validateRecommendations(Rating[] recommendations, int howMany) {
|
private static void validateRecommendations(Rating[] recommendations, int howMany) {
|
||||||
Assert.assertEquals(howMany, recommendations.length);
|
Assert.assertEquals(howMany, recommendations.length);
|
||||||
for (int i = 1; i < recommendations.length; i++) {
|
for (int i = 1; i < recommendations.length; i++) {
|
||||||
Assert.assertTrue(recommendations[i-1].rating() >= recommendations[i].rating());
|
Assert.assertTrue(recommendations[i - 1].rating() >= recommendations[i].rating());
|
||||||
}
|
}
|
||||||
Assert.assertTrue(recommendations[0].rating() > 0.7);
|
Assert.assertTrue(recommendations[0].rating() > 0.7);
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,15 +32,17 @@ import org.junit.Test;
|
||||||
import org.apache.spark.api.java.JavaDoubleRDD;
|
import org.apache.spark.api.java.JavaDoubleRDD;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaIsotonicRegressionSuite implements Serializable {
|
public class JavaIsotonicRegressionSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
private static List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) {
|
private static List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) {
|
||||||
List<Tuple3<Double, Double, Double>> input = new ArrayList<>(labels.length);
|
List<Tuple3<Double, Double, Double>> input = new ArrayList<>(labels.length);
|
||||||
|
|
||||||
for (int i = 1; i <= labels.length; i++) {
|
for (int i = 1; i <= labels.length; i++) {
|
||||||
input.add(new Tuple3<>(labels[i-1], (double) i, 1.0));
|
input.add(new Tuple3<>(labels[i - 1], (double) i, 1.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
return input;
|
return input;
|
||||||
|
@ -48,20 +50,24 @@ public class JavaIsotonicRegressionSuite implements Serializable {
|
||||||
|
|
||||||
private IsotonicRegressionModel runIsotonicRegression(double[] labels) {
|
private IsotonicRegressionModel runIsotonicRegression(double[] labels) {
|
||||||
JavaRDD<Tuple3<Double, Double, Double>> trainRDD =
|
JavaRDD<Tuple3<Double, Double, Double>> trainRDD =
|
||||||
sc.parallelize(generateIsotonicInput(labels), 2).cache();
|
jsc.parallelize(generateIsotonicInput(labels), 2).cache();
|
||||||
|
|
||||||
return new IsotonicRegression().run(trainRDD);
|
return new IsotonicRegression().run(trainRDD);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaLinearRegressionSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -70,7 +76,7 @@ public class JavaIsotonicRegressionSuite implements Serializable {
|
||||||
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
|
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
|
||||||
|
|
||||||
Assert.assertArrayEquals(
|
Assert.assertArrayEquals(
|
||||||
new double[] {1, 2, 7.0/3, 7.0/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14);
|
new double[]{1, 2, 7.0 / 3, 7.0 / 3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -78,7 +84,7 @@ public class JavaIsotonicRegressionSuite implements Serializable {
|
||||||
IsotonicRegressionModel model =
|
IsotonicRegressionModel model =
|
||||||
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
|
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
|
||||||
|
|
||||||
JavaDoubleRDD testRDD = sc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0));
|
JavaDoubleRDD testRDD = jsc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0));
|
||||||
List<Double> predictions = model.predict(testRDD).collect();
|
List<Double> predictions = model.predict(testRDD).collect();
|
||||||
|
|
||||||
Assert.assertEquals(1.0, predictions.get(0).doubleValue(), 1.0e-14);
|
Assert.assertEquals(1.0, predictions.get(0).doubleValue(), 1.0e-14);
|
||||||
|
|
|
@ -28,24 +28,30 @@ import org.junit.Test;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.util.LinearDataGenerator;
|
import org.apache.spark.mllib.util.LinearDataGenerator;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaLassoSuite implements Serializable {
|
public class JavaLassoSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaLassoSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaLassoSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
int validatePrediction(List<LabeledPoint> validationData, LassoModel model) {
|
int validatePrediction(List<LabeledPoint> validationData, LassoModel model) {
|
||||||
int numAccurate = 0;
|
int numAccurate = 0;
|
||||||
for (LabeledPoint point: validationData) {
|
for (LabeledPoint point : validationData) {
|
||||||
Double prediction = model.predict(point.features());
|
Double prediction = model.predict(point.features());
|
||||||
// A prediction is off if the prediction is more than 0.5 away from expected value.
|
// A prediction is off if the prediction is more than 0.5 away from expected value.
|
||||||
if (Math.abs(prediction - point.label()) <= 0.5) {
|
if (Math.abs(prediction - point.label()) <= 0.5) {
|
||||||
|
@ -61,15 +67,15 @@ public class JavaLassoSuite implements Serializable {
|
||||||
double A = 0.0;
|
double A = 0.0;
|
||||||
double[] weights = {-1.5, 1.0e-2};
|
double[] weights = {-1.5, 1.0e-2};
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
|
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
|
||||||
weights, nPoints, 42, 0.1), 2).cache();
|
weights, nPoints, 42, 0.1), 2).cache();
|
||||||
List<LabeledPoint> validationData =
|
List<LabeledPoint> validationData =
|
||||||
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
|
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
|
||||||
|
|
||||||
LassoWithSGD lassoSGDImpl = new LassoWithSGD();
|
LassoWithSGD lassoSGDImpl = new LassoWithSGD();
|
||||||
lassoSGDImpl.optimizer().setStepSize(1.0)
|
lassoSGDImpl.optimizer().setStepSize(1.0)
|
||||||
.setRegParam(0.01)
|
.setRegParam(0.01)
|
||||||
.setNumIterations(20);
|
.setNumIterations(20);
|
||||||
LassoModel model = lassoSGDImpl.run(testRDD.rdd());
|
LassoModel model = lassoSGDImpl.run(testRDD.rdd());
|
||||||
|
|
||||||
int numAccurate = validatePrediction(validationData, model);
|
int numAccurate = validatePrediction(validationData, model);
|
||||||
|
@ -82,10 +88,10 @@ public class JavaLassoSuite implements Serializable {
|
||||||
double A = 0.0;
|
double A = 0.0;
|
||||||
double[] weights = {-1.5, 1.0e-2};
|
double[] weights = {-1.5, 1.0e-2};
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
|
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
|
||||||
weights, nPoints, 42, 0.1), 2).cache();
|
weights, nPoints, 42, 0.1), 2).cache();
|
||||||
List<LabeledPoint> validationData =
|
List<LabeledPoint> validationData =
|
||||||
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
|
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
|
||||||
|
|
||||||
LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0);
|
LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0);
|
||||||
|
|
||||||
|
|
|
@ -25,34 +25,40 @@ import org.junit.Assert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import org.apache.spark.api.java.function.Function;
|
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
import org.apache.spark.api.java.function.Function;
|
||||||
import org.apache.spark.mllib.linalg.Vector;
|
import org.apache.spark.mllib.linalg.Vector;
|
||||||
import org.apache.spark.mllib.util.LinearDataGenerator;
|
import org.apache.spark.mllib.util.LinearDataGenerator;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaLinearRegressionSuite implements Serializable {
|
public class JavaLinearRegressionSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaLinearRegressionSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) {
|
int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) {
|
||||||
int numAccurate = 0;
|
int numAccurate = 0;
|
||||||
for (LabeledPoint point: validationData) {
|
for (LabeledPoint point : validationData) {
|
||||||
Double prediction = model.predict(point.features());
|
Double prediction = model.predict(point.features());
|
||||||
// A prediction is off if the prediction is more than 0.5 away from expected value.
|
// A prediction is off if the prediction is more than 0.5 away from expected value.
|
||||||
if (Math.abs(prediction - point.label()) <= 0.5) {
|
if (Math.abs(prediction - point.label()) <= 0.5) {
|
||||||
numAccurate++;
|
numAccurate++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return numAccurate;
|
return numAccurate;
|
||||||
}
|
}
|
||||||
|
@ -63,10 +69,10 @@ public class JavaLinearRegressionSuite implements Serializable {
|
||||||
double A = 3.0;
|
double A = 3.0;
|
||||||
double[] weights = {10, 10};
|
double[] weights = {10, 10};
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
|
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
|
||||||
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
|
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
|
||||||
List<LabeledPoint> validationData =
|
List<LabeledPoint> validationData =
|
||||||
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
|
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
|
||||||
|
|
||||||
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
|
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
|
||||||
linSGDImpl.setIntercept(true);
|
linSGDImpl.setIntercept(true);
|
||||||
|
@ -82,10 +88,10 @@ public class JavaLinearRegressionSuite implements Serializable {
|
||||||
double A = 0.0;
|
double A = 0.0;
|
||||||
double[] weights = {10, 10};
|
double[] weights = {10, 10};
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
|
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
|
||||||
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
|
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
|
||||||
List<LabeledPoint> validationData =
|
List<LabeledPoint> validationData =
|
||||||
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
|
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
|
||||||
|
|
||||||
LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100);
|
LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100);
|
||||||
|
|
||||||
|
@ -98,7 +104,7 @@ public class JavaLinearRegressionSuite implements Serializable {
|
||||||
int nPoints = 100;
|
int nPoints = 100;
|
||||||
double A = 0.0;
|
double A = 0.0;
|
||||||
double[] weights = {10, 10};
|
double[] weights = {10, 10};
|
||||||
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
|
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
|
||||||
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
|
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
|
||||||
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
|
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
|
||||||
LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
|
LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
|
||||||
|
|
|
@ -29,25 +29,31 @@ import org.junit.Test;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.util.LinearDataGenerator;
|
import org.apache.spark.mllib.util.LinearDataGenerator;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
public class JavaRidgeRegressionSuite implements Serializable {
|
public class JavaRidgeRegressionSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaRidgeRegressionSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static double predictionError(List<LabeledPoint> validationData,
|
private static double predictionError(List<LabeledPoint> validationData,
|
||||||
RidgeRegressionModel model) {
|
RidgeRegressionModel model) {
|
||||||
double errorSum = 0;
|
double errorSum = 0;
|
||||||
for (LabeledPoint point: validationData) {
|
for (LabeledPoint point : validationData) {
|
||||||
Double prediction = model.predict(point.features());
|
Double prediction = model.predict(point.features());
|
||||||
errorSum += (prediction - point.label()) * (prediction - point.label());
|
errorSum += (prediction - point.label()) * (prediction - point.label());
|
||||||
}
|
}
|
||||||
|
@ -68,9 +74,9 @@ public class JavaRidgeRegressionSuite implements Serializable {
|
||||||
public void runRidgeRegressionUsingConstructor() {
|
public void runRidgeRegressionUsingConstructor() {
|
||||||
int numExamples = 50;
|
int numExamples = 50;
|
||||||
int numFeatures = 20;
|
int numFeatures = 20;
|
||||||
List<LabeledPoint> data = generateRidgeData(2*numExamples, numFeatures, 10.0);
|
List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
|
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(data.subList(0, numExamples));
|
||||||
List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
|
List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
|
||||||
|
|
||||||
RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
|
RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
|
||||||
|
@ -94,7 +100,7 @@ public class JavaRidgeRegressionSuite implements Serializable {
|
||||||
int numFeatures = 20;
|
int numFeatures = 20;
|
||||||
List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
|
List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
|
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(data.subList(0, numExamples));
|
||||||
List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
|
List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
|
||||||
|
|
||||||
RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);
|
RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);
|
||||||
|
|
|
@ -24,13 +24,11 @@ import java.util.List;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import static org.apache.spark.streaming.JavaTestUtils.*;
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
|
||||||
import org.apache.spark.api.java.JavaDoubleRDD;
|
import org.apache.spark.api.java.JavaDoubleRDD;
|
||||||
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.mllib.linalg.Vectors;
|
import org.apache.spark.mllib.linalg.Vectors;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
|
@ -38,36 +36,42 @@ import org.apache.spark.mllib.stat.test.BinarySample;
|
||||||
import org.apache.spark.mllib.stat.test.ChiSqTestResult;
|
import org.apache.spark.mllib.stat.test.ChiSqTestResult;
|
||||||
import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult;
|
import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult;
|
||||||
import org.apache.spark.mllib.stat.test.StreamingTest;
|
import org.apache.spark.mllib.stat.test.StreamingTest;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.apache.spark.streaming.Duration;
|
import org.apache.spark.streaming.Duration;
|
||||||
import org.apache.spark.streaming.api.java.JavaDStream;
|
import org.apache.spark.streaming.api.java.JavaDStream;
|
||||||
import org.apache.spark.streaming.api.java.JavaStreamingContext;
|
import org.apache.spark.streaming.api.java.JavaStreamingContext;
|
||||||
|
import static org.apache.spark.streaming.JavaTestUtils.*;
|
||||||
|
|
||||||
public class JavaStatisticsSuite implements Serializable {
|
public class JavaStatisticsSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
private transient JavaStreamingContext ssc;
|
private transient JavaStreamingContext ssc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
SparkConf conf = new SparkConf()
|
SparkConf conf = new SparkConf()
|
||||||
.setMaster("local[2]")
|
|
||||||
.setAppName("JavaStatistics")
|
|
||||||
.set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
|
.set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
|
||||||
sc = new JavaSparkContext(conf);
|
spark = SparkSession.builder()
|
||||||
ssc = new JavaStreamingContext(sc, new Duration(1000));
|
.master("local[2]")
|
||||||
|
.appName("JavaStatistics")
|
||||||
|
.config(conf)
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
|
ssc = new JavaStreamingContext(jsc, new Duration(1000));
|
||||||
ssc.checkpoint("checkpoint");
|
ssc.checkpoint("checkpoint");
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
|
spark.stop();
|
||||||
ssc.stop();
|
ssc.stop();
|
||||||
ssc = null;
|
spark = null;
|
||||||
sc = null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCorr() {
|
public void testCorr() {
|
||||||
JavaRDD<Double> x = sc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0));
|
JavaRDD<Double> x = jsc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0));
|
||||||
JavaRDD<Double> y = sc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3));
|
JavaRDD<Double> y = jsc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3));
|
||||||
|
|
||||||
Double corr1 = Statistics.corr(x, y);
|
Double corr1 = Statistics.corr(x, y);
|
||||||
Double corr2 = Statistics.corr(x, y, "pearson");
|
Double corr2 = Statistics.corr(x, y, "pearson");
|
||||||
|
@ -77,7 +81,7 @@ public class JavaStatisticsSuite implements Serializable {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void kolmogorovSmirnovTest() {
|
public void kolmogorovSmirnovTest() {
|
||||||
JavaDoubleRDD data = sc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0));
|
JavaDoubleRDD data = jsc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0));
|
||||||
KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm");
|
KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm");
|
||||||
KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest(
|
KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest(
|
||||||
data, "norm", 0.0, 1.0);
|
data, "norm", 0.0, 1.0);
|
||||||
|
@ -85,7 +89,7 @@ public class JavaStatisticsSuite implements Serializable {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void chiSqTest() {
|
public void chiSqTest() {
|
||||||
JavaRDD<LabeledPoint> data = sc.parallelize(Arrays.asList(
|
JavaRDD<LabeledPoint> data = jsc.parallelize(Arrays.asList(
|
||||||
new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)),
|
new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)),
|
||||||
new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)),
|
new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)),
|
||||||
new LabeledPoint(0.0, Vectors.dense(2.4, 8.1))));
|
new LabeledPoint(0.0, Vectors.dense(2.4, 8.1))));
|
||||||
|
|
|
@ -35,25 +35,31 @@ import org.apache.spark.mllib.tree.configuration.Algo;
|
||||||
import org.apache.spark.mllib.tree.configuration.Strategy;
|
import org.apache.spark.mllib.tree.configuration.Strategy;
|
||||||
import org.apache.spark.mllib.tree.impurity.Gini;
|
import org.apache.spark.mllib.tree.impurity.Gini;
|
||||||
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
|
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
|
||||||
|
|
||||||
public class JavaDecisionTreeSuite implements Serializable {
|
public class JavaDecisionTreeSuite implements Serializable {
|
||||||
private transient JavaSparkContext sc;
|
private transient SparkSession spark;
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
sc = new JavaSparkContext("local", "JavaDecisionTreeSuite");
|
spark = SparkSession.builder()
|
||||||
|
.master("local")
|
||||||
|
.appName("JavaDecisionTreeSuite")
|
||||||
|
.getOrCreate();
|
||||||
|
jsc = new JavaSparkContext(spark.sparkContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void tearDown() {
|
public void tearDown() {
|
||||||
sc.stop();
|
spark.stop();
|
||||||
sc = null;
|
spark = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
int validatePrediction(List<LabeledPoint> validationData, DecisionTreeModel model) {
|
int validatePrediction(List<LabeledPoint> validationData, DecisionTreeModel model) {
|
||||||
int numCorrect = 0;
|
int numCorrect = 0;
|
||||||
for (LabeledPoint point: validationData) {
|
for (LabeledPoint point : validationData) {
|
||||||
Double prediction = model.predict(point.features());
|
Double prediction = model.predict(point.features());
|
||||||
if (prediction == point.label()) {
|
if (prediction == point.label()) {
|
||||||
numCorrect++;
|
numCorrect++;
|
||||||
|
@ -65,7 +71,7 @@ public class JavaDecisionTreeSuite implements Serializable {
|
||||||
@Test
|
@Test
|
||||||
public void runDTUsingConstructor() {
|
public void runDTUsingConstructor() {
|
||||||
List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
|
List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
|
||||||
JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
|
JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
|
||||||
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
|
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
|
||||||
categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
|
categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
|
||||||
|
|
||||||
|
@ -73,7 +79,7 @@ public class JavaDecisionTreeSuite implements Serializable {
|
||||||
int numClasses = 2;
|
int numClasses = 2;
|
||||||
int maxBins = 100;
|
int maxBins = 100;
|
||||||
Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
|
Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
|
||||||
maxBins, categoricalFeaturesInfo);
|
maxBins, categoricalFeaturesInfo);
|
||||||
|
|
||||||
DecisionTree learner = new DecisionTree(strategy);
|
DecisionTree learner = new DecisionTree(strategy);
|
||||||
DecisionTreeModel model = learner.run(rdd.rdd());
|
DecisionTreeModel model = learner.run(rdd.rdd());
|
||||||
|
@ -85,7 +91,7 @@ public class JavaDecisionTreeSuite implements Serializable {
|
||||||
@Test
|
@Test
|
||||||
public void runDTUsingStaticMethods() {
|
public void runDTUsingStaticMethods() {
|
||||||
List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
|
List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
|
||||||
JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
|
JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
|
||||||
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
|
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
|
||||||
categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
|
categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
|
||||||
|
|
||||||
|
@ -93,7 +99,7 @@ public class JavaDecisionTreeSuite implements Serializable {
|
||||||
int numClasses = 2;
|
int numClasses = 2;
|
||||||
int maxBins = 100;
|
int maxBins = 100;
|
||||||
Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
|
Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
|
||||||
maxBins, categoricalFeaturesInfo);
|
maxBins, categoricalFeaturesInfo);
|
||||||
|
|
||||||
DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);
|
DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);
|
||||||
|
|
||||||
|
|
|
@ -183,7 +183,7 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
}
|
}
|
||||||
|
|
||||||
test("pipeline validateParams") {
|
test("pipeline validateParams") {
|
||||||
val df = sqlContext.createDataFrame(
|
val df = spark.createDataFrame(
|
||||||
Seq(
|
Seq(
|
||||||
(1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
|
(1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
|
||||||
(2, Vectors.dense(1.0, 0.0, 4.0), 2.0),
|
(2, Vectors.dense(1.0, 0.0, 4.0), 2.0),
|
||||||
|
|
|
@ -32,7 +32,7 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
test("extractLabeledPoints") {
|
test("extractLabeledPoints") {
|
||||||
def getTestData(labels: Seq[Double]): DataFrame = {
|
def getTestData(labels: Seq[Double]): DataFrame = {
|
||||||
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
|
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
|
||||||
sqlContext.createDataFrame(data)
|
spark.createDataFrame(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
val c = new MockClassifier
|
val c = new MockClassifier
|
||||||
|
@ -72,7 +72,7 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
test("getNumClasses") {
|
test("getNumClasses") {
|
||||||
def getTestData(labels: Seq[Double]): DataFrame = {
|
def getTestData(labels: Seq[Double]): DataFrame = {
|
||||||
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
|
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
|
||||||
sqlContext.createDataFrame(data)
|
spark.createDataFrame(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
val c = new MockClassifier
|
val c = new MockClassifier
|
||||||
|
|
|
@ -337,13 +337,13 @@ class DecisionTreeClassifierSuite
|
||||||
test("should support all NumericType labels and not support other types") {
|
test("should support all NumericType labels and not support other types") {
|
||||||
val dt = new DecisionTreeClassifier().setMaxDepth(1)
|
val dt = new DecisionTreeClassifier().setMaxDepth(1)
|
||||||
MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier](
|
MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier](
|
||||||
dt, isClassification = true, sqlContext) { (expected, actual) =>
|
dt, isClassification = true, spark) { (expected, actual) =>
|
||||||
TreeTests.checkEqual(expected, actual)
|
TreeTests.checkEqual(expected, actual)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Fitting without numClasses in metadata") {
|
test("Fitting without numClasses in metadata") {
|
||||||
val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
|
val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc))
|
||||||
val dt = new DecisionTreeClassifier().setMaxDepth(1)
|
val dt = new DecisionTreeClassifier().setMaxDepth(1)
|
||||||
dt.fit(df)
|
dt.fit(df)
|
||||||
}
|
}
|
||||||
|
|
|
@ -106,7 +106,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
test("should support all NumericType labels and not support other types") {
|
test("should support all NumericType labels and not support other types") {
|
||||||
val gbt = new GBTClassifier().setMaxDepth(1)
|
val gbt = new GBTClassifier().setMaxDepth(1)
|
||||||
MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier](
|
MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier](
|
||||||
gbt, isClassification = true, sqlContext) { (expected, actual) =>
|
gbt, isClassification = true, spark) { (expected, actual) =>
|
||||||
TreeTests.checkEqual(expected, actual)
|
TreeTests.checkEqual(expected, actual)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -130,7 +130,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
*/
|
*/
|
||||||
|
|
||||||
test("Fitting without numClasses in metadata") {
|
test("Fitting without numClasses in metadata") {
|
||||||
val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
|
val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc))
|
||||||
val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
|
val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
|
||||||
gbt.fit(df)
|
gbt.fit(df)
|
||||||
}
|
}
|
||||||
|
@ -138,7 +138,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
test("extractLabeledPoints with bad data") {
|
test("extractLabeledPoints with bad data") {
|
||||||
def getTestData(labels: Seq[Double]): DataFrame = {
|
def getTestData(labels: Seq[Double]): DataFrame = {
|
||||||
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
|
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
|
||||||
sqlContext.createDataFrame(data)
|
spark.createDataFrame(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
|
val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
|
||||||
|
|
|
@ -42,7 +42,7 @@ class LogisticRegressionSuite
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
|
|
||||||
dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42))
|
dataset = spark.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42))
|
||||||
|
|
||||||
binaryDataset = {
|
binaryDataset = {
|
||||||
val nPoints = 10000
|
val nPoints = 10000
|
||||||
|
@ -54,7 +54,7 @@ class LogisticRegressionSuite
|
||||||
generateMultinomialLogisticInput(coefficients, xMean, xVariance,
|
generateMultinomialLogisticInput(coefficients, xMean, xVariance,
|
||||||
addIntercept = true, nPoints, 42)
|
addIntercept = true, nPoints, 42)
|
||||||
|
|
||||||
sqlContext.createDataFrame(sc.parallelize(testData, 4))
|
spark.createDataFrame(sc.parallelize(testData, 4))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -202,7 +202,7 @@ class LogisticRegressionSuite
|
||||||
}
|
}
|
||||||
|
|
||||||
test("logistic regression: Predictor, Classifier methods") {
|
test("logistic regression: Predictor, Classifier methods") {
|
||||||
val sqlContext = this.sqlContext
|
val spark = this.spark
|
||||||
val lr = new LogisticRegression
|
val lr = new LogisticRegression
|
||||||
|
|
||||||
val model = lr.fit(dataset)
|
val model = lr.fit(dataset)
|
||||||
|
@ -864,8 +864,8 @@ class LogisticRegressionSuite
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
(sqlContext.createDataFrame(sc.parallelize(data1, 4)),
|
(spark.createDataFrame(sc.parallelize(data1, 4)),
|
||||||
sqlContext.createDataFrame(sc.parallelize(data2, 4)))
|
spark.createDataFrame(sc.parallelize(data2, 4)))
|
||||||
}
|
}
|
||||||
|
|
||||||
val trainer1a = (new LogisticRegression).setFitIntercept(true)
|
val trainer1a = (new LogisticRegression).setFitIntercept(true)
|
||||||
|
@ -938,7 +938,7 @@ class LogisticRegressionSuite
|
||||||
test("should support all NumericType labels and not support other types") {
|
test("should support all NumericType labels and not support other types") {
|
||||||
val lr = new LogisticRegression().setMaxIter(1)
|
val lr = new LogisticRegression().setMaxIter(1)
|
||||||
MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression](
|
MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression](
|
||||||
lr, isClassification = true, sqlContext) { (expected, actual) =>
|
lr, isClassification = true, spark) { (expected, actual) =>
|
||||||
assert(expected.intercept === actual.intercept)
|
assert(expected.intercept === actual.intercept)
|
||||||
assert(expected.coefficients.toArray === actual.coefficients.toArray)
|
assert(expected.coefficients.toArray === actual.coefficients.toArray)
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ class MultilayerPerceptronClassifierSuite
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
|
|
||||||
dataset = sqlContext.createDataFrame(Seq(
|
dataset = spark.createDataFrame(Seq(
|
||||||
(Vectors.dense(0.0, 0.0), 0.0),
|
(Vectors.dense(0.0, 0.0), 0.0),
|
||||||
(Vectors.dense(0.0, 1.0), 1.0),
|
(Vectors.dense(0.0, 1.0), 1.0),
|
||||||
(Vectors.dense(1.0, 0.0), 1.0),
|
(Vectors.dense(1.0, 0.0), 1.0),
|
||||||
|
@ -77,7 +77,7 @@ class MultilayerPerceptronClassifierSuite
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Test setWeights by training restart") {
|
test("Test setWeights by training restart") {
|
||||||
val dataFrame = sqlContext.createDataFrame(Seq(
|
val dataFrame = spark.createDataFrame(Seq(
|
||||||
(Vectors.dense(0.0, 0.0), 0.0),
|
(Vectors.dense(0.0, 0.0), 0.0),
|
||||||
(Vectors.dense(0.0, 1.0), 1.0),
|
(Vectors.dense(0.0, 1.0), 1.0),
|
||||||
(Vectors.dense(1.0, 0.0), 1.0),
|
(Vectors.dense(1.0, 0.0), 1.0),
|
||||||
|
@ -113,7 +113,7 @@ class MultilayerPerceptronClassifierSuite
|
||||||
// the input seed is somewhat magic, to make this test pass
|
// the input seed is somewhat magic, to make this test pass
|
||||||
val rdd = sc.parallelize(generateMultinomialLogisticInput(
|
val rdd = sc.parallelize(generateMultinomialLogisticInput(
|
||||||
coefficients, xMean, xVariance, true, nPoints, 1), 2)
|
coefficients, xMean, xVariance, true, nPoints, 1), 2)
|
||||||
val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features")
|
val dataFrame = spark.createDataFrame(rdd).toDF("label", "features")
|
||||||
val numClasses = 3
|
val numClasses = 3
|
||||||
val numIterations = 100
|
val numIterations = 100
|
||||||
val layers = Array[Int](4, 5, 4, numClasses)
|
val layers = Array[Int](4, 5, 4, numClasses)
|
||||||
|
@ -169,7 +169,7 @@ class MultilayerPerceptronClassifierSuite
|
||||||
val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1)
|
val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1)
|
||||||
MLTestingUtils.checkNumericTypes[
|
MLTestingUtils.checkNumericTypes[
|
||||||
MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier](
|
MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier](
|
||||||
mpc, isClassification = true, sqlContext) { (expected, actual) =>
|
mpc, isClassification = true, spark) { (expected, actual) =>
|
||||||
assert(expected.layers === actual.layers)
|
assert(expected.layers === actual.layers)
|
||||||
assert(expected.weights === actual.weights)
|
assert(expected.weights === actual.weights)
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,7 +43,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
|
||||||
Array(0.10, 0.10, 0.70, 0.10) // label 2
|
Array(0.10, 0.10, 0.70, 0.10) // label 2
|
||||||
).map(_.map(math.log))
|
).map(_.map(math.log))
|
||||||
|
|
||||||
dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42))
|
dataset = spark.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42))
|
||||||
}
|
}
|
||||||
|
|
||||||
def validatePrediction(predictionAndLabels: DataFrame): Unit = {
|
def validatePrediction(predictionAndLabels: DataFrame): Unit = {
|
||||||
|
@ -127,7 +127,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
|
||||||
val pi = Vectors.dense(piArray)
|
val pi = Vectors.dense(piArray)
|
||||||
val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
|
val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
|
||||||
|
|
||||||
val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
|
val testDataset = spark.createDataFrame(generateNaiveBayesInput(
|
||||||
piArray, thetaArray, nPoints, 42, "multinomial"))
|
piArray, thetaArray, nPoints, 42, "multinomial"))
|
||||||
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
|
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
|
||||||
val model = nb.fit(testDataset)
|
val model = nb.fit(testDataset)
|
||||||
|
@ -135,7 +135,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
|
||||||
validateModelFit(pi, theta, model)
|
validateModelFit(pi, theta, model)
|
||||||
assert(model.hasParent)
|
assert(model.hasParent)
|
||||||
|
|
||||||
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
|
val validationDataset = spark.createDataFrame(generateNaiveBayesInput(
|
||||||
piArray, thetaArray, nPoints, 17, "multinomial"))
|
piArray, thetaArray, nPoints, 17, "multinomial"))
|
||||||
|
|
||||||
val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
|
val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
|
||||||
|
@ -157,7 +157,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
|
||||||
val pi = Vectors.dense(piArray)
|
val pi = Vectors.dense(piArray)
|
||||||
val theta = new DenseMatrix(3, 12, thetaArray.flatten, true)
|
val theta = new DenseMatrix(3, 12, thetaArray.flatten, true)
|
||||||
|
|
||||||
val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
|
val testDataset = spark.createDataFrame(generateNaiveBayesInput(
|
||||||
piArray, thetaArray, nPoints, 45, "bernoulli"))
|
piArray, thetaArray, nPoints, 45, "bernoulli"))
|
||||||
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli")
|
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli")
|
||||||
val model = nb.fit(testDataset)
|
val model = nb.fit(testDataset)
|
||||||
|
@ -165,7 +165,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
|
||||||
validateModelFit(pi, theta, model)
|
validateModelFit(pi, theta, model)
|
||||||
assert(model.hasParent)
|
assert(model.hasParent)
|
||||||
|
|
||||||
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
|
val validationDataset = spark.createDataFrame(generateNaiveBayesInput(
|
||||||
piArray, thetaArray, nPoints, 20, "bernoulli"))
|
piArray, thetaArray, nPoints, 20, "bernoulli"))
|
||||||
|
|
||||||
val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
|
val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
|
||||||
|
@ -188,7 +188,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
|
||||||
test("should support all NumericType labels and not support other types") {
|
test("should support all NumericType labels and not support other types") {
|
||||||
val nb = new NaiveBayes()
|
val nb = new NaiveBayes()
|
||||||
MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes](
|
MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes](
|
||||||
nb, isClassification = true, sqlContext) { (expected, actual) =>
|
nb, isClassification = true, spark) { (expected, actual) =>
|
||||||
assert(expected.pi === actual.pi)
|
assert(expected.pi === actual.pi)
|
||||||
assert(expected.theta === actual.theta)
|
assert(expected.theta === actual.theta)
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,7 +53,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
|
||||||
val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
|
val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
|
||||||
rdd = sc.parallelize(generateMultinomialLogisticInput(
|
rdd = sc.parallelize(generateMultinomialLogisticInput(
|
||||||
coefficients, xMean, xVariance, true, nPoints, 42), 2)
|
coefficients, xMean, xVariance, true, nPoints, 42), 2)
|
||||||
dataset = sqlContext.createDataFrame(rdd)
|
dataset = spark.createDataFrame(rdd)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("params") {
|
test("params") {
|
||||||
|
@ -228,7 +228,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
|
||||||
test("should support all NumericType labels and not support other types") {
|
test("should support all NumericType labels and not support other types") {
|
||||||
val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1))
|
val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1))
|
||||||
MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest](
|
MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest](
|
||||||
ovr, isClassification = true, sqlContext) { (expected, actual) =>
|
ovr, isClassification = true, spark) { (expected, actual) =>
|
||||||
val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel])
|
val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel])
|
||||||
val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel])
|
val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel])
|
||||||
assert(expectedModels.length === actualModels.length)
|
assert(expectedModels.length === actualModels.length)
|
||||||
|
|
|
@ -155,7 +155,7 @@ class RandomForestClassifierSuite
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Fitting without numClasses in metadata") {
|
test("Fitting without numClasses in metadata") {
|
||||||
val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
|
val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc))
|
||||||
val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1)
|
val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1)
|
||||||
rf.fit(df)
|
rf.fit(df)
|
||||||
}
|
}
|
||||||
|
@ -189,7 +189,7 @@ class RandomForestClassifierSuite
|
||||||
test("should support all NumericType labels and not support other types") {
|
test("should support all NumericType labels and not support other types") {
|
||||||
val rf = new RandomForestClassifier().setMaxDepth(1)
|
val rf = new RandomForestClassifier().setMaxDepth(1)
|
||||||
MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier](
|
MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier](
|
||||||
rf, isClassification = true, sqlContext) { (expected, actual) =>
|
rf, isClassification = true, spark) { (expected, actual) =>
|
||||||
TreeTests.checkEqual(expected, actual)
|
TreeTests.checkEqual(expected, actual)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,7 @@ class BisectingKMeansSuite
|
||||||
|
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
|
dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("default parameters") {
|
test("default parameters") {
|
||||||
|
|
|
@ -32,7 +32,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
|
|
||||||
dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
|
dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("default parameters") {
|
test("default parameters") {
|
||||||
|
|
|
@ -22,7 +22,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||||
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
|
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
|
||||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.sql.{DataFrame, Dataset, SQLContext}
|
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
|
||||||
|
|
||||||
private[clustering] case class TestRow(features: Vector)
|
private[clustering] case class TestRow(features: Vector)
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
|
|
||||||
dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
|
dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("default parameters") {
|
test("default parameters") {
|
||||||
|
@ -142,11 +142,11 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
|
||||||
}
|
}
|
||||||
|
|
||||||
object KMeansSuite {
|
object KMeansSuite {
|
||||||
def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
|
def generateKMeansData(spark: SparkSession, rows: Int, dim: Int, k: Int): DataFrame = {
|
||||||
val sc = sql.sparkContext
|
val sc = spark.sparkContext
|
||||||
val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
|
val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
|
||||||
.map(v => new TestRow(v))
|
.map(v => new TestRow(v))
|
||||||
sql.createDataFrame(rdd)
|
spark.createDataFrame(rdd)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -17,30 +17,30 @@
|
||||||
|
|
||||||
package org.apache.spark.ml.clustering
|
package org.apache.spark.ml.clustering
|
||||||
|
|
||||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
import org.apache.hadoop.fs.Path
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
|
import org.apache.spark.sql._
|
||||||
|
|
||||||
|
|
||||||
object LDASuite {
|
object LDASuite {
|
||||||
def generateLDAData(
|
def generateLDAData(
|
||||||
sql: SQLContext,
|
spark: SparkSession,
|
||||||
rows: Int,
|
rows: Int,
|
||||||
k: Int,
|
k: Int,
|
||||||
vocabSize: Int): DataFrame = {
|
vocabSize: Int): DataFrame = {
|
||||||
val avgWC = 1 // average instances of each word in a doc
|
val avgWC = 1 // average instances of each word in a doc
|
||||||
val sc = sql.sparkContext
|
val sc = spark.sparkContext
|
||||||
val rng = new java.util.Random()
|
val rng = new java.util.Random()
|
||||||
rng.setSeed(1)
|
rng.setSeed(1)
|
||||||
val rdd = sc.parallelize(1 to rows).map { i =>
|
val rdd = sc.parallelize(1 to rows).map { i =>
|
||||||
Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble))
|
Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble))
|
||||||
}.map(v => new TestRow(v))
|
}.map(v => new TestRow(v))
|
||||||
sql.createDataFrame(rdd)
|
spark.createDataFrame(rdd)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -68,7 +68,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
|
||||||
|
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
dataset = LDASuite.generateLDAData(sqlContext, 50, k, vocabSize)
|
dataset = LDASuite.generateLDAData(spark, 50, k, vocabSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("default parameters") {
|
test("default parameters") {
|
||||||
|
@ -140,7 +140,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
|
||||||
new LDA().setTopicConcentration(-1.1)
|
new LDA().setTopicConcentration(-1.1)
|
||||||
}
|
}
|
||||||
|
|
||||||
val dummyDF = sqlContext.createDataFrame(Seq(
|
val dummyDF = spark.createDataFrame(Seq(
|
||||||
(1, Vectors.dense(1.0, 2.0)))).toDF("id", "features")
|
(1, Vectors.dense(1.0, 2.0)))).toDF("id", "features")
|
||||||
// validate parameters
|
// validate parameters
|
||||||
lda.transformSchema(dummyDF.schema)
|
lda.transformSchema(dummyDF.schema)
|
||||||
|
@ -274,7 +274,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
|
||||||
// There should be 1 checkpoint remaining.
|
// There should be 1 checkpoint remaining.
|
||||||
assert(model.getCheckpointFiles.length === 1)
|
assert(model.getCheckpointFiles.length === 1)
|
||||||
val checkpointFile = new Path(model.getCheckpointFiles.head)
|
val checkpointFile = new Path(model.getCheckpointFiles.head)
|
||||||
val fs = checkpointFile.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
|
val fs = checkpointFile.getFileSystem(spark.sparkContext.hadoopConfiguration)
|
||||||
assert(fs.exists(checkpointFile))
|
assert(fs.exists(checkpointFile))
|
||||||
model.deleteCheckpointFiles()
|
model.deleteCheckpointFiles()
|
||||||
assert(model.getCheckpointFiles.isEmpty)
|
assert(model.getCheckpointFiles.isEmpty)
|
||||||
|
|
|
@ -42,21 +42,21 @@ class BinaryClassificationEvaluatorSuite
|
||||||
val evaluator = new BinaryClassificationEvaluator()
|
val evaluator = new BinaryClassificationEvaluator()
|
||||||
.setMetricName("areaUnderPR")
|
.setMetricName("areaUnderPR")
|
||||||
|
|
||||||
val vectorDF = sqlContext.createDataFrame(Seq(
|
val vectorDF = spark.createDataFrame(Seq(
|
||||||
(0d, Vectors.dense(12, 2.5)),
|
(0d, Vectors.dense(12, 2.5)),
|
||||||
(1d, Vectors.dense(1, 3)),
|
(1d, Vectors.dense(1, 3)),
|
||||||
(0d, Vectors.dense(10, 2))
|
(0d, Vectors.dense(10, 2))
|
||||||
)).toDF("label", "rawPrediction")
|
)).toDF("label", "rawPrediction")
|
||||||
assert(evaluator.evaluate(vectorDF) === 1.0)
|
assert(evaluator.evaluate(vectorDF) === 1.0)
|
||||||
|
|
||||||
val doubleDF = sqlContext.createDataFrame(Seq(
|
val doubleDF = spark.createDataFrame(Seq(
|
||||||
(0d, 0d),
|
(0d, 0d),
|
||||||
(1d, 1d),
|
(1d, 1d),
|
||||||
(0d, 0d)
|
(0d, 0d)
|
||||||
)).toDF("label", "rawPrediction")
|
)).toDF("label", "rawPrediction")
|
||||||
assert(evaluator.evaluate(doubleDF) === 1.0)
|
assert(evaluator.evaluate(doubleDF) === 1.0)
|
||||||
|
|
||||||
val stringDF = sqlContext.createDataFrame(Seq(
|
val stringDF = spark.createDataFrame(Seq(
|
||||||
(0d, "0d"),
|
(0d, "0d"),
|
||||||
(1d, "1d"),
|
(1d, "1d"),
|
||||||
(0d, "0d")
|
(0d, "0d")
|
||||||
|
@ -71,6 +71,6 @@ class BinaryClassificationEvaluatorSuite
|
||||||
|
|
||||||
test("should support all NumericType labels and not support other types") {
|
test("should support all NumericType labels and not support other types") {
|
||||||
val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction")
|
val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction")
|
||||||
MLTestingUtils.checkNumericTypes(evaluator, sqlContext)
|
MLTestingUtils.checkNumericTypes(evaluator, spark)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,6 +38,6 @@ class MulticlassClassificationEvaluatorSuite
|
||||||
}
|
}
|
||||||
|
|
||||||
test("should support all NumericType labels and not support other types") {
|
test("should support all NumericType labels and not support other types") {
|
||||||
MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, sqlContext)
|
MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, spark)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,7 +42,7 @@ class RegressionEvaluatorSuite
|
||||||
* data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1))
|
* data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1))
|
||||||
* .saveAsTextFile("path")
|
* .saveAsTextFile("path")
|
||||||
*/
|
*/
|
||||||
val dataset = sqlContext.createDataFrame(
|
val dataset = spark.createDataFrame(
|
||||||
sc.parallelize(LinearDataGenerator.generateLinearInput(
|
sc.parallelize(LinearDataGenerator.generateLinearInput(
|
||||||
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
|
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
|
||||||
|
|
||||||
|
@ -85,6 +85,6 @@ class RegressionEvaluatorSuite
|
||||||
}
|
}
|
||||||
|
|
||||||
test("should support all NumericType labels and not support other types") {
|
test("should support all NumericType labels and not support other types") {
|
||||||
MLTestingUtils.checkNumericTypes(new RegressionEvaluator, sqlContext)
|
MLTestingUtils.checkNumericTypes(new RegressionEvaluator, spark)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,7 +39,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
|
||||||
|
|
||||||
test("Binarize continuous features with default parameter") {
|
test("Binarize continuous features with default parameter") {
|
||||||
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
|
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
|
||||||
val dataFrame: DataFrame = sqlContext.createDataFrame(
|
val dataFrame: DataFrame = spark.createDataFrame(
|
||||||
data.zip(defaultBinarized)).toDF("feature", "expected")
|
data.zip(defaultBinarized)).toDF("feature", "expected")
|
||||||
|
|
||||||
val binarizer: Binarizer = new Binarizer()
|
val binarizer: Binarizer = new Binarizer()
|
||||||
|
@ -55,7 +55,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
|
||||||
test("Binarize continuous features with setter") {
|
test("Binarize continuous features with setter") {
|
||||||
val threshold: Double = 0.2
|
val threshold: Double = 0.2
|
||||||
val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
|
val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
|
||||||
val dataFrame: DataFrame = sqlContext.createDataFrame(
|
val dataFrame: DataFrame = spark.createDataFrame(
|
||||||
data.zip(thresholdBinarized)).toDF("feature", "expected")
|
data.zip(thresholdBinarized)).toDF("feature", "expected")
|
||||||
|
|
||||||
val binarizer: Binarizer = new Binarizer()
|
val binarizer: Binarizer = new Binarizer()
|
||||||
|
@ -71,7 +71,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
|
||||||
|
|
||||||
test("Binarize vector of continuous features with default parameter") {
|
test("Binarize vector of continuous features with default parameter") {
|
||||||
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
|
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
|
||||||
val dataFrame: DataFrame = sqlContext.createDataFrame(Seq(
|
val dataFrame: DataFrame = spark.createDataFrame(Seq(
|
||||||
(Vectors.dense(data), Vectors.dense(defaultBinarized))
|
(Vectors.dense(data), Vectors.dense(defaultBinarized))
|
||||||
)).toDF("feature", "expected")
|
)).toDF("feature", "expected")
|
||||||
|
|
||||||
|
@ -88,7 +88,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
|
||||||
test("Binarize vector of continuous features with setter") {
|
test("Binarize vector of continuous features with setter") {
|
||||||
val threshold: Double = 0.2
|
val threshold: Double = 0.2
|
||||||
val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
|
val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
|
||||||
val dataFrame: DataFrame = sqlContext.createDataFrame(Seq(
|
val dataFrame: DataFrame = spark.createDataFrame(Seq(
|
||||||
(Vectors.dense(data), Vectors.dense(defaultBinarized))
|
(Vectors.dense(data), Vectors.dense(defaultBinarized))
|
||||||
)).toDF("feature", "expected")
|
)).toDF("feature", "expected")
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
|
||||||
val validData = Array(-0.5, -0.3, 0.0, 0.2)
|
val validData = Array(-0.5, -0.3, 0.0, 0.2)
|
||||||
val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0)
|
val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0)
|
||||||
val dataFrame: DataFrame =
|
val dataFrame: DataFrame =
|
||||||
sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
|
spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
|
||||||
|
|
||||||
val bucketizer: Bucketizer = new Bucketizer()
|
val bucketizer: Bucketizer = new Bucketizer()
|
||||||
.setInputCol("feature")
|
.setInputCol("feature")
|
||||||
|
@ -55,13 +55,13 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
|
||||||
// Check for exceptions when using a set of invalid feature values.
|
// Check for exceptions when using a set of invalid feature values.
|
||||||
val invalidData1: Array[Double] = Array(-0.9) ++ validData
|
val invalidData1: Array[Double] = Array(-0.9) ++ validData
|
||||||
val invalidData2 = Array(0.51) ++ validData
|
val invalidData2 = Array(0.51) ++ validData
|
||||||
val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
|
val badDF1 = spark.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
|
||||||
withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
|
withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
|
||||||
intercept[SparkException] {
|
intercept[SparkException] {
|
||||||
bucketizer.transform(badDF1).collect()
|
bucketizer.transform(badDF1).collect()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
|
val badDF2 = spark.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
|
||||||
withClue("Invalid feature value 0.51 was not caught as an invalid feature!") {
|
withClue("Invalid feature value 0.51 was not caught as an invalid feature!") {
|
||||||
intercept[SparkException] {
|
intercept[SparkException] {
|
||||||
bucketizer.transform(badDF2).collect()
|
bucketizer.transform(badDF2).collect()
|
||||||
|
@ -74,7 +74,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
|
||||||
val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
|
val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
|
||||||
val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
|
val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
|
||||||
val dataFrame: DataFrame =
|
val dataFrame: DataFrame =
|
||||||
sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
|
spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
|
||||||
|
|
||||||
val bucketizer: Bucketizer = new Bucketizer()
|
val bucketizer: Bucketizer = new Bucketizer()
|
||||||
.setInputCol("feature")
|
.setInputCol("feature")
|
||||||
|
|
|
@ -24,14 +24,17 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.sql.{Row, SQLContext}
|
import org.apache.spark.sql.{Row, SparkSession}
|
||||||
|
|
||||||
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
|
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
with DefaultReadWriteTest {
|
with DefaultReadWriteTest {
|
||||||
|
|
||||||
test("Test Chi-Square selector") {
|
test("Test Chi-Square selector") {
|
||||||
val sqlContext = SQLContext.getOrCreate(sc)
|
val spark = SparkSession.builder
|
||||||
import sqlContext.implicits._
|
.master("local[2]")
|
||||||
|
.appName("ChiSqSelectorSuite")
|
||||||
|
.getOrCreate()
|
||||||
|
import spark.implicits._
|
||||||
|
|
||||||
val data = Seq(
|
val data = Seq(
|
||||||
LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
|
LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
|
||||||
|
|
|
@ -35,7 +35,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
private def split(s: String): Seq[String] = s.split("\\s+")
|
private def split(s: String): Seq[String] = s.split("\\s+")
|
||||||
|
|
||||||
test("CountVectorizerModel common cases") {
|
test("CountVectorizerModel common cases") {
|
||||||
val df = sqlContext.createDataFrame(Seq(
|
val df = spark.createDataFrame(Seq(
|
||||||
(0, split("a b c d"),
|
(0, split("a b c d"),
|
||||||
Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
|
Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
|
||||||
(1, split("a b b c d a"),
|
(1, split("a b b c d a"),
|
||||||
|
@ -55,7 +55,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
}
|
}
|
||||||
|
|
||||||
test("CountVectorizer common cases") {
|
test("CountVectorizer common cases") {
|
||||||
val df = sqlContext.createDataFrame(Seq(
|
val df = spark.createDataFrame(Seq(
|
||||||
(0, split("a b c d e"),
|
(0, split("a b c d e"),
|
||||||
Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))),
|
Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))),
|
||||||
(1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))),
|
(1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))),
|
||||||
|
@ -76,7 +76,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
}
|
}
|
||||||
|
|
||||||
test("CountVectorizer vocabSize and minDF") {
|
test("CountVectorizer vocabSize and minDF") {
|
||||||
val df = sqlContext.createDataFrame(Seq(
|
val df = spark.createDataFrame(Seq(
|
||||||
(0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
|
(0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
|
||||||
(1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
|
(1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
|
||||||
(2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
|
(2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
|
||||||
|
@ -118,7 +118,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
|
|
||||||
test("CountVectorizer throws exception when vocab is empty") {
|
test("CountVectorizer throws exception when vocab is empty") {
|
||||||
intercept[IllegalArgumentException] {
|
intercept[IllegalArgumentException] {
|
||||||
val df = sqlContext.createDataFrame(Seq(
|
val df = spark.createDataFrame(Seq(
|
||||||
(0, split("a a b b c c")),
|
(0, split("a a b b c c")),
|
||||||
(1, split("aa bb cc")))
|
(1, split("aa bb cc")))
|
||||||
).toDF("id", "words")
|
).toDF("id", "words")
|
||||||
|
@ -132,7 +132,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
}
|
}
|
||||||
|
|
||||||
test("CountVectorizerModel with minTF count") {
|
test("CountVectorizerModel with minTF count") {
|
||||||
val df = sqlContext.createDataFrame(Seq(
|
val df = spark.createDataFrame(Seq(
|
||||||
(0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
|
(0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
|
||||||
(1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
|
(1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
|
||||||
(2, split("a"), Vectors.sparse(4, Seq())),
|
(2, split("a"), Vectors.sparse(4, Seq())),
|
||||||
|
@ -151,7 +151,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
}
|
}
|
||||||
|
|
||||||
test("CountVectorizerModel with minTF freq") {
|
test("CountVectorizerModel with minTF freq") {
|
||||||
val df = sqlContext.createDataFrame(Seq(
|
val df = spark.createDataFrame(Seq(
|
||||||
(0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
|
(0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
|
||||||
(1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
|
(1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
|
||||||
(2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))),
|
(2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))),
|
||||||
|
@ -170,7 +170,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
}
|
}
|
||||||
|
|
||||||
test("CountVectorizerModel and CountVectorizer with binary") {
|
test("CountVectorizerModel and CountVectorizer with binary") {
|
||||||
val df = sqlContext.createDataFrame(Seq(
|
val df = spark.createDataFrame(Seq(
|
||||||
(0, split("a a a a b b b b c d"),
|
(0, split("a a a a b b b b c d"),
|
||||||
Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
|
Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
|
||||||
(1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))),
|
(1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))),
|
||||||
|
|
|
@ -63,7 +63,7 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
|
||||||
}
|
}
|
||||||
val expectedResult = Vectors.dense(expectedResultBuffer)
|
val expectedResult = Vectors.dense(expectedResultBuffer)
|
||||||
|
|
||||||
val dataset = sqlContext.createDataFrame(Seq(
|
val dataset = spark.createDataFrame(Seq(
|
||||||
DCTTestData(data, expectedResult)
|
DCTTestData(data, expectedResult)
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,7 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
|
||||||
}
|
}
|
||||||
|
|
||||||
test("hashingTF") {
|
test("hashingTF") {
|
||||||
val df = sqlContext.createDataFrame(Seq(
|
val df = spark.createDataFrame(Seq(
|
||||||
(0, "a a b b c d".split(" ").toSeq)
|
(0, "a a b b c d".split(" ").toSeq)
|
||||||
)).toDF("id", "words")
|
)).toDF("id", "words")
|
||||||
val n = 100
|
val n = 100
|
||||||
|
@ -54,7 +54,7 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
|
||||||
}
|
}
|
||||||
|
|
||||||
test("applying binary term freqs") {
|
test("applying binary term freqs") {
|
||||||
val df = sqlContext.createDataFrame(Seq(
|
val df = spark.createDataFrame(Seq(
|
||||||
(0, "a a b c c c".split(" ").toSeq)
|
(0, "a a b c c c".split(" ").toSeq)
|
||||||
)).toDF("id", "words")
|
)).toDF("id", "words")
|
||||||
val n = 100
|
val n = 100
|
||||||
|
|
|
@ -60,7 +60,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
|
||||||
})
|
})
|
||||||
val expected = scaleDataWithIDF(data, idf)
|
val expected = scaleDataWithIDF(data, idf)
|
||||||
|
|
||||||
val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
|
val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
|
||||||
|
|
||||||
val idfModel = new IDF()
|
val idfModel = new IDF()
|
||||||
.setInputCol("features")
|
.setInputCol("features")
|
||||||
|
@ -86,7 +86,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
|
||||||
})
|
})
|
||||||
val expected = scaleDataWithIDF(data, idf)
|
val expected = scaleDataWithIDF(data, idf)
|
||||||
|
|
||||||
val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
|
val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
|
||||||
|
|
||||||
val idfModel = new IDF()
|
val idfModel = new IDF()
|
||||||
.setInputCol("features")
|
.setInputCol("features")
|
||||||
|
|
|
@ -59,7 +59,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
|
||||||
}
|
}
|
||||||
|
|
||||||
test("numeric interaction") {
|
test("numeric interaction") {
|
||||||
val data = sqlContext.createDataFrame(
|
val data = spark.createDataFrame(
|
||||||
Seq(
|
Seq(
|
||||||
(2, Vectors.dense(3.0, 4.0)),
|
(2, Vectors.dense(3.0, 4.0)),
|
||||||
(1, Vectors.dense(1.0, 5.0)))
|
(1, Vectors.dense(1.0, 5.0)))
|
||||||
|
@ -74,7 +74,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
|
||||||
col("b").as("b", groupAttr.toMetadata()))
|
col("b").as("b", groupAttr.toMetadata()))
|
||||||
val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features")
|
val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features")
|
||||||
val res = trans.transform(df)
|
val res = trans.transform(df)
|
||||||
val expected = sqlContext.createDataFrame(
|
val expected = spark.createDataFrame(
|
||||||
Seq(
|
Seq(
|
||||||
(2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)),
|
(2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)),
|
||||||
(1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)))
|
(1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)))
|
||||||
|
@ -90,7 +90,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
|
||||||
}
|
}
|
||||||
|
|
||||||
test("nominal interaction") {
|
test("nominal interaction") {
|
||||||
val data = sqlContext.createDataFrame(
|
val data = spark.createDataFrame(
|
||||||
Seq(
|
Seq(
|
||||||
(2, Vectors.dense(3.0, 4.0)),
|
(2, Vectors.dense(3.0, 4.0)),
|
||||||
(1, Vectors.dense(1.0, 5.0)))
|
(1, Vectors.dense(1.0, 5.0)))
|
||||||
|
@ -106,7 +106,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
|
||||||
col("b").as("b", groupAttr.toMetadata()))
|
col("b").as("b", groupAttr.toMetadata()))
|
||||||
val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features")
|
val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features")
|
||||||
val res = trans.transform(df)
|
val res = trans.transform(df)
|
||||||
val expected = sqlContext.createDataFrame(
|
val expected = spark.createDataFrame(
|
||||||
Seq(
|
Seq(
|
||||||
(2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)),
|
(2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)),
|
||||||
(1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)))
|
(1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)))
|
||||||
|
@ -126,7 +126,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
|
||||||
}
|
}
|
||||||
|
|
||||||
test("default attr names") {
|
test("default attr names") {
|
||||||
val data = sqlContext.createDataFrame(
|
val data = spark.createDataFrame(
|
||||||
Seq(
|
Seq(
|
||||||
(2, Vectors.dense(0.0, 4.0), 1.0),
|
(2, Vectors.dense(0.0, 4.0), 1.0),
|
||||||
(1, Vectors.dense(1.0, 5.0), 10.0))
|
(1, Vectors.dense(1.0, 5.0), 10.0))
|
||||||
|
@ -142,7 +142,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
|
||||||
col("c").as("c", NumericAttribute.defaultAttr.toMetadata()))
|
col("c").as("c", NumericAttribute.defaultAttr.toMetadata()))
|
||||||
val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features")
|
val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features")
|
||||||
val res = trans.transform(df)
|
val res = trans.transform(df)
|
||||||
val expected = sqlContext.createDataFrame(
|
val expected = spark.createDataFrame(
|
||||||
Seq(
|
Seq(
|
||||||
(2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)),
|
(2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)),
|
||||||
(1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0)))
|
(1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0)))
|
||||||
|
|
|
@ -36,7 +36,7 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
|
||||||
Vectors.sparse(3, Array(0, 2), Array(-1, -1)),
|
Vectors.sparse(3, Array(0, 2), Array(-1, -1)),
|
||||||
Vectors.sparse(3, Array(0), Array(-0.75)))
|
Vectors.sparse(3, Array(0), Array(-0.75)))
|
||||||
|
|
||||||
val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
|
val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
|
||||||
val scaler = new MaxAbsScaler()
|
val scaler = new MaxAbsScaler()
|
||||||
.setInputCol("features")
|
.setInputCol("features")
|
||||||
.setOutputCol("scaled")
|
.setOutputCol("scaled")
|
||||||
|
|
|
@ -38,7 +38,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
|
||||||
Vectors.sparse(3, Array(0, 2), Array(5, 5)),
|
Vectors.sparse(3, Array(0, 2), Array(5, 5)),
|
||||||
Vectors.sparse(3, Array(0), Array(-2.5)))
|
Vectors.sparse(3, Array(0), Array(-2.5)))
|
||||||
|
|
||||||
val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
|
val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
|
||||||
val scaler = new MinMaxScaler()
|
val scaler = new MinMaxScaler()
|
||||||
.setInputCol("features")
|
.setInputCol("features")
|
||||||
.setOutputCol("scaled")
|
.setOutputCol("scaled")
|
||||||
|
@ -57,7 +57,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
|
||||||
|
|
||||||
test("MinMaxScaler arguments max must be larger than min") {
|
test("MinMaxScaler arguments max must be larger than min") {
|
||||||
withClue("arguments max must be larger than min") {
|
withClue("arguments max must be larger than min") {
|
||||||
val dummyDF = sqlContext.createDataFrame(Seq(
|
val dummyDF = spark.createDataFrame(Seq(
|
||||||
(1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature")
|
(1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature")
|
||||||
intercept[IllegalArgumentException] {
|
intercept[IllegalArgumentException] {
|
||||||
val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature")
|
val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature")
|
||||||
|
|
|
@ -34,7 +34,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
|
||||||
val nGram = new NGram()
|
val nGram = new NGram()
|
||||||
.setInputCol("inputTokens")
|
.setInputCol("inputTokens")
|
||||||
.setOutputCol("nGrams")
|
.setOutputCol("nGrams")
|
||||||
val dataset = sqlContext.createDataFrame(Seq(
|
val dataset = spark.createDataFrame(Seq(
|
||||||
NGramTestData(
|
NGramTestData(
|
||||||
Array("Test", "for", "ngram", "."),
|
Array("Test", "for", "ngram", "."),
|
||||||
Array("Test for", "for ngram", "ngram .")
|
Array("Test for", "for ngram", "ngram .")
|
||||||
|
@ -47,7 +47,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
|
||||||
.setInputCol("inputTokens")
|
.setInputCol("inputTokens")
|
||||||
.setOutputCol("nGrams")
|
.setOutputCol("nGrams")
|
||||||
.setN(4)
|
.setN(4)
|
||||||
val dataset = sqlContext.createDataFrame(Seq(
|
val dataset = spark.createDataFrame(Seq(
|
||||||
NGramTestData(
|
NGramTestData(
|
||||||
Array("a", "b", "c", "d", "e"),
|
Array("a", "b", "c", "d", "e"),
|
||||||
Array("a b c d", "b c d e")
|
Array("a b c d", "b c d e")
|
||||||
|
@ -60,7 +60,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
|
||||||
.setInputCol("inputTokens")
|
.setInputCol("inputTokens")
|
||||||
.setOutputCol("nGrams")
|
.setOutputCol("nGrams")
|
||||||
.setN(4)
|
.setN(4)
|
||||||
val dataset = sqlContext.createDataFrame(Seq(
|
val dataset = spark.createDataFrame(Seq(
|
||||||
NGramTestData(
|
NGramTestData(
|
||||||
Array(),
|
Array(),
|
||||||
Array()
|
Array()
|
||||||
|
@ -73,7 +73,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
|
||||||
.setInputCol("inputTokens")
|
.setInputCol("inputTokens")
|
||||||
.setOutputCol("nGrams")
|
.setOutputCol("nGrams")
|
||||||
.setN(6)
|
.setN(6)
|
||||||
val dataset = sqlContext.createDataFrame(Seq(
|
val dataset = spark.createDataFrame(Seq(
|
||||||
NGramTestData(
|
NGramTestData(
|
||||||
Array("a", "b", "c", "d", "e"),
|
Array("a", "b", "c", "d", "e"),
|
||||||
Array()
|
Array()
|
||||||
|
|
|
@ -61,7 +61,7 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
|
||||||
Vectors.sparse(3, Seq())
|
Vectors.sparse(3, Seq())
|
||||||
)
|
)
|
||||||
|
|
||||||
dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData))
|
dataFrame = spark.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData))
|
||||||
normalizer = new Normalizer()
|
normalizer = new Normalizer()
|
||||||
.setInputCol("features")
|
.setInputCol("features")
|
||||||
.setOutputCol("normalized_features")
|
.setOutputCol("normalized_features")
|
||||||
|
|
|
@ -32,7 +32,7 @@ class OneHotEncoderSuite
|
||||||
|
|
||||||
def stringIndexed(): DataFrame = {
|
def stringIndexed(): DataFrame = {
|
||||||
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
|
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
|
||||||
val df = sqlContext.createDataFrame(data).toDF("id", "label")
|
val df = spark.createDataFrame(data).toDF("id", "label")
|
||||||
val indexer = new StringIndexer()
|
val indexer = new StringIndexer()
|
||||||
.setInputCol("label")
|
.setInputCol("label")
|
||||||
.setOutputCol("labelIndex")
|
.setOutputCol("labelIndex")
|
||||||
|
@ -81,7 +81,7 @@ class OneHotEncoderSuite
|
||||||
|
|
||||||
test("input column with ML attribute") {
|
test("input column with ML attribute") {
|
||||||
val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
|
val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
|
||||||
val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
|
val df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
|
||||||
.select(col("size").as("size", attr.toMetadata()))
|
.select(col("size").as("size", attr.toMetadata()))
|
||||||
val encoder = new OneHotEncoder()
|
val encoder = new OneHotEncoder()
|
||||||
.setInputCol("size")
|
.setInputCol("size")
|
||||||
|
@ -94,7 +94,7 @@ class OneHotEncoderSuite
|
||||||
}
|
}
|
||||||
|
|
||||||
test("input column without ML attribute") {
|
test("input column without ML attribute") {
|
||||||
val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index")
|
val df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index")
|
||||||
val encoder = new OneHotEncoder()
|
val encoder = new OneHotEncoder()
|
||||||
.setInputCol("index")
|
.setInputCol("index")
|
||||||
.setOutputCol("encoded")
|
.setOutputCol("encoded")
|
||||||
|
|
|
@ -49,7 +49,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
|
||||||
val pc = mat.computePrincipalComponents(3)
|
val pc = mat.computePrincipalComponents(3)
|
||||||
val expected = mat.multiply(pc).rows
|
val expected = mat.multiply(pc).rows
|
||||||
|
|
||||||
val df = sqlContext.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected")
|
val df = spark.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected")
|
||||||
|
|
||||||
val pca = new PCA()
|
val pca = new PCA()
|
||||||
.setInputCol("features")
|
.setInputCol("features")
|
||||||
|
|
|
@ -59,7 +59,7 @@ class PolynomialExpansionSuite
|
||||||
Vectors.sparse(19, Array.empty, Array.empty))
|
Vectors.sparse(19, Array.empty, Array.empty))
|
||||||
|
|
||||||
test("Polynomial expansion with default parameter") {
|
test("Polynomial expansion with default parameter") {
|
||||||
val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected")
|
val df = spark.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected")
|
||||||
|
|
||||||
val polynomialExpansion = new PolynomialExpansion()
|
val polynomialExpansion = new PolynomialExpansion()
|
||||||
.setInputCol("features")
|
.setInputCol("features")
|
||||||
|
@ -76,7 +76,7 @@ class PolynomialExpansionSuite
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Polynomial expansion with setter") {
|
test("Polynomial expansion with setter") {
|
||||||
val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected")
|
val df = spark.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected")
|
||||||
|
|
||||||
val polynomialExpansion = new PolynomialExpansion()
|
val polynomialExpansion = new PolynomialExpansion()
|
||||||
.setInputCol("features")
|
.setInputCol("features")
|
||||||
|
@ -94,7 +94,7 @@ class PolynomialExpansionSuite
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Polynomial expansion with degree 1 is identity on vectors") {
|
test("Polynomial expansion with degree 1 is identity on vectors") {
|
||||||
val df = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected")
|
val df = spark.createDataFrame(data.zip(data)).toDF("features", "expected")
|
||||||
|
|
||||||
val polynomialExpansion = new PolynomialExpansion()
|
val polynomialExpansion = new PolynomialExpansion()
|
||||||
.setInputCol("features")
|
.setInputCol("features")
|
||||||
|
|
|
@ -32,12 +32,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("transform numeric data") {
|
test("transform numeric data") {
|
||||||
val formula = new RFormula().setFormula("id ~ v1 + v2")
|
val formula = new RFormula().setFormula("id ~ v1 + v2")
|
||||||
val original = sqlContext.createDataFrame(
|
val original = spark.createDataFrame(
|
||||||
Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
|
Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
|
||||||
val model = formula.fit(original)
|
val model = formula.fit(original)
|
||||||
val result = model.transform(original)
|
val result = model.transform(original)
|
||||||
val resultSchema = model.transformSchema(original.schema)
|
val resultSchema = model.transformSchema(original.schema)
|
||||||
val expected = sqlContext.createDataFrame(
|
val expected = spark.createDataFrame(
|
||||||
Seq(
|
Seq(
|
||||||
(0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0),
|
(0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0),
|
||||||
(2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0))
|
(2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0))
|
||||||
|
@ -50,7 +50,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("features column already exists") {
|
test("features column already exists") {
|
||||||
val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
|
val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
|
||||||
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
|
val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
|
||||||
intercept[IllegalArgumentException] {
|
intercept[IllegalArgumentException] {
|
||||||
formula.fit(original)
|
formula.fit(original)
|
||||||
}
|
}
|
||||||
|
@ -61,7 +61,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("label column already exists") {
|
test("label column already exists") {
|
||||||
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
|
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
|
||||||
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
|
val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
|
||||||
val model = formula.fit(original)
|
val model = formula.fit(original)
|
||||||
val resultSchema = model.transformSchema(original.schema)
|
val resultSchema = model.transformSchema(original.schema)
|
||||||
assert(resultSchema.length == 3)
|
assert(resultSchema.length == 3)
|
||||||
|
@ -70,7 +70,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("label column already exists but is not double type") {
|
test("label column already exists but is not double type") {
|
||||||
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
|
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
|
||||||
val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
|
val original = spark.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
|
||||||
val model = formula.fit(original)
|
val model = formula.fit(original)
|
||||||
intercept[IllegalArgumentException] {
|
intercept[IllegalArgumentException] {
|
||||||
model.transformSchema(original.schema)
|
model.transformSchema(original.schema)
|
||||||
|
@ -82,7 +82,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("allow missing label column for test datasets") {
|
test("allow missing label column for test datasets") {
|
||||||
val formula = new RFormula().setFormula("y ~ x").setLabelCol("label")
|
val formula = new RFormula().setFormula("y ~ x").setLabelCol("label")
|
||||||
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
|
val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
|
||||||
val model = formula.fit(original)
|
val model = formula.fit(original)
|
||||||
val resultSchema = model.transformSchema(original.schema)
|
val resultSchema = model.transformSchema(original.schema)
|
||||||
assert(resultSchema.length == 3)
|
assert(resultSchema.length == 3)
|
||||||
|
@ -91,14 +91,14 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
}
|
}
|
||||||
|
|
||||||
test("allow empty label") {
|
test("allow empty label") {
|
||||||
val original = sqlContext.createDataFrame(
|
val original = spark.createDataFrame(
|
||||||
Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0))
|
Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0))
|
||||||
).toDF("id", "a", "b")
|
).toDF("id", "a", "b")
|
||||||
val formula = new RFormula().setFormula("~ a + b")
|
val formula = new RFormula().setFormula("~ a + b")
|
||||||
val model = formula.fit(original)
|
val model = formula.fit(original)
|
||||||
val result = model.transform(original)
|
val result = model.transform(original)
|
||||||
val resultSchema = model.transformSchema(original.schema)
|
val resultSchema = model.transformSchema(original.schema)
|
||||||
val expected = sqlContext.createDataFrame(
|
val expected = spark.createDataFrame(
|
||||||
Seq(
|
Seq(
|
||||||
(1, 2.0, 3.0, Vectors.dense(2.0, 3.0)),
|
(1, 2.0, 3.0, Vectors.dense(2.0, 3.0)),
|
||||||
(4, 5.0, 6.0, Vectors.dense(5.0, 6.0)),
|
(4, 5.0, 6.0, Vectors.dense(5.0, 6.0)),
|
||||||
|
@ -110,13 +110,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("encodes string terms") {
|
test("encodes string terms") {
|
||||||
val formula = new RFormula().setFormula("id ~ a + b")
|
val formula = new RFormula().setFormula("id ~ a + b")
|
||||||
val original = sqlContext.createDataFrame(
|
val original = spark.createDataFrame(
|
||||||
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
|
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
|
||||||
).toDF("id", "a", "b")
|
).toDF("id", "a", "b")
|
||||||
val model = formula.fit(original)
|
val model = formula.fit(original)
|
||||||
val result = model.transform(original)
|
val result = model.transform(original)
|
||||||
val resultSchema = model.transformSchema(original.schema)
|
val resultSchema = model.transformSchema(original.schema)
|
||||||
val expected = sqlContext.createDataFrame(
|
val expected = spark.createDataFrame(
|
||||||
Seq(
|
Seq(
|
||||||
(1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
|
(1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
|
||||||
(2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
|
(2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
|
||||||
|
@ -129,13 +129,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("index string label") {
|
test("index string label") {
|
||||||
val formula = new RFormula().setFormula("id ~ a + b")
|
val formula = new RFormula().setFormula("id ~ a + b")
|
||||||
val original = sqlContext.createDataFrame(
|
val original = spark.createDataFrame(
|
||||||
Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5))
|
Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5))
|
||||||
).toDF("id", "a", "b")
|
).toDF("id", "a", "b")
|
||||||
val model = formula.fit(original)
|
val model = formula.fit(original)
|
||||||
val result = model.transform(original)
|
val result = model.transform(original)
|
||||||
val resultSchema = model.transformSchema(original.schema)
|
val resultSchema = model.transformSchema(original.schema)
|
||||||
val expected = sqlContext.createDataFrame(
|
val expected = spark.createDataFrame(
|
||||||
Seq(
|
Seq(
|
||||||
("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
|
("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
|
||||||
("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
|
("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
|
||||||
|
@ -148,7 +148,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("attribute generation") {
|
test("attribute generation") {
|
||||||
val formula = new RFormula().setFormula("id ~ a + b")
|
val formula = new RFormula().setFormula("id ~ a + b")
|
||||||
val original = sqlContext.createDataFrame(
|
val original = spark.createDataFrame(
|
||||||
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
|
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
|
||||||
).toDF("id", "a", "b")
|
).toDF("id", "a", "b")
|
||||||
val model = formula.fit(original)
|
val model = formula.fit(original)
|
||||||
|
@ -165,7 +165,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("vector attribute generation") {
|
test("vector attribute generation") {
|
||||||
val formula = new RFormula().setFormula("id ~ vec")
|
val formula = new RFormula().setFormula("id ~ vec")
|
||||||
val original = sqlContext.createDataFrame(
|
val original = spark.createDataFrame(
|
||||||
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
|
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
|
||||||
).toDF("id", "vec")
|
).toDF("id", "vec")
|
||||||
val model = formula.fit(original)
|
val model = formula.fit(original)
|
||||||
|
@ -181,7 +181,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("vector attribute generation with unnamed input attrs") {
|
test("vector attribute generation with unnamed input attrs") {
|
||||||
val formula = new RFormula().setFormula("id ~ vec2")
|
val formula = new RFormula().setFormula("id ~ vec2")
|
||||||
val base = sqlContext.createDataFrame(
|
val base = spark.createDataFrame(
|
||||||
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
|
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
|
||||||
).toDF("id", "vec")
|
).toDF("id", "vec")
|
||||||
val metadata = new AttributeGroup(
|
val metadata = new AttributeGroup(
|
||||||
|
@ -203,12 +203,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("numeric interaction") {
|
test("numeric interaction") {
|
||||||
val formula = new RFormula().setFormula("a ~ b:c:d")
|
val formula = new RFormula().setFormula("a ~ b:c:d")
|
||||||
val original = sqlContext.createDataFrame(
|
val original = spark.createDataFrame(
|
||||||
Seq((1, 2, 4, 2), (2, 3, 4, 1))
|
Seq((1, 2, 4, 2), (2, 3, 4, 1))
|
||||||
).toDF("a", "b", "c", "d")
|
).toDF("a", "b", "c", "d")
|
||||||
val model = formula.fit(original)
|
val model = formula.fit(original)
|
||||||
val result = model.transform(original)
|
val result = model.transform(original)
|
||||||
val expected = sqlContext.createDataFrame(
|
val expected = spark.createDataFrame(
|
||||||
Seq(
|
Seq(
|
||||||
(1, 2, 4, 2, Vectors.dense(16.0), 1.0),
|
(1, 2, 4, 2, Vectors.dense(16.0), 1.0),
|
||||||
(2, 3, 4, 1, Vectors.dense(12.0), 2.0))
|
(2, 3, 4, 1, Vectors.dense(12.0), 2.0))
|
||||||
|
@ -223,12 +223,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("factor numeric interaction") {
|
test("factor numeric interaction") {
|
||||||
val formula = new RFormula().setFormula("id ~ a:b")
|
val formula = new RFormula().setFormula("id ~ a:b")
|
||||||
val original = sqlContext.createDataFrame(
|
val original = spark.createDataFrame(
|
||||||
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5))
|
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5))
|
||||||
).toDF("id", "a", "b")
|
).toDF("id", "a", "b")
|
||||||
val model = formula.fit(original)
|
val model = formula.fit(original)
|
||||||
val result = model.transform(original)
|
val result = model.transform(original)
|
||||||
val expected = sqlContext.createDataFrame(
|
val expected = spark.createDataFrame(
|
||||||
Seq(
|
Seq(
|
||||||
(1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0),
|
(1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0),
|
||||||
(2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0),
|
(2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0),
|
||||||
|
@ -250,12 +250,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("factor factor interaction") {
|
test("factor factor interaction") {
|
||||||
val formula = new RFormula().setFormula("id ~ a:b")
|
val formula = new RFormula().setFormula("id ~ a:b")
|
||||||
val original = sqlContext.createDataFrame(
|
val original = spark.createDataFrame(
|
||||||
Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
|
Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
|
||||||
).toDF("id", "a", "b")
|
).toDF("id", "a", "b")
|
||||||
val model = formula.fit(original)
|
val model = formula.fit(original)
|
||||||
val result = model.transform(original)
|
val result = model.transform(original)
|
||||||
val expected = sqlContext.createDataFrame(
|
val expected = spark.createDataFrame(
|
||||||
Seq(
|
Seq(
|
||||||
(1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0),
|
(1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0),
|
||||||
(2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0),
|
(2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0),
|
||||||
|
@ -299,7 +299,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val dataset = sqlContext.createDataFrame(
|
val dataset = spark.createDataFrame(
|
||||||
Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
|
Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
|
||||||
).toDF("id", "a", "b")
|
).toDF("id", "a", "b")
|
||||||
|
|
||||||
|
|
|
@ -31,13 +31,13 @@ class SQLTransformerSuite
|
||||||
}
|
}
|
||||||
|
|
||||||
test("transform numeric data") {
|
test("transform numeric data") {
|
||||||
val original = sqlContext.createDataFrame(
|
val original = spark.createDataFrame(
|
||||||
Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
|
Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
|
||||||
val sqlTrans = new SQLTransformer().setStatement(
|
val sqlTrans = new SQLTransformer().setStatement(
|
||||||
"SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
|
"SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
|
||||||
val result = sqlTrans.transform(original)
|
val result = sqlTrans.transform(original)
|
||||||
val resultSchema = sqlTrans.transformSchema(original.schema)
|
val resultSchema = sqlTrans.transformSchema(original.schema)
|
||||||
val expected = sqlContext.createDataFrame(
|
val expected = spark.createDataFrame(
|
||||||
Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)))
|
Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)))
|
||||||
.toDF("id", "v1", "v2", "v3", "v4")
|
.toDF("id", "v1", "v2", "v3", "v4")
|
||||||
assert(result.schema.toString == resultSchema.toString)
|
assert(result.schema.toString == resultSchema.toString)
|
||||||
|
@ -52,7 +52,7 @@ class SQLTransformerSuite
|
||||||
}
|
}
|
||||||
|
|
||||||
test("transformSchema") {
|
test("transformSchema") {
|
||||||
val df = sqlContext.range(10)
|
val df = spark.range(10)
|
||||||
val outputSchema = new SQLTransformer()
|
val outputSchema = new SQLTransformer()
|
||||||
.setStatement("SELECT id + 1 AS id1 FROM __THIS__")
|
.setStatement("SELECT id + 1 AS id1 FROM __THIS__")
|
||||||
.transformSchema(df.schema)
|
.transformSchema(df.schema)
|
||||||
|
|
|
@ -73,7 +73,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Standardization with default parameter") {
|
test("Standardization with default parameter") {
|
||||||
val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected")
|
val df0 = spark.createDataFrame(data.zip(resWithStd)).toDF("features", "expected")
|
||||||
|
|
||||||
val standardScaler0 = new StandardScaler()
|
val standardScaler0 = new StandardScaler()
|
||||||
.setInputCol("features")
|
.setInputCol("features")
|
||||||
|
@ -84,9 +84,9 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Standardization with setter") {
|
test("Standardization with setter") {
|
||||||
val df1 = sqlContext.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected")
|
val df1 = spark.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected")
|
||||||
val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected")
|
val df2 = spark.createDataFrame(data.zip(resWithMean)).toDF("features", "expected")
|
||||||
val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected")
|
val df3 = spark.createDataFrame(data.zip(data)).toDF("features", "expected")
|
||||||
|
|
||||||
val standardScaler1 = new StandardScaler()
|
val standardScaler1 = new StandardScaler()
|
||||||
.setInputCol("features")
|
.setInputCol("features")
|
||||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
import org.apache.spark.sql.{Dataset, Row}
|
||||||
|
|
||||||
object StopWordsRemoverSuite extends SparkFunSuite {
|
object StopWordsRemoverSuite extends SparkFunSuite {
|
||||||
def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = {
|
def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = {
|
||||||
|
@ -42,7 +42,7 @@ class StopWordsRemoverSuite
|
||||||
val remover = new StopWordsRemover()
|
val remover = new StopWordsRemover()
|
||||||
.setInputCol("raw")
|
.setInputCol("raw")
|
||||||
.setOutputCol("filtered")
|
.setOutputCol("filtered")
|
||||||
val dataSet = sqlContext.createDataFrame(Seq(
|
val dataSet = spark.createDataFrame(Seq(
|
||||||
(Seq("test", "test"), Seq("test", "test")),
|
(Seq("test", "test"), Seq("test", "test")),
|
||||||
(Seq("a", "b", "c", "d"), Seq("b", "c")),
|
(Seq("a", "b", "c", "d"), Seq("b", "c")),
|
||||||
(Seq("a", "the", "an"), Seq()),
|
(Seq("a", "the", "an"), Seq()),
|
||||||
|
@ -60,7 +60,7 @@ class StopWordsRemoverSuite
|
||||||
.setInputCol("raw")
|
.setInputCol("raw")
|
||||||
.setOutputCol("filtered")
|
.setOutputCol("filtered")
|
||||||
.setStopWords(stopWords)
|
.setStopWords(stopWords)
|
||||||
val dataSet = sqlContext.createDataFrame(Seq(
|
val dataSet = spark.createDataFrame(Seq(
|
||||||
(Seq("test", "test"), Seq()),
|
(Seq("test", "test"), Seq()),
|
||||||
(Seq("a", "b", "c", "d"), Seq("b", "c", "d")),
|
(Seq("a", "b", "c", "d"), Seq("b", "c", "d")),
|
||||||
(Seq("a", "the", "an"), Seq()),
|
(Seq("a", "the", "an"), Seq()),
|
||||||
|
@ -77,7 +77,7 @@ class StopWordsRemoverSuite
|
||||||
.setInputCol("raw")
|
.setInputCol("raw")
|
||||||
.setOutputCol("filtered")
|
.setOutputCol("filtered")
|
||||||
.setCaseSensitive(true)
|
.setCaseSensitive(true)
|
||||||
val dataSet = sqlContext.createDataFrame(Seq(
|
val dataSet = spark.createDataFrame(Seq(
|
||||||
(Seq("A"), Seq("A")),
|
(Seq("A"), Seq("A")),
|
||||||
(Seq("The", "the"), Seq("The"))
|
(Seq("The", "the"), Seq("The"))
|
||||||
)).toDF("raw", "expected")
|
)).toDF("raw", "expected")
|
||||||
|
@ -98,7 +98,7 @@ class StopWordsRemoverSuite
|
||||||
.setInputCol("raw")
|
.setInputCol("raw")
|
||||||
.setOutputCol("filtered")
|
.setOutputCol("filtered")
|
||||||
.setStopWords(stopWords)
|
.setStopWords(stopWords)
|
||||||
val dataSet = sqlContext.createDataFrame(Seq(
|
val dataSet = spark.createDataFrame(Seq(
|
||||||
(Seq("acaba", "ama", "biri"), Seq()),
|
(Seq("acaba", "ama", "biri"), Seq()),
|
||||||
(Seq("hep", "her", "scala"), Seq("scala"))
|
(Seq("hep", "her", "scala"), Seq("scala"))
|
||||||
)).toDF("raw", "expected")
|
)).toDF("raw", "expected")
|
||||||
|
@ -112,7 +112,7 @@ class StopWordsRemoverSuite
|
||||||
.setInputCol("raw")
|
.setInputCol("raw")
|
||||||
.setOutputCol("filtered")
|
.setOutputCol("filtered")
|
||||||
.setStopWords(stopWords.toArray)
|
.setStopWords(stopWords.toArray)
|
||||||
val dataSet = sqlContext.createDataFrame(Seq(
|
val dataSet = spark.createDataFrame(Seq(
|
||||||
(Seq("python", "scala", "a"), Seq("python", "scala", "a")),
|
(Seq("python", "scala", "a"), Seq("python", "scala", "a")),
|
||||||
(Seq("Python", "Scala", "swift"), Seq("Python", "Scala", "swift"))
|
(Seq("Python", "Scala", "swift"), Seq("Python", "Scala", "swift"))
|
||||||
)).toDF("raw", "expected")
|
)).toDF("raw", "expected")
|
||||||
|
@ -126,7 +126,7 @@ class StopWordsRemoverSuite
|
||||||
.setInputCol("raw")
|
.setInputCol("raw")
|
||||||
.setOutputCol("filtered")
|
.setOutputCol("filtered")
|
||||||
.setStopWords(stopWords.toArray)
|
.setStopWords(stopWords.toArray)
|
||||||
val dataSet = sqlContext.createDataFrame(Seq(
|
val dataSet = spark.createDataFrame(Seq(
|
||||||
(Seq("python", "scala", "a"), Seq()),
|
(Seq("python", "scala", "a"), Seq()),
|
||||||
(Seq("Python", "Scala", "swift"), Seq("swift"))
|
(Seq("Python", "Scala", "swift"), Seq("swift"))
|
||||||
)).toDF("raw", "expected")
|
)).toDF("raw", "expected")
|
||||||
|
@ -148,7 +148,7 @@ class StopWordsRemoverSuite
|
||||||
val remover = new StopWordsRemover()
|
val remover = new StopWordsRemover()
|
||||||
.setInputCol("raw")
|
.setInputCol("raw")
|
||||||
.setOutputCol(outputCol)
|
.setOutputCol(outputCol)
|
||||||
val dataSet = sqlContext.createDataFrame(Seq(
|
val dataSet = spark.createDataFrame(Seq(
|
||||||
(Seq("The", "the", "swift"), Seq("swift"))
|
(Seq("The", "the", "swift"), Seq("swift"))
|
||||||
)).toDF("raw", outputCol)
|
)).toDF("raw", outputCol)
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,7 @@ class StringIndexerSuite
|
||||||
|
|
||||||
test("StringIndexer") {
|
test("StringIndexer") {
|
||||||
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
|
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
|
||||||
val df = sqlContext.createDataFrame(data).toDF("id", "label")
|
val df = spark.createDataFrame(data).toDF("id", "label")
|
||||||
val indexer = new StringIndexer()
|
val indexer = new StringIndexer()
|
||||||
.setInputCol("label")
|
.setInputCol("label")
|
||||||
.setOutputCol("labelIndex")
|
.setOutputCol("labelIndex")
|
||||||
|
@ -63,8 +63,8 @@ class StringIndexerSuite
|
||||||
test("StringIndexerUnseen") {
|
test("StringIndexerUnseen") {
|
||||||
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2)
|
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2)
|
||||||
val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
|
val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
|
||||||
val df = sqlContext.createDataFrame(data).toDF("id", "label")
|
val df = spark.createDataFrame(data).toDF("id", "label")
|
||||||
val df2 = sqlContext.createDataFrame(data2).toDF("id", "label")
|
val df2 = spark.createDataFrame(data2).toDF("id", "label")
|
||||||
val indexer = new StringIndexer()
|
val indexer = new StringIndexer()
|
||||||
.setInputCol("label")
|
.setInputCol("label")
|
||||||
.setOutputCol("labelIndex")
|
.setOutputCol("labelIndex")
|
||||||
|
@ -93,7 +93,7 @@ class StringIndexerSuite
|
||||||
|
|
||||||
test("StringIndexer with a numeric input column") {
|
test("StringIndexer with a numeric input column") {
|
||||||
val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
|
val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
|
||||||
val df = sqlContext.createDataFrame(data).toDF("id", "label")
|
val df = spark.createDataFrame(data).toDF("id", "label")
|
||||||
val indexer = new StringIndexer()
|
val indexer = new StringIndexer()
|
||||||
.setInputCol("label")
|
.setInputCol("label")
|
||||||
.setOutputCol("labelIndex")
|
.setOutputCol("labelIndex")
|
||||||
|
@ -114,12 +114,12 @@ class StringIndexerSuite
|
||||||
val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
|
val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
|
||||||
.setInputCol("label")
|
.setInputCol("label")
|
||||||
.setOutputCol("labelIndex")
|
.setOutputCol("labelIndex")
|
||||||
val df = sqlContext.range(0L, 10L).toDF()
|
val df = spark.range(0L, 10L).toDF()
|
||||||
assert(indexerModel.transform(df).collect().toSet === df.collect().toSet)
|
assert(indexerModel.transform(df).collect().toSet === df.collect().toSet)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("StringIndexerModel can't overwrite output column") {
|
test("StringIndexerModel can't overwrite output column") {
|
||||||
val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
|
val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
|
||||||
val indexer = new StringIndexer()
|
val indexer = new StringIndexer()
|
||||||
.setInputCol("input")
|
.setInputCol("input")
|
||||||
.setOutputCol("output")
|
.setOutputCol("output")
|
||||||
|
@ -153,7 +153,7 @@ class StringIndexerSuite
|
||||||
|
|
||||||
test("IndexToString.transform") {
|
test("IndexToString.transform") {
|
||||||
val labels = Array("a", "b", "c")
|
val labels = Array("a", "b", "c")
|
||||||
val df0 = sqlContext.createDataFrame(Seq(
|
val df0 = spark.createDataFrame(Seq(
|
||||||
(0, "a"), (1, "b"), (2, "c"), (0, "a")
|
(0, "a"), (1, "b"), (2, "c"), (0, "a")
|
||||||
)).toDF("index", "expected")
|
)).toDF("index", "expected")
|
||||||
|
|
||||||
|
@ -180,7 +180,7 @@ class StringIndexerSuite
|
||||||
|
|
||||||
test("StringIndexer, IndexToString are inverses") {
|
test("StringIndexer, IndexToString are inverses") {
|
||||||
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
|
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
|
||||||
val df = sqlContext.createDataFrame(data).toDF("id", "label")
|
val df = spark.createDataFrame(data).toDF("id", "label")
|
||||||
val indexer = new StringIndexer()
|
val indexer = new StringIndexer()
|
||||||
.setInputCol("label")
|
.setInputCol("label")
|
||||||
.setOutputCol("labelIndex")
|
.setOutputCol("labelIndex")
|
||||||
|
@ -213,7 +213,7 @@ class StringIndexerSuite
|
||||||
|
|
||||||
test("StringIndexer metadata") {
|
test("StringIndexer metadata") {
|
||||||
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
|
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
|
||||||
val df = sqlContext.createDataFrame(data).toDF("id", "label")
|
val df = spark.createDataFrame(data).toDF("id", "label")
|
||||||
val indexer = new StringIndexer()
|
val indexer = new StringIndexer()
|
||||||
.setInputCol("label")
|
.setInputCol("label")
|
||||||
.setOutputCol("labelIndex")
|
.setOutputCol("labelIndex")
|
||||||
|
|
|
@ -57,13 +57,13 @@ class RegexTokenizerSuite
|
||||||
.setPattern("\\w+|\\p{Punct}")
|
.setPattern("\\w+|\\p{Punct}")
|
||||||
.setInputCol("rawText")
|
.setInputCol("rawText")
|
||||||
.setOutputCol("tokens")
|
.setOutputCol("tokens")
|
||||||
val dataset0 = sqlContext.createDataFrame(Seq(
|
val dataset0 = spark.createDataFrame(Seq(
|
||||||
TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")),
|
TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")),
|
||||||
TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct"))
|
TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct"))
|
||||||
))
|
))
|
||||||
testRegexTokenizer(tokenizer0, dataset0)
|
testRegexTokenizer(tokenizer0, dataset0)
|
||||||
|
|
||||||
val dataset1 = sqlContext.createDataFrame(Seq(
|
val dataset1 = spark.createDataFrame(Seq(
|
||||||
TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")),
|
TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")),
|
||||||
TokenizerTestData("Te,st. punct", Array("punct"))
|
TokenizerTestData("Te,st. punct", Array("punct"))
|
||||||
))
|
))
|
||||||
|
@ -73,7 +73,7 @@ class RegexTokenizerSuite
|
||||||
val tokenizer2 = new RegexTokenizer()
|
val tokenizer2 = new RegexTokenizer()
|
||||||
.setInputCol("rawText")
|
.setInputCol("rawText")
|
||||||
.setOutputCol("tokens")
|
.setOutputCol("tokens")
|
||||||
val dataset2 = sqlContext.createDataFrame(Seq(
|
val dataset2 = spark.createDataFrame(Seq(
|
||||||
TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")),
|
TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")),
|
||||||
TokenizerTestData("Te,st. punct", Array("te,st.", "punct"))
|
TokenizerTestData("Te,st. punct", Array("te,st.", "punct"))
|
||||||
))
|
))
|
||||||
|
@ -85,7 +85,7 @@ class RegexTokenizerSuite
|
||||||
.setInputCol("rawText")
|
.setInputCol("rawText")
|
||||||
.setOutputCol("tokens")
|
.setOutputCol("tokens")
|
||||||
.setToLowercase(false)
|
.setToLowercase(false)
|
||||||
val dataset = sqlContext.createDataFrame(Seq(
|
val dataset = spark.createDataFrame(Seq(
|
||||||
TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")),
|
TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")),
|
||||||
TokenizerTestData("java scala", Array("java", "scala"))
|
TokenizerTestData("java scala", Array("java", "scala"))
|
||||||
))
|
))
|
||||||
|
|
|
@ -57,7 +57,7 @@ class VectorAssemblerSuite
|
||||||
}
|
}
|
||||||
|
|
||||||
test("VectorAssembler") {
|
test("VectorAssembler") {
|
||||||
val df = sqlContext.createDataFrame(Seq(
|
val df = spark.createDataFrame(Seq(
|
||||||
(0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
|
(0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
|
||||||
)).toDF("id", "x", "y", "name", "z", "n")
|
)).toDF("id", "x", "y", "name", "z", "n")
|
||||||
val assembler = new VectorAssembler()
|
val assembler = new VectorAssembler()
|
||||||
|
@ -70,7 +70,7 @@ class VectorAssemblerSuite
|
||||||
}
|
}
|
||||||
|
|
||||||
test("transform should throw an exception in case of unsupported type") {
|
test("transform should throw an exception in case of unsupported type") {
|
||||||
val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c")
|
val df = spark.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c")
|
||||||
val assembler = new VectorAssembler()
|
val assembler = new VectorAssembler()
|
||||||
.setInputCols(Array("a", "b", "c"))
|
.setInputCols(Array("a", "b", "c"))
|
||||||
.setOutputCol("features")
|
.setOutputCol("features")
|
||||||
|
@ -87,7 +87,7 @@ class VectorAssemblerSuite
|
||||||
NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"),
|
NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"),
|
||||||
NumericAttribute.defaultAttr.withName("salary")))
|
NumericAttribute.defaultAttr.withName("salary")))
|
||||||
val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0)))
|
val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0)))
|
||||||
val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
|
val df = spark.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
|
||||||
.select(
|
.select(
|
||||||
col("browser").as("browser", browser.toMetadata()),
|
col("browser").as("browser", browser.toMetadata()),
|
||||||
col("hour").as("hour", hour.toMetadata()),
|
col("hour").as("hour", hour.toMetadata()),
|
||||||
|
|
|
@ -85,11 +85,11 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
checkPair(densePoints1Seq, sparsePoints1Seq)
|
checkPair(densePoints1Seq, sparsePoints1Seq)
|
||||||
checkPair(densePoints2Seq, sparsePoints2Seq)
|
checkPair(densePoints2Seq, sparsePoints2Seq)
|
||||||
|
|
||||||
densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
|
densePoints1 = spark.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
|
||||||
sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
|
sparsePoints1 = spark.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
|
||||||
densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
|
densePoints2 = spark.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
|
||||||
sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData))
|
sparsePoints2 = spark.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData))
|
||||||
badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData))
|
badPoints = spark.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData))
|
||||||
}
|
}
|
||||||
|
|
||||||
private def getIndexer: VectorIndexer =
|
private def getIndexer: VectorIndexer =
|
||||||
|
@ -102,7 +102,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Cannot fit an empty DataFrame") {
|
test("Cannot fit an empty DataFrame") {
|
||||||
val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
|
val rdd = spark.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
|
||||||
val vectorIndexer = getIndexer
|
val vectorIndexer = getIndexer
|
||||||
intercept[IllegalArgumentException] {
|
intercept[IllegalArgumentException] {
|
||||||
vectorIndexer.fit(rdd)
|
vectorIndexer.fit(rdd)
|
||||||
|
|
|
@ -79,7 +79,7 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De
|
||||||
val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]])
|
val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]])
|
||||||
|
|
||||||
val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) }
|
val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) }
|
||||||
val df = sqlContext.createDataFrame(rdd,
|
val df = spark.createDataFrame(rdd,
|
||||||
StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField())))
|
StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField())))
|
||||||
|
|
||||||
val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")
|
val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")
|
||||||
|
|
|
@ -36,8 +36,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("Word2Vec") {
|
test("Word2Vec") {
|
||||||
|
|
||||||
val sqlContext = this.sqlContext
|
val spark = this.spark
|
||||||
import sqlContext.implicits._
|
import spark.implicits._
|
||||||
|
|
||||||
val sentence = "a b " * 100 + "a c " * 10
|
val sentence = "a b " * 100 + "a c " * 10
|
||||||
val numOfWords = sentence.split(" ").size
|
val numOfWords = sentence.split(" ").size
|
||||||
|
@ -78,8 +78,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("getVectors") {
|
test("getVectors") {
|
||||||
|
|
||||||
val sqlContext = this.sqlContext
|
val spark = this.spark
|
||||||
import sqlContext.implicits._
|
import spark.implicits._
|
||||||
|
|
||||||
val sentence = "a b " * 100 + "a c " * 10
|
val sentence = "a b " * 100 + "a c " * 10
|
||||||
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
|
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
|
||||||
|
@ -119,8 +119,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("findSynonyms") {
|
test("findSynonyms") {
|
||||||
|
|
||||||
val sqlContext = this.sqlContext
|
val spark = this.spark
|
||||||
import sqlContext.implicits._
|
import spark.implicits._
|
||||||
|
|
||||||
val sentence = "a b " * 100 + "a c " * 10
|
val sentence = "a b " * 100 + "a c " * 10
|
||||||
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
|
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
|
||||||
|
@ -146,8 +146,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
|
|
||||||
test("window size") {
|
test("window size") {
|
||||||
|
|
||||||
val sqlContext = this.sqlContext
|
val spark = this.spark
|
||||||
import sqlContext.implicits._
|
import spark.implicits._
|
||||||
|
|
||||||
val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
|
val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
|
||||||
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
|
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
|
||||||
|
|
|
@ -38,7 +38,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
|
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
|
||||||
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
|
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
|
||||||
import org.apache.spark.storage.StorageLevel
|
import org.apache.spark.storage.StorageLevel
|
||||||
import org.apache.spark.util.Utils
|
import org.apache.spark.util.Utils
|
||||||
|
|
||||||
|
@ -305,8 +305,8 @@ class ALSSuite
|
||||||
numUserBlocks: Int = 2,
|
numUserBlocks: Int = 2,
|
||||||
numItemBlocks: Int = 3,
|
numItemBlocks: Int = 3,
|
||||||
targetRMSE: Double = 0.05): Unit = {
|
targetRMSE: Double = 0.05): Unit = {
|
||||||
val sqlContext = this.sqlContext
|
val spark = this.spark
|
||||||
import sqlContext.implicits._
|
import spark.implicits._
|
||||||
val als = new ALS()
|
val als = new ALS()
|
||||||
.setRank(rank)
|
.setRank(rank)
|
||||||
.setRegParam(regParam)
|
.setRegParam(regParam)
|
||||||
|
@ -460,8 +460,8 @@ class ALSSuite
|
||||||
allEstimatorParamSettings.foreach { case (p, v) =>
|
allEstimatorParamSettings.foreach { case (p, v) =>
|
||||||
als.set(als.getParam(p), v)
|
als.set(als.getParam(p), v)
|
||||||
}
|
}
|
||||||
val sqlContext = this.sqlContext
|
val spark = this.spark
|
||||||
import sqlContext.implicits._
|
import spark.implicits._
|
||||||
val model = als.fit(ratings.toDF())
|
val model = als.fit(ratings.toDF())
|
||||||
|
|
||||||
// Test Estimator save/load
|
// Test Estimator save/load
|
||||||
|
@ -535,8 +535,11 @@ class ALSCleanerSuite extends SparkFunSuite {
|
||||||
// Generate test data
|
// Generate test data
|
||||||
val (training, _) = ALSSuite.genImplicitTestData(sc, 20, 5, 1, 0.2, 0)
|
val (training, _) = ALSSuite.genImplicitTestData(sc, 20, 5, 1, 0.2, 0)
|
||||||
// Implicitly test the cleaning of parents during ALS training
|
// Implicitly test the cleaning of parents during ALS training
|
||||||
val sqlContext = new SQLContext(sc)
|
val spark = SparkSession.builder
|
||||||
import sqlContext.implicits._
|
.master("local[2]")
|
||||||
|
.appName("ALSCleanerSuite")
|
||||||
|
.getOrCreate()
|
||||||
|
import spark.implicits._
|
||||||
val als = new ALS()
|
val als = new ALS()
|
||||||
.setRank(1)
|
.setRank(1)
|
||||||
.setRegParam(1e-5)
|
.setRegParam(1e-5)
|
||||||
|
@ -577,8 +580,8 @@ class ALSStorageSuite
|
||||||
}
|
}
|
||||||
|
|
||||||
test("default and non-default storage params set correct RDD StorageLevels") {
|
test("default and non-default storage params set correct RDD StorageLevels") {
|
||||||
val sqlContext = this.sqlContext
|
val spark = this.spark
|
||||||
import sqlContext.implicits._
|
import spark.implicits._
|
||||||
val data = Seq(
|
val data = Seq(
|
||||||
(0, 0, 1.0),
|
(0, 0, 1.0),
|
||||||
(0, 1, 2.0),
|
(0, 1, 2.0),
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue