[SPARK-15037][SQL][MLLIB] Use SparkSession instead of SQLContext in Scala/Java TestSuites

## What changes were proposed in this pull request?
Use SparkSession instead of SQLContext in Scala/Java TestSuites
as this PR already very big working Python TestSuites in a diff PR.

## How was this patch tested?
Existing tests

Author: Sandeep Singh <sandeep@techaddict.me>

Closes #12907 from techaddict/SPARK-15037.
This commit is contained in:
Sandeep Singh 2016-05-10 11:17:47 -07:00 committed by Andrew Or
parent bcfee153b1
commit ed0b4070fb
224 changed files with 2916 additions and 2593 deletions

View file

@ -17,18 +17,18 @@
package org.apache.spark.ml;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.StandardScaler;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
/**
@ -36,23 +36,26 @@ import static org.apache.spark.mllib.classification.LogisticRegressionSuite.gene
*/
public class JavaPipelineSuite {
private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient Dataset<Row> dataset;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaPipelineSuite");
jsql = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaPipelineSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
JavaRDD<LabeledPoint> points =
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
dataset = jsql.createDataFrame(points, LabeledPoint.class);
dataset = spark.createDataFrame(points, LabeledPoint.class);
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
}
@Test
@ -63,10 +66,10 @@ public class JavaPipelineSuite {
LogisticRegression lr = new LogisticRegression()
.setFeaturesCol("scaledFeatures");
Pipeline pipeline = new Pipeline()
.setStages(new PipelineStage[] {scaler, lr});
.setStages(new PipelineStage[]{scaler, lr});
PipelineModel model = pipeline.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
}
}

View file

@ -17,8 +17,8 @@
package org.apache.spark.ml.attribute;
import org.junit.Test;
import org.junit.Assert;
import org.junit.Test;
public class JavaAttributeSuite {

View file

@ -21,8 +21,6 @@ import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@ -32,21 +30,28 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class JavaDecisionTreeClassifierSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaDecisionTreeClassifierSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
@ -55,7 +60,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
double A = 2.0;
double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize(
JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
@ -70,7 +75,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
.setCacheNodeIds(false)
.setCheckpointInterval(10)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
for (String impurity: DecisionTreeClassifier.supportedImpurities()) {
for (String impurity : DecisionTreeClassifier.supportedImpurities()) {
dt.setImpurity(impurity);
}
DecisionTreeClassificationModel model = dt.fit(dataFrame);

View file

@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class JavaGBTClassifierSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaGBTClassifierSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaGBTClassifierSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
@ -55,7 +61,7 @@ public class JavaGBTClassifierSuite implements Serializable {
double A = 2.0;
double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize(
JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
@ -74,7 +80,7 @@ public class JavaGBTClassifierSuite implements Serializable {
.setMaxIter(3)
.setStepSize(0.1)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
for (String lossType: GBTClassifier.supportedLossTypes()) {
for (String lossType : GBTClassifier.supportedLossTypes()) {
rf.setLossType(lossType);
}
GBTClassificationModel model = rf.fit(dataFrame);

View file

@ -27,18 +27,17 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaLogisticRegressionSuite implements Serializable {
private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
@ -46,18 +45,22 @@ public class JavaLogisticRegressionSuite implements Serializable {
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
jsql = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaLogisticRegressionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
}
@Test
@ -66,7 +69,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(lr.getLabelCol(), "label");
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
// Check defaults
Assert.assertEquals(0.5, model.getThreshold(), eps);
@ -95,23 +98,23 @@ public class JavaLogisticRegressionSuite implements Serializable {
// Modify model params, and check that the params worked.
model.setThreshold(1.0);
model.transform(dataset).registerTempTable("predAllZero");
Dataset<Row> predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
for (Row r: predAllZero.collectAsList()) {
Dataset<Row> predAllZero = spark.sql("SELECT prediction, myProbability FROM predAllZero");
for (Row r : predAllZero.collectAsList()) {
Assert.assertEquals(0.0, r.getDouble(0), eps);
}
// Call transform with params, and check that the params worked.
model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
.registerTempTable("predNotAllZero");
Dataset<Row> predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
Dataset<Row> predNotAllZero = spark.sql("SELECT prediction, myProb FROM predNotAllZero");
boolean foundNonZero = false;
for (Row r: predNotAllZero.collectAsList()) {
for (Row r : predNotAllZero.collectAsList()) {
if (r.getDouble(0) != 0.0) foundNonZero = true;
}
Assert.assertTrue(foundNonZero);
// Call fit() with new params, and check as many params as we can.
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
LogisticRegression parent2 = (LogisticRegression) model2.parent();
Assert.assertEquals(5, parent2.getMaxIter());
Assert.assertEquals(0.1, parent2.getRegParam(), eps);
@ -128,10 +131,10 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(2, model.numClasses());
model.transform(dataset).registerTempTable("transformed");
Dataset<Row> trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
for (Row row: trans1.collectAsList()) {
Vector raw = (Vector)row.get(0);
Vector prob = (Vector)row.get(1);
Dataset<Row> trans1 = spark.sql("SELECT rawPrediction, probability FROM transformed");
for (Row row : trans1.collectAsList()) {
Vector raw = (Vector) row.get(0);
Vector prob = (Vector) row.get(1);
Assert.assertEquals(raw.size(), 2);
Assert.assertEquals(prob.size(), 2);
double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
@ -139,11 +142,11 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
}
Dataset<Row> trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
for (Row row: trans2.collectAsList()) {
Dataset<Row> trans2 = spark.sql("SELECT prediction, probability FROM transformed");
for (Row row : trans2.collectAsList()) {
double pred = row.getDouble(0);
Vector prob = (Vector)row.get(1);
double probOfPred = prob.apply((int)pred);
Vector prob = (Vector) row.get(1);
double probOfPred = prob.apply((int) pred);
for (int i = 0; i < prob.size(); ++i) {
Assert.assertTrue(probOfPred >= prob.apply(i));
}

View file

@ -26,49 +26,49 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext sqlContext;
private transient SparkSession spark;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
sqlContext = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaLogisticRegressionSuite")
.getOrCreate();
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
sqlContext = null;
spark.stop();
spark = null;
}
@Test
public void testMLPC() {
Dataset<Row> dataFrame = sqlContext.createDataFrame(
jsc.parallelize(Arrays.asList(
new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))),
LabeledPoint.class);
List<LabeledPoint> data = Arrays.asList(
new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
new LabeledPoint(0.0, Vectors.dense(1.0, 1.0))
);
Dataset<Row> dataFrame = spark.createDataFrame(data, LabeledPoint.class);
MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
.setLayers(new int[] {2, 5, 2})
.setLayers(new int[]{2, 5, 2})
.setBlockSize(1)
.setSeed(123L)
.setMaxIter(100);
MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
Dataset<Row> result = model.transform(dataFrame);
List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList();
for (Row r: predictionAndLabels) {
for (Row r : predictionAndLabels) {
Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1));
}
}

View file

@ -26,13 +26,12 @@ import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
@ -40,19 +39,20 @@ import org.apache.spark.sql.types.StructType;
public class JavaNaiveBayesSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient SparkSession spark;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
jsql = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaLogisticRegressionSuite")
.getOrCreate();
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
}
public void validatePrediction(Dataset<Row> predictionAndLabels) {
@ -88,7 +88,7 @@ public class JavaNaiveBayesSuite implements Serializable {
new StructField("features", new VectorUDT(), false, Metadata.empty())
});
Dataset<Row> dataset = jsql.createDataFrame(data, schema);
Dataset<Row> dataset = spark.createDataFrame(data, schema);
NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
NaiveBayesModel model = nb.fit(dataset);

View file

@ -20,7 +20,6 @@ package org.apache.spark.ml.classification;
import java.io.Serializable;
import java.util.List;
import org.apache.spark.sql.Row;
import scala.collection.JavaConverters;
import org.junit.After;
@ -30,56 +29,61 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
public class JavaOneVsRestSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite");
jsql = new SQLContext(jsc);
int nPoints = 3;
@Before
public void setUp() {
spark = SparkSession.builder()
.master("local")
.appName("JavaLOneVsRestSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
// The following coefficients and xMean/xVariance are computed from iris dataset with
// lambda=0.2.
// As a result, we are drawing samples from probability distribution of an actual model.
double[] coefficients = {
-0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
-0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
int nPoints = 3;
double[] xMean = {5.843, 3.057, 3.758, 1.199};
double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter(
generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42)
).asJava();
datasetRDD = jsc.parallelize(points, 2);
dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
}
// The following coefficients and xMean/xVariance are computed from iris dataset with
// lambda=0.2.
// As a result, we are drawing samples from probability distribution of an actual model.
double[] coefficients = {
-0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
-0.16624, -0.84355, -0.048509, -0.301789, 4.170682};
@After
public void tearDown() {
jsc.stop();
jsc = null;
}
double[] xMean = {5.843, 3.057, 3.758, 1.199};
double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter(
generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42)
).asJava();
datasetRDD = jsc.parallelize(points, 2);
dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
}
@Test
public void oneVsRestDefaultParams() {
OneVsRest ova = new OneVsRest();
ova.setClassifier(new LogisticRegression());
Assert.assertEquals(ova.getLabelCol() , "label");
Assert.assertEquals(ova.getPredictionCol() , "prediction");
OneVsRestModel ovaModel = ova.fit(dataset);
Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction");
predictions.collectAsList();
Assert.assertEquals(ovaModel.getLabelCol(), "label");
Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");
}
@After
public void tearDown() {
spark.stop();
spark = null;
}
@Test
public void oneVsRestDefaultParams() {
OneVsRest ova = new OneVsRest();
ova.setClassifier(new LogisticRegression());
Assert.assertEquals(ova.getLabelCol(), "label");
Assert.assertEquals(ova.getPredictionCol(), "prediction");
OneVsRestModel ovaModel = ova.fit(dataset);
Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction");
predictions.collectAsList();
Assert.assertEquals(ovaModel.getLabelCol(), "label");
Assert.assertEquals(ovaModel.getPredictionCol(), "prediction");
}
}

View file

@ -34,21 +34,27 @@ import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class JavaRandomForestClassifierSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaRandomForestClassifierSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
@ -57,7 +63,7 @@ public class JavaRandomForestClassifierSuite implements Serializable {
double A = 2.0;
double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize(
JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
@ -75,22 +81,22 @@ public class JavaRandomForestClassifierSuite implements Serializable {
.setSeed(1234)
.setNumTrees(3)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
for (String impurity: RandomForestClassifier.supportedImpurities()) {
for (String impurity : RandomForestClassifier.supportedImpurities()) {
rf.setImpurity(impurity);
}
for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) {
for (String featureSubsetStrategy : RandomForestClassifier.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
}
String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
for (String strategy: realStrategies) {
for (String strategy : realStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
for (String strategy: integerStrategies) {
for (String strategy : integerStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
for (String strategy: invalidStrategies) {
for (String strategy : invalidStrategies) {
try {
rf.setFeatureSubsetStrategy(strategy);
Assert.fail("Expected exception to be thrown for invalid strategies");

View file

@ -21,37 +21,37 @@ import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
public class JavaKMeansSuite implements Serializable {
private transient int k = 5;
private transient JavaSparkContext sc;
private transient Dataset<Row> dataset;
private transient SQLContext sql;
private transient SparkSession spark;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaKMeansSuite");
sql = new SQLContext(sc);
dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k);
spark = SparkSession.builder()
.master("local")
.appName("JavaKMeansSuite")
.getOrCreate();
dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k);
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
@ -65,7 +65,7 @@ public class JavaKMeansSuite implements Serializable {
Dataset<Row> transformed = model.transform(dataset);
List<String> columns = Arrays.asList(transformed.columns());
List<String> expectedColumns = Arrays.asList("features", "prediction");
for (String column: expectedColumns) {
for (String column : expectedColumns) {
assertTrue(columns.contains(column));
}
}

View file

@ -25,40 +25,40 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
public class JavaBucketizerSuite {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient SparkSession spark;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaBucketizerSuite");
jsql = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaBucketizerSuite")
.getOrCreate();
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
}
@Test
public void bucketizerTest() {
double[] splits = {-0.5, 0.0, 0.5};
StructType schema = new StructType(new StructField[] {
StructType schema = new StructType(new StructField[]{
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
});
Dataset<Row> dataset = jsql.createDataFrame(
Dataset<Row> dataset = spark.createDataFrame(
Arrays.asList(
RowFactory.create(-0.5),
RowFactory.create(-0.3),

View file

@ -21,43 +21,44 @@ import java.util.Arrays;
import java.util.List;
import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
public class JavaDCTSuite {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient SparkSession spark;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaDCTSuite");
jsql = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaDCTSuite")
.getOrCreate();
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
}
@Test
public void javaCompatibilityTest() {
double[] input = new double[] {1D, 2D, 3D, 4D};
Dataset<Row> dataset = jsql.createDataFrame(
double[] input = new double[]{1D, 2D, 3D, 4D};
Dataset<Row> dataset = spark.createDataFrame(
Arrays.asList(RowFactory.create(Vectors.dense(input))),
new StructType(new StructField[]{
new StructField("vec", (new VectorUDT()), false, Metadata.empty())

View file

@ -25,12 +25,11 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
@ -38,19 +37,20 @@ import org.apache.spark.sql.types.StructType;
public class JavaHashingTFSuite {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient SparkSession spark;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaHashingTFSuite");
jsql = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaHashingTFSuite")
.getOrCreate();
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
}
@Test
@ -65,7 +65,7 @@ public class JavaHashingTFSuite {
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
Dataset<Row> sentenceData = jsql.createDataFrame(data, schema);
Dataset<Row> sentenceData = spark.createDataFrame(data, schema);
Tokenizer tokenizer = new Tokenizer()
.setInputCol("sentence")
.setOutputCol("words");

View file

@ -23,27 +23,30 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
public class JavaNormalizerSuite {
private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaNormalizerSuite");
jsql = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaNormalizerSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
}
@Test
@ -54,7 +57,7 @@ public class JavaNormalizerSuite {
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
));
Dataset<Row> dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
Dataset<Row> dataFrame = spark.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
Normalizer normalizer = new Normalizer()
.setInputCol("features")
.setOutputCol("normFeatures");

View file

@ -28,31 +28,34 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
public class JavaPCASuite implements Serializable {
private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient SQLContext sqlContext;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaPCASuite");
sqlContext = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaPCASuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
}
public static class VectorPair implements Serializable {
@ -100,7 +103,7 @@ public class JavaPCASuite implements Serializable {
}
);
Dataset<Row> df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
Dataset<Row> df = spark.createDataFrame(featuresExpected, VectorPair.class);
PCAModel pca = new PCA()
.setInputCol("features")
.setOutputCol("pca_features")

View file

@ -32,19 +32,22 @@ import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
public class JavaPolynomialExpansionSuite {
private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaPolynomialExpansionSuite");
jsql = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaPolynomialExpansionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
@ -72,20 +75,20 @@ public class JavaPolynomialExpansionSuite {
)
);
StructType schema = new StructType(new StructField[] {
StructType schema = new StructType(new StructField[]{
new StructField("features", new VectorUDT(), false, Metadata.empty()),
new StructField("expected", new VectorUDT(), false, Metadata.empty())
});
Dataset<Row> dataset = jsql.createDataFrame(data, schema);
Dataset<Row> dataset = spark.createDataFrame(data, schema);
List<Row> pairs = polyExpansion.transform(dataset)
.select("polyFeatures", "expected")
.collectAsList();
for (Row r : pairs) {
double[] polyFeatures = ((Vector)r.get(0)).toArray();
double[] expected = ((Vector)r.get(1)).toArray();
double[] polyFeatures = ((Vector) r.get(0)).toArray();
double[] expected = ((Vector) r.get(1)).toArray();
Assert.assertArrayEquals(polyFeatures, expected, 1e-1);
}
}

View file

@ -28,22 +28,25 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
public class JavaStandardScalerSuite {
private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaStandardScalerSuite");
jsql = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaStandardScalerSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
}
@Test
@ -54,7 +57,7 @@ public class JavaStandardScalerSuite {
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
);
Dataset<Row> dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
Dataset<Row> dataFrame = spark.createDataFrame(jsc.parallelize(points, 2),
VectorIndexerSuite.FeatureData.class);
StandardScaler scaler = new StandardScaler()
.setInputCol("features")

View file

@ -24,11 +24,10 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
@ -37,19 +36,20 @@ import org.apache.spark.sql.types.StructType;
public class JavaStopWordsRemoverSuite {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient SparkSession spark;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaStopWordsRemoverSuite");
jsql = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaStopWordsRemoverSuite")
.getOrCreate();
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
}
@Test
@ -62,11 +62,11 @@ public class JavaStopWordsRemoverSuite {
RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")),
RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb"))
);
StructType schema = new StructType(new StructField[] {
StructType schema = new StructType(new StructField[]{
new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false,
Metadata.empty())
Metadata.empty())
});
Dataset<Row> dataset = jsql.createDataFrame(data, schema);
Dataset<Row> dataset = spark.createDataFrame(data, schema);
remover.transform(dataset).collect();
}

View file

@ -25,40 +25,42 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.*;
public class JavaStringIndexerSuite {
private transient JavaSparkContext jsc;
private transient SQLContext sqlContext;
private transient SparkSession spark;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaStringIndexerSuite");
sqlContext = new SQLContext(jsc);
SparkConf sparkConf = new SparkConf();
sparkConf.setMaster("local");
sparkConf.setAppName("JavaStringIndexerSuite");
spark = SparkSession.builder().config(sparkConf).getOrCreate();
}
@After
public void tearDown() {
jsc.stop();
sqlContext = null;
spark.stop();
spark = null;
}
@Test
public void testStringIndexer() {
StructType schema = createStructType(new StructField[] {
StructType schema = createStructType(new StructField[]{
createStructField("id", IntegerType, false),
createStructField("label", StringType, false)
});
List<Row> data = Arrays.asList(
cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c"));
Dataset<Row> dataset = sqlContext.createDataFrame(data, schema);
Dataset<Row> dataset = spark.createDataFrame(data, schema);
StringIndexer indexer = new StringIndexer()
.setInputCol("label")
@ -70,7 +72,9 @@ public class JavaStringIndexerSuite {
output.orderBy("id").select("id", "labelIndex").collectAsList());
}
/** An alias for RowFactory.create. */
/**
* An alias for RowFactory.create.
*/
private Row cr(Object... values) {
return RowFactory.create(values);
}

View file

@ -29,22 +29,25 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
public class JavaTokenizerSuite {
private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaTokenizerSuite");
jsql = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaTokenizerSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
}
@Test
@ -59,10 +62,10 @@ public class JavaTokenizerSuite {
JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Arrays.asList(
new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"})
new TokenizerTestData("Test of tok.", new String[]{"Test", "tok."}),
new TokenizerTestData("Te,st. punct", new String[]{"Te,st.", "punct"})
));
Dataset<Row> dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);
Dataset<Row> dataset = spark.createDataFrame(rdd, TokenizerTestData.class);
List<Row> pairs = myRegExTokenizer.transform(dataset)
.select("tokens", "wantedTokens")

View file

@ -24,36 +24,39 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.SparkConf;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.*;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.*;
public class JavaVectorAssemblerSuite {
private transient JavaSparkContext jsc;
private transient SQLContext sqlContext;
private transient SparkSession spark;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite");
sqlContext = new SQLContext(jsc);
SparkConf sparkConf = new SparkConf();
sparkConf.setMaster("local");
sparkConf.setAppName("JavaVectorAssemblerSuite");
spark = SparkSession.builder().config(sparkConf).getOrCreate();
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
}
@Test
public void testVectorAssembler() {
StructType schema = createStructType(new StructField[] {
StructType schema = createStructType(new StructField[]{
createStructField("id", IntegerType, false),
createStructField("x", DoubleType, false),
createStructField("y", new VectorUDT(), false),
@ -63,14 +66,14 @@ public class JavaVectorAssemblerSuite {
});
Row row = RowFactory.create(
0, 0.0, Vectors.dense(1.0, 2.0), "a",
Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
Dataset<Row> dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
Vectors.sparse(2, new int[]{1}, new double[]{3.0}), 10L);
Dataset<Row> dataset = spark.createDataFrame(Arrays.asList(row), schema);
VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[] {"x", "y", "z", "n"})
.setInputCols(new String[]{"x", "y", "z", "n"})
.setOutputCol("features");
Dataset<Row> output = assembler.transform(dataset);
Assert.assertEquals(
Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}),
Vectors.sparse(6, new int[]{1, 2, 4, 5}, new double[]{1.0, 2.0, 3.0, 10.0}),
output.select("features").first().<Vector>getAs(0));
}
}

View file

@ -32,21 +32,26 @@ import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
public class JavaVectorIndexerSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaVectorIndexerSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaVectorIndexerSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
@ -57,8 +62,7 @@ public class JavaVectorIndexerSuite implements Serializable {
new FeatureData(Vectors.dense(1.0, 3.0)),
new FeatureData(Vectors.dense(1.0, 4.0))
);
SQLContext sqlContext = new SQLContext(sc);
Dataset<Row> data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class);
Dataset<Row> data = spark.createDataFrame(jsc.parallelize(points, 2), FeatureData.class);
VectorIndexer indexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexed")

View file

@ -25,7 +25,6 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute;
@ -34,24 +33,25 @@ import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;
public class JavaVectorSlicerSuite {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient SparkSession spark;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite");
jsql = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaVectorSlicerSuite")
.getOrCreate();
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
}
@Test
@ -69,7 +69,7 @@ public class JavaVectorSlicerSuite {
);
Dataset<Row> dataset =
jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));
spark.createDataFrame(data, (new StructType()).add(group.toStructField()));
VectorSlicer vectorSlicer = new VectorSlicer()
.setInputCol("userFeatures").setOutputCol("features");

View file

@ -24,28 +24,28 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.*;
public class JavaWord2VecSuite {
private transient JavaSparkContext jsc;
private transient SQLContext sqlContext;
private transient SparkSession spark;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaWord2VecSuite");
sqlContext = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaWord2VecSuite")
.getOrCreate();
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
}
@Test
@ -53,7 +53,7 @@ public class JavaWord2VecSuite {
StructType schema = new StructType(new StructField[]{
new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
});
Dataset<Row> documentDF = sqlContext.createDataFrame(
Dataset<Row> documentDF = spark.createDataFrame(
Arrays.asList(
RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
@ -68,8 +68,8 @@ public class JavaWord2VecSuite {
Word2VecModel model = word2Vec.fit(documentDF);
Dataset<Row> result = model.transform(documentDF);
for (Row r: result.select("result").collectAsList()) {
double[] polyFeatures = ((Vector)r.get(0)).toArray();
for (Row r : result.select("result").collectAsList()) {
double[] polyFeatures = ((Vector) r.get(0)).toArray();
Assert.assertEquals(polyFeatures.length, 3);
}
}

View file

@ -25,23 +25,29 @@ import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
/**
* Test Param and related classes in Java
*/
public class JavaParamsSuite {
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaParamsSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaParamsSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
}
@Test
@ -51,7 +57,7 @@ public class JavaParamsSuite {
testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
Assert.assertEquals(testParams.getMyStringParam(), "a");
Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0);
Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[]{1.0, 2.0}, 0.0);
}
@Test

View file

@ -45,9 +45,14 @@ public class JavaTestParams extends JavaParams {
}
private IntParam myIntParam_;
public IntParam myIntParam() { return myIntParam_; }
public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); }
public IntParam myIntParam() {
return myIntParam_;
}
public int getMyIntParam() {
return (Integer) getOrDefault(myIntParam_);
}
public JavaTestParams setMyIntParam(int value) {
set(myIntParam_, value);
@ -55,9 +60,14 @@ public class JavaTestParams extends JavaParams {
}
private DoubleParam myDoubleParam_;
public DoubleParam myDoubleParam() { return myDoubleParam_; }
public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); }
public DoubleParam myDoubleParam() {
return myDoubleParam_;
}
public double getMyDoubleParam() {
return (Double) getOrDefault(myDoubleParam_);
}
public JavaTestParams setMyDoubleParam(double value) {
set(myDoubleParam_, value);
@ -65,9 +75,14 @@ public class JavaTestParams extends JavaParams {
}
private Param<String> myStringParam_;
public Param<String> myStringParam() { return myStringParam_; }
public String getMyStringParam() { return getOrDefault(myStringParam_); }
public Param<String> myStringParam() {
return myStringParam_;
}
public String getMyStringParam() {
return getOrDefault(myStringParam_);
}
public JavaTestParams setMyStringParam(String value) {
set(myStringParam_, value);
@ -75,9 +90,14 @@ public class JavaTestParams extends JavaParams {
}
private DoubleArrayParam myDoubleArrayParam_;
public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; }
public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); }
public DoubleArrayParam myDoubleArrayParam() {
return myDoubleArrayParam_;
}
public double[] getMyDoubleArrayParam() {
return getOrDefault(myDoubleArrayParam_);
}
public JavaTestParams setMyDoubleArrayParam(double[] value) {
set(myDoubleArrayParam_, value);
@ -96,7 +116,7 @@ public class JavaTestParams extends JavaParams {
setDefault(myIntParam(), 1);
setDefault(myDoubleParam(), 0.5);
setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
setDefault(myDoubleArrayParam(), new double[]{1.0, 2.0});
}
@Override

View file

@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class JavaDecisionTreeRegressorSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaDecisionTreeRegressorSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
@ -55,7 +61,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
double A = 2.0;
double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize(
JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
@ -70,7 +76,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
.setCacheNodeIds(false)
.setCheckpointInterval(10)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
for (String impurity: DecisionTreeRegressor.supportedImpurities()) {
for (String impurity : DecisionTreeRegressor.supportedImpurities()) {
dt.setImpurity(impurity);
}
DecisionTreeRegressionModel model = dt.fit(dataFrame);

View file

@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class JavaGBTRegressorSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaGBTRegressorSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaGBTRegressorSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
@ -55,7 +61,7 @@ public class JavaGBTRegressorSuite implements Serializable {
double A = 2.0;
double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize(
JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
@ -73,7 +79,7 @@ public class JavaGBTRegressorSuite implements Serializable {
.setMaxIter(3)
.setStepSize(0.1)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
for (String lossType: GBTRegressor.supportedLossTypes()) {
for (String lossType : GBTRegressor.supportedLossTypes()) {
rf.setLossType(lossType);
}
GBTRegressionModel model = rf.fit(dataFrame);

View file

@ -30,25 +30,26 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite
.generateLogisticInputAsList;
import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaLinearRegressionSuite implements Serializable {
private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
jsql = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaLinearRegressionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
@ -65,7 +66,7 @@ public class JavaLinearRegressionSuite implements Serializable {
assertEquals("auto", lr.getSolver());
LinearRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
Dataset<Row> predictions = jsql.sql("SELECT label, prediction FROM prediction");
Dataset<Row> predictions = spark.sql("SELECT label, prediction FROM prediction");
predictions.collect();
// Check defaults
assertEquals("features", model.getFeaturesCol());
@ -76,8 +77,8 @@ public class JavaLinearRegressionSuite implements Serializable {
public void linearRegressionWithSetters() {
// Set params, train, and check as many params as we can.
LinearRegression lr = new LinearRegression()
.setMaxIter(10)
.setRegParam(1.0).setSolver("l-bfgs");
.setMaxIter(10)
.setRegParam(1.0).setSolver("l-bfgs");
LinearRegressionModel model = lr.fit(dataset);
LinearRegression parent = (LinearRegression) model.parent();
assertEquals(10, parent.getMaxIter());
@ -85,7 +86,7 @@ public class JavaLinearRegressionSuite implements Serializable {
// Call fit() with new params, and check as many params as we can.
LinearRegressionModel model2 =
lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
LinearRegression parent2 = (LinearRegression) model2.parent();
assertEquals(5, parent2.getMaxIter());
assertEquals(0.1, parent2.getRegParam(), 0.0);

View file

@ -28,27 +28,33 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class JavaRandomForestRegressorSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaRandomForestRegressorSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
@ -57,7 +63,7 @@ public class JavaRandomForestRegressorSuite implements Serializable {
double A = 2.0;
double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize(
JavaRDD<LabeledPoint> data = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
@ -75,22 +81,22 @@ public class JavaRandomForestRegressorSuite implements Serializable {
.setSeed(1234)
.setNumTrees(3)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
for (String impurity: RandomForestRegressor.supportedImpurities()) {
for (String impurity : RandomForestRegressor.supportedImpurities()) {
rf.setImpurity(impurity);
}
for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) {
for (String featureSubsetStrategy : RandomForestRegressor.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
}
String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
for (String strategy: realStrategies) {
for (String strategy : realStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
for (String strategy: integerStrategies) {
for (String strategy : integerStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
for (String strategy: invalidStrategies) {
for (String strategy : invalidStrategies) {
try {
rf.setFeatureSubsetStrategy(strategy);
Assert.fail("Expected exception to be thrown for invalid strategies");

View file

@ -28,12 +28,11 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
@ -41,16 +40,17 @@ import org.apache.spark.util.Utils;
* Test LibSVMRelation in Java.
*/
public class JavaLibSVMRelationSuite {
private transient JavaSparkContext jsc;
private transient SQLContext sqlContext;
private transient SparkSession spark;
private File tempDir;
private String path;
@Before
public void setUp() throws IOException {
jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
sqlContext = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaLibSVMRelationSuite")
.getOrCreate();
tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
File file = new File(tempDir, "part-00000");
@ -61,14 +61,14 @@ public class JavaLibSVMRelationSuite {
@After
public void tearDown() {
jsc.stop();
jsc = null;
spark.stop();
spark = null;
Utils.deleteRecursively(tempDir);
}
@Test
public void verifyLibSVMDF() {
Dataset<Row> dataset = sqlContext.read().format("libsvm").option("vectorType", "dense")
Dataset<Row> dataset = spark.read().format("libsvm").option("vectorType", "dense")
.load(path);
Assert.assertEquals("label", dataset.columns()[0]);
Assert.assertEquals("features", dataset.columns()[1]);

View file

@ -32,21 +32,25 @@ import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaCrossValidatorSuite implements Serializable {
private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient Dataset<Row> dataset;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
jsql = new SQLContext(jsc);
spark = SparkSession.builder()
.master("local")
.appName("JavaCrossValidatorSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
}
@After
@ -59,8 +63,8 @@ public class JavaCrossValidatorSuite implements Serializable {
public void crossValidationWithLogisticRegression() {
LogisticRegression lr = new LogisticRegression();
ParamMap[] lrParamMaps = new ParamGridBuilder()
.addGrid(lr.regParam(), new double[] {0.001, 1000.0})
.addGrid(lr.maxIter(), new int[] {0, 10})
.addGrid(lr.regParam(), new double[]{0.001, 1000.0})
.addGrid(lr.maxIter(), new int[]{0, 10})
.build();
BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
CrossValidator cv = new CrossValidator()

View file

@ -37,4 +37,5 @@ object IdentifiableSuite {
class Test(override val uid: String) extends Identifiable {
def this() = this(Identifiable.randomUID("test"))
}
}

View file

@ -27,31 +27,34 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
public class JavaDefaultReadWriteSuite {
JavaSparkContext jsc = null;
SQLContext sqlContext = null;
SparkSession spark = null;
File tempDir = null;
@Before
public void setUp() {
jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite");
SQLContext.clearActive();
sqlContext = new SQLContext(jsc);
SQLContext.setActive(sqlContext);
spark = SparkSession.builder()
.master("local[2]")
.appName("JavaDefaultReadWriteSuite")
.getOrCreate();
SQLContext.setActive(spark.wrapped());
tempDir = Utils.createTempDir(
System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite");
}
@After
public void tearDown() {
sqlContext = null;
SQLContext.clearActive();
if (jsc != null) {
jsc.stop();
jsc = null;
if (spark != null) {
spark.stop();
spark = null;
}
Utils.deleteRecursively(tempDir);
}
@ -70,7 +73,7 @@ public class JavaDefaultReadWriteSuite {
} catch (IOException e) {
// expected
}
instance.write().context(sqlContext).overwrite().save(outputPath);
instance.write().context(spark.wrapped()).overwrite().save(outputPath);
MyParams newInstance = MyParams.load(outputPath);
Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
Assert.assertEquals("Params should be preserved.",

View file

@ -27,26 +27,31 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.SparkSession;
public class JavaLogisticRegressionSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaLogisticRegressionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
int validatePrediction(List<LabeledPoint> validationData, LogisticRegressionModel model) {
int numAccurate = 0;
for (LabeledPoint point: validationData) {
for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
if (prediction == point.label()) {
numAccurate++;
@ -61,16 +66,16 @@ public class JavaLogisticRegressionSuite implements Serializable {
double A = 2.0;
double B = -1.5;
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD();
lrImpl.setIntercept(true);
lrImpl.optimizer().setStepSize(1.0)
.setRegParam(1.0)
.setNumIterations(100);
.setRegParam(1.0)
.setNumIterations(100);
LogisticRegressionModel model = lrImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
@ -83,13 +88,13 @@ public class JavaLogisticRegressionSuite implements Serializable {
double A = 0.0;
double B = -2.5;
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
LogisticRegressionModel model = LogisticRegressionWithSGD.train(
testRDD.rdd(), 100, 1.0, 1.0);
testRDD.rdd(), 100, 1.0, 1.0);
int numAccurate = validatePrediction(validationData, model);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);

View file

@ -32,20 +32,26 @@ import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.SparkSession;
public class JavaNaiveBayesSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaNaiveBayesSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
private static final List<LabeledPoint> POINTS = Arrays.asList(
@ -59,7 +65,7 @@ public class JavaNaiveBayesSuite implements Serializable {
private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
int correct = 0;
for (LabeledPoint p: points) {
for (LabeledPoint p : points) {
if (model.predict(p.features()) == p.label()) {
correct += 1;
}
@ -69,7 +75,7 @@ public class JavaNaiveBayesSuite implements Serializable {
@Test
public void runUsingConstructor() {
JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
NaiveBayes nb = new NaiveBayes().setLambda(1.0);
NaiveBayesModel model = nb.run(testRDD.rdd());
@ -80,7 +86,7 @@ public class JavaNaiveBayesSuite implements Serializable {
@Test
public void runUsingStaticMethods() {
JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
int numAccurate1 = validatePrediction(POINTS, model1);
@ -93,13 +99,14 @@ public class JavaNaiveBayesSuite implements Serializable {
@Test
public void testPredictJavaRDD() {
JavaRDD<LabeledPoint> examples = sc.parallelize(POINTS, 2).cache();
JavaRDD<LabeledPoint> examples = jsc.parallelize(POINTS, 2).cache();
NaiveBayesModel model = NaiveBayes.train(examples.rdd());
JavaRDD<Vector> vectors = examples.map(new Function<LabeledPoint, Vector>() {
@Override
public Vector call(LabeledPoint v) throws Exception {
return v.features();
}});
}
});
JavaRDD<Double> predictions = model.predict(vectors);
// Should be able to get the first prediction.
predictions.first();

View file

@ -28,24 +28,30 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.SparkSession;
public class JavaSVMSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaSVMSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaSVMSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
int validatePrediction(List<LabeledPoint> validationData, SVMModel model) {
int numAccurate = 0;
for (LabeledPoint point: validationData) {
for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
if (prediction == point.label()) {
numAccurate++;
@ -60,16 +66,16 @@ public class JavaSVMSuite implements Serializable {
double A = 2.0;
double[] weights = {-1.5, 1.0};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A,
weights, nPoints, 42), 2).cache();
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A,
weights, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
SVMWithSGD svmSGDImpl = new SVMWithSGD();
svmSGDImpl.setIntercept(true);
svmSGDImpl.optimizer().setStepSize(1.0)
.setRegParam(1.0)
.setNumIterations(100);
.setRegParam(1.0)
.setNumIterations(100);
SVMModel model = svmSGDImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
@ -82,10 +88,10 @@ public class JavaSVMSuite implements Serializable {
double A = 0.0;
double[] weights = {-1.5, 1.0};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A,
weights, nPoints, 42), 2).cache();
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A,
weights, nPoints, 42), 2).cache();
List<LabeledPoint> validationData =
SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0);

View file

@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering;
import java.io.Serializable;
import com.google.common.collect.Lists;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
@ -29,27 +30,33 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.SparkSession;
public class JavaBisectingKMeansSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", this.getClass().getSimpleName());
spark = SparkSession.builder()
.master("local")
.appName("JavaBisectingKMeansSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
public void twoDimensionalData() {
JavaRDD<Vector> points = sc.parallelize(Lists.newArrayList(
JavaRDD<Vector> points = jsc.parallelize(Lists.newArrayList(
Vectors.dense(4, -1),
Vectors.dense(4, 1),
Vectors.sparse(2, new int[] {0}, new double[] {1.0})
Vectors.sparse(2, new int[]{0}, new double[]{1.0})
), 2);
BisectingKMeans bkm = new BisectingKMeans()
@ -58,15 +65,15 @@ public class JavaBisectingKMeansSuite implements Serializable {
.setSeed(1L);
BisectingKMeansModel model = bkm.run(points);
Assert.assertEquals(3, model.k());
Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12);
for (ClusteringTreeNode child: model.root().children()) {
Assert.assertArrayEquals(new double[]{3.0, 0.0}, model.root().center().toArray(), 1e-12);
for (ClusteringTreeNode child : model.root().children()) {
double[] center = child.center().toArray();
if (center[0] > 2) {
Assert.assertEquals(2, child.size());
Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12);
Assert.assertArrayEquals(new double[]{4.0, 0.0}, center, 1e-12);
} else {
Assert.assertEquals(1, child.size());
Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12);
Assert.assertArrayEquals(new double[]{1.0, 0.0}, center, 1e-12);
}
}
}

View file

@ -21,29 +21,35 @@ import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.SparkSession;
public class JavaGaussianMixtureSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaGaussianMixture");
spark = SparkSession.builder()
.master("local")
.appName("JavaGaussianMixture")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
@ -54,7 +60,7 @@ public class JavaGaussianMixtureSuite implements Serializable {
Vectors.dense(1.0, 4.0, 6.0)
);
JavaRDD<Vector> data = sc.parallelize(points, 2);
JavaRDD<Vector> data = jsc.parallelize(points, 2);
GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234)
.run(data);
assertEquals(model.gaussians().length, 2);

View file

@ -21,28 +21,35 @@ import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.*;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.SparkSession;
public class JavaKMeansSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaKMeans");
spark = SparkSession.builder()
.master("local")
.appName("JavaKMeans")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
@ -55,7 +62,7 @@ public class JavaKMeansSuite implements Serializable {
Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
JavaRDD<Vector> data = sc.parallelize(points, 2);
JavaRDD<Vector> data = jsc.parallelize(points, 2);
KMeansModel model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.K_MEANS_PARALLEL());
assertEquals(1, model.clusterCenters().length);
assertEquals(expectedCenter, model.clusterCenters()[0]);
@ -74,7 +81,7 @@ public class JavaKMeansSuite implements Serializable {
Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
JavaRDD<Vector> data = sc.parallelize(points, 2);
JavaRDD<Vector> data = jsc.parallelize(points, 2);
KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
assertEquals(1, model.clusterCenters().length);
assertEquals(expectedCenter, model.clusterCenters()[0]);
@ -94,7 +101,7 @@ public class JavaKMeansSuite implements Serializable {
Vectors.dense(1.0, 3.0, 0.0),
Vectors.dense(1.0, 4.0, 6.0)
);
JavaRDD<Vector> data = sc.parallelize(points, 2);
JavaRDD<Vector> data = jsc.parallelize(points, 2);
KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
JavaRDD<Integer> predictions = model.predict(data);
// Should be able to get the first prediction.

View file

@ -27,37 +27,42 @@ import scala.Tuple3;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.SparkSession;
public class JavaLDASuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaLDA");
spark = SparkSession.builder()
.master("local")
.appName("JavaLDASuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
ArrayList<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<>();
for (int i = 0; i < LDASuite.tinyCorpus().length; i++) {
tinyCorpus.add(new Tuple2<>((Long)LDASuite.tinyCorpus()[i]._1(),
LDASuite.tinyCorpus()[i]._2()));
tinyCorpus.add(new Tuple2<>((Long) LDASuite.tinyCorpus()[i]._1(),
LDASuite.tinyCorpus()[i]._2()));
}
JavaRDD<Tuple2<Long, Vector>> tmpCorpus = sc.parallelize(tinyCorpus, 2);
JavaRDD<Tuple2<Long, Vector>> tmpCorpus = jsc.parallelize(tinyCorpus, 2);
corpus = JavaPairRDD.fromJavaRDD(tmpCorpus);
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
@ -95,7 +100,7 @@ public class JavaLDASuite implements Serializable {
.setMaxIterations(5)
.setSeed(12345);
DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus);
DistributedLDAModel model = (DistributedLDAModel) lda.run(corpus);
// Check: basic parameters
LocalLDAModel localModel = model.toLocal();
@ -124,7 +129,7 @@ public class JavaLDASuite implements Serializable {
public Boolean call(Tuple2<Long, Vector> tuple2) {
return Vectors.norm(tuple2._2(), 1.0) != 0.0;
}
});
});
assertEquals(topicDistributions.count(), nonEmptyCorpus.count());
// Check: javaTopTopicsPerDocuments
@ -179,7 +184,7 @@ public class JavaLDASuite implements Serializable {
@Test
public void localLdaMethods() {
JavaRDD<Tuple2<Long, Vector>> docs = sc.parallelize(toyData, 2);
JavaRDD<Tuple2<Long, Vector>> docs = jsc.parallelize(toyData, 2);
JavaPairRDD<Long, Vector> pairedDocs = JavaPairRDD.fromJavaRDD(docs);
// check: topicDistributions
@ -191,7 +196,7 @@ public class JavaLDASuite implements Serializable {
// check: logLikelihood.
ArrayList<Tuple2<Long, Vector>> docsSingleWord = new ArrayList<>();
docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0, 0.0, 0.0)));
JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord));
JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(jsc.parallelize(docsSingleWord));
double logLikelihood = toyModel.logLikelihood(single);
}
@ -199,7 +204,7 @@ public class JavaLDASuite implements Serializable {
private static int tinyVocabSize = LDASuite.tinyVocabSize();
private static Matrix tinyTopics = LDASuite.tinyTopics();
private static Tuple2<int[], double[]>[] tinyTopicDescription =
LDASuite.tinyTopicDescription();
LDASuite.tinyTopicDescription();
private JavaPairRDD<Long, Vector> corpus;
private LocalLDAModel toyModel = LDASuite.toyModel();
private ArrayList<Tuple2<Long, Vector>> toyData = LDASuite.javaToyData();

View file

@ -27,8 +27,6 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import static org.apache.spark.streaming.JavaTestUtils.*;
import org.apache.spark.SparkConf;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
@ -36,6 +34,7 @@ import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import static org.apache.spark.streaming.JavaTestUtils.*;
public class JavaStreamingKMeansSuite implements Serializable {

View file

@ -31,27 +31,34 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
public class JavaRankingMetricsSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaRankingMetricsSuite");
predictionAndLabels = sc.parallelize(Arrays.asList(
spark = SparkSession.builder()
.master("local")
.appName("JavaPCASuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
predictionAndLabels = jsc.parallelize(Arrays.asList(
Tuple2$.MODULE$.apply(
Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)),
Tuple2$.MODULE$.apply(
Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
Tuple2$.MODULE$.apply(
Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test

View file

@ -29,19 +29,25 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.SparkSession;
public class JavaTfIdfSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaTfIdfSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaPCASuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
@ -49,7 +55,7 @@ public class JavaTfIdfSuite implements Serializable {
// The tests are to check Java compatibility.
HashingTF tf = new HashingTF();
@SuppressWarnings("unchecked")
JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList(
JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList(
Arrays.asList("this is a sentence".split(" ")),
Arrays.asList("this is another sentence".split(" ")),
Arrays.asList("this is still a sentence".split(" "))), 2);
@ -59,7 +65,7 @@ public class JavaTfIdfSuite implements Serializable {
JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
List<Vector> localTfIdfs = tfIdfs.collect();
int indexOfThis = tf.indexOf("this");
for (Vector v: localTfIdfs) {
for (Vector v : localTfIdfs) {
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
}
}
@ -69,7 +75,7 @@ public class JavaTfIdfSuite implements Serializable {
// The tests are to check Java compatibility.
HashingTF tf = new HashingTF();
@SuppressWarnings("unchecked")
JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList(
JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList(
Arrays.asList("this is a sentence".split(" ")),
Arrays.asList("this is another sentence".split(" ")),
Arrays.asList("this is still a sentence".split(" "))), 2);
@ -79,7 +85,7 @@ public class JavaTfIdfSuite implements Serializable {
JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
List<Vector> localTfIdfs = tfIdfs.collect();
int indexOfThis = tf.indexOf("this");
for (Vector v: localTfIdfs) {
for (Vector v : localTfIdfs) {
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
}
}

View file

@ -21,9 +21,10 @@ import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import com.google.common.base.Strings;
import scala.Tuple2;
import com.google.common.base.Strings;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
@ -31,19 +32,25 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
public class JavaWord2VecSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaWord2VecSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaPCASuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
@ -53,7 +60,7 @@ public class JavaWord2VecSuite implements Serializable {
String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10);
List<String> words = Arrays.asList(sentence.split(" "));
List<List<String>> localDoc = Arrays.asList(words, words);
JavaRDD<List<String>> doc = sc.parallelize(localDoc);
JavaRDD<List<String>> doc = jsc.parallelize(localDoc);
Word2Vec word2vec = new Word2Vec()
.setVectorSize(10)
.setSeed(42L);

View file

@ -26,32 +26,37 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
import org.apache.spark.sql.SparkSession;
public class JavaAssociationRulesSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaFPGrowth");
spark = SparkSession.builder()
.master("local")
.appName("JavaAssociationRulesSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
public void runAssociationRules() {
@SuppressWarnings("unchecked")
JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Arrays.asList(
new FreqItemset<String>(new String[] {"a"}, 15L),
new FreqItemset<String>(new String[] {"b"}, 35L),
new FreqItemset<String>(new String[] {"a", "b"}, 12L)
JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = jsc.parallelize(Arrays.asList(
new FreqItemset<String>(new String[]{"a"}, 15L),
new FreqItemset<String>(new String[]{"b"}, 35L),
new FreqItemset<String>(new String[]{"a", "b"}, 12L)
));
JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets);
}
}

View file

@ -22,34 +22,41 @@ import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.*;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
public class JavaFPGrowthSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaFPGrowth");
spark = SparkSession.builder()
.master("local")
.appName("JavaFPGrowth")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
public void runFPGrowth() {
@SuppressWarnings("unchecked")
JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList(
Arrays.asList("r z h k p".split(" ")),
Arrays.asList("z y x w v u t s".split(" ")),
Arrays.asList("s x o n r".split(" ")),
@ -65,7 +72,7 @@ public class JavaFPGrowthSuite implements Serializable {
List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
assertEquals(18, freqItemsets.size());
for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
for (FPGrowth.FreqItemset<String> itemset : freqItemsets) {
// Test return types.
List<String> items = itemset.javaItems();
long freq = itemset.freq();
@ -76,7 +83,7 @@ public class JavaFPGrowthSuite implements Serializable {
public void runFPGrowthSaveLoad() {
@SuppressWarnings("unchecked")
JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList(
Arrays.asList("r z h k p".split(" ")),
Arrays.asList("z y x w v u t s".split(" ")),
Arrays.asList("s x o n r".split(" ")),
@ -94,15 +101,15 @@ public class JavaFPGrowthSuite implements Serializable {
String outputPath = tempDir.getPath();
try {
model.save(sc.sc(), outputPath);
model.save(spark.sparkContext(), outputPath);
@SuppressWarnings("unchecked")
FPGrowthModel<String> newModel =
(FPGrowthModel<String>) FPGrowthModel.load(sc.sc(), outputPath);
(FPGrowthModel<String>) FPGrowthModel.load(spark.sparkContext(), outputPath);
List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD()
.collect();
assertEquals(18, freqItemsets.size());
for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
for (FPGrowth.FreqItemset<String> itemset : freqItemsets) {
// Test return types.
List<String> items = itemset.javaItems();
long freq = itemset.freq();

View file

@ -29,25 +29,31 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
public class JavaPrefixSpanSuite {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaPrefixSpan");
spark = SparkSession.builder()
.master("local")
.appName("JavaPrefixSpan")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
public void runPrefixSpan() {
JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
JavaRDD<List<List<Integer>>> sequences = jsc.parallelize(Arrays.asList(
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
@ -61,7 +67,7 @@ public class JavaPrefixSpanSuite {
List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
Assert.assertEquals(5, localFreqSeqs.size());
// Check that each frequent sequence could be materialized.
for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
for (PrefixSpan.FreqSequence<Integer> freqSeq : localFreqSeqs) {
List<List<Integer>> seq = freqSeq.javaSequence();
long freq = freqSeq.freq();
}
@ -69,7 +75,7 @@ public class JavaPrefixSpanSuite {
@Test
public void runPrefixSpanSaveLoad() {
JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
JavaRDD<List<List<Integer>>> sequences = jsc.parallelize(Arrays.asList(
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
@ -85,13 +91,13 @@ public class JavaPrefixSpanSuite {
String outputPath = tempDir.getPath();
try {
model.save(sc.sc(), outputPath);
PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath);
model.save(spark.sparkContext(), outputPath);
PrefixSpanModel newModel = PrefixSpanModel.load(spark.sparkContext(), outputPath);
JavaRDD<FreqSequence<Integer>> freqSeqs = newModel.freqSequences().toJavaRDD();
List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
Assert.assertEquals(5, localFreqSeqs.size());
// Check that each frequent sequence could be materialized.
for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
for (PrefixSpan.FreqSequence<Integer> freqSeq : localFreqSeqs) {
List<List<Integer>> seq = freqSeq.javaSequence();
long freq = freqSeq.freq();
}

View file

@ -17,147 +17,149 @@
package org.apache.spark.mllib.linalg;
import static org.junit.Assert.*;
import org.junit.Test;
import java.io.Serializable;
import java.util.Random;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import org.junit.Test;
public class JavaMatricesSuite implements Serializable {
@Test
public void randMatrixConstruction() {
Random rng = new Random(24);
Matrix r = Matrices.rand(3, 4, rng);
rng.setSeed(24);
DenseMatrix dr = DenseMatrix.rand(3, 4, rng);
assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
@Test
public void randMatrixConstruction() {
Random rng = new Random(24);
Matrix r = Matrices.rand(3, 4, rng);
rng.setSeed(24);
DenseMatrix dr = DenseMatrix.rand(3, 4, rng);
assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
rng.setSeed(24);
Matrix rn = Matrices.randn(3, 4, rng);
rng.setSeed(24);
DenseMatrix drn = DenseMatrix.randn(3, 4, rng);
assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);
rng.setSeed(24);
Matrix rn = Matrices.randn(3, 4, rng);
rng.setSeed(24);
DenseMatrix drn = DenseMatrix.randn(3, 4, rng);
assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);
rng.setSeed(24);
Matrix s = Matrices.sprand(3, 4, 0.5, rng);
rng.setSeed(24);
SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng);
assertArrayEquals(s.toArray(), sr.toArray(), 0.0);
rng.setSeed(24);
Matrix s = Matrices.sprand(3, 4, 0.5, rng);
rng.setSeed(24);
SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng);
assertArrayEquals(s.toArray(), sr.toArray(), 0.0);
rng.setSeed(24);
Matrix sn = Matrices.sprandn(3, 4, 0.5, rng);
rng.setSeed(24);
SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng);
assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
}
rng.setSeed(24);
Matrix sn = Matrices.sprandn(3, 4, 0.5, rng);
rng.setSeed(24);
SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng);
assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
}
@Test
public void identityMatrixConstruction() {
Matrix r = Matrices.eye(2);
DenseMatrix dr = DenseMatrix.eye(2);
SparseMatrix sr = SparseMatrix.speye(2);
assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
}
@Test
public void identityMatrixConstruction() {
Matrix r = Matrices.eye(2);
DenseMatrix dr = DenseMatrix.eye(2);
SparseMatrix sr = SparseMatrix.speye(2);
assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
}
@Test
public void diagonalMatrixConstruction() {
Vector v = Vectors.dense(1.0, 0.0, 2.0);
Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});
@Test
public void diagonalMatrixConstruction() {
Vector v = Vectors.dense(1.0, 0.0, 2.0);
Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});
Matrix m = Matrices.diag(v);
Matrix sm = Matrices.diag(sv);
DenseMatrix d = DenseMatrix.diag(v);
DenseMatrix sd = DenseMatrix.diag(sv);
SparseMatrix s = SparseMatrix.spdiag(v);
SparseMatrix ss = SparseMatrix.spdiag(sv);
Matrix m = Matrices.diag(v);
Matrix sm = Matrices.diag(sv);
DenseMatrix d = DenseMatrix.diag(v);
DenseMatrix sd = DenseMatrix.diag(sv);
SparseMatrix s = SparseMatrix.spdiag(v);
SparseMatrix ss = SparseMatrix.spdiag(sv);
assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
assertArrayEquals(s.values(), ss.values(), 0.0);
assertEquals(2, s.values().length);
assertEquals(2, ss.values().length);
assertEquals(4, s.colPtrs().length);
assertEquals(4, ss.colPtrs().length);
}
assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
assertArrayEquals(s.values(), ss.values(), 0.0);
assertEquals(2, s.values().length);
assertEquals(2, ss.values().length);
assertEquals(4, s.colPtrs().length);
assertEquals(4, ss.colPtrs().length);
}
@Test
public void zerosMatrixConstruction() {
Matrix z = Matrices.zeros(2, 2);
Matrix one = Matrices.ones(2, 2);
DenseMatrix dz = DenseMatrix.zeros(2, 2);
DenseMatrix done = DenseMatrix.ones(2, 2);
@Test
public void zerosMatrixConstruction() {
Matrix z = Matrices.zeros(2, 2);
Matrix one = Matrices.ones(2, 2);
DenseMatrix dz = DenseMatrix.zeros(2, 2);
DenseMatrix done = DenseMatrix.ones(2, 2);
assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
}
assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
}
@Test
public void sparseDenseConversion() {
int m = 3;
int n = 2;
double[] values = new double[]{1.0, 2.0, 4.0, 5.0};
double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0};
int[] colPtrs = new int[]{0, 2, 4};
int[] rowIndices = new int[]{0, 1, 1, 2};
@Test
public void sparseDenseConversion() {
int m = 3;
int n = 2;
double[] values = new double[]{1.0, 2.0, 4.0, 5.0};
double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0};
int[] colPtrs = new int[]{0, 2, 4};
int[] rowIndices = new int[]{0, 1, 1, 2};
SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values);
DenseMatrix deMat1 = new DenseMatrix(m, n, allValues);
SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values);
DenseMatrix deMat1 = new DenseMatrix(m, n, allValues);
SparseMatrix spMat2 = deMat1.toSparse();
DenseMatrix deMat2 = spMat1.toDense();
SparseMatrix spMat2 = deMat1.toSparse();
DenseMatrix deMat2 = spMat1.toDense();
assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0);
assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0);
}
assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0);
assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0);
}
@Test
public void concatenateMatrices() {
int m = 3;
int n = 2;
@Test
public void concatenateMatrices() {
int m = 3;
int n = 2;
Random rng = new Random(42);
SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng);
rng.setSeed(42);
DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng);
Matrix deMat2 = Matrices.eye(3);
Matrix spMat2 = Matrices.speye(3);
Matrix deMat3 = Matrices.eye(2);
Matrix spMat3 = Matrices.speye(2);
Random rng = new Random(42);
SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng);
rng.setSeed(42);
DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng);
Matrix deMat2 = Matrices.eye(3);
Matrix spMat2 = Matrices.speye(3);
Matrix deMat3 = Matrices.eye(2);
Matrix spMat3 = Matrices.speye(2);
Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
assertEquals(3, deHorz1.numRows());
assertEquals(3, deHorz2.numRows());
assertEquals(3, deHorz3.numRows());
assertEquals(3, spHorz.numRows());
assertEquals(5, deHorz1.numCols());
assertEquals(5, deHorz2.numCols());
assertEquals(5, deHorz3.numCols());
assertEquals(5, spHorz.numCols());
assertEquals(3, deHorz1.numRows());
assertEquals(3, deHorz2.numRows());
assertEquals(3, deHorz3.numRows());
assertEquals(3, spHorz.numRows());
assertEquals(5, deHorz1.numCols());
assertEquals(5, deHorz2.numCols());
assertEquals(5, deHorz3.numCols());
assertEquals(5, spHorz.numCols());
Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
assertEquals(5, deVert1.numRows());
assertEquals(5, deVert2.numRows());
assertEquals(5, deVert3.numRows());
assertEquals(5, spVert.numRows());
assertEquals(2, deVert1.numCols());
assertEquals(2, deVert2.numCols());
assertEquals(2, deVert3.numCols());
assertEquals(2, spVert.numCols());
}
assertEquals(5, deVert1.numRows());
assertEquals(5, deVert2.numRows());
assertEquals(5, deVert3.numRows());
assertEquals(5, spVert.numRows());
assertEquals(2, deVert1.numCols());
assertEquals(2, deVert2.numCols());
assertEquals(2, deVert3.numCols());
assertEquals(2, spVert.numCols());
}
}

View file

@ -20,10 +20,11 @@ package org.apache.spark.mllib.linalg;
import java.io.Serializable;
import java.util.Arrays;
import static org.junit.Assert.assertArrayEquals;
import scala.Tuple2;
import org.junit.Test;
import static org.junit.Assert.*;
public class JavaVectorsSuite implements Serializable {
@ -37,8 +38,8 @@ public class JavaVectorsSuite implements Serializable {
public void sparseArrayConstruction() {
@SuppressWarnings("unchecked")
Vector v = Vectors.sparse(3, Arrays.asList(
new Tuple2<>(0, 2.0),
new Tuple2<>(2, 3.0)));
new Tuple2<>(0, 2.0),
new Tuple2<>(2, 3.0)));
assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0);
}
}

View file

@ -20,29 +20,35 @@ package org.apache.spark.mllib.random;
import java.io.Serializable;
import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.junit.Assert;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.SparkSession;
import static org.apache.spark.mllib.random.RandomRDDs.*;
public class JavaRandomRDDsSuite {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaRandomRDDsSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaRandomRDDsSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
@ -50,10 +56,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L;
int p = 2;
long seed = 1L;
JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m);
JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p);
JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed);
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
JavaDoubleRDD rdd1 = uniformJavaRDD(jsc, m);
JavaDoubleRDD rdd2 = uniformJavaRDD(jsc, m, p);
JavaDoubleRDD rdd3 = uniformJavaRDD(jsc, m, p, seed);
for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@ -63,10 +69,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L;
int p = 2;
long seed = 1L;
JavaDoubleRDD rdd1 = normalJavaRDD(sc, m);
JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p);
JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed);
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
JavaDoubleRDD rdd1 = normalJavaRDD(jsc, m);
JavaDoubleRDD rdd2 = normalJavaRDD(jsc, m, p);
JavaDoubleRDD rdd3 = normalJavaRDD(jsc, m, p, seed);
for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@ -78,10 +84,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L;
int p = 2;
long seed = 1L;
JavaDoubleRDD rdd1 = logNormalJavaRDD(sc, mean, std, m);
JavaDoubleRDD rdd2 = logNormalJavaRDD(sc, mean, std, m, p);
JavaDoubleRDD rdd3 = logNormalJavaRDD(sc, mean, std, m, p, seed);
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
JavaDoubleRDD rdd1 = logNormalJavaRDD(jsc, mean, std, m);
JavaDoubleRDD rdd2 = logNormalJavaRDD(jsc, mean, std, m, p);
JavaDoubleRDD rdd3 = logNormalJavaRDD(jsc, mean, std, m, p, seed);
for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@ -92,10 +98,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L;
int p = 2;
long seed = 1L;
JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m);
JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p);
JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed);
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
JavaDoubleRDD rdd1 = poissonJavaRDD(jsc, mean, m);
JavaDoubleRDD rdd2 = poissonJavaRDD(jsc, mean, m, p);
JavaDoubleRDD rdd3 = poissonJavaRDD(jsc, mean, m, p, seed);
for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@ -106,10 +112,10 @@ public class JavaRandomRDDsSuite {
long m = 1000L;
int p = 2;
long seed = 1L;
JavaDoubleRDD rdd1 = exponentialJavaRDD(sc, mean, m);
JavaDoubleRDD rdd2 = exponentialJavaRDD(sc, mean, m, p);
JavaDoubleRDD rdd3 = exponentialJavaRDD(sc, mean, m, p, seed);
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
JavaDoubleRDD rdd1 = exponentialJavaRDD(jsc, mean, m);
JavaDoubleRDD rdd2 = exponentialJavaRDD(jsc, mean, m, p);
JavaDoubleRDD rdd3 = exponentialJavaRDD(jsc, mean, m, p, seed);
for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@ -117,14 +123,14 @@ public class JavaRandomRDDsSuite {
@Test
public void testGammaRDD() {
double shape = 1.0;
double scale = 2.0;
double jscale = 2.0;
long m = 1000L;
int p = 2;
long seed = 1L;
JavaDoubleRDD rdd1 = gammaJavaRDD(sc, shape, scale, m);
JavaDoubleRDD rdd2 = gammaJavaRDD(sc, shape, scale, m, p);
JavaDoubleRDD rdd3 = gammaJavaRDD(sc, shape, scale, m, p, seed);
for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
JavaDoubleRDD rdd1 = gammaJavaRDD(jsc, shape, jscale, m);
JavaDoubleRDD rdd2 = gammaJavaRDD(jsc, shape, jscale, m, p);
JavaDoubleRDD rdd3 = gammaJavaRDD(jsc, shape, jscale, m, p, seed);
for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
}
}
@ -137,10 +143,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(sc, m, n);
JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(sc, m, n, p);
JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(sc, m, n, p, seed);
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(jsc, m, n);
JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(jsc, m, n, p);
JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(jsc, m, n, p, seed);
for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@ -153,10 +159,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
JavaRDD<Vector> rdd1 = normalJavaVectorRDD(sc, m, n);
JavaRDD<Vector> rdd2 = normalJavaVectorRDD(sc, m, n, p);
JavaRDD<Vector> rdd3 = normalJavaVectorRDD(sc, m, n, p, seed);
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
JavaRDD<Vector> rdd1 = normalJavaVectorRDD(jsc, m, n);
JavaRDD<Vector> rdd2 = normalJavaVectorRDD(jsc, m, n, p);
JavaRDD<Vector> rdd3 = normalJavaVectorRDD(jsc, m, n, p, seed);
for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@ -171,10 +177,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(sc, mean, std, m, n);
JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(sc, mean, std, m, n, p);
JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(sc, mean, std, m, n, p, seed);
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(jsc, mean, std, m, n);
JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p);
JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p, seed);
for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@ -188,10 +194,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(sc, mean, m, n);
JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p);
JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed);
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(jsc, mean, m, n);
JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(jsc, mean, m, n, p);
JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(jsc, mean, m, n, p, seed);
for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@ -205,10 +211,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(sc, mean, m, n);
JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(sc, mean, m, n, p);
JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(sc, mean, m, n, p, seed);
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(jsc, mean, m, n);
JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(jsc, mean, m, n, p);
JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(jsc, mean, m, n, p, seed);
for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@ -218,15 +224,15 @@ public class JavaRandomRDDsSuite {
@SuppressWarnings("unchecked")
public void testGammaVectorRDD() {
double shape = 1.0;
double scale = 2.0;
double jscale = 2.0;
long m = 100L;
int n = 10;
int p = 2;
long seed = 1L;
JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(sc, shape, scale, m, n);
JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(sc, shape, scale, m, n, p);
JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(sc, shape, scale, m, n, p, seed);
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(jsc, shape, jscale, m, n);
JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p);
JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p, seed);
for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@ -238,10 +244,10 @@ public class JavaRandomRDDsSuite {
long seed = 1L;
int numPartitions = 0;
StringGenerator gen = new StringGenerator();
JavaRDD<String> rdd1 = randomJavaRDD(sc, gen, size);
JavaRDD<String> rdd2 = randomJavaRDD(sc, gen, size, numPartitions);
JavaRDD<String> rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed);
for (JavaRDD<String> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
JavaRDD<String> rdd1 = randomJavaRDD(jsc, gen, size);
JavaRDD<String> rdd2 = randomJavaRDD(jsc, gen, size, numPartitions);
JavaRDD<String> rdd3 = randomJavaRDD(jsc, gen, size, numPartitions, seed);
for (JavaRDD<String> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(size, rdd.count());
Assert.assertEquals(2, rdd.first().length());
}
@ -255,10 +261,10 @@ public class JavaRandomRDDsSuite {
int n = 10;
int p = 2;
long seed = 1L;
JavaRDD<Vector> rdd1 = randomJavaVectorRDD(sc, generator, m, n);
JavaRDD<Vector> rdd2 = randomJavaVectorRDD(sc, generator, m, n, p);
JavaRDD<Vector> rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed);
for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
JavaRDD<Vector> rdd1 = randomJavaVectorRDD(jsc, generator, m, n);
JavaRDD<Vector> rdd2 = randomJavaVectorRDD(jsc, generator, m, n, p);
JavaRDD<Vector> rdd3 = randomJavaVectorRDD(jsc, generator, m, n, p, seed);
for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
Assert.assertEquals(m, rdd.count());
Assert.assertEquals(n, rdd.first().size());
}
@ -271,10 +277,12 @@ class StringGenerator implements RandomDataGenerator<String>, Serializable {
public String nextValue() {
return "42";
}
@Override
public StringGenerator copy() {
return new StringGenerator();
}
@Override
public void setSeed(long seed) {
}

View file

@ -32,40 +32,46 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
public class JavaALSSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaALS");
spark = SparkSession.builder()
.master("local")
.appName("JavaALS")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
private void validatePrediction(
MatrixFactorizationModel model,
int users,
int products,
double[] trueRatings,
double matchThreshold,
boolean implicitPrefs,
double[] truePrefs) {
MatrixFactorizationModel model,
int users,
int products,
double[] trueRatings,
double matchThreshold,
boolean implicitPrefs,
double[] truePrefs) {
List<Tuple2<Integer, Integer>> localUsersProducts = new ArrayList<>(users * products);
for (int u=0; u < users; ++u) {
for (int p=0; p < products; ++p) {
for (int u = 0; u < users; ++u) {
for (int p = 0; p < products; ++p) {
localUsersProducts.add(new Tuple2<>(u, p));
}
}
JavaPairRDD<Integer, Integer> usersProducts = sc.parallelizePairs(localUsersProducts);
JavaPairRDD<Integer, Integer> usersProducts = jsc.parallelizePairs(localUsersProducts);
List<Rating> predictedRatings = model.predict(usersProducts).collect();
Assert.assertEquals(users * products, predictedRatings.size());
if (!implicitPrefs) {
for (Rating r: predictedRatings) {
for (Rating r : predictedRatings) {
double prediction = r.rating();
double correct = trueRatings[r.product() * users + r.user()];
Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
@ -76,7 +82,7 @@ public class JavaALSSuite implements Serializable {
// (ref Mahout's implicit ALS tests)
double sqErr = 0.0;
double denom = 0.0;
for (Rating r: predictedRatings) {
for (Rating r : predictedRatings) {
double prediction = r.rating();
double truePref = truePrefs[r.product() * users + r.user()];
double confidence = 1.0 +
@ -98,9 +104,9 @@ public class JavaALSSuite implements Serializable {
int users = 50;
int products = 100;
Tuple3<List<Rating>, double[], double[]> testData =
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
}
@ -112,9 +118,9 @@ public class JavaALSSuite implements Serializable {
int users = 100;
int products = 200;
Tuple3<List<Rating>, double[], double[]> testData =
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
@ -129,9 +135,9 @@ public class JavaALSSuite implements Serializable {
int users = 80;
int products = 160;
Tuple3<List<Rating>, double[], double[]> testData =
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
}
@ -143,9 +149,9 @@ public class JavaALSSuite implements Serializable {
int users = 100;
int products = 200;
Tuple3<List<Rating>, double[], double[]> testData =
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
@ -161,9 +167,9 @@ public class JavaALSSuite implements Serializable {
int users = 80;
int products = 160;
Tuple3<List<Rating>, double[], double[]> testData =
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true);
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true);
JavaRDD<Rating> data = sc.parallelize(testData._1());
JavaRDD<Rating> data = jsc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
.setImplicitPrefs(true)
@ -179,8 +185,8 @@ public class JavaALSSuite implements Serializable {
int users = 200;
int products = 50;
List<Rating> testData = ALSSuite.generateRatingsAsJava(
users, products, features, 0.7, true, false)._1();
JavaRDD<Rating> data = sc.parallelize(testData);
users, products, features, 0.7, true, false)._1();
JavaRDD<Rating> data = jsc.parallelize(testData);
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
.setImplicitPrefs(true)
@ -193,7 +199,7 @@ public class JavaALSSuite implements Serializable {
private static void validateRecommendations(Rating[] recommendations, int howMany) {
Assert.assertEquals(howMany, recommendations.length);
for (int i = 1; i < recommendations.length; i++) {
Assert.assertTrue(recommendations[i-1].rating() >= recommendations[i].rating());
Assert.assertTrue(recommendations[i - 1].rating() >= recommendations[i].rating());
}
Assert.assertTrue(recommendations[0].rating() > 0.7);
}

View file

@ -32,15 +32,17 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
public class JavaIsotonicRegressionSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
private static List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) {
List<Tuple3<Double, Double, Double>> input = new ArrayList<>(labels.length);
for (int i = 1; i <= labels.length; i++) {
input.add(new Tuple3<>(labels[i-1], (double) i, 1.0));
input.add(new Tuple3<>(labels[i - 1], (double) i, 1.0));
}
return input;
@ -48,20 +50,24 @@ public class JavaIsotonicRegressionSuite implements Serializable {
private IsotonicRegressionModel runIsotonicRegression(double[] labels) {
JavaRDD<Tuple3<Double, Double, Double>> trainRDD =
sc.parallelize(generateIsotonicInput(labels), 2).cache();
jsc.parallelize(generateIsotonicInput(labels), 2).cache();
return new IsotonicRegression().run(trainRDD);
}
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaLinearRegressionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
@Test
@ -70,7 +76,7 @@ public class JavaIsotonicRegressionSuite implements Serializable {
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
Assert.assertArrayEquals(
new double[] {1, 2, 7.0/3, 7.0/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14);
new double[]{1, 2, 7.0 / 3, 7.0 / 3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14);
}
@Test
@ -78,7 +84,7 @@ public class JavaIsotonicRegressionSuite implements Serializable {
IsotonicRegressionModel model =
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
JavaDoubleRDD testRDD = sc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0));
JavaDoubleRDD testRDD = jsc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0));
List<Double> predictions = model.predict(testRDD).collect();
Assert.assertEquals(1.0, predictions.get(0).doubleValue(), 1.0e-14);

View file

@ -28,24 +28,30 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.util.LinearDataGenerator;
import org.apache.spark.sql.SparkSession;
public class JavaLassoSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaLassoSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaLassoSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
int validatePrediction(List<LabeledPoint> validationData, LassoModel model) {
int numAccurate = 0;
for (LabeledPoint point: validationData) {
for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
// A prediction is off if the prediction is more than 0.5 away from expected value.
if (Math.abs(prediction - point.label()) <= 0.5) {
@ -61,15 +67,15 @@ public class JavaLassoSuite implements Serializable {
double A = 0.0;
double[] weights = {-1.5, 1.0e-2};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
weights, nPoints, 42, 0.1), 2).cache();
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LassoWithSGD lassoSGDImpl = new LassoWithSGD();
lassoSGDImpl.optimizer().setStepSize(1.0)
.setRegParam(0.01)
.setNumIterations(20);
.setRegParam(0.01)
.setNumIterations(20);
LassoModel model = lassoSGDImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
@ -82,10 +88,10 @@ public class JavaLassoSuite implements Serializable {
double A = 0.0;
double[] weights = {-1.5, 1.0e-2};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
weights, nPoints, 42, 0.1), 2).cache();
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0);

View file

@ -25,34 +25,40 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.util.LinearDataGenerator;
import org.apache.spark.sql.SparkSession;
public class JavaLinearRegressionSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaLinearRegressionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) {
int numAccurate = 0;
for (LabeledPoint point: validationData) {
Double prediction = model.predict(point.features());
// A prediction is off if the prediction is more than 0.5 away from expected value.
if (Math.abs(prediction - point.label()) <= 0.5) {
numAccurate++;
}
for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
// A prediction is off if the prediction is more than 0.5 away from expected value.
if (Math.abs(prediction - point.label()) <= 0.5) {
numAccurate++;
}
}
return numAccurate;
}
@ -63,10 +69,10 @@ public class JavaLinearRegressionSuite implements Serializable {
double A = 3.0;
double[] weights = {10, 10};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
linSGDImpl.setIntercept(true);
@ -82,10 +88,10 @@ public class JavaLinearRegressionSuite implements Serializable {
double A = 0.0;
double[] weights = {10, 10};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100);
@ -98,7 +104,7 @@ public class JavaLinearRegressionSuite implements Serializable {
int nPoints = 100;
double A = 0.0;
double[] weights = {10, 10};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());

View file

@ -29,25 +29,31 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.util.LinearDataGenerator;
import org.apache.spark.sql.SparkSession;
public class JavaRidgeRegressionSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaRidgeRegressionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
private static double predictionError(List<LabeledPoint> validationData,
RidgeRegressionModel model) {
double errorSum = 0;
for (LabeledPoint point: validationData) {
for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
errorSum += (prediction - point.label()) * (prediction - point.label());
}
@ -68,9 +74,9 @@ public class JavaRidgeRegressionSuite implements Serializable {
public void runRidgeRegressionUsingConstructor() {
int numExamples = 50;
int numFeatures = 20;
List<LabeledPoint> data = generateRidgeData(2*numExamples, numFeatures, 10.0);
List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(data.subList(0, numExamples));
List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
@ -94,7 +100,7 @@ public class JavaRidgeRegressionSuite implements Serializable {
int numFeatures = 20;
List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(data.subList(0, numExamples));
List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);

View file

@ -24,13 +24,11 @@ import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import static org.apache.spark.streaming.JavaTestUtils.*;
import static org.junit.Assert.assertEquals;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
@ -38,36 +36,42 @@ import org.apache.spark.mllib.stat.test.BinarySample;
import org.apache.spark.mllib.stat.test.ChiSqTestResult;
import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult;
import org.apache.spark.mllib.stat.test.StreamingTest;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import static org.apache.spark.streaming.JavaTestUtils.*;
public class JavaStatisticsSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient JavaStreamingContext ssc;
@Before
public void setUp() {
SparkConf conf = new SparkConf()
.setMaster("local[2]")
.setAppName("JavaStatistics")
.set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
sc = new JavaSparkContext(conf);
ssc = new JavaStreamingContext(sc, new Duration(1000));
spark = SparkSession.builder()
.master("local[2]")
.appName("JavaStatistics")
.config(conf)
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
ssc = new JavaStreamingContext(jsc, new Duration(1000));
ssc.checkpoint("checkpoint");
}
@After
public void tearDown() {
spark.stop();
ssc.stop();
ssc = null;
sc = null;
spark = null;
}
@Test
public void testCorr() {
JavaRDD<Double> x = sc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0));
JavaRDD<Double> y = sc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3));
JavaRDD<Double> x = jsc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0));
JavaRDD<Double> y = jsc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3));
Double corr1 = Statistics.corr(x, y);
Double corr2 = Statistics.corr(x, y, "pearson");
@ -77,7 +81,7 @@ public class JavaStatisticsSuite implements Serializable {
@Test
public void kolmogorovSmirnovTest() {
JavaDoubleRDD data = sc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0));
JavaDoubleRDD data = jsc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0));
KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm");
KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest(
data, "norm", 0.0, 1.0);
@ -85,7 +89,7 @@ public class JavaStatisticsSuite implements Serializable {
@Test
public void chiSqTest() {
JavaRDD<LabeledPoint> data = sc.parallelize(Arrays.asList(
JavaRDD<LabeledPoint> data = jsc.parallelize(Arrays.asList(
new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)),
new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)),
new LabeledPoint(0.0, Vectors.dense(2.4, 8.1))));

View file

@ -35,25 +35,31 @@ import org.apache.spark.mllib.tree.configuration.Algo;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Gini;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.sql.SparkSession;
public class JavaDecisionTreeSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SparkSession spark;
private transient JavaSparkContext jsc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaDecisionTreeSuite");
spark = SparkSession.builder()
.master("local")
.appName("JavaDecisionTreeSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
sc.stop();
sc = null;
spark.stop();
spark = null;
}
int validatePrediction(List<LabeledPoint> validationData, DecisionTreeModel model) {
int numCorrect = 0;
for (LabeledPoint point: validationData) {
for (LabeledPoint point : validationData) {
Double prediction = model.predict(point.features());
if (prediction == point.label()) {
numCorrect++;
@ -65,7 +71,7 @@ public class JavaDecisionTreeSuite implements Serializable {
@Test
public void runDTUsingConstructor() {
List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
@ -73,7 +79,7 @@ public class JavaDecisionTreeSuite implements Serializable {
int numClasses = 2;
int maxBins = 100;
Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
maxBins, categoricalFeaturesInfo);
maxBins, categoricalFeaturesInfo);
DecisionTree learner = new DecisionTree(strategy);
DecisionTreeModel model = learner.run(rdd.rdd());
@ -85,7 +91,7 @@ public class JavaDecisionTreeSuite implements Serializable {
@Test
public void runDTUsingStaticMethods() {
List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
@ -93,7 +99,7 @@ public class JavaDecisionTreeSuite implements Serializable {
int numClasses = 2;
int maxBins = 100;
Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
maxBins, categoricalFeaturesInfo);
maxBins, categoricalFeaturesInfo);
DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);

View file

@ -183,7 +183,7 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
test("pipeline validateParams") {
val df = sqlContext.createDataFrame(
val df = spark.createDataFrame(
Seq(
(1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
(2, Vectors.dense(1.0, 0.0, 4.0), 2.0),

View file

@ -32,7 +32,7 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
test("extractLabeledPoints") {
def getTestData(labels: Seq[Double]): DataFrame = {
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
sqlContext.createDataFrame(data)
spark.createDataFrame(data)
}
val c = new MockClassifier
@ -72,7 +72,7 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
test("getNumClasses") {
def getTestData(labels: Seq[Double]): DataFrame = {
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
sqlContext.createDataFrame(data)
spark.createDataFrame(data)
}
val c = new MockClassifier

View file

@ -337,13 +337,13 @@ class DecisionTreeClassifierSuite
test("should support all NumericType labels and not support other types") {
val dt = new DecisionTreeClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier](
dt, isClassification = true, sqlContext) { (expected, actual) =>
dt, isClassification = true, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
test("Fitting without numClasses in metadata") {
val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc))
val dt = new DecisionTreeClassifier().setMaxDepth(1)
dt.fit(df)
}

View file

@ -106,7 +106,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
test("should support all NumericType labels and not support other types") {
val gbt = new GBTClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier](
gbt, isClassification = true, sqlContext) { (expected, actual) =>
gbt, isClassification = true, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
@ -130,7 +130,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
*/
test("Fitting without numClasses in metadata") {
val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc))
val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
gbt.fit(df)
}
@ -138,7 +138,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
test("extractLabeledPoints with bad data") {
def getTestData(labels: Seq[Double]): DataFrame = {
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
sqlContext.createDataFrame(data)
spark.createDataFrame(data)
}
val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)

View file

@ -42,7 +42,7 @@ class LogisticRegressionSuite
override def beforeAll(): Unit = {
super.beforeAll()
dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42))
dataset = spark.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42))
binaryDataset = {
val nPoints = 10000
@ -54,7 +54,7 @@ class LogisticRegressionSuite
generateMultinomialLogisticInput(coefficients, xMean, xVariance,
addIntercept = true, nPoints, 42)
sqlContext.createDataFrame(sc.parallelize(testData, 4))
spark.createDataFrame(sc.parallelize(testData, 4))
}
}
@ -202,7 +202,7 @@ class LogisticRegressionSuite
}
test("logistic regression: Predictor, Classifier methods") {
val sqlContext = this.sqlContext
val spark = this.spark
val lr = new LogisticRegression
val model = lr.fit(dataset)
@ -864,8 +864,8 @@ class LogisticRegressionSuite
}
}
(sqlContext.createDataFrame(sc.parallelize(data1, 4)),
sqlContext.createDataFrame(sc.parallelize(data2, 4)))
(spark.createDataFrame(sc.parallelize(data1, 4)),
spark.createDataFrame(sc.parallelize(data2, 4)))
}
val trainer1a = (new LogisticRegression).setFitIntercept(true)
@ -938,7 +938,7 @@ class LogisticRegressionSuite
test("should support all NumericType labels and not support other types") {
val lr = new LogisticRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression](
lr, isClassification = true, sqlContext) { (expected, actual) =>
lr, isClassification = true, spark) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients.toArray === actual.coefficients.toArray)
}

View file

@ -36,7 +36,7 @@ class MultilayerPerceptronClassifierSuite
override def beforeAll(): Unit = {
super.beforeAll()
dataset = sqlContext.createDataFrame(Seq(
dataset = spark.createDataFrame(Seq(
(Vectors.dense(0.0, 0.0), 0.0),
(Vectors.dense(0.0, 1.0), 1.0),
(Vectors.dense(1.0, 0.0), 1.0),
@ -77,7 +77,7 @@ class MultilayerPerceptronClassifierSuite
}
test("Test setWeights by training restart") {
val dataFrame = sqlContext.createDataFrame(Seq(
val dataFrame = spark.createDataFrame(Seq(
(Vectors.dense(0.0, 0.0), 0.0),
(Vectors.dense(0.0, 1.0), 1.0),
(Vectors.dense(1.0, 0.0), 1.0),
@ -113,7 +113,7 @@ class MultilayerPerceptronClassifierSuite
// the input seed is somewhat magic, to make this test pass
val rdd = sc.parallelize(generateMultinomialLogisticInput(
coefficients, xMean, xVariance, true, nPoints, 1), 2)
val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features")
val dataFrame = spark.createDataFrame(rdd).toDF("label", "features")
val numClasses = 3
val numIterations = 100
val layers = Array[Int](4, 5, 4, numClasses)
@ -169,7 +169,7 @@ class MultilayerPerceptronClassifierSuite
val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1)
MLTestingUtils.checkNumericTypes[
MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier](
mpc, isClassification = true, sqlContext) { (expected, actual) =>
mpc, isClassification = true, spark) { (expected, actual) =>
assert(expected.layers === actual.layers)
assert(expected.weights === actual.weights)
}

View file

@ -43,7 +43,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
Array(0.10, 0.10, 0.70, 0.10) // label 2
).map(_.map(math.log))
dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42))
dataset = spark.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42))
}
def validatePrediction(predictionAndLabels: DataFrame): Unit = {
@ -127,7 +127,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val pi = Vectors.dense(piArray)
val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
val testDataset = spark.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 42, "multinomial"))
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
val model = nb.fit(testDataset)
@ -135,7 +135,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
validateModelFit(pi, theta, model)
assert(model.hasParent)
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
val validationDataset = spark.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 17, "multinomial"))
val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
@ -157,7 +157,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val pi = Vectors.dense(piArray)
val theta = new DenseMatrix(3, 12, thetaArray.flatten, true)
val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
val testDataset = spark.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 45, "bernoulli"))
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli")
val model = nb.fit(testDataset)
@ -165,7 +165,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
validateModelFit(pi, theta, model)
assert(model.hasParent)
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
val validationDataset = spark.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 20, "bernoulli"))
val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
@ -188,7 +188,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
test("should support all NumericType labels and not support other types") {
val nb = new NaiveBayes()
MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes](
nb, isClassification = true, sqlContext) { (expected, actual) =>
nb, isClassification = true, spark) { (expected, actual) =>
assert(expected.pi === actual.pi)
assert(expected.theta === actual.theta)
}

View file

@ -53,7 +53,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
rdd = sc.parallelize(generateMultinomialLogisticInput(
coefficients, xMean, xVariance, true, nPoints, 42), 2)
dataset = sqlContext.createDataFrame(rdd)
dataset = spark.createDataFrame(rdd)
}
test("params") {
@ -228,7 +228,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("should support all NumericType labels and not support other types") {
val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1))
MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest](
ovr, isClassification = true, sqlContext) { (expected, actual) =>
ovr, isClassification = true, spark) { (expected, actual) =>
val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel])
val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel])
assert(expectedModels.length === actualModels.length)

View file

@ -155,7 +155,7 @@ class RandomForestClassifierSuite
}
test("Fitting without numClasses in metadata") {
val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc))
val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1)
rf.fit(df)
}
@ -189,7 +189,7 @@ class RandomForestClassifierSuite
test("should support all NumericType labels and not support other types") {
val rf = new RandomForestClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier](
rf, isClassification = true, sqlContext) { (expected, actual) =>
rf, isClassification = true, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}

View file

@ -30,7 +30,7 @@ class BisectingKMeansSuite
override def beforeAll(): Unit = {
super.beforeAll()
dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
}
test("default parameters") {

View file

@ -32,7 +32,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
override def beforeAll(): Unit = {
super.beforeAll()
dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
}
test("default parameters") {

View file

@ -22,7 +22,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, SQLContext}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
private[clustering] case class TestRow(features: Vector)
@ -34,7 +34,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
override def beforeAll(): Unit = {
super.beforeAll()
dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
}
test("default parameters") {
@ -142,11 +142,11 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
}
object KMeansSuite {
def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
val sc = sql.sparkContext
def generateKMeansData(spark: SparkSession, rows: Int, dim: Int, k: Int): DataFrame = {
val sc = spark.sparkContext
val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
.map(v => new TestRow(v))
sql.createDataFrame(rdd)
spark.createDataFrame(rdd)
}
/**

View file

@ -17,30 +17,30 @@
package org.apache.spark.ml.clustering
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
import org.apache.spark.sql._
object LDASuite {
def generateLDAData(
sql: SQLContext,
spark: SparkSession,
rows: Int,
k: Int,
vocabSize: Int): DataFrame = {
val avgWC = 1 // average instances of each word in a doc
val sc = sql.sparkContext
val sc = spark.sparkContext
val rng = new java.util.Random()
rng.setSeed(1)
val rdd = sc.parallelize(1 to rows).map { i =>
Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble))
}.map(v => new TestRow(v))
sql.createDataFrame(rdd)
spark.createDataFrame(rdd)
}
/**
@ -68,7 +68,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
override def beforeAll(): Unit = {
super.beforeAll()
dataset = LDASuite.generateLDAData(sqlContext, 50, k, vocabSize)
dataset = LDASuite.generateLDAData(spark, 50, k, vocabSize)
}
test("default parameters") {
@ -140,7 +140,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
new LDA().setTopicConcentration(-1.1)
}
val dummyDF = sqlContext.createDataFrame(Seq(
val dummyDF = spark.createDataFrame(Seq(
(1, Vectors.dense(1.0, 2.0)))).toDF("id", "features")
// validate parameters
lda.transformSchema(dummyDF.schema)
@ -274,7 +274,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
// There should be 1 checkpoint remaining.
assert(model.getCheckpointFiles.length === 1)
val checkpointFile = new Path(model.getCheckpointFiles.head)
val fs = checkpointFile.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
val fs = checkpointFile.getFileSystem(spark.sparkContext.hadoopConfiguration)
assert(fs.exists(checkpointFile))
model.deleteCheckpointFiles()
assert(model.getCheckpointFiles.isEmpty)

View file

@ -42,21 +42,21 @@ class BinaryClassificationEvaluatorSuite
val evaluator = new BinaryClassificationEvaluator()
.setMetricName("areaUnderPR")
val vectorDF = sqlContext.createDataFrame(Seq(
val vectorDF = spark.createDataFrame(Seq(
(0d, Vectors.dense(12, 2.5)),
(1d, Vectors.dense(1, 3)),
(0d, Vectors.dense(10, 2))
)).toDF("label", "rawPrediction")
assert(evaluator.evaluate(vectorDF) === 1.0)
val doubleDF = sqlContext.createDataFrame(Seq(
val doubleDF = spark.createDataFrame(Seq(
(0d, 0d),
(1d, 1d),
(0d, 0d)
)).toDF("label", "rawPrediction")
assert(evaluator.evaluate(doubleDF) === 1.0)
val stringDF = sqlContext.createDataFrame(Seq(
val stringDF = spark.createDataFrame(Seq(
(0d, "0d"),
(1d, "1d"),
(0d, "0d")
@ -71,6 +71,6 @@ class BinaryClassificationEvaluatorSuite
test("should support all NumericType labels and not support other types") {
val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction")
MLTestingUtils.checkNumericTypes(evaluator, sqlContext)
MLTestingUtils.checkNumericTypes(evaluator, spark)
}
}

View file

@ -38,6 +38,6 @@ class MulticlassClassificationEvaluatorSuite
}
test("should support all NumericType labels and not support other types") {
MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, sqlContext)
MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, spark)
}
}

View file

@ -42,7 +42,7 @@ class RegressionEvaluatorSuite
* data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1))
* .saveAsTextFile("path")
*/
val dataset = sqlContext.createDataFrame(
val dataset = spark.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
@ -85,6 +85,6 @@ class RegressionEvaluatorSuite
}
test("should support all NumericType labels and not support other types") {
MLTestingUtils.checkNumericTypes(new RegressionEvaluator, sqlContext)
MLTestingUtils.checkNumericTypes(new RegressionEvaluator, spark)
}
}

View file

@ -39,7 +39,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("Binarize continuous features with default parameter") {
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(
val dataFrame: DataFrame = spark.createDataFrame(
data.zip(defaultBinarized)).toDF("feature", "expected")
val binarizer: Binarizer = new Binarizer()
@ -55,7 +55,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("Binarize continuous features with setter") {
val threshold: Double = 0.2
val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(
val dataFrame: DataFrame = spark.createDataFrame(
data.zip(thresholdBinarized)).toDF("feature", "expected")
val binarizer: Binarizer = new Binarizer()
@ -71,7 +71,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("Binarize vector of continuous features with default parameter") {
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(Seq(
val dataFrame: DataFrame = spark.createDataFrame(Seq(
(Vectors.dense(data), Vectors.dense(defaultBinarized))
)).toDF("feature", "expected")
@ -88,7 +88,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("Binarize vector of continuous features with setter") {
val threshold: Double = 0.2
val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(Seq(
val dataFrame: DataFrame = spark.createDataFrame(Seq(
(Vectors.dense(data), Vectors.dense(defaultBinarized))
)).toDF("feature", "expected")

View file

@ -39,7 +39,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val validData = Array(-0.5, -0.3, 0.0, 0.2)
val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0)
val dataFrame: DataFrame =
sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
val bucketizer: Bucketizer = new Bucketizer()
.setInputCol("feature")
@ -55,13 +55,13 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
// Check for exceptions when using a set of invalid feature values.
val invalidData1: Array[Double] = Array(-0.9) ++ validData
val invalidData2 = Array(0.51) ++ validData
val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
val badDF1 = spark.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
intercept[SparkException] {
bucketizer.transform(badDF1).collect()
}
}
val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
val badDF2 = spark.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
withClue("Invalid feature value 0.51 was not caught as an invalid feature!") {
intercept[SparkException] {
bucketizer.transform(badDF2).collect()
@ -74,7 +74,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
val dataFrame: DataFrame =
sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
val bucketizer: Bucketizer = new Bucketizer()
.setInputCol("feature")

View file

@ -24,14 +24,17 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.{Row, SparkSession}
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
with DefaultReadWriteTest {
test("Test Chi-Square selector") {
val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
val spark = SparkSession.builder
.master("local[2]")
.appName("ChiSqSelectorSuite")
.getOrCreate()
import spark.implicits._
val data = Seq(
LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),

View file

@ -35,7 +35,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
private def split(s: String): Seq[String] = s.split("\\s+")
test("CountVectorizerModel common cases") {
val df = sqlContext.createDataFrame(Seq(
val df = spark.createDataFrame(Seq(
(0, split("a b c d"),
Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
(1, split("a b b c d a"),
@ -55,7 +55,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("CountVectorizer common cases") {
val df = sqlContext.createDataFrame(Seq(
val df = spark.createDataFrame(Seq(
(0, split("a b c d e"),
Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))),
(1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))),
@ -76,7 +76,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("CountVectorizer vocabSize and minDF") {
val df = sqlContext.createDataFrame(Seq(
val df = spark.createDataFrame(Seq(
(0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
(1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
(2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
@ -118,7 +118,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
test("CountVectorizer throws exception when vocab is empty") {
intercept[IllegalArgumentException] {
val df = sqlContext.createDataFrame(Seq(
val df = spark.createDataFrame(Seq(
(0, split("a a b b c c")),
(1, split("aa bb cc")))
).toDF("id", "words")
@ -132,7 +132,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("CountVectorizerModel with minTF count") {
val df = sqlContext.createDataFrame(Seq(
val df = spark.createDataFrame(Seq(
(0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
(1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
(2, split("a"), Vectors.sparse(4, Seq())),
@ -151,7 +151,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("CountVectorizerModel with minTF freq") {
val df = sqlContext.createDataFrame(Seq(
val df = spark.createDataFrame(Seq(
(0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
(1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
(2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))),
@ -170,7 +170,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("CountVectorizerModel and CountVectorizer with binary") {
val df = sqlContext.createDataFrame(Seq(
val df = spark.createDataFrame(Seq(
(0, split("a a a a b b b b c d"),
Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
(1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))),

View file

@ -63,7 +63,7 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
}
val expectedResult = Vectors.dense(expectedResultBuffer)
val dataset = sqlContext.createDataFrame(Seq(
val dataset = spark.createDataFrame(Seq(
DCTTestData(data, expectedResult)
))

View file

@ -34,7 +34,7 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
}
test("hashingTF") {
val df = sqlContext.createDataFrame(Seq(
val df = spark.createDataFrame(Seq(
(0, "a a b b c d".split(" ").toSeq)
)).toDF("id", "words")
val n = 100
@ -54,7 +54,7 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
}
test("applying binary term freqs") {
val df = sqlContext.createDataFrame(Seq(
val df = spark.createDataFrame(Seq(
(0, "a a b c c c".split(" ").toSeq)
)).toDF("id", "words")
val n = 100

View file

@ -60,7 +60,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
})
val expected = scaleDataWithIDF(data, idf)
val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
val idfModel = new IDF()
.setInputCol("features")
@ -86,7 +86,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
})
val expected = scaleDataWithIDF(data, idf)
val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
val idfModel = new IDF()
.setInputCol("features")

View file

@ -59,7 +59,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
}
test("numeric interaction") {
val data = sqlContext.createDataFrame(
val data = spark.createDataFrame(
Seq(
(2, Vectors.dense(3.0, 4.0)),
(1, Vectors.dense(1.0, 5.0)))
@ -74,7 +74,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
col("b").as("b", groupAttr.toMetadata()))
val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features")
val res = trans.transform(df)
val expected = sqlContext.createDataFrame(
val expected = spark.createDataFrame(
Seq(
(2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)),
(1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)))
@ -90,7 +90,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
}
test("nominal interaction") {
val data = sqlContext.createDataFrame(
val data = spark.createDataFrame(
Seq(
(2, Vectors.dense(3.0, 4.0)),
(1, Vectors.dense(1.0, 5.0)))
@ -106,7 +106,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
col("b").as("b", groupAttr.toMetadata()))
val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features")
val res = trans.transform(df)
val expected = sqlContext.createDataFrame(
val expected = spark.createDataFrame(
Seq(
(2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)),
(1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)))
@ -126,7 +126,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
}
test("default attr names") {
val data = sqlContext.createDataFrame(
val data = spark.createDataFrame(
Seq(
(2, Vectors.dense(0.0, 4.0), 1.0),
(1, Vectors.dense(1.0, 5.0), 10.0))
@ -142,7 +142,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
col("c").as("c", NumericAttribute.defaultAttr.toMetadata()))
val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features")
val res = trans.transform(df)
val expected = sqlContext.createDataFrame(
val expected = spark.createDataFrame(
Seq(
(2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)),
(1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0)))

View file

@ -36,7 +36,7 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
Vectors.sparse(3, Array(0, 2), Array(-1, -1)),
Vectors.sparse(3, Array(0), Array(-0.75)))
val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
val scaler = new MaxAbsScaler()
.setInputCol("features")
.setOutputCol("scaled")

View file

@ -38,7 +38,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
Vectors.sparse(3, Array(0, 2), Array(5, 5)),
Vectors.sparse(3, Array(0), Array(-2.5)))
val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
val scaler = new MinMaxScaler()
.setInputCol("features")
.setOutputCol("scaled")
@ -57,7 +57,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
test("MinMaxScaler arguments max must be larger than min") {
withClue("arguments max must be larger than min") {
val dummyDF = sqlContext.createDataFrame(Seq(
val dummyDF = spark.createDataFrame(Seq(
(1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature")
intercept[IllegalArgumentException] {
val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature")

View file

@ -34,7 +34,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
val nGram = new NGram()
.setInputCol("inputTokens")
.setOutputCol("nGrams")
val dataset = sqlContext.createDataFrame(Seq(
val dataset = spark.createDataFrame(Seq(
NGramTestData(
Array("Test", "for", "ngram", "."),
Array("Test for", "for ngram", "ngram .")
@ -47,7 +47,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
.setInputCol("inputTokens")
.setOutputCol("nGrams")
.setN(4)
val dataset = sqlContext.createDataFrame(Seq(
val dataset = spark.createDataFrame(Seq(
NGramTestData(
Array("a", "b", "c", "d", "e"),
Array("a b c d", "b c d e")
@ -60,7 +60,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
.setInputCol("inputTokens")
.setOutputCol("nGrams")
.setN(4)
val dataset = sqlContext.createDataFrame(Seq(
val dataset = spark.createDataFrame(Seq(
NGramTestData(
Array(),
Array()
@ -73,7 +73,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
.setInputCol("inputTokens")
.setOutputCol("nGrams")
.setN(6)
val dataset = sqlContext.createDataFrame(Seq(
val dataset = spark.createDataFrame(Seq(
NGramTestData(
Array("a", "b", "c", "d", "e"),
Array()

View file

@ -61,7 +61,7 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
Vectors.sparse(3, Seq())
)
dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData))
dataFrame = spark.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData))
normalizer = new Normalizer()
.setInputCol("features")
.setOutputCol("normalized_features")

View file

@ -32,7 +32,7 @@ class OneHotEncoderSuite
def stringIndexed(): DataFrame = {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
val df = spark.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
@ -81,7 +81,7 @@ class OneHotEncoderSuite
test("input column with ML attribute") {
val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
val df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
.select(col("size").as("size", attr.toMetadata()))
val encoder = new OneHotEncoder()
.setInputCol("size")
@ -94,7 +94,7 @@ class OneHotEncoderSuite
}
test("input column without ML attribute") {
val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index")
val df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index")
val encoder = new OneHotEncoder()
.setInputCol("index")
.setOutputCol("encoded")

View file

@ -49,7 +49,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val pc = mat.computePrincipalComponents(3)
val expected = mat.multiply(pc).rows
val df = sqlContext.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected")
val df = spark.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected")
val pca = new PCA()
.setInputCol("features")

View file

@ -59,7 +59,7 @@ class PolynomialExpansionSuite
Vectors.sparse(19, Array.empty, Array.empty))
test("Polynomial expansion with default parameter") {
val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected")
val df = spark.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected")
val polynomialExpansion = new PolynomialExpansion()
.setInputCol("features")
@ -76,7 +76,7 @@ class PolynomialExpansionSuite
}
test("Polynomial expansion with setter") {
val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected")
val df = spark.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected")
val polynomialExpansion = new PolynomialExpansion()
.setInputCol("features")
@ -94,7 +94,7 @@ class PolynomialExpansionSuite
}
test("Polynomial expansion with degree 1 is identity on vectors") {
val df = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected")
val df = spark.createDataFrame(data.zip(data)).toDF("features", "expected")
val polynomialExpansion = new PolynomialExpansion()
.setInputCol("features")

View file

@ -32,12 +32,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("transform numeric data") {
val formula = new RFormula().setFormula("id ~ v1 + v2")
val original = sqlContext.createDataFrame(
val original = spark.createDataFrame(
Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
val model = formula.fit(original)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
val expected = sqlContext.createDataFrame(
val expected = spark.createDataFrame(
Seq(
(0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0),
(2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0))
@ -50,7 +50,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("features column already exists") {
val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
intercept[IllegalArgumentException] {
formula.fit(original)
}
@ -61,7 +61,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("label column already exists") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
val model = formula.fit(original)
val resultSchema = model.transformSchema(original.schema)
assert(resultSchema.length == 3)
@ -70,7 +70,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("label column already exists but is not double type") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
val original = spark.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
val model = formula.fit(original)
intercept[IllegalArgumentException] {
model.transformSchema(original.schema)
@ -82,7 +82,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("allow missing label column for test datasets") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("label")
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
val model = formula.fit(original)
val resultSchema = model.transformSchema(original.schema)
assert(resultSchema.length == 3)
@ -91,14 +91,14 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
test("allow empty label") {
val original = sqlContext.createDataFrame(
val original = spark.createDataFrame(
Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0))
).toDF("id", "a", "b")
val formula = new RFormula().setFormula("~ a + b")
val model = formula.fit(original)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
val expected = sqlContext.createDataFrame(
val expected = spark.createDataFrame(
Seq(
(1, 2.0, 3.0, Vectors.dense(2.0, 3.0)),
(4, 5.0, 6.0, Vectors.dense(5.0, 6.0)),
@ -110,13 +110,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("encodes string terms") {
val formula = new RFormula().setFormula("id ~ a + b")
val original = sqlContext.createDataFrame(
val original = spark.createDataFrame(
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
).toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
val expected = sqlContext.createDataFrame(
val expected = spark.createDataFrame(
Seq(
(1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
(2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
@ -129,13 +129,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("index string label") {
val formula = new RFormula().setFormula("id ~ a + b")
val original = sqlContext.createDataFrame(
val original = spark.createDataFrame(
Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5))
).toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
val expected = sqlContext.createDataFrame(
val expected = spark.createDataFrame(
Seq(
("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
@ -148,7 +148,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("attribute generation") {
val formula = new RFormula().setFormula("id ~ a + b")
val original = sqlContext.createDataFrame(
val original = spark.createDataFrame(
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
).toDF("id", "a", "b")
val model = formula.fit(original)
@ -165,7 +165,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("vector attribute generation") {
val formula = new RFormula().setFormula("id ~ vec")
val original = sqlContext.createDataFrame(
val original = spark.createDataFrame(
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
).toDF("id", "vec")
val model = formula.fit(original)
@ -181,7 +181,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("vector attribute generation with unnamed input attrs") {
val formula = new RFormula().setFormula("id ~ vec2")
val base = sqlContext.createDataFrame(
val base = spark.createDataFrame(
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
).toDF("id", "vec")
val metadata = new AttributeGroup(
@ -203,12 +203,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("numeric interaction") {
val formula = new RFormula().setFormula("a ~ b:c:d")
val original = sqlContext.createDataFrame(
val original = spark.createDataFrame(
Seq((1, 2, 4, 2), (2, 3, 4, 1))
).toDF("a", "b", "c", "d")
val model = formula.fit(original)
val result = model.transform(original)
val expected = sqlContext.createDataFrame(
val expected = spark.createDataFrame(
Seq(
(1, 2, 4, 2, Vectors.dense(16.0), 1.0),
(2, 3, 4, 1, Vectors.dense(12.0), 2.0))
@ -223,12 +223,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("factor numeric interaction") {
val formula = new RFormula().setFormula("id ~ a:b")
val original = sqlContext.createDataFrame(
val original = spark.createDataFrame(
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5))
).toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
val expected = sqlContext.createDataFrame(
val expected = spark.createDataFrame(
Seq(
(1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0),
(2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0),
@ -250,12 +250,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("factor factor interaction") {
val formula = new RFormula().setFormula("id ~ a:b")
val original = sqlContext.createDataFrame(
val original = spark.createDataFrame(
Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
).toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
val expected = sqlContext.createDataFrame(
val expected = spark.createDataFrame(
Seq(
(1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0),
(2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0),
@ -299,7 +299,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
}
val dataset = sqlContext.createDataFrame(
val dataset = spark.createDataFrame(
Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
).toDF("id", "a", "b")

View file

@ -31,13 +31,13 @@ class SQLTransformerSuite
}
test("transform numeric data") {
val original = sqlContext.createDataFrame(
val original = spark.createDataFrame(
Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
val sqlTrans = new SQLTransformer().setStatement(
"SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
val result = sqlTrans.transform(original)
val resultSchema = sqlTrans.transformSchema(original.schema)
val expected = sqlContext.createDataFrame(
val expected = spark.createDataFrame(
Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)))
.toDF("id", "v1", "v2", "v3", "v4")
assert(result.schema.toString == resultSchema.toString)
@ -52,7 +52,7 @@ class SQLTransformerSuite
}
test("transformSchema") {
val df = sqlContext.range(10)
val df = spark.range(10)
val outputSchema = new SQLTransformer()
.setStatement("SELECT id + 1 AS id1 FROM __THIS__")
.transformSchema(df.schema)

View file

@ -73,7 +73,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("Standardization with default parameter") {
val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected")
val df0 = spark.createDataFrame(data.zip(resWithStd)).toDF("features", "expected")
val standardScaler0 = new StandardScaler()
.setInputCol("features")
@ -84,9 +84,9 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("Standardization with setter") {
val df1 = sqlContext.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected")
val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected")
val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected")
val df1 = spark.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected")
val df2 = spark.createDataFrame(data.zip(resWithMean)).toDF("features", "expected")
val df3 = spark.createDataFrame(data.zip(data)).toDF("features", "expected")
val standardScaler1 = new StandardScaler()
.setInputCol("features")

View file

@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.{Dataset, Row}
object StopWordsRemoverSuite extends SparkFunSuite {
def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = {
@ -42,7 +42,7 @@ class StopWordsRemoverSuite
val remover = new StopWordsRemover()
.setInputCol("raw")
.setOutputCol("filtered")
val dataSet = sqlContext.createDataFrame(Seq(
val dataSet = spark.createDataFrame(Seq(
(Seq("test", "test"), Seq("test", "test")),
(Seq("a", "b", "c", "d"), Seq("b", "c")),
(Seq("a", "the", "an"), Seq()),
@ -60,7 +60,7 @@ class StopWordsRemoverSuite
.setInputCol("raw")
.setOutputCol("filtered")
.setStopWords(stopWords)
val dataSet = sqlContext.createDataFrame(Seq(
val dataSet = spark.createDataFrame(Seq(
(Seq("test", "test"), Seq()),
(Seq("a", "b", "c", "d"), Seq("b", "c", "d")),
(Seq("a", "the", "an"), Seq()),
@ -77,7 +77,7 @@ class StopWordsRemoverSuite
.setInputCol("raw")
.setOutputCol("filtered")
.setCaseSensitive(true)
val dataSet = sqlContext.createDataFrame(Seq(
val dataSet = spark.createDataFrame(Seq(
(Seq("A"), Seq("A")),
(Seq("The", "the"), Seq("The"))
)).toDF("raw", "expected")
@ -98,7 +98,7 @@ class StopWordsRemoverSuite
.setInputCol("raw")
.setOutputCol("filtered")
.setStopWords(stopWords)
val dataSet = sqlContext.createDataFrame(Seq(
val dataSet = spark.createDataFrame(Seq(
(Seq("acaba", "ama", "biri"), Seq()),
(Seq("hep", "her", "scala"), Seq("scala"))
)).toDF("raw", "expected")
@ -112,7 +112,7 @@ class StopWordsRemoverSuite
.setInputCol("raw")
.setOutputCol("filtered")
.setStopWords(stopWords.toArray)
val dataSet = sqlContext.createDataFrame(Seq(
val dataSet = spark.createDataFrame(Seq(
(Seq("python", "scala", "a"), Seq("python", "scala", "a")),
(Seq("Python", "Scala", "swift"), Seq("Python", "Scala", "swift"))
)).toDF("raw", "expected")
@ -126,7 +126,7 @@ class StopWordsRemoverSuite
.setInputCol("raw")
.setOutputCol("filtered")
.setStopWords(stopWords.toArray)
val dataSet = sqlContext.createDataFrame(Seq(
val dataSet = spark.createDataFrame(Seq(
(Seq("python", "scala", "a"), Seq()),
(Seq("Python", "Scala", "swift"), Seq("swift"))
)).toDF("raw", "expected")
@ -148,7 +148,7 @@ class StopWordsRemoverSuite
val remover = new StopWordsRemover()
.setInputCol("raw")
.setOutputCol(outputCol)
val dataSet = sqlContext.createDataFrame(Seq(
val dataSet = spark.createDataFrame(Seq(
(Seq("The", "the", "swift"), Seq("swift"))
)).toDF("raw", outputCol)

View file

@ -39,7 +39,7 @@ class StringIndexerSuite
test("StringIndexer") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
val df = spark.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
@ -63,8 +63,8 @@ class StringIndexerSuite
test("StringIndexerUnseen") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2)
val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
val df2 = sqlContext.createDataFrame(data2).toDF("id", "label")
val df = spark.createDataFrame(data).toDF("id", "label")
val df2 = spark.createDataFrame(data2).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
@ -93,7 +93,7 @@ class StringIndexerSuite
test("StringIndexer with a numeric input column") {
val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
val df = spark.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
@ -114,12 +114,12 @@ class StringIndexerSuite
val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
.setInputCol("label")
.setOutputCol("labelIndex")
val df = sqlContext.range(0L, 10L).toDF()
val df = spark.range(0L, 10L).toDF()
assert(indexerModel.transform(df).collect().toSet === df.collect().toSet)
}
test("StringIndexerModel can't overwrite output column") {
val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
val indexer = new StringIndexer()
.setInputCol("input")
.setOutputCol("output")
@ -153,7 +153,7 @@ class StringIndexerSuite
test("IndexToString.transform") {
val labels = Array("a", "b", "c")
val df0 = sqlContext.createDataFrame(Seq(
val df0 = spark.createDataFrame(Seq(
(0, "a"), (1, "b"), (2, "c"), (0, "a")
)).toDF("index", "expected")
@ -180,7 +180,7 @@ class StringIndexerSuite
test("StringIndexer, IndexToString are inverses") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
val df = spark.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
@ -213,7 +213,7 @@ class StringIndexerSuite
test("StringIndexer metadata") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
val df = spark.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")

View file

@ -57,13 +57,13 @@ class RegexTokenizerSuite
.setPattern("\\w+|\\p{Punct}")
.setInputCol("rawText")
.setOutputCol("tokens")
val dataset0 = sqlContext.createDataFrame(Seq(
val dataset0 = spark.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")),
TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct"))
))
testRegexTokenizer(tokenizer0, dataset0)
val dataset1 = sqlContext.createDataFrame(Seq(
val dataset1 = spark.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")),
TokenizerTestData("Te,st. punct", Array("punct"))
))
@ -73,7 +73,7 @@ class RegexTokenizerSuite
val tokenizer2 = new RegexTokenizer()
.setInputCol("rawText")
.setOutputCol("tokens")
val dataset2 = sqlContext.createDataFrame(Seq(
val dataset2 = spark.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")),
TokenizerTestData("Te,st. punct", Array("te,st.", "punct"))
))
@ -85,7 +85,7 @@ class RegexTokenizerSuite
.setInputCol("rawText")
.setOutputCol("tokens")
.setToLowercase(false)
val dataset = sqlContext.createDataFrame(Seq(
val dataset = spark.createDataFrame(Seq(
TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")),
TokenizerTestData("java scala", Array("java", "scala"))
))

View file

@ -57,7 +57,7 @@ class VectorAssemblerSuite
}
test("VectorAssembler") {
val df = sqlContext.createDataFrame(Seq(
val df = spark.createDataFrame(Seq(
(0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
)).toDF("id", "x", "y", "name", "z", "n")
val assembler = new VectorAssembler()
@ -70,7 +70,7 @@ class VectorAssemblerSuite
}
test("transform should throw an exception in case of unsupported type") {
val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c")
val df = spark.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c")
val assembler = new VectorAssembler()
.setInputCols(Array("a", "b", "c"))
.setOutputCol("features")
@ -87,7 +87,7 @@ class VectorAssemblerSuite
NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"),
NumericAttribute.defaultAttr.withName("salary")))
val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0)))
val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
val df = spark.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
.select(
col("browser").as("browser", browser.toMetadata()),
col("hour").as("hour", hour.toMetadata()),

View file

@ -85,11 +85,11 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
checkPair(densePoints1Seq, sparsePoints1Seq)
checkPair(densePoints2Seq, sparsePoints2Seq)
densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData))
badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData))
densePoints1 = spark.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
sparsePoints1 = spark.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
densePoints2 = spark.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
sparsePoints2 = spark.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData))
badPoints = spark.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData))
}
private def getIndexer: VectorIndexer =
@ -102,7 +102,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("Cannot fit an empty DataFrame") {
val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
val rdd = spark.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
val vectorIndexer = getIndexer
intercept[IllegalArgumentException] {
vectorIndexer.fit(rdd)

View file

@ -79,7 +79,7 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De
val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]])
val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) }
val df = sqlContext.createDataFrame(rdd,
val df = spark.createDataFrame(rdd,
StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField())))
val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")

View file

@ -36,8 +36,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("Word2Vec") {
val sqlContext = this.sqlContext
import sqlContext.implicits._
val spark = this.spark
import spark.implicits._
val sentence = "a b " * 100 + "a c " * 10
val numOfWords = sentence.split(" ").size
@ -78,8 +78,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("getVectors") {
val sqlContext = this.sqlContext
import sqlContext.implicits._
val spark = this.spark
import spark.implicits._
val sentence = "a b " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
@ -119,8 +119,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("findSynonyms") {
val sqlContext = this.sqlContext
import sqlContext.implicits._
val spark = this.spark
import spark.implicits._
val sentence = "a b " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
@ -146,8 +146,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("window size") {
val sqlContext = this.sqlContext
import sqlContext.implicits._
val spark = this.spark
import spark.implicits._
val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))

View file

@ -38,7 +38,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@ -305,8 +305,8 @@ class ALSSuite
numUserBlocks: Int = 2,
numItemBlocks: Int = 3,
targetRMSE: Double = 0.05): Unit = {
val sqlContext = this.sqlContext
import sqlContext.implicits._
val spark = this.spark
import spark.implicits._
val als = new ALS()
.setRank(rank)
.setRegParam(regParam)
@ -460,8 +460,8 @@ class ALSSuite
allEstimatorParamSettings.foreach { case (p, v) =>
als.set(als.getParam(p), v)
}
val sqlContext = this.sqlContext
import sqlContext.implicits._
val spark = this.spark
import spark.implicits._
val model = als.fit(ratings.toDF())
// Test Estimator save/load
@ -535,8 +535,11 @@ class ALSCleanerSuite extends SparkFunSuite {
// Generate test data
val (training, _) = ALSSuite.genImplicitTestData(sc, 20, 5, 1, 0.2, 0)
// Implicitly test the cleaning of parents during ALS training
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
val spark = SparkSession.builder
.master("local[2]")
.appName("ALSCleanerSuite")
.getOrCreate()
import spark.implicits._
val als = new ALS()
.setRank(1)
.setRegParam(1e-5)
@ -577,8 +580,8 @@ class ALSStorageSuite
}
test("default and non-default storage params set correct RDD StorageLevels") {
val sqlContext = this.sqlContext
import sqlContext.implicits._
val spark = this.spark
import spark.implicits._
val data = Seq(
(0, 0, 1.0),
(0, 1, 2.0),

Some files were not shown because too many files have changed in this diff Show more