diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 60a4a1d2ea..e0c4363597 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -17,18 +17,18 @@ package org.apache.spark.ml; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; /** @@ -36,23 +36,26 @@ import static org.apache.spark.mllib.classification.LogisticRegressionSuite.gene */ public class JavaPipelineSuite { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; private transient Dataset dataset; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaPipelineSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaPipelineSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); JavaRDD points = jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); - dataset = jsql.createDataFrame(points, LabeledPoint.class); + dataset = spark.createDataFrame(points, LabeledPoint.class); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -63,10 +66,10 @@ public class JavaPipelineSuite { LogisticRegression lr = new LogisticRegression() .setFeaturesCol("scaledFeatures"); Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {scaler, lr}); + .setStages(new PipelineStage[]{scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - Dataset predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); + Dataset predictions = spark.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java index b74bbed231..15cde0d3c0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java @@ -17,8 +17,8 @@ package org.apache.spark.ml.attribute; -import org.junit.Test; import org.junit.Assert; +import org.junit.Test; public class JavaAttributeSuite { diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java index 1f23682621..8b89991327 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -21,8 +21,6 @@ import java.io.Serializable; 
import java.util.HashMap; import java.util.Map; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -32,21 +30,28 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; - +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; public class JavaDecisionTreeClassifierSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaDecisionTreeClassifierSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -55,7 +60,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable { double A = 2.0; double B = -1.5; - JavaRDD data = sc.parallelize( + JavaRDD data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap<>(); Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); @@ -70,7 +75,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable { .setCacheNodeIds(false) .setCheckpointInterval(10) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String impurity: DecisionTreeClassifier.supportedImpurities()) { + for (String impurity : DecisionTreeClassifier.supportedImpurities()) { dt.setImpurity(impurity); } DecisionTreeClassificationModel model = dt.fit(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java index 74841058a2..682371eb9e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java @@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; public class JavaGBTClassifierSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaGBTClassifierSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaGBTClassifierSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -55,7 +61,7 @@ public class JavaGBTClassifierSuite implements Serializable { double A = 2.0; double B = -1.5; - JavaRDD data = sc.parallelize( + JavaRDD data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap<>(); Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); @@ -74,7 +80,7 @@ public class 
JavaGBTClassifierSuite implements Serializable { .setMaxIter(3) .setStepSize(0.1) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String lossType: GBTClassifier.supportedLossTypes()) { + for (String lossType : GBTClassifier.supportedLossTypes()) { rf.setLossType(lossType); } GBTClassificationModel model = rf.fit(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index e160a5a47e..e3ff68364e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -27,18 +27,17 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; - +import org.apache.spark.sql.SparkSession; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; public class JavaLogisticRegressionSuite implements Serializable { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; private transient Dataset dataset; private transient JavaRDD datasetRDD; @@ -46,18 +45,22 @@ public class JavaLogisticRegressionSuite implements Serializable { @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaLogisticRegressionSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); + List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); datasetRDD = jsc.parallelize(points, 2); - dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); + dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class); dataset.registerTempTable("dataset"); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -66,7 +69,7 @@ public class JavaLogisticRegressionSuite implements Serializable { Assert.assertEquals(lr.getLabelCol(), "label"); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - Dataset predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); + Dataset predictions = spark.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); // Check defaults Assert.assertEquals(0.5, model.getThreshold(), eps); @@ -95,23 +98,23 @@ public class JavaLogisticRegressionSuite implements Serializable { // Modify model params, and check that the params worked. model.setThreshold(1.0); model.transform(dataset).registerTempTable("predAllZero"); - Dataset predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); - for (Row r: predAllZero.collectAsList()) { + Dataset predAllZero = spark.sql("SELECT prediction, myProbability FROM predAllZero"); + for (Row r : predAllZero.collectAsList()) { Assert.assertEquals(0.0, r.getDouble(0), eps); } // Call transform with params, and check that the params worked. 
model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) .registerTempTable("predNotAllZero"); - Dataset predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); + Dataset predNotAllZero = spark.sql("SELECT prediction, myProb FROM predNotAllZero"); boolean foundNonZero = false; - for (Row r: predNotAllZero.collectAsList()) { + for (Row r : predNotAllZero.collectAsList()) { if (r.getDouble(0) != 0.0) foundNonZero = true; } Assert.assertTrue(foundNonZero); // Call fit() with new params, and check as many params as we can. LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), - lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); + lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); LogisticRegression parent2 = (LogisticRegression) model2.parent(); Assert.assertEquals(5, parent2.getMaxIter()); Assert.assertEquals(0.1, parent2.getRegParam(), eps); @@ -128,10 +131,10 @@ public class JavaLogisticRegressionSuite implements Serializable { Assert.assertEquals(2, model.numClasses()); model.transform(dataset).registerTempTable("transformed"); - Dataset trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); - for (Row row: trans1.collectAsList()) { - Vector raw = (Vector)row.get(0); - Vector prob = (Vector)row.get(1); + Dataset trans1 = spark.sql("SELECT rawPrediction, probability FROM transformed"); + for (Row row : trans1.collectAsList()) { + Vector raw = (Vector) row.get(0); + Vector prob = (Vector) row.get(1); Assert.assertEquals(raw.size(), 2); Assert.assertEquals(prob.size(), 2); double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1))); @@ -139,11 +142,11 @@ public class JavaLogisticRegressionSuite implements Serializable { Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps); } - Dataset trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); - for (Row row: trans2.collectAsList()) { + Dataset trans2 = spark.sql("SELECT prediction, probability FROM transformed"); + for (Row row : trans2.collectAsList()) { double pred = row.getDouble(0); - Vector prob = (Vector)row.get(1); - double probOfPred = prob.apply((int)pred); + Vector prob = (Vector) row.get(1); + double probOfPred = prob.apply((int) pred); for (int i = 0; i < prob.size(); ++i) { Assert.assertTrue(probOfPred >= prob.apply(i)); } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java index bc955f3cf6..b0624cea3e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java @@ -26,49 +26,49 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; public class JavaMultilayerPerceptronClassifierSuite implements Serializable { - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); - sqlContext 
= new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaLogisticRegressionSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; - sqlContext = null; + spark.stop(); + spark = null; } @Test public void testMLPC() { - Dataset dataFrame = sqlContext.createDataFrame( - jsc.parallelize(Arrays.asList( - new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), - new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), - new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), - new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))), - LabeledPoint.class); + List data = Arrays.asList( + new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)) + ); + Dataset dataFrame = spark.createDataFrame(data, LabeledPoint.class); + MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier() - .setLayers(new int[] {2, 5, 2}) + .setLayers(new int[]{2, 5, 2}) .setBlockSize(1) .setSeed(123L) .setMaxIter(100); MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame); Dataset result = model.transform(dataFrame); List predictionAndLabels = result.select("prediction", "label").collectAsList(); - for (Row r: predictionAndLabels) { + for (Row r : predictionAndLabels) { Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1)); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index 45101f286c..3fc3648627 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -26,13 +26,12 @@ import org.junit.Before; import org.junit.Test; import static org.junit.Assert.assertEquals; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; @@ -40,19 +39,20 @@ import org.apache.spark.sql.types.StructType; public class JavaNaiveBayesSuite implements Serializable { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaLogisticRegressionSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } public void validatePrediction(Dataset predictionAndLabels) { @@ -88,7 +88,7 @@ public class JavaNaiveBayesSuite implements Serializable { new StructField("features", new VectorUDT(), false, Metadata.empty()) }); - Dataset dataset = jsql.createDataFrame(data, schema); + Dataset dataset = spark.createDataFrame(data, schema); NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial"); NaiveBayesModel model = nb.fit(dataset); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java index 00f4476841..486fbbd58c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -20,7 +20,6 @@ package org.apache.spark.ml.classification; import java.io.Serializable; import java.util.List; -import org.apache.spark.sql.Row; import scala.collection.JavaConverters; import org.junit.After; @@ -30,56 +29,61 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput; public class JavaOneVsRestSuite implements Serializable { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; - private transient Dataset dataset; - private transient JavaRDD datasetRDD; + private transient SparkSession spark; + private transient JavaSparkContext jsc; + private transient Dataset dataset; + private transient JavaRDD datasetRDD; - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite"); - jsql = new SQLContext(jsc); - int nPoints = 3; + @Before + public void setUp() { + spark = SparkSession.builder() + .master("local") + .appName("JavaLOneVsRestSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); - // The following coefficients and xMean/xVariance are computed from iris dataset with - // lambda=0.2. - // As a result, we are drawing samples from probability distribution of an actual model. - double[] coefficients = { - -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, - -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 }; + int nPoints = 3; - double[] xMean = {5.843, 3.057, 3.758, 1.199}; - double[] xVariance = {0.6856, 0.1899, 3.116, 0.581}; - List points = JavaConverters.seqAsJavaListConverter( - generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42) - ).asJava(); - datasetRDD = jsc.parallelize(points, 2); - dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); - } + // The following coefficients and xMean/xVariance are computed from iris dataset with + // lambda=0.2. + // As a result, we are drawing samples from probability distribution of an actual model. 
+ double[] coefficients = { + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682}; - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } + double[] xMean = {5.843, 3.057, 3.758, 1.199}; + double[] xVariance = {0.6856, 0.1899, 3.116, 0.581}; + List points = JavaConverters.seqAsJavaListConverter( + generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42) + ).asJava(); + datasetRDD = jsc.parallelize(points, 2); + dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class); + } - @Test - public void oneVsRestDefaultParams() { - OneVsRest ova = new OneVsRest(); - ova.setClassifier(new LogisticRegression()); - Assert.assertEquals(ova.getLabelCol() , "label"); - Assert.assertEquals(ova.getPredictionCol() , "prediction"); - OneVsRestModel ovaModel = ova.fit(dataset); - Dataset predictions = ovaModel.transform(dataset).select("label", "prediction"); - predictions.collectAsList(); - Assert.assertEquals(ovaModel.getLabelCol(), "label"); - Assert.assertEquals(ovaModel.getPredictionCol() , "prediction"); - } + @After + public void tearDown() { + spark.stop(); + spark = null; + } + + @Test + public void oneVsRestDefaultParams() { + OneVsRest ova = new OneVsRest(); + ova.setClassifier(new LogisticRegression()); + Assert.assertEquals(ova.getLabelCol(), "label"); + Assert.assertEquals(ova.getPredictionCol(), "prediction"); + OneVsRestModel ovaModel = ova.fit(dataset); + Dataset predictions = ovaModel.transform(dataset).select("label", "prediction"); + predictions.collectAsList(); + Assert.assertEquals(ovaModel.getLabelCol(), "label"); + Assert.assertEquals(ovaModel.getPredictionCol(), "prediction"); + } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index 4f40fd65b9..e3855662fb 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -34,21 +34,27 @@ import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; public class JavaRandomForestClassifierSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaRandomForestClassifierSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -57,7 +63,7 @@ public class JavaRandomForestClassifierSuite implements Serializable { double A = 2.0; double B = -1.5; - JavaRDD data = sc.parallelize( + JavaRDD data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap<>(); Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); @@ -75,22 +81,22 @@ public class JavaRandomForestClassifierSuite implements Serializable { .setSeed(1234) .setNumTrees(3) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String impurity: 
RandomForestClassifier.supportedImpurities()) { + for (String impurity : RandomForestClassifier.supportedImpurities()) { rf.setImpurity(impurity); } - for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) { + for (String featureSubsetStrategy : RandomForestClassifier.supportedFeatureSubsetStrategies()) { rf.setFeatureSubsetStrategy(featureSubsetStrategy); } String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; - for (String strategy: realStrategies) { + for (String strategy : realStrategies) { rf.setFeatureSubsetStrategy(strategy); } String[] integerStrategies = {"1", "10", "100", "1000", "10000"}; - for (String strategy: integerStrategies) { + for (String strategy : integerStrategies) { rf.setFeatureSubsetStrategy(strategy); } String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; - for (String strategy: invalidStrategies) { + for (String strategy : invalidStrategies) { try { rf.setFeatureSubsetStrategy(strategy); Assert.fail("Expected exception to be thrown for invalid strategies"); diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java index a3fcdb54ee..3ab09ac27d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java @@ -21,37 +21,37 @@ import java.io.Serializable; import java.util.Arrays; import java.util.List; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import org.apache.spark.api.java.JavaSparkContext; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; public class JavaKMeansSuite implements Serializable { private transient int k = 5; - private transient JavaSparkContext sc; private transient Dataset dataset; - private transient SQLContext sql; + private transient SparkSession spark; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaKMeansSuite"); - sql = new SQLContext(sc); - - dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k); + spark = SparkSession.builder() + .master("local") + .appName("JavaKMeansSuite") + .getOrCreate(); + dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -65,7 +65,7 @@ public class JavaKMeansSuite implements Serializable { Dataset transformed = model.transform(dataset); List columns = Arrays.asList(transformed.columns()); List expectedColumns = Arrays.asList("features", "prediction"); - for (String column: expectedColumns) { + for (String column : expectedColumns) { assertTrue(columns.contains(column)); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java index 77e3a489a9..a96b43de15 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -25,40 +25,40 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import 
org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; public class JavaBucketizerSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaBucketizerSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaBucketizerSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test public void bucketizerTest() { double[] splits = {-0.5, 0.0, 0.5}; - StructType schema = new StructType(new StructField[] { + StructType schema = new StructType(new StructField[]{ new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) }); - Dataset dataset = jsql.createDataFrame( + Dataset dataset = spark.createDataFrame( Arrays.asList( RowFactory.create(-0.5), RowFactory.create(-0.3), diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java index ed1ad4c3a3..06482d8f0d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -21,43 +21,44 @@ import java.util.Arrays; import java.util.List; import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; + import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; public class JavaDCTSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaDCTSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaDCTSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test public void javaCompatibilityTest() { - double[] input = new double[] {1D, 2D, 3D, 4D}; - Dataset dataset = jsql.createDataFrame( + double[] input = new double[]{1D, 2D, 3D, 4D}; + Dataset dataset = spark.createDataFrame( Arrays.asList(RowFactory.create(Vectors.dense(input))), new StructType(new StructField[]{ new StructField("vec", (new VectorUDT()), false, Metadata.empty()) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index 6e2cc7e887..0e21d4a94f 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -25,12 +25,11 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; @@ -38,19 +37,20 @@ import org.apache.spark.sql.types.StructType; public class JavaHashingTFSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaHashingTFSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaHashingTFSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -65,7 +65,7 @@ public class JavaHashingTFSuite { new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); - Dataset sentenceData = jsql.createDataFrame(data, schema); + Dataset sentenceData = spark.createDataFrame(data, schema); Tokenizer tokenizer = new Tokenizer() .setInputCol("sentence") .setOutputCol("words"); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java index 5bbd9634b2..04b2897b18 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java @@ -23,27 +23,30 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; public class JavaNormalizerSuite { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaNormalizerSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaNormalizerSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -54,7 +57,7 @@ public class JavaNormalizerSuite { new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) )); - Dataset dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class); + Dataset dataFrame = spark.createDataFrame(points, VectorIndexerSuite.FeatureData.class); Normalizer normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normFeatures"); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java index 1389d17e7e..32f6b4375e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java +++ 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java @@ -28,31 +28,34 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.distributed.RowMatrix; +import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.linalg.distributed.RowMatrix; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; public class JavaPCASuite implements Serializable { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaPCASuite"); - sqlContext = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaPCASuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } public static class VectorPair implements Serializable { @@ -100,7 +103,7 @@ public class JavaPCASuite implements Serializable { } ); - Dataset df = sqlContext.createDataFrame(featuresExpected, VectorPair.class); + Dataset df = spark.createDataFrame(featuresExpected, VectorPair.class); PCAModel pca = new PCA() .setInputCol("features") .setOutputCol("pca_features") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java index 6a8bb64801..8f726077a2 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java @@ -32,19 +32,22 @@ import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; public class JavaPolynomialExpansionSuite { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaPolynomialExpansionSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaPolynomialExpansionSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After @@ -72,20 +75,20 @@ public class JavaPolynomialExpansionSuite { ) ); - StructType schema = new StructType(new StructField[] { + StructType schema = new StructType(new StructField[]{ new StructField("features", new VectorUDT(), false, Metadata.empty()), new StructField("expected", new VectorUDT(), false, Metadata.empty()) }); - Dataset dataset = jsql.createDataFrame(data, schema); + Dataset dataset = spark.createDataFrame(data, schema); List pairs = polyExpansion.transform(dataset) .select("polyFeatures", "expected") .collectAsList(); for (Row r : pairs) { - double[] polyFeatures = ((Vector)r.get(0)).toArray(); - 
double[] expected = ((Vector)r.get(1)).toArray(); + double[] polyFeatures = ((Vector) r.get(0)).toArray(); + double[] expected = ((Vector) r.get(1)).toArray(); Assert.assertArrayEquals(polyFeatures, expected, 1e-1); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java index 3f6fc333e4..c7397bdd68 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java @@ -28,22 +28,25 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; public class JavaStandardScalerSuite { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaStandardScalerSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaStandardScalerSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -54,7 +57,7 @@ public class JavaStandardScalerSuite { new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) ); - Dataset dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + Dataset dataFrame = spark.createDataFrame(jsc.parallelize(points, 2), VectorIndexerSuite.FeatureData.class); StandardScaler scaler = new StandardScaler() .setInputCol("features") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java index bdcbde5e26..2b156f3bca 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java @@ -24,11 +24,10 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; @@ -37,19 +36,20 @@ import org.apache.spark.sql.types.StructType; public class JavaStopWordsRemoverSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaStopWordsRemoverSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaStopWordsRemoverSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -62,11 +62,11 @@ public class JavaStopWordsRemoverSuite { RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) ); - StructType schema = new StructType(new StructField[] { + StructType schema = new 
StructType(new StructField[]{ new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, - Metadata.empty()) + Metadata.empty()) }); - Dataset dataset = jsql.createDataFrame(data, schema); + Dataset dataset = spark.createDataFrame(data, schema); remover.transform(dataset).collect(); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java index 431779cd2e..52c0bde8f3 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -25,40 +25,42 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SparkConf; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import static org.apache.spark.sql.types.DataTypes.*; public class JavaStringIndexerSuite { - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaStringIndexerSuite"); - sqlContext = new SQLContext(jsc); + SparkConf sparkConf = new SparkConf(); + sparkConf.setMaster("local"); + sparkConf.setAppName("JavaStringIndexerSuite"); + + spark = SparkSession.builder().config(sparkConf).getOrCreate(); } @After public void tearDown() { - jsc.stop(); - sqlContext = null; + spark.stop(); + spark = null; } @Test public void testStringIndexer() { - StructType schema = createStructType(new StructField[] { + StructType schema = createStructType(new StructField[]{ createStructField("id", IntegerType, false), createStructField("label", StringType, false) }); List data = Arrays.asList( cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c")); - Dataset dataset = sqlContext.createDataFrame(data, schema); + Dataset dataset = spark.createDataFrame(data, schema); StringIndexer indexer = new StringIndexer() .setInputCol("label") @@ -70,7 +72,9 @@ public class JavaStringIndexerSuite { output.orderBy("id").select("id", "labelIndex").collectAsList()); } - /** An alias for RowFactory.create. */ + /** + * An alias for RowFactory.create. + */ private Row cr(Object... 
values) { return RowFactory.create(values); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java index 83d16cbd0e..0bac2839e1 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -29,22 +29,25 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; public class JavaTokenizerSuite { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaTokenizerSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaTokenizerSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -59,10 +62,10 @@ public class JavaTokenizerSuite { JavaRDD rdd = jsc.parallelize(Arrays.asList( - new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}), - new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"}) + new TokenizerTestData("Test of tok.", new String[]{"Test", "tok."}), + new TokenizerTestData("Te,st. punct", new String[]{"Te,st.", "punct"}) )); - Dataset dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); + Dataset dataset = spark.createDataFrame(rdd, TokenizerTestData.class); List pairs = myRegExTokenizer.transform(dataset) .select("tokens", "wantedTokens") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java index e45e198043..8774cd0c69 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java @@ -24,36 +24,39 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SparkConf; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import static org.apache.spark.sql.types.DataTypes.*; public class JavaVectorAssemblerSuite { - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite"); - sqlContext = new SQLContext(jsc); + SparkConf sparkConf = new SparkConf(); + sparkConf.setMaster("local"); + sparkConf.setAppName("JavaVectorAssemblerSuite"); + + spark = SparkSession.builder().config(sparkConf).getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test public void testVectorAssembler() { - StructType schema = 
createStructType(new StructField[] { + StructType schema = createStructType(new StructField[]{ createStructField("id", IntegerType, false), createStructField("x", DoubleType, false), createStructField("y", new VectorUDT(), false), @@ -63,14 +66,14 @@ public class JavaVectorAssemblerSuite { }); Row row = RowFactory.create( 0, 0.0, Vectors.dense(1.0, 2.0), "a", - Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L); - Dataset dataset = sqlContext.createDataFrame(Arrays.asList(row), schema); + Vectors.sparse(2, new int[]{1}, new double[]{3.0}), 10L); + Dataset dataset = spark.createDataFrame(Arrays.asList(row), schema); VectorAssembler assembler = new VectorAssembler() - .setInputCols(new String[] {"x", "y", "z", "n"}) + .setInputCols(new String[]{"x", "y", "z", "n"}) .setOutputCol("features"); Dataset output = assembler.transform(dataset); Assert.assertEquals( - Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}), + Vectors.sparse(6, new int[]{1, 2, 4, 5}, new double[]{1.0, 2.0, 3.0, 10.0}), output.select("features").first().getAs(0)); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java index fec6cac8be..c386c9a45b 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -32,21 +32,26 @@ import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; public class JavaVectorIndexerSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaVectorIndexerSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaVectorIndexerSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -57,8 +62,7 @@ public class JavaVectorIndexerSuite implements Serializable { new FeatureData(Vectors.dense(1.0, 3.0)), new FeatureData(Vectors.dense(1.0, 4.0)) ); - SQLContext sqlContext = new SQLContext(sc); - Dataset data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); + Dataset data = spark.createDataFrame(jsc.parallelize(points, 2), FeatureData.class); VectorIndexer indexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexed") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java index e2da11183b..59ad3c2f61 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -25,7 +25,6 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.attribute.Attribute; import org.apache.spark.ml.attribute.AttributeGroup; import org.apache.spark.ml.attribute.NumericAttribute; @@ -34,24 +33,25 @@ import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; 
import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.StructType; public class JavaVectorSlicerSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaVectorSlicerSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -69,7 +69,7 @@ public class JavaVectorSlicerSuite { ); Dataset dataset = - jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); + spark.createDataFrame(data, (new StructType()).add(group.toStructField())); VectorSlicer vectorSlicer = new VectorSlicer() .setInputCol("userFeatures").setOutputCol("features"); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java index 7517b70cc9..392aabc96d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -24,28 +24,28 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.*; public class JavaWord2VecSuite { - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaWord2VecSuite"); - sqlContext = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaWord2VecSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -53,7 +53,7 @@ public class JavaWord2VecSuite { StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) }); - Dataset documentDF = sqlContext.createDataFrame( + Dataset documentDF = spark.createDataFrame( Arrays.asList( RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), @@ -68,8 +68,8 @@ public class JavaWord2VecSuite { Word2VecModel model = word2Vec.fit(documentDF); Dataset result = model.transform(documentDF); - for (Row r: result.select("result").collectAsList()) { - double[] polyFeatures = ((Vector)r.get(0)).toArray(); + for (Row r : result.select("result").collectAsList()) { + double[] polyFeatures = ((Vector) r.get(0)).toArray(); Assert.assertEquals(polyFeatures.length, 3); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java index fa777f3d42..a5b5dd4088 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java @@ -25,23 +25,29 @@ import org.junit.Before; import org.junit.Test; import 
org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; /** * Test Param and related classes in Java */ public class JavaParamsSuite { + private transient SparkSession spark; private transient JavaSparkContext jsc; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaParamsSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaParamsSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -51,7 +57,7 @@ public class JavaParamsSuite { testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a"); Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0); Assert.assertEquals(testParams.getMyStringParam(), "a"); - Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0); + Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[]{1.0, 2.0}, 0.0); } @Test diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 06f7fbb86e..1ad5f7a442 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -45,9 +45,14 @@ public class JavaTestParams extends JavaParams { } private IntParam myIntParam_; - public IntParam myIntParam() { return myIntParam_; } - public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); } + public IntParam myIntParam() { + return myIntParam_; + } + + public int getMyIntParam() { + return (Integer) getOrDefault(myIntParam_); + } public JavaTestParams setMyIntParam(int value) { set(myIntParam_, value); @@ -55,9 +60,14 @@ public class JavaTestParams extends JavaParams { } private DoubleParam myDoubleParam_; - public DoubleParam myDoubleParam() { return myDoubleParam_; } - public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); } + public DoubleParam myDoubleParam() { + return myDoubleParam_; + } + + public double getMyDoubleParam() { + return (Double) getOrDefault(myDoubleParam_); + } public JavaTestParams setMyDoubleParam(double value) { set(myDoubleParam_, value); @@ -65,9 +75,14 @@ public class JavaTestParams extends JavaParams { } private Param myStringParam_; - public Param myStringParam() { return myStringParam_; } - public String getMyStringParam() { return getOrDefault(myStringParam_); } + public Param myStringParam() { + return myStringParam_; + } + + public String getMyStringParam() { + return getOrDefault(myStringParam_); + } public JavaTestParams setMyStringParam(String value) { set(myStringParam_, value); @@ -75,9 +90,14 @@ public class JavaTestParams extends JavaParams { } private DoubleArrayParam myDoubleArrayParam_; - public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; } - public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); } + public DoubleArrayParam myDoubleArrayParam() { + return myDoubleArrayParam_; + } + + public double[] getMyDoubleArrayParam() { + return getOrDefault(myDoubleArrayParam_); + } public JavaTestParams setMyDoubleArrayParam(double[] value) { set(myDoubleArrayParam_, value); @@ -96,7 +116,7 @@ public class JavaTestParams extends JavaParams { setDefault(myIntParam(), 1); setDefault(myDoubleParam(), 0.5); - setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0}); + setDefault(myDoubleArrayParam(), new double[]{1.0, 
2.0}); } @Override diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java index fa3b28ed4f..bbd59a04ec 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; public class JavaDecisionTreeRegressorSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaDecisionTreeRegressorSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -55,7 +61,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable { double A = 2.0; double B = -1.5; - JavaRDD data = sc.parallelize( + JavaRDD data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap<>(); Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); @@ -70,7 +76,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable { .setCacheNodeIds(false) .setCheckpointInterval(10) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String impurity: DecisionTreeRegressor.supportedImpurities()) { + for (String impurity : DecisionTreeRegressor.supportedImpurities()) { dt.setImpurity(impurity); } DecisionTreeRegressionModel model = dt.fit(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java index 8413ea0e0a..5370b58e8f 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java @@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; public class JavaGBTRegressorSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaGBTRegressorSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaGBTRegressorSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -55,7 +61,7 @@ public class JavaGBTRegressorSuite implements Serializable { double A = 2.0; double B = -1.5; - JavaRDD data = sc.parallelize( + JavaRDD data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap<>(); Dataset 
dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); @@ -73,7 +79,7 @@ public class JavaGBTRegressorSuite implements Serializable { .setMaxIter(3) .setStepSize(0.1) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String lossType: GBTRegressor.supportedLossTypes()) { + for (String lossType : GBTRegressor.supportedLossTypes()) { rf.setLossType(lossType); } GBTRegressionModel model = rf.fit(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index 9f817515eb..00c59f08b6 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -30,25 +30,26 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite - .generateLogisticInputAsList; - +import org.apache.spark.sql.SparkSession; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; public class JavaLinearRegressionSuite implements Serializable { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; private transient Dataset dataset; private transient JavaRDD datasetRDD; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaLinearRegressionSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); datasetRDD = jsc.parallelize(points, 2); - dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); + dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class); dataset.registerTempTable("dataset"); } @@ -65,7 +66,7 @@ public class JavaLinearRegressionSuite implements Serializable { assertEquals("auto", lr.getSolver()); LinearRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - Dataset predictions = jsql.sql("SELECT label, prediction FROM prediction"); + Dataset predictions = spark.sql("SELECT label, prediction FROM prediction"); predictions.collect(); // Check defaults assertEquals("features", model.getFeaturesCol()); @@ -76,8 +77,8 @@ public class JavaLinearRegressionSuite implements Serializable { public void linearRegressionWithSetters() { // Set params, train, and check as many params as we can. LinearRegression lr = new LinearRegression() - .setMaxIter(10) - .setRegParam(1.0).setSolver("l-bfgs"); + .setMaxIter(10) + .setRegParam(1.0).setSolver("l-bfgs"); LinearRegressionModel model = lr.fit(dataset); LinearRegression parent = (LinearRegression) model.parent(); assertEquals(10, parent.getMaxIter()); @@ -85,7 +86,7 @@ public class JavaLinearRegressionSuite implements Serializable { // Call fit() with new params, and check as many params as we can. 
LinearRegressionModel model2 = - lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); + lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); LinearRegression parent2 = (LinearRegression) model2.parent(); assertEquals(5, parent2.getMaxIter()); assertEquals(0.1, parent2.getRegParam(), 0.0); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index 38b895f1fd..fdb41ffc10 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -28,27 +28,33 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.ml.tree.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; public class JavaRandomForestRegressorSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaRandomForestRegressorSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -57,7 +63,7 @@ public class JavaRandomForestRegressorSuite implements Serializable { double A = 2.0; double B = -1.5; - JavaRDD data = sc.parallelize( + JavaRDD data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap<>(); Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); @@ -75,22 +81,22 @@ public class JavaRandomForestRegressorSuite implements Serializable { .setSeed(1234) .setNumTrees(3) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String impurity: RandomForestRegressor.supportedImpurities()) { + for (String impurity : RandomForestRegressor.supportedImpurities()) { rf.setImpurity(impurity); } - for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) { + for (String featureSubsetStrategy : RandomForestRegressor.supportedFeatureSubsetStrategies()) { rf.setFeatureSubsetStrategy(featureSubsetStrategy); } String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; - for (String strategy: realStrategies) { + for (String strategy : realStrategies) { rf.setFeatureSubsetStrategy(strategy); } String[] integerStrategies = {"1", "10", "100", "1000", "10000"}; - for (String strategy: integerStrategies) { + for (String strategy : integerStrategies) { rf.setFeatureSubsetStrategy(strategy); } String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; - for (String strategy: invalidStrategies) { + for (String strategy : invalidStrategies) { try { rf.setFeatureSubsetStrategy(strategy); Assert.fail("Expected exception to be thrown for invalid strategies"); 
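Every suite touched in this patch follows the same mechanical migration: the JavaSparkContext constructor plus SQLContext is replaced by a single SparkSession, and a JavaSparkContext is re-derived from that session for parallelize() calls. A minimal, self-contained sketch of the resulting pattern is shown below; the suite name and the small test body are illustrative only (they do not appear in the patch) and assume Spark 2.0-era APIs.

import java.util.Arrays;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// Hypothetical suite name, used only for this sketch.
public class JavaSparkSessionMigrationSketch {
  private transient SparkSession spark;
  private transient JavaSparkContext jsc;

  @Before
  public void setUp() {
    // One SparkSession replaces the old JavaSparkContext + SQLContext pair.
    spark = SparkSession.builder()
      .master("local")
      .appName("JavaSparkSessionMigrationSketch")
      .getOrCreate();
    // Keep a JavaSparkContext for parallelize() by wrapping the session's SparkContext.
    jsc = new JavaSparkContext(spark.sparkContext());
  }

  @After
  public void tearDown() {
    // Stopping the session also stops the underlying SparkContext.
    spark.stop();
    spark = null;
  }

  @Test
  public void createDataFrameFromBeans() {
    // spark.createDataFrame takes over the role previously played by sqlContext.createDataFrame.
    Dataset<Row> df = spark.createDataFrame(
      jsc.parallelize(Arrays.asList(
        new LabeledPoint(1.0, Vectors.dense(0.5, 1.5)),
        new LabeledPoint(0.0, Vectors.dense(2.0, 0.0))), 2),
      LabeledPoint.class);
    df.collectAsList();
  }
}
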
diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java index 1c18b2b266..058f2ddafd 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -28,12 +28,11 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.DenseVector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.util.Utils; @@ -41,16 +40,17 @@ import org.apache.spark.util.Utils; * Test LibSVMRelation in Java. */ public class JavaLibSVMRelationSuite { - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; + private transient SparkSession spark; private File tempDir; private String path; @Before public void setUp() throws IOException { - jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite"); - sqlContext = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaLibSVMRelationSuite") + .getOrCreate(); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); File file = new File(tempDir, "part-00000"); @@ -61,14 +61,14 @@ public class JavaLibSVMRelationSuite { @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; Utils.deleteRecursively(tempDir); } @Test public void verifyLibSVMDF() { - Dataset dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") + Dataset dataset = spark.read().format("libsvm").option("vectorType", "dense") .load(path); Assert.assertEquals("label", dataset.columns()[0]); Assert.assertEquals("features", dataset.columns()[1]); diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index 24b0097454..8b4d034ffe 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -32,21 +32,25 @@ import org.apache.spark.ml.param.ParamMap; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; public class JavaCrossValidatorSuite implements Serializable { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; private transient Dataset dataset; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaCrossValidatorSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); + List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); - dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class); + dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class); } @After @@ -59,8 +63,8 @@ public class 
JavaCrossValidatorSuite implements Serializable { public void crossValidationWithLogisticRegression() { LogisticRegression lr = new LogisticRegression(); ParamMap[] lrParamMaps = new ParamGridBuilder() - .addGrid(lr.regParam(), new double[] {0.001, 1000.0}) - .addGrid(lr.maxIter(), new int[] {0, 10}) + .addGrid(lr.regParam(), new double[]{0.001, 1000.0}) + .addGrid(lr.maxIter(), new int[]{0, 10}) .build(); BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator(); CrossValidator cv = new CrossValidator() diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala index 928301523f..878bc66ee3 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala +++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala @@ -37,4 +37,5 @@ object IdentifiableSuite { class Test(override val uid: String) extends Identifiable { def this() = this(Identifiable.randomUID("test")) } + } diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java index 01ff1ea658..7151e27cde 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -27,31 +27,34 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.util.Utils; public class JavaDefaultReadWriteSuite { JavaSparkContext jsc = null; - SQLContext sqlContext = null; + SparkSession spark = null; File tempDir = null; @Before public void setUp() { - jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite"); SQLContext.clearActive(); - sqlContext = new SQLContext(jsc); - SQLContext.setActive(sqlContext); + spark = SparkSession.builder() + .master("local[2]") + .appName("JavaDefaultReadWriteSuite") + .getOrCreate(); + SQLContext.setActive(spark.wrapped()); + tempDir = Utils.createTempDir( System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite"); } @After public void tearDown() { - sqlContext = null; SQLContext.clearActive(); - if (jsc != null) { - jsc.stop(); - jsc = null; + if (spark != null) { + spark.stop(); + spark = null; } Utils.deleteRecursively(tempDir); } @@ -70,7 +73,7 @@ public class JavaDefaultReadWriteSuite { } catch (IOException e) { // expected } - instance.write().context(sqlContext).overwrite().save(outputPath); + instance.write().context(spark.wrapped()).overwrite().save(outputPath); MyParams newInstance = MyParams.load(outputPath); Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); Assert.assertEquals("Params should be preserved.", diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java index 862221d487..2f10d14da5 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java @@ -27,26 +27,31 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; - import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.SparkSession; public class 
JavaLogisticRegressionSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaLogisticRegressionSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } int validatePrediction(List validationData, LogisticRegressionModel model) { int numAccurate = 0; - for (LabeledPoint point: validationData) { + for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); if (prediction == point.label()) { numAccurate++; @@ -61,16 +66,16 @@ public class JavaLogisticRegressionSuite implements Serializable { double A = 2.0; double B = -1.5; - JavaRDD testRDD = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + JavaRDD testRDD = jsc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); List validationData = - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD(); lrImpl.setIntercept(true); lrImpl.optimizer().setStepSize(1.0) - .setRegParam(1.0) - .setNumIterations(100); + .setRegParam(1.0) + .setNumIterations(100); LogisticRegressionModel model = lrImpl.run(testRDD.rdd()); int numAccurate = validatePrediction(validationData, model); @@ -83,13 +88,13 @@ public class JavaLogisticRegressionSuite implements Serializable { double A = 0.0; double B = -2.5; - JavaRDD testRDD = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + JavaRDD testRDD = jsc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); List validationData = - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); LogisticRegressionModel model = LogisticRegressionWithSGD.train( - testRDD.rdd(), 100, 1.0, 1.0); + testRDD.rdd(), 100, 1.0, 1.0); int numAccurate = validatePrediction(validationData, model); Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 3771c0ea7a..5e212e2fc5 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -32,20 +32,26 @@ import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.SparkSession; public class JavaNaiveBayesSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaNaiveBayesSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaNaiveBayesSuite") + .getOrCreate(); + jsc = new 
JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } private static final List POINTS = Arrays.asList( @@ -59,7 +65,7 @@ public class JavaNaiveBayesSuite implements Serializable { private int validatePrediction(List points, NaiveBayesModel model) { int correct = 0; - for (LabeledPoint p: points) { + for (LabeledPoint p : points) { if (model.predict(p.features()) == p.label()) { correct += 1; } @@ -69,7 +75,7 @@ public class JavaNaiveBayesSuite implements Serializable { @Test public void runUsingConstructor() { - JavaRDD testRDD = sc.parallelize(POINTS, 2).cache(); + JavaRDD testRDD = jsc.parallelize(POINTS, 2).cache(); NaiveBayes nb = new NaiveBayes().setLambda(1.0); NaiveBayesModel model = nb.run(testRDD.rdd()); @@ -80,7 +86,7 @@ public class JavaNaiveBayesSuite implements Serializable { @Test public void runUsingStaticMethods() { - JavaRDD testRDD = sc.parallelize(POINTS, 2).cache(); + JavaRDD testRDD = jsc.parallelize(POINTS, 2).cache(); NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd()); int numAccurate1 = validatePrediction(POINTS, model1); @@ -93,13 +99,14 @@ public class JavaNaiveBayesSuite implements Serializable { @Test public void testPredictJavaRDD() { - JavaRDD examples = sc.parallelize(POINTS, 2).cache(); + JavaRDD examples = jsc.parallelize(POINTS, 2).cache(); NaiveBayesModel model = NaiveBayes.train(examples.rdd()); JavaRDD vectors = examples.map(new Function() { @Override public Vector call(LabeledPoint v) throws Exception { return v.features(); - }}); + } + }); JavaRDD predictions = model.predict(vectors); // Should be able to get the first prediction. predictions.first(); diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java index 31b9f3e8d4..2a090c054f 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java @@ -28,24 +28,30 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.SparkSession; public class JavaSVMSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaSVMSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaSVMSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } int validatePrediction(List validationData, SVMModel model) { int numAccurate = 0; - for (LabeledPoint point: validationData) { + for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); if (prediction == point.label()) { numAccurate++; @@ -60,16 +66,16 @@ public class JavaSVMSuite implements Serializable { double A = 2.0; double[] weights = {-1.5, 1.0}; - JavaRDD testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, - weights, nPoints, 42), 2).cache(); + JavaRDD testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A, + weights, nPoints, 42), 2).cache(); List validationData = - SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); + SVMSuite.generateSVMInputAsList(A, weights, 
nPoints, 17); SVMWithSGD svmSGDImpl = new SVMWithSGD(); svmSGDImpl.setIntercept(true); svmSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(1.0) - .setNumIterations(100); + .setRegParam(1.0) + .setNumIterations(100); SVMModel model = svmSGDImpl.run(testRDD.rdd()); int numAccurate = validatePrediction(validationData, model); @@ -82,10 +88,10 @@ public class JavaSVMSuite implements Serializable { double A = 0.0; double[] weights = {-1.5, 1.0}; - JavaRDD testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, - weights, nPoints, 42), 2).cache(); + JavaRDD testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A, + weights, nPoints, 42), 2).cache(); List validationData = - SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); + SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java index a714620ff7..7f29b05047 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java @@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering; import java.io.Serializable; import com.google.common.collect.Lists; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -29,27 +30,33 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.SparkSession; public class JavaBisectingKMeansSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", this.getClass().getSimpleName()); + spark = SparkSession.builder() + .master("local") + .appName("JavaBisectingKMeansSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test public void twoDimensionalData() { - JavaRDD points = sc.parallelize(Lists.newArrayList( + JavaRDD points = jsc.parallelize(Lists.newArrayList( Vectors.dense(4, -1), Vectors.dense(4, 1), - Vectors.sparse(2, new int[] {0}, new double[] {1.0}) + Vectors.sparse(2, new int[]{0}, new double[]{1.0}) ), 2); BisectingKMeans bkm = new BisectingKMeans() @@ -58,15 +65,15 @@ public class JavaBisectingKMeansSuite implements Serializable { .setSeed(1L); BisectingKMeansModel model = bkm.run(points); Assert.assertEquals(3, model.k()); - Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12); - for (ClusteringTreeNode child: model.root().children()) { + Assert.assertArrayEquals(new double[]{3.0, 0.0}, model.root().center().toArray(), 1e-12); + for (ClusteringTreeNode child : model.root().children()) { double[] center = child.center().toArray(); if (center[0] > 2) { Assert.assertEquals(2, child.size()); - Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12); + Assert.assertArrayEquals(new double[]{4.0, 0.0}, center, 1e-12); } else { Assert.assertEquals(1, child.size()); - Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12); + Assert.assertArrayEquals(new double[]{1.0, 0.0}, center, 1e-12); } } } diff --git 
a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java index 123f78da54..20edd08a21 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java @@ -21,29 +21,35 @@ import java.io.Serializable; import java.util.Arrays; import java.util.List; +import static org.junit.Assert.assertEquals; + import org.junit.After; import org.junit.Before; import org.junit.Test; -import static org.junit.Assert.assertEquals; - import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.SparkSession; public class JavaGaussianMixtureSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaGaussianMixture"); + spark = SparkSession.builder() + .master("local") + .appName("JavaGaussianMixture") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -54,7 +60,7 @@ public class JavaGaussianMixtureSuite implements Serializable { Vectors.dense(1.0, 4.0, 6.0) ); - JavaRDD data = sc.parallelize(points, 2); + JavaRDD data = jsc.parallelize(points, 2); GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234) .run(data); assertEquals(model.gaussians().length, 2); diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java index ad06676c72..4e5b87f588 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java @@ -21,28 +21,35 @@ import java.io.Serializable; import java.util.Arrays; import java.util.List; +import static org.junit.Assert.assertEquals; + import org.junit.After; import org.junit.Before; import org.junit.Test; -import static org.junit.Assert.*; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.SparkSession; public class JavaKMeansSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaKMeans"); + spark = SparkSession.builder() + .master("local") + .appName("JavaKMeans") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -55,7 +62,7 @@ public class JavaKMeansSuite implements Serializable { Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0); - JavaRDD data = sc.parallelize(points, 2); + JavaRDD data = jsc.parallelize(points, 2); KMeansModel model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.K_MEANS_PARALLEL()); assertEquals(1, model.clusterCenters().length); assertEquals(expectedCenter, model.clusterCenters()[0]); @@ -74,7 +81,7 @@ public class 
JavaKMeansSuite implements Serializable { Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0); - JavaRDD data = sc.parallelize(points, 2); + JavaRDD data = jsc.parallelize(points, 2); KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd()); assertEquals(1, model.clusterCenters().length); assertEquals(expectedCenter, model.clusterCenters()[0]); @@ -94,7 +101,7 @@ public class JavaKMeansSuite implements Serializable { Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) ); - JavaRDD data = sc.parallelize(points, 2); + JavaRDD data = jsc.parallelize(points, 2); KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd()); JavaRDD predictions = model.predict(data); // Should be able to get the first prediction. diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index db19b309f6..f16585aff4 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -27,37 +27,42 @@ import scala.Tuple3; import org.junit.After; import org.junit.Before; import org.junit.Test; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; -import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.SparkSession; public class JavaLDASuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaLDA"); + spark = SparkSession.builder() + .master("local") + .appName("JavaLDASuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); + ArrayList> tinyCorpus = new ArrayList<>(); for (int i = 0; i < LDASuite.tinyCorpus().length; i++) { - tinyCorpus.add(new Tuple2<>((Long)LDASuite.tinyCorpus()[i]._1(), - LDASuite.tinyCorpus()[i]._2())); + tinyCorpus.add(new Tuple2<>((Long) LDASuite.tinyCorpus()[i]._1(), + LDASuite.tinyCorpus()[i]._2())); } - JavaRDD> tmpCorpus = sc.parallelize(tinyCorpus, 2); + JavaRDD> tmpCorpus = jsc.parallelize(tinyCorpus, 2); corpus = JavaPairRDD.fromJavaRDD(tmpCorpus); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -95,7 +100,7 @@ public class JavaLDASuite implements Serializable { .setMaxIterations(5) .setSeed(12345); - DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus); + DistributedLDAModel model = (DistributedLDAModel) lda.run(corpus); // Check: basic parameters LocalLDAModel localModel = model.toLocal(); @@ -124,7 +129,7 @@ public class JavaLDASuite implements Serializable { public Boolean call(Tuple2 tuple2) { return Vectors.norm(tuple2._2(), 1.0) != 0.0; } - }); + }); assertEquals(topicDistributions.count(), nonEmptyCorpus.count()); // Check: javaTopTopicsPerDocuments @@ -179,7 +184,7 @@ public class JavaLDASuite implements Serializable { @Test public void localLdaMethods() { - JavaRDD> docs = sc.parallelize(toyData, 2); + JavaRDD> docs = 
jsc.parallelize(toyData, 2); JavaPairRDD pairedDocs = JavaPairRDD.fromJavaRDD(docs); // check: topicDistributions @@ -191,7 +196,7 @@ public class JavaLDASuite implements Serializable { // check: logLikelihood. ArrayList> docsSingleWord = new ArrayList<>(); docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0, 0.0, 0.0))); - JavaPairRDD single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord)); + JavaPairRDD single = JavaPairRDD.fromJavaRDD(jsc.parallelize(docsSingleWord)); double logLikelihood = toyModel.logLikelihood(single); } @@ -199,7 +204,7 @@ public class JavaLDASuite implements Serializable { private static int tinyVocabSize = LDASuite.tinyVocabSize(); private static Matrix tinyTopics = LDASuite.tinyTopics(); private static Tuple2[] tinyTopicDescription = - LDASuite.tinyTopicDescription(); + LDASuite.tinyTopicDescription(); private JavaPairRDD corpus; private LocalLDAModel toyModel = LDASuite.toyModel(); private ArrayList> toyData = LDASuite.javaToyData(); diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java index 62edbd3a29..d1d618f7de 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java @@ -27,8 +27,6 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; -import static org.apache.spark.streaming.JavaTestUtils.*; - import org.apache.spark.SparkConf; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; @@ -36,6 +34,7 @@ import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; +import static org.apache.spark.streaming.JavaTestUtils.*; public class JavaStreamingKMeansSuite implements Serializable { diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java index fa4d334801..6a096d6386 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java @@ -31,27 +31,34 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; public class JavaRankingMetricsSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; private transient JavaRDD, List>> predictionAndLabels; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaRankingMetricsSuite"); - predictionAndLabels = sc.parallelize(Arrays.asList( + spark = SparkSession.builder() + .master("local") + .appName("JavaRankingMetricsSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); + + predictionAndLabels = jsc.parallelize(Arrays.asList( Tuple2$.MODULE$.apply( Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)), Tuple2$.MODULE$.apply( - Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)), + Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)), Tuple2$.MODULE$.apply( - Arrays.asList(1, 2, 3, 4, 5), Arrays.asList())), 2); + 
Arrays.asList(1, 2, 3, 4, 5), Arrays.asList())), 2); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java index 8a320afa4b..de50fb8c4f 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java @@ -29,19 +29,25 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.SparkSession; public class JavaTfIdfSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaTfIdfSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaTfIdfSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -49,7 +55,7 @@ public class JavaTfIdfSuite implements Serializable { // The tests are to check Java compatibility. HashingTF tf = new HashingTF(); @SuppressWarnings("unchecked") - JavaRDD> documents = sc.parallelize(Arrays.asList( + JavaRDD> documents = jsc.parallelize(Arrays.asList( Arrays.asList("this is a sentence".split(" ")), Arrays.asList("this is another sentence".split(" ")), Arrays.asList("this is still a sentence".split(" "))), 2); @@ -59,7 +65,7 @@ public class JavaTfIdfSuite implements Serializable { JavaRDD tfIdfs = idf.fit(termFreqs).transform(termFreqs); List localTfIdfs = tfIdfs.collect(); int indexOfThis = tf.indexOf("this"); - for (Vector v: localTfIdfs) { + for (Vector v : localTfIdfs) { Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); } } @@ -69,7 +75,7 @@ public class JavaTfIdfSuite implements Serializable { // The tests are to check Java compatibility. 
HashingTF tf = new HashingTF(); @SuppressWarnings("unchecked") - JavaRDD> documents = sc.parallelize(Arrays.asList( + JavaRDD> documents = jsc.parallelize(Arrays.asList( Arrays.asList("this is a sentence".split(" ")), Arrays.asList("this is another sentence".split(" ")), Arrays.asList("this is still a sentence".split(" "))), 2); @@ -79,7 +85,7 @@ public class JavaTfIdfSuite implements Serializable { JavaRDD tfIdfs = idf.fit(termFreqs).transform(termFreqs); List localTfIdfs = tfIdfs.collect(); int indexOfThis = tf.indexOf("this"); - for (Vector v: localTfIdfs) { + for (Vector v : localTfIdfs) { Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java index e13ed07e28..64885cc842 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java @@ -21,9 +21,10 @@ import java.io.Serializable; import java.util.Arrays; import java.util.List; +import com.google.common.base.Strings; + import scala.Tuple2; -import com.google.common.base.Strings; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -31,19 +32,25 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; public class JavaWord2VecSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaWord2VecSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaWord2VecSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -53,7 +60,7 @@ public class JavaWord2VecSuite implements Serializable { String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10); List words = Arrays.asList(sentence.split(" ")); List> localDoc = Arrays.asList(words, words); - JavaRDD> doc = sc.parallelize(localDoc); + JavaRDD> doc = jsc.parallelize(localDoc); Word2Vec word2vec = new Word2Vec() .setVectorSize(10) .setSeed(42L); diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java index 2bef7a8609..fdc19a5b3d 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java @@ -26,32 +26,37 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; +import org.apache.spark.sql.SparkSession; public class JavaAssociationRulesSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaFPGrowth"); + spark = SparkSession.builder() + .master("local") + .appName("JavaAssociationRulesSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test public void 
runAssociationRules() { @SuppressWarnings("unchecked") - JavaRDD> freqItemsets = sc.parallelize(Arrays.asList( - new FreqItemset(new String[] {"a"}, 15L), - new FreqItemset(new String[] {"b"}, 35L), - new FreqItemset(new String[] {"a", "b"}, 12L) + JavaRDD> freqItemsets = jsc.parallelize(Arrays.asList( + new FreqItemset(new String[]{"a"}, 15L), + new FreqItemset(new String[]{"b"}, 35L), + new FreqItemset(new String[]{"a", "b"}, 12L) )); JavaRDD> results = (new AssociationRules()).run(freqItemsets); } } - diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java index 916fff14a7..f235251e61 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -22,34 +22,41 @@ import java.io.Serializable; import java.util.Arrays; import java.util.List; +import static org.junit.Assert.assertEquals; + import org.junit.After; import org.junit.Before; import org.junit.Test; -import static org.junit.Assert.*; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.util.Utils; public class JavaFPGrowthSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaFPGrowth"); + spark = SparkSession.builder() + .master("local") + .appName("JavaFPGrowth") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test public void runFPGrowth() { @SuppressWarnings("unchecked") - JavaRDD> rdd = sc.parallelize(Arrays.asList( + JavaRDD> rdd = jsc.parallelize(Arrays.asList( Arrays.asList("r z h k p".split(" ")), Arrays.asList("z y x w v u t s".split(" ")), Arrays.asList("s x o n r".split(" ")), @@ -65,7 +72,7 @@ public class JavaFPGrowthSuite implements Serializable { List> freqItemsets = model.freqItemsets().toJavaRDD().collect(); assertEquals(18, freqItemsets.size()); - for (FPGrowth.FreqItemset itemset: freqItemsets) { + for (FPGrowth.FreqItemset itemset : freqItemsets) { // Test return types. List items = itemset.javaItems(); long freq = itemset.freq(); @@ -76,7 +83,7 @@ public class JavaFPGrowthSuite implements Serializable { public void runFPGrowthSaveLoad() { @SuppressWarnings("unchecked") - JavaRDD> rdd = sc.parallelize(Arrays.asList( + JavaRDD> rdd = jsc.parallelize(Arrays.asList( Arrays.asList("r z h k p".split(" ")), Arrays.asList("z y x w v u t s".split(" ")), Arrays.asList("s x o n r".split(" ")), @@ -94,15 +101,15 @@ public class JavaFPGrowthSuite implements Serializable { String outputPath = tempDir.getPath(); try { - model.save(sc.sc(), outputPath); + model.save(spark.sparkContext(), outputPath); @SuppressWarnings("unchecked") FPGrowthModel newModel = - (FPGrowthModel) FPGrowthModel.load(sc.sc(), outputPath); + (FPGrowthModel) FPGrowthModel.load(spark.sparkContext(), outputPath); List> freqItemsets = newModel.freqItemsets().toJavaRDD() .collect(); assertEquals(18, freqItemsets.size()); - for (FPGrowth.FreqItemset itemset: freqItemsets) { + for (FPGrowth.FreqItemset itemset : freqItemsets) { // Test return types. 
List items = itemset.javaItems(); long freq = itemset.freq(); diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java index 8a67793abc..bf7f1fc71b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java @@ -29,25 +29,31 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence; +import org.apache.spark.sql.SparkSession; import org.apache.spark.util.Utils; public class JavaPrefixSpanSuite { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaPrefixSpan"); + spark = SparkSession.builder() + .master("local") + .appName("JavaPrefixSpan") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test public void runPrefixSpan() { - JavaRDD>> sequences = sc.parallelize(Arrays.asList( + JavaRDD>> sequences = jsc.parallelize(Arrays.asList( Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), @@ -61,7 +67,7 @@ public class JavaPrefixSpanSuite { List> localFreqSeqs = freqSeqs.collect(); Assert.assertEquals(5, localFreqSeqs.size()); // Check that each frequent sequence could be materialized. - for (PrefixSpan.FreqSequence freqSeq: localFreqSeqs) { + for (PrefixSpan.FreqSequence freqSeq : localFreqSeqs) { List> seq = freqSeq.javaSequence(); long freq = freqSeq.freq(); } @@ -69,7 +75,7 @@ public class JavaPrefixSpanSuite { @Test public void runPrefixSpanSaveLoad() { - JavaRDD>> sequences = sc.parallelize(Arrays.asList( + JavaRDD>> sequences = jsc.parallelize(Arrays.asList( Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), @@ -85,13 +91,13 @@ public class JavaPrefixSpanSuite { String outputPath = tempDir.getPath(); try { - model.save(sc.sc(), outputPath); - PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath); + model.save(spark.sparkContext(), outputPath); + PrefixSpanModel newModel = PrefixSpanModel.load(spark.sparkContext(), outputPath); JavaRDD> freqSeqs = newModel.freqSequences().toJavaRDD(); List> localFreqSeqs = freqSeqs.collect(); Assert.assertEquals(5, localFreqSeqs.size()); // Check that each frequent sequence could be materialized. 
- for (PrefixSpan.FreqSequence freqSeq: localFreqSeqs) { + for (PrefixSpan.FreqSequence freqSeq : localFreqSeqs) { List> seq = freqSeq.javaSequence(); long freq = freqSeq.freq(); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java index 8beea102ef..92fc57871c 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java @@ -17,147 +17,149 @@ package org.apache.spark.mllib.linalg; -import static org.junit.Assert.*; -import org.junit.Test; - import java.io.Serializable; import java.util.Random; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +import org.junit.Test; + public class JavaMatricesSuite implements Serializable { - @Test - public void randMatrixConstruction() { - Random rng = new Random(24); - Matrix r = Matrices.rand(3, 4, rng); - rng.setSeed(24); - DenseMatrix dr = DenseMatrix.rand(3, 4, rng); - assertArrayEquals(r.toArray(), dr.toArray(), 0.0); + @Test + public void randMatrixConstruction() { + Random rng = new Random(24); + Matrix r = Matrices.rand(3, 4, rng); + rng.setSeed(24); + DenseMatrix dr = DenseMatrix.rand(3, 4, rng); + assertArrayEquals(r.toArray(), dr.toArray(), 0.0); - rng.setSeed(24); - Matrix rn = Matrices.randn(3, 4, rng); - rng.setSeed(24); - DenseMatrix drn = DenseMatrix.randn(3, 4, rng); - assertArrayEquals(rn.toArray(), drn.toArray(), 0.0); + rng.setSeed(24); + Matrix rn = Matrices.randn(3, 4, rng); + rng.setSeed(24); + DenseMatrix drn = DenseMatrix.randn(3, 4, rng); + assertArrayEquals(rn.toArray(), drn.toArray(), 0.0); - rng.setSeed(24); - Matrix s = Matrices.sprand(3, 4, 0.5, rng); - rng.setSeed(24); - SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng); - assertArrayEquals(s.toArray(), sr.toArray(), 0.0); + rng.setSeed(24); + Matrix s = Matrices.sprand(3, 4, 0.5, rng); + rng.setSeed(24); + SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng); + assertArrayEquals(s.toArray(), sr.toArray(), 0.0); - rng.setSeed(24); - Matrix sn = Matrices.sprandn(3, 4, 0.5, rng); - rng.setSeed(24); - SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng); - assertArrayEquals(sn.toArray(), srn.toArray(), 0.0); - } + rng.setSeed(24); + Matrix sn = Matrices.sprandn(3, 4, 0.5, rng); + rng.setSeed(24); + SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng); + assertArrayEquals(sn.toArray(), srn.toArray(), 0.0); + } - @Test - public void identityMatrixConstruction() { - Matrix r = Matrices.eye(2); - DenseMatrix dr = DenseMatrix.eye(2); - SparseMatrix sr = SparseMatrix.speye(2); - assertArrayEquals(r.toArray(), dr.toArray(), 0.0); - assertArrayEquals(sr.toArray(), dr.toArray(), 0.0); - assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0); - } + @Test + public void identityMatrixConstruction() { + Matrix r = Matrices.eye(2); + DenseMatrix dr = DenseMatrix.eye(2); + SparseMatrix sr = SparseMatrix.speye(2); + assertArrayEquals(r.toArray(), dr.toArray(), 0.0); + assertArrayEquals(sr.toArray(), dr.toArray(), 0.0); + assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0); + } - @Test - public void diagonalMatrixConstruction() { - Vector v = Vectors.dense(1.0, 0.0, 2.0); - Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0}); + @Test + public void diagonalMatrixConstruction() { + Vector v = Vectors.dense(1.0, 0.0, 2.0); + Vector sv = Vectors.sparse(3, new int[]{0, 2}, new 
double[]{1.0, 2.0}); - Matrix m = Matrices.diag(v); - Matrix sm = Matrices.diag(sv); - DenseMatrix d = DenseMatrix.diag(v); - DenseMatrix sd = DenseMatrix.diag(sv); - SparseMatrix s = SparseMatrix.spdiag(v); - SparseMatrix ss = SparseMatrix.spdiag(sv); + Matrix m = Matrices.diag(v); + Matrix sm = Matrices.diag(sv); + DenseMatrix d = DenseMatrix.diag(v); + DenseMatrix sd = DenseMatrix.diag(sv); + SparseMatrix s = SparseMatrix.spdiag(v); + SparseMatrix ss = SparseMatrix.spdiag(sv); - assertArrayEquals(m.toArray(), sm.toArray(), 0.0); - assertArrayEquals(d.toArray(), sm.toArray(), 0.0); - assertArrayEquals(d.toArray(), sd.toArray(), 0.0); - assertArrayEquals(sd.toArray(), s.toArray(), 0.0); - assertArrayEquals(s.toArray(), ss.toArray(), 0.0); - assertArrayEquals(s.values(), ss.values(), 0.0); - assertEquals(2, s.values().length); - assertEquals(2, ss.values().length); - assertEquals(4, s.colPtrs().length); - assertEquals(4, ss.colPtrs().length); - } + assertArrayEquals(m.toArray(), sm.toArray(), 0.0); + assertArrayEquals(d.toArray(), sm.toArray(), 0.0); + assertArrayEquals(d.toArray(), sd.toArray(), 0.0); + assertArrayEquals(sd.toArray(), s.toArray(), 0.0); + assertArrayEquals(s.toArray(), ss.toArray(), 0.0); + assertArrayEquals(s.values(), ss.values(), 0.0); + assertEquals(2, s.values().length); + assertEquals(2, ss.values().length); + assertEquals(4, s.colPtrs().length); + assertEquals(4, ss.colPtrs().length); + } - @Test - public void zerosMatrixConstruction() { - Matrix z = Matrices.zeros(2, 2); - Matrix one = Matrices.ones(2, 2); - DenseMatrix dz = DenseMatrix.zeros(2, 2); - DenseMatrix done = DenseMatrix.ones(2, 2); + @Test + public void zerosMatrixConstruction() { + Matrix z = Matrices.zeros(2, 2); + Matrix one = Matrices.ones(2, 2); + DenseMatrix dz = DenseMatrix.zeros(2, 2); + DenseMatrix done = DenseMatrix.ones(2, 2); - assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); - assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); - assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); - assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); - } + assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); + assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); + assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); + assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); + } - @Test - public void sparseDenseConversion() { - int m = 3; - int n = 2; - double[] values = new double[]{1.0, 2.0, 4.0, 5.0}; - double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0}; - int[] colPtrs = new int[]{0, 2, 4}; - int[] rowIndices = new int[]{0, 1, 1, 2}; + @Test + public void sparseDenseConversion() { + int m = 3; + int n = 2; + double[] values = new double[]{1.0, 2.0, 4.0, 5.0}; + double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0}; + int[] colPtrs = new int[]{0, 2, 4}; + int[] rowIndices = new int[]{0, 1, 1, 2}; - SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values); - DenseMatrix deMat1 = new DenseMatrix(m, n, allValues); + SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values); + DenseMatrix deMat1 = new DenseMatrix(m, n, allValues); - SparseMatrix spMat2 = deMat1.toSparse(); - DenseMatrix deMat2 = spMat1.toDense(); + SparseMatrix spMat2 = deMat1.toSparse(); + DenseMatrix deMat2 = spMat1.toDense(); - assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0); - assertArrayEquals(deMat1.toArray(), 
deMat2.toArray(), 0.0); - } + assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0); + assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0); + } - @Test - public void concatenateMatrices() { - int m = 3; - int n = 2; + @Test + public void concatenateMatrices() { + int m = 3; + int n = 2; - Random rng = new Random(42); - SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng); - rng.setSeed(42); - DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng); - Matrix deMat2 = Matrices.eye(3); - Matrix spMat2 = Matrices.speye(3); - Matrix deMat3 = Matrices.eye(2); - Matrix spMat3 = Matrices.speye(2); + Random rng = new Random(42); + SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng); + rng.setSeed(42); + DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng); + Matrix deMat2 = Matrices.eye(3); + Matrix spMat2 = Matrices.speye(3); + Matrix deMat3 = Matrices.eye(2); + Matrix spMat3 = Matrices.speye(2); - Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2}); - Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2}); - Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2}); - Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2}); + Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2}); + Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2}); + Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2}); + Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2}); - assertEquals(3, deHorz1.numRows()); - assertEquals(3, deHorz2.numRows()); - assertEquals(3, deHorz3.numRows()); - assertEquals(3, spHorz.numRows()); - assertEquals(5, deHorz1.numCols()); - assertEquals(5, deHorz2.numCols()); - assertEquals(5, deHorz3.numCols()); - assertEquals(5, spHorz.numCols()); + assertEquals(3, deHorz1.numRows()); + assertEquals(3, deHorz2.numRows()); + assertEquals(3, deHorz3.numRows()); + assertEquals(3, spHorz.numRows()); + assertEquals(5, deHorz1.numCols()); + assertEquals(5, deHorz2.numCols()); + assertEquals(5, deHorz3.numCols()); + assertEquals(5, spHorz.numCols()); - Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3}); - Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3}); - Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3}); - Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3}); + Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3}); + Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3}); + Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3}); + Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3}); - assertEquals(5, deVert1.numRows()); - assertEquals(5, deVert2.numRows()); - assertEquals(5, deVert3.numRows()); - assertEquals(5, spVert.numRows()); - assertEquals(2, deVert1.numCols()); - assertEquals(2, deVert2.numCols()); - assertEquals(2, deVert3.numCols()); - assertEquals(2, spVert.numCols()); - } + assertEquals(5, deVert1.numRows()); + assertEquals(5, deVert2.numRows()); + assertEquals(5, deVert3.numRows()); + assertEquals(5, spVert.numRows()); + assertEquals(2, deVert1.numCols()); + assertEquals(2, deVert2.numCols()); + assertEquals(2, deVert3.numCols()); + assertEquals(2, spVert.numCols()); + } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java index 4ba8e543a9..817b962c75 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java @@ 
-20,10 +20,11 @@ package org.apache.spark.mllib.linalg; import java.io.Serializable; import java.util.Arrays; +import static org.junit.Assert.assertArrayEquals; + import scala.Tuple2; import org.junit.Test; -import static org.junit.Assert.*; public class JavaVectorsSuite implements Serializable { @@ -37,8 +38,8 @@ public class JavaVectorsSuite implements Serializable { public void sparseArrayConstruction() { @SuppressWarnings("unchecked") Vector v = Vectors.sparse(3, Arrays.asList( - new Tuple2<>(0, 2.0), - new Tuple2<>(2, 3.0))); + new Tuple2<>(0, 2.0), + new Tuple2<>(2, 3.0))); assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0); } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index be58691f4d..b449108a9b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -20,29 +20,35 @@ package org.apache.spark.mllib.random; import java.io.Serializable; import java.util.Arrays; -import org.apache.spark.api.java.JavaRDD; -import org.junit.Assert; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.SparkSession; import static org.apache.spark.mllib.random.RandomRDDs.*; public class JavaRandomRDDsSuite { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaRandomRDDsSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaRandomRDDsSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -50,10 +56,10 @@ public class JavaRandomRDDsSuite { long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m); - JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p); - JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = uniformJavaRDD(jsc, m); + JavaDoubleRDD rdd2 = uniformJavaRDD(jsc, m, p); + JavaDoubleRDD rdd3 = uniformJavaRDD(jsc, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -63,10 +69,10 @@ public class JavaRandomRDDsSuite { long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = normalJavaRDD(sc, m); - JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p); - JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = normalJavaRDD(jsc, m); + JavaDoubleRDD rdd2 = normalJavaRDD(jsc, m, p); + JavaDoubleRDD rdd3 = normalJavaRDD(jsc, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -78,10 +84,10 @@ public class JavaRandomRDDsSuite { long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = logNormalJavaRDD(sc, mean, std, m); - JavaDoubleRDD rdd2 = logNormalJavaRDD(sc, mean, std, m, p); - JavaDoubleRDD rdd3 = logNormalJavaRDD(sc, mean, std, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, 
rdd2, rdd3)) { + JavaDoubleRDD rdd1 = logNormalJavaRDD(jsc, mean, std, m); + JavaDoubleRDD rdd2 = logNormalJavaRDD(jsc, mean, std, m, p); + JavaDoubleRDD rdd3 = logNormalJavaRDD(jsc, mean, std, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -92,10 +98,10 @@ public class JavaRandomRDDsSuite { long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m); - JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p); - JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = poissonJavaRDD(jsc, mean, m); + JavaDoubleRDD rdd2 = poissonJavaRDD(jsc, mean, m, p); + JavaDoubleRDD rdd3 = poissonJavaRDD(jsc, mean, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -106,10 +112,10 @@ public class JavaRandomRDDsSuite { long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = exponentialJavaRDD(sc, mean, m); - JavaDoubleRDD rdd2 = exponentialJavaRDD(sc, mean, m, p); - JavaDoubleRDD rdd3 = exponentialJavaRDD(sc, mean, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = exponentialJavaRDD(jsc, mean, m); + JavaDoubleRDD rdd2 = exponentialJavaRDD(jsc, mean, m, p); + JavaDoubleRDD rdd3 = exponentialJavaRDD(jsc, mean, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -117,14 +123,14 @@ public class JavaRandomRDDsSuite { @Test public void testGammaRDD() { double shape = 1.0; - double scale = 2.0; + double jscale = 2.0; long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = gammaJavaRDD(sc, shape, scale, m); - JavaDoubleRDD rdd2 = gammaJavaRDD(sc, shape, scale, m, p); - JavaDoubleRDD rdd3 = gammaJavaRDD(sc, shape, scale, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = gammaJavaRDD(jsc, shape, jscale, m); + JavaDoubleRDD rdd2 = gammaJavaRDD(jsc, shape, jscale, m, p); + JavaDoubleRDD rdd3 = gammaJavaRDD(jsc, shape, jscale, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -137,10 +143,10 @@ public class JavaRandomRDDsSuite { int n = 10; int p = 2; long seed = 1L; - JavaRDD rdd1 = uniformJavaVectorRDD(sc, m, n); - JavaRDD rdd2 = uniformJavaVectorRDD(sc, m, n, p); - JavaRDD rdd3 = uniformJavaVectorRDD(sc, m, n, p, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = uniformJavaVectorRDD(jsc, m, n); + JavaRDD rdd2 = uniformJavaVectorRDD(jsc, m, n, p); + JavaRDD rdd3 = uniformJavaVectorRDD(jsc, m, n, p, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -153,10 +159,10 @@ public class JavaRandomRDDsSuite { int n = 10; int p = 2; long seed = 1L; - JavaRDD rdd1 = normalJavaVectorRDD(sc, m, n); - JavaRDD rdd2 = normalJavaVectorRDD(sc, m, n, p); - JavaRDD rdd3 = normalJavaVectorRDD(sc, m, n, p, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = normalJavaVectorRDD(jsc, m, n); + JavaRDD rdd2 = normalJavaVectorRDD(jsc, m, n, p); + JavaRDD rdd3 = normalJavaVectorRDD(jsc, m, n, p, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -171,10 +177,10 @@ public class JavaRandomRDDsSuite { 
int n = 10; int p = 2; long seed = 1L; - JavaRDD rdd1 = logNormalJavaVectorRDD(sc, mean, std, m, n); - JavaRDD rdd2 = logNormalJavaVectorRDD(sc, mean, std, m, n, p); - JavaRDD rdd3 = logNormalJavaVectorRDD(sc, mean, std, m, n, p, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = logNormalJavaVectorRDD(jsc, mean, std, m, n); + JavaRDD rdd2 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p); + JavaRDD rdd3 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -188,10 +194,10 @@ public class JavaRandomRDDsSuite { int n = 10; int p = 2; long seed = 1L; - JavaRDD rdd1 = poissonJavaVectorRDD(sc, mean, m, n); - JavaRDD rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p); - JavaRDD rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = poissonJavaVectorRDD(jsc, mean, m, n); + JavaRDD rdd2 = poissonJavaVectorRDD(jsc, mean, m, n, p); + JavaRDD rdd3 = poissonJavaVectorRDD(jsc, mean, m, n, p, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -205,10 +211,10 @@ public class JavaRandomRDDsSuite { int n = 10; int p = 2; long seed = 1L; - JavaRDD rdd1 = exponentialJavaVectorRDD(sc, mean, m, n); - JavaRDD rdd2 = exponentialJavaVectorRDD(sc, mean, m, n, p); - JavaRDD rdd3 = exponentialJavaVectorRDD(sc, mean, m, n, p, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = exponentialJavaVectorRDD(jsc, mean, m, n); + JavaRDD rdd2 = exponentialJavaVectorRDD(jsc, mean, m, n, p); + JavaRDD rdd3 = exponentialJavaVectorRDD(jsc, mean, m, n, p, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -218,15 +224,15 @@ public class JavaRandomRDDsSuite { @SuppressWarnings("unchecked") public void testGammaVectorRDD() { double shape = 1.0; - double scale = 2.0; + double jscale = 2.0; long m = 100L; int n = 10; int p = 2; long seed = 1L; - JavaRDD rdd1 = gammaJavaVectorRDD(sc, shape, scale, m, n); - JavaRDD rdd2 = gammaJavaVectorRDD(sc, shape, scale, m, n, p); - JavaRDD rdd3 = gammaJavaVectorRDD(sc, shape, scale, m, n, p, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = gammaJavaVectorRDD(jsc, shape, jscale, m, n); + JavaRDD rdd2 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p); + JavaRDD rdd3 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -238,10 +244,10 @@ public class JavaRandomRDDsSuite { long seed = 1L; int numPartitions = 0; StringGenerator gen = new StringGenerator(); - JavaRDD rdd1 = randomJavaRDD(sc, gen, size); - JavaRDD rdd2 = randomJavaRDD(sc, gen, size, numPartitions); - JavaRDD rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = randomJavaRDD(jsc, gen, size); + JavaRDD rdd2 = randomJavaRDD(jsc, gen, size, numPartitions); + JavaRDD rdd3 = randomJavaRDD(jsc, gen, size, numPartitions, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(size, rdd.count()); Assert.assertEquals(2, rdd.first().length()); } @@ -255,10 +261,10 @@ public class JavaRandomRDDsSuite { int n = 10; int p 
= 2; long seed = 1L; - JavaRDD rdd1 = randomJavaVectorRDD(sc, generator, m, n); - JavaRDD rdd2 = randomJavaVectorRDD(sc, generator, m, n, p); - JavaRDD rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = randomJavaVectorRDD(jsc, generator, m, n); + JavaRDD rdd2 = randomJavaVectorRDD(jsc, generator, m, n, p); + JavaRDD rdd3 = randomJavaVectorRDD(jsc, generator, m, n, p, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -271,10 +277,12 @@ class StringGenerator implements RandomDataGenerator, Serializable { public String nextValue() { return "42"; } + @Override public StringGenerator copy() { return new StringGenerator(); } + @Override public void setSeed(long seed) { } diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index d0bf7f556d..aa784054d5 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -32,40 +32,46 @@ import org.junit.Test; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; public class JavaALSSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaALS"); + spark = SparkSession.builder() + .master("local") + .appName("JavaALS") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } private void validatePrediction( - MatrixFactorizationModel model, - int users, - int products, - double[] trueRatings, - double matchThreshold, - boolean implicitPrefs, - double[] truePrefs) { + MatrixFactorizationModel model, + int users, + int products, + double[] trueRatings, + double matchThreshold, + boolean implicitPrefs, + double[] truePrefs) { List> localUsersProducts = new ArrayList<>(users * products); - for (int u=0; u < users; ++u) { - for (int p=0; p < products; ++p) { + for (int u = 0; u < users; ++u) { + for (int p = 0; p < products; ++p) { localUsersProducts.add(new Tuple2<>(u, p)); } } - JavaPairRDD usersProducts = sc.parallelizePairs(localUsersProducts); + JavaPairRDD usersProducts = jsc.parallelizePairs(localUsersProducts); List predictedRatings = model.predict(usersProducts).collect(); Assert.assertEquals(users * products, predictedRatings.size()); if (!implicitPrefs) { - for (Rating r: predictedRatings) { + for (Rating r : predictedRatings) { double prediction = r.rating(); double correct = trueRatings[r.product() * users + r.user()]; Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", @@ -76,7 +82,7 @@ public class JavaALSSuite implements Serializable { // (ref Mahout's implicit ALS tests) double sqErr = 0.0; double denom = 0.0; - for (Rating r: predictedRatings) { + for (Rating r : predictedRatings) { double prediction = r.rating(); double truePref = truePrefs[r.product() * users + r.user()]; double confidence = 1.0 + @@ -98,9 +104,9 @@ public class JavaALSSuite implements Serializable { int users = 50; int products = 100; Tuple3, 
double[], double[]> testData = - ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); - JavaRDD data = sc.parallelize(testData._1()); + JavaRDD data = jsc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3()); } @@ -112,9 +118,9 @@ public class JavaALSSuite implements Serializable { int users = 100; int products = 200; Tuple3, double[], double[]> testData = - ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); - JavaRDD data = sc.parallelize(testData._1()); + JavaRDD data = jsc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) @@ -129,9 +135,9 @@ public class JavaALSSuite implements Serializable { int users = 80; int products = 160; Tuple3, double[], double[]> testData = - ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); - JavaRDD data = sc.parallelize(testData._1()); + JavaRDD data = jsc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); } @@ -143,9 +149,9 @@ public class JavaALSSuite implements Serializable { int users = 100; int products = 200; Tuple3, double[], double[]> testData = - ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); - JavaRDD data = sc.parallelize(testData._1()); + JavaRDD data = jsc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) @@ -161,9 +167,9 @@ public class JavaALSSuite implements Serializable { int users = 80; int products = 160; Tuple3, double[], double[]> testData = - ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true); + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true); - JavaRDD data = sc.parallelize(testData._1()); + JavaRDD data = jsc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) .setImplicitPrefs(true) @@ -179,8 +185,8 @@ public class JavaALSSuite implements Serializable { int users = 200; int products = 50; List testData = ALSSuite.generateRatingsAsJava( - users, products, features, 0.7, true, false)._1(); - JavaRDD data = sc.parallelize(testData); + users, products, features, 0.7, true, false)._1(); + JavaRDD data = jsc.parallelize(testData); MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) .setImplicitPrefs(true) @@ -193,7 +199,7 @@ public class JavaALSSuite implements Serializable { private static void validateRecommendations(Rating[] recommendations, int howMany) { Assert.assertEquals(howMany, recommendations.length); for (int i = 1; i < recommendations.length; i++) { - Assert.assertTrue(recommendations[i-1].rating() >= recommendations[i].rating()); + Assert.assertTrue(recommendations[i - 1].rating() >= recommendations[i].rating()); } Assert.assertTrue(recommendations[0].rating() > 0.7); } diff --git 
a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java index 3db9b39e74..8b05675d65 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java @@ -32,15 +32,17 @@ import org.junit.Test; import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; public class JavaIsotonicRegressionSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; private static List> generateIsotonicInput(double[] labels) { List> input = new ArrayList<>(labels.length); for (int i = 1; i <= labels.length; i++) { - input.add(new Tuple3<>(labels[i-1], (double) i, 1.0)); + input.add(new Tuple3<>(labels[i - 1], (double) i, 1.0)); } return input; @@ -48,20 +50,24 @@ public class JavaIsotonicRegressionSuite implements Serializable { private IsotonicRegressionModel runIsotonicRegression(double[] labels) { JavaRDD> trainRDD = - sc.parallelize(generateIsotonicInput(labels), 2).cache(); + jsc.parallelize(generateIsotonicInput(labels), 2).cache(); return new IsotonicRegression().run(trainRDD); } @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaLinearRegressionSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -70,7 +76,7 @@ public class JavaIsotonicRegressionSuite implements Serializable { runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); Assert.assertArrayEquals( - new double[] {1, 2, 7.0/3, 7.0/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14); + new double[]{1, 2, 7.0 / 3, 7.0 / 3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14); } @Test @@ -78,7 +84,7 @@ public class JavaIsotonicRegressionSuite implements Serializable { IsotonicRegressionModel model = runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); - JavaDoubleRDD testRDD = sc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0)); + JavaDoubleRDD testRDD = jsc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0)); List predictions = model.predict(testRDD).collect(); Assert.assertEquals(1.0, predictions.get(0).doubleValue(), 1.0e-14); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java index 8950b48888..098bac3bed 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java @@ -28,24 +28,30 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.util.LinearDataGenerator; +import org.apache.spark.sql.SparkSession; public class JavaLassoSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaLassoSuite"); + spark = 
SparkSession.builder() + .master("local") + .appName("JavaLassoSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } int validatePrediction(List validationData, LassoModel model) { int numAccurate = 0; - for (LabeledPoint point: validationData) { + for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); // A prediction is off if the prediction is more than 0.5 away from expected value. if (Math.abs(prediction - point.label()) <= 0.5) { @@ -61,15 +67,15 @@ public class JavaLassoSuite implements Serializable { double A = 0.0; double[] weights = {-1.5, 1.0e-2}; - JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, - weights, nPoints, 42, 0.1), 2).cache(); + JavaRDD testRDD = jsc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, + weights, nPoints, 42, 0.1), 2).cache(); List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); LassoWithSGD lassoSGDImpl = new LassoWithSGD(); lassoSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(0.01) - .setNumIterations(20); + .setRegParam(0.01) + .setNumIterations(20); LassoModel model = lassoSGDImpl.run(testRDD.rdd()); int numAccurate = validatePrediction(validationData, model); @@ -82,10 +88,10 @@ public class JavaLassoSuite implements Serializable { double A = 0.0; double[] weights = {-1.5, 1.0e-2}; - JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, - weights, nPoints, 42, 0.1), 2).cache(); + JavaRDD testRDD = jsc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, + weights, nPoints, 42, 0.1), 2).cache(); List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java index 24c4c20d9a..35087a5e46 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java @@ -25,34 +25,40 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.util.LinearDataGenerator; +import org.apache.spark.sql.SparkSession; public class JavaLinearRegressionSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaLinearRegressionSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } int validatePrediction(List validationData, LinearRegressionModel model) { int 
numAccurate = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - // A prediction is off if the prediction is more than 0.5 away from expected value. - if (Math.abs(prediction - point.label()) <= 0.5) { - numAccurate++; - } + for (LabeledPoint point : validationData) { + Double prediction = model.predict(point.features()); + // A prediction is off if the prediction is more than 0.5 away from expected value. + if (Math.abs(prediction - point.label()) <= 0.5) { + numAccurate++; + } } return numAccurate; } @@ -63,10 +69,10 @@ public class JavaLinearRegressionSuite implements Serializable { double A = 3.0; double[] weights = {10, 10}; - JavaRDD testRDD = sc.parallelize( - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); + JavaRDD testRDD = jsc.parallelize( + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); linSGDImpl.setIntercept(true); @@ -82,10 +88,10 @@ public class JavaLinearRegressionSuite implements Serializable { double A = 0.0; double[] weights = {10, 10}; - JavaRDD testRDD = sc.parallelize( - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); + JavaRDD testRDD = jsc.parallelize( + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100); @@ -98,7 +104,7 @@ public class JavaLinearRegressionSuite implements Serializable { int nPoints = 100; double A = 0.0; double[] weights = {10, 10}; - JavaRDD testRDD = sc.parallelize( + JavaRDD testRDD = jsc.parallelize( LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); LinearRegressionModel model = linSGDImpl.run(testRDD.rdd()); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java index c56db703ea..b2efb2e72e 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java @@ -29,25 +29,31 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.util.LinearDataGenerator; +import org.apache.spark.sql.SparkSession; public class JavaRidgeRegressionSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaRidgeRegressionSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } private static double predictionError(List validationData, 
RidgeRegressionModel model) { double errorSum = 0; - for (LabeledPoint point: validationData) { + for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); errorSum += (prediction - point.label()) * (prediction - point.label()); } @@ -68,9 +74,9 @@ public class JavaRidgeRegressionSuite implements Serializable { public void runRidgeRegressionUsingConstructor() { int numExamples = 50; int numFeatures = 20; - List data = generateRidgeData(2*numExamples, numFeatures, 10.0); + List data = generateRidgeData(2 * numExamples, numFeatures, 10.0); - JavaRDD testRDD = sc.parallelize(data.subList(0, numExamples)); + JavaRDD testRDD = jsc.parallelize(data.subList(0, numExamples)); List validationData = data.subList(numExamples, 2 * numExamples); RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(); @@ -94,7 +100,7 @@ public class JavaRidgeRegressionSuite implements Serializable { int numFeatures = 20; List data = generateRidgeData(2 * numExamples, numFeatures, 10.0); - JavaRDD testRDD = sc.parallelize(data.subList(0, numExamples)); + JavaRDD testRDD = jsc.parallelize(data.subList(0, numExamples)); List validationData = data.subList(numExamples, 2 * numExamples); RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index 5f1d5987e8..373417d3ba 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -24,13 +24,11 @@ import java.util.List; import org.junit.After; import org.junit.Before; import org.junit.Test; - -import static org.apache.spark.streaming.JavaTestUtils.*; import static org.junit.Assert.assertEquals; import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; @@ -38,36 +36,42 @@ import org.apache.spark.mllib.stat.test.BinarySample; import org.apache.spark.mllib.stat.test.ChiSqTestResult; import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; import org.apache.spark.mllib.stat.test.StreamingTest; +import org.apache.spark.sql.SparkSession; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; +import static org.apache.spark.streaming.JavaTestUtils.*; public class JavaStatisticsSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; private transient JavaStreamingContext ssc; @Before public void setUp() { SparkConf conf = new SparkConf() - .setMaster("local[2]") - .setAppName("JavaStatistics") .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); - sc = new JavaSparkContext(conf); - ssc = new JavaStreamingContext(sc, new Duration(1000)); + spark = SparkSession.builder() + .master("local[2]") + .appName("JavaStatistics") + .config(conf) + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); + ssc = new JavaStreamingContext(jsc, new Duration(1000)); ssc.checkpoint("checkpoint"); } @After public void tearDown() { + spark.stop(); 
ssc.stop(); - ssc = null; - sc = null; + spark = null; } @Test public void testCorr() { - JavaRDD x = sc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0)); - JavaRDD y = sc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3)); + JavaRDD x = jsc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + JavaRDD y = jsc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3)); Double corr1 = Statistics.corr(x, y); Double corr2 = Statistics.corr(x, y, "pearson"); @@ -77,7 +81,7 @@ public class JavaStatisticsSuite implements Serializable { @Test public void kolmogorovSmirnovTest() { - JavaDoubleRDD data = sc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0)); + JavaDoubleRDD data = jsc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0)); KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm"); KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest( data, "norm", 0.0, 1.0); @@ -85,7 +89,7 @@ public class JavaStatisticsSuite implements Serializable { @Test public void chiSqTest() { - JavaRDD data = sc.parallelize(Arrays.asList( + JavaRDD data = jsc.parallelize(Arrays.asList( new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)), new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)), new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java index 60585d2727..5b464a4722 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -35,25 +35,31 @@ import org.apache.spark.mllib.tree.configuration.Algo; import org.apache.spark.mllib.tree.configuration.Strategy; import org.apache.spark.mllib.tree.impurity.Gini; import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.sql.SparkSession; public class JavaDecisionTreeSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaDecisionTreeSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaDecisionTreeSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } int validatePrediction(List validationData, DecisionTreeModel model) { int numCorrect = 0; - for (LabeledPoint point: validationData) { + for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); if (prediction == point.label()) { numCorrect++; @@ -65,7 +71,7 @@ public class JavaDecisionTreeSuite implements Serializable { @Test public void runDTUsingConstructor() { List arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); - JavaRDD rdd = sc.parallelize(arr); + JavaRDD rdd = jsc.parallelize(arr); HashMap categoricalFeaturesInfo = new HashMap<>(); categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories @@ -73,7 +79,7 @@ public class JavaDecisionTreeSuite implements Serializable { int numClasses = 2; int maxBins = 100; Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses, - maxBins, categoricalFeaturesInfo); + maxBins, categoricalFeaturesInfo); DecisionTree learner = new DecisionTree(strategy); DecisionTreeModel model = learner.run(rdd.rdd()); @@ -85,7 +91,7 @@ public class JavaDecisionTreeSuite 
implements Serializable { @Test public void runDTUsingStaticMethods() { List arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); - JavaRDD rdd = sc.parallelize(arr); + JavaRDD rdd = jsc.parallelize(arr); HashMap categoricalFeaturesInfo = new HashMap<>(); categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories @@ -93,7 +99,7 @@ public class JavaDecisionTreeSuite implements Serializable { int numClasses = 2; int maxBins = 100; Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses, - maxBins, categoricalFeaturesInfo); + maxBins, categoricalFeaturesInfo); DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy); diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 1de638f245..55448325e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -183,7 +183,7 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("pipeline validateParams") { - val df = sqlContext.createDataFrame( + val df = spark.createDataFrame( Seq( (1, Vectors.dense(0.0, 1.0, 4.0), 1.0), (2, Vectors.dense(1.0, 0.0, 4.0), 2.0), diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index 89afb94b0f..98116656ba 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -32,7 +32,7 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { test("extractLabeledPoints") { def getTestData(labels: Seq[Double]): DataFrame = { val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } - sqlContext.createDataFrame(data) + spark.createDataFrame(data) } val c = new MockClassifier @@ -72,7 +72,7 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { test("getNumClasses") { def getTestData(labels: Seq[Double]): DataFrame = { val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } - sqlContext.createDataFrame(data) + spark.createDataFrame(data) } val c = new MockClassifier diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 29845b5554..f94d336df5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -337,13 +337,13 @@ class DecisionTreeClassifierSuite test("should support all NumericType labels and not support other types") { val dt = new DecisionTreeClassifier().setMaxDepth(1) MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier]( - dt, isClassification = true, sqlContext) { (expected, actual) => + dt, isClassification = true, spark) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } test("Fitting without numClasses in metadata") { - val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc)) + val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc)) val dt = new DecisionTreeClassifier().setMaxDepth(1) dt.fit(df) } diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 087e201234..c9453aaec2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -106,7 +106,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext test("should support all NumericType labels and not support other types") { val gbt = new GBTClassifier().setMaxDepth(1) MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier]( - gbt, isClassification = true, sqlContext) { (expected, actual) => + gbt, isClassification = true, spark) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } @@ -130,7 +130,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext */ test("Fitting without numClasses in metadata") { - val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc)) + val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc)) val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1) gbt.fit(df) } @@ -138,7 +138,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext test("extractLabeledPoints with bad data") { def getTestData(labels: Seq[Double]): DataFrame = { val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } - sqlContext.createDataFrame(data) + spark.createDataFrame(data) } val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 73e961dbbc..cb4d087ce5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -42,7 +42,7 @@ class LogisticRegressionSuite override def beforeAll(): Unit = { super.beforeAll() - dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) + dataset = spark.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) binaryDataset = { val nPoints = 10000 @@ -54,7 +54,7 @@ class LogisticRegressionSuite generateMultinomialLogisticInput(coefficients, xMean, xVariance, addIntercept = true, nPoints, 42) - sqlContext.createDataFrame(sc.parallelize(testData, 4)) + spark.createDataFrame(sc.parallelize(testData, 4)) } } @@ -202,7 +202,7 @@ class LogisticRegressionSuite } test("logistic regression: Predictor, Classifier methods") { - val sqlContext = this.sqlContext + val spark = this.spark val lr = new LogisticRegression val model = lr.fit(dataset) @@ -864,8 +864,8 @@ class LogisticRegressionSuite } } - (sqlContext.createDataFrame(sc.parallelize(data1, 4)), - sqlContext.createDataFrame(sc.parallelize(data2, 4))) + (spark.createDataFrame(sc.parallelize(data1, 4)), + spark.createDataFrame(sc.parallelize(data2, 4))) } val trainer1a = (new LogisticRegression).setFitIntercept(true) @@ -938,7 +938,7 @@ class LogisticRegressionSuite test("should support all NumericType labels and not support other types") { val lr = new LogisticRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression]( - lr, isClassification = true, sqlContext) { (expected, actual) => + lr, isClassification = true, spark) { 
(expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients.toArray === actual.coefficients.toArray) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index f41db31f1e..876e047db5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -36,7 +36,7 @@ class MultilayerPerceptronClassifierSuite override def beforeAll(): Unit = { super.beforeAll() - dataset = sqlContext.createDataFrame(Seq( + dataset = spark.createDataFrame(Seq( (Vectors.dense(0.0, 0.0), 0.0), (Vectors.dense(0.0, 1.0), 1.0), (Vectors.dense(1.0, 0.0), 1.0), @@ -77,7 +77,7 @@ class MultilayerPerceptronClassifierSuite } test("Test setWeights by training restart") { - val dataFrame = sqlContext.createDataFrame(Seq( + val dataFrame = spark.createDataFrame(Seq( (Vectors.dense(0.0, 0.0), 0.0), (Vectors.dense(0.0, 1.0), 1.0), (Vectors.dense(1.0, 0.0), 1.0), @@ -113,7 +113,7 @@ class MultilayerPerceptronClassifierSuite // the input seed is somewhat magic, to make this test pass val rdd = sc.parallelize(generateMultinomialLogisticInput( coefficients, xMean, xVariance, true, nPoints, 1), 2) - val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features") + val dataFrame = spark.createDataFrame(rdd).toDF("label", "features") val numClasses = 3 val numIterations = 100 val layers = Array[Int](4, 5, 4, numClasses) @@ -169,7 +169,7 @@ class MultilayerPerceptronClassifierSuite val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1) MLTestingUtils.checkNumericTypes[ MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier]( - mpc, isClassification = true, sqlContext) { (expected, actual) => + mpc, isClassification = true, spark) { (expected, actual) => assert(expected.layers === actual.layers) assert(expected.weights === actual.weights) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 80a46fc70c..15d0358c3f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -43,7 +43,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa Array(0.10, 0.10, 0.70, 0.10) // label 2 ).map(_.map(math.log)) - dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42)) + dataset = spark.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42)) } def validatePrediction(predictionAndLabels: DataFrame): Unit = { @@ -127,7 +127,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val pi = Vectors.dense(piArray) val theta = new DenseMatrix(3, 4, thetaArray.flatten, true) - val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + val testDataset = spark.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 42, "multinomial")) val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") val model = nb.fit(testDataset) @@ -135,7 +135,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa validateModelFit(pi, theta, model) assert(model.hasParent) - val 
validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + val validationDataset = spark.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 17, "multinomial")) val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") @@ -157,7 +157,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val pi = Vectors.dense(piArray) val theta = new DenseMatrix(3, 12, thetaArray.flatten, true) - val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + val testDataset = spark.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 45, "bernoulli")) val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli") val model = nb.fit(testDataset) @@ -165,7 +165,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa validateModelFit(pi, theta, model) assert(model.hasParent) - val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + val validationDataset = spark.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 20, "bernoulli")) val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") @@ -188,7 +188,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa test("should support all NumericType labels and not support other types") { val nb = new NaiveBayes() MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes]( - nb, isClassification = true, sqlContext) { (expected, actual) => + nb, isClassification = true, spark) { (expected, actual) => assert(expected.pi === actual.pi) assert(expected.theta === actual.theta) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 51871a9bab..005d609307 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -53,7 +53,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) rdd = sc.parallelize(generateMultinomialLogisticInput( coefficients, xMean, xVariance, true, nPoints, 42), 2) - dataset = sqlContext.createDataFrame(rdd) + dataset = spark.createDataFrame(rdd) } test("params") { @@ -228,7 +228,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("should support all NumericType labels and not support other types") { val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1)) MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest]( - ovr, isClassification = true, sqlContext) { (expected, actual) => + ovr, isClassification = true, spark) { (expected, actual) => val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel]) val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel]) assert(expectedModels.length === actualModels.length) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 90744353d9..97f3feacca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -155,7 +155,7 @@ class 
RandomForestClassifierSuite } test("Fitting without numClasses in metadata") { - val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc)) + val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc)) val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1) rf.fit(df) } @@ -189,7 +189,7 @@ class RandomForestClassifierSuite test("should support all NumericType labels and not support other types") { val rf = new RandomForestClassifier().setMaxDepth(1) MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier]( - rf, isClassification = true, sqlContext) { (expected, actual) => + rf, isClassification = true, spark) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 212ea7a0a9..4f7d4418a8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -30,7 +30,7 @@ class BisectingKMeansSuite override def beforeAll(): Unit = { super.beforeAll() - dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) + dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k) } test("default parameters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 9d868174c1..04366f5250 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -32,7 +32,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext override def beforeAll(): Unit = { super.beforeAll() - dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) + dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k) } test("default parameters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 241d21961f..2832db2f99 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} private[clustering] case class TestRow(features: Vector) @@ -34,7 +34,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR override def beforeAll(): Unit = { super.beforeAll() - dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) + dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k) } test("default parameters") { @@ -142,11 +142,11 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } object KMeansSuite { - def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { - val sc = sql.sparkContext + def generateKMeansData(spark: SparkSession, rows: Int, dim: Int, k: Int): DataFrame = { + val sc = spark.sparkContext val rdd = 
sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) .map(v => new TestRow(v)) - sql.createDataFrame(rdd) + spark.createDataFrame(rdd) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 6cb07aecb9..34e8964286 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -17,30 +17,30 @@ package org.apache.spark.ml.clustering -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} +import org.apache.spark.sql._ object LDASuite { def generateLDAData( - sql: SQLContext, + spark: SparkSession, rows: Int, k: Int, vocabSize: Int): DataFrame = { val avgWC = 1 // average instances of each word in a doc - val sc = sql.sparkContext + val sc = spark.sparkContext val rng = new java.util.Random() rng.setSeed(1) val rdd = sc.parallelize(1 to rows).map { i => Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble)) }.map(v => new TestRow(v)) - sql.createDataFrame(rdd) + spark.createDataFrame(rdd) } /** @@ -68,7 +68,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead override def beforeAll(): Unit = { super.beforeAll() - dataset = LDASuite.generateLDAData(sqlContext, 50, k, vocabSize) + dataset = LDASuite.generateLDAData(spark, 50, k, vocabSize) } test("default parameters") { @@ -140,7 +140,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead new LDA().setTopicConcentration(-1.1) } - val dummyDF = sqlContext.createDataFrame(Seq( + val dummyDF = spark.createDataFrame(Seq( (1, Vectors.dense(1.0, 2.0)))).toDF("id", "features") // validate parameters lda.transformSchema(dummyDF.schema) @@ -274,7 +274,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead // There should be 1 checkpoint remaining. 
assert(model.getCheckpointFiles.length === 1) val checkpointFile = new Path(model.getCheckpointFiles.head) - val fs = checkpointFile.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val fs = checkpointFile.getFileSystem(spark.sparkContext.hadoopConfiguration) assert(fs.exists(checkpointFile)) model.deleteCheckpointFiles() assert(model.getCheckpointFiles.isEmpty) diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index ff34522178..a8766f9035 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -42,21 +42,21 @@ class BinaryClassificationEvaluatorSuite val evaluator = new BinaryClassificationEvaluator() .setMetricName("areaUnderPR") - val vectorDF = sqlContext.createDataFrame(Seq( + val vectorDF = spark.createDataFrame(Seq( (0d, Vectors.dense(12, 2.5)), (1d, Vectors.dense(1, 3)), (0d, Vectors.dense(10, 2)) )).toDF("label", "rawPrediction") assert(evaluator.evaluate(vectorDF) === 1.0) - val doubleDF = sqlContext.createDataFrame(Seq( + val doubleDF = spark.createDataFrame(Seq( (0d, 0d), (1d, 1d), (0d, 0d) )).toDF("label", "rawPrediction") assert(evaluator.evaluate(doubleDF) === 1.0) - val stringDF = sqlContext.createDataFrame(Seq( + val stringDF = spark.createDataFrame(Seq( (0d, "0d"), (1d, "1d"), (0d, "0d") @@ -71,6 +71,6 @@ class BinaryClassificationEvaluatorSuite test("should support all NumericType labels and not support other types") { val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction") - MLTestingUtils.checkNumericTypes(evaluator, sqlContext) + MLTestingUtils.checkNumericTypes(evaluator, spark) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala index 87e511a368..522f6675d7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -38,6 +38,6 @@ class MulticlassClassificationEvaluatorSuite } test("should support all NumericType labels and not support other types") { - MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, sqlContext) + MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, spark) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index c7b9483069..dcc004358d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -42,7 +42,7 @@ class RegressionEvaluatorSuite * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)) * .saveAsTextFile("path") */ - val dataset = sqlContext.createDataFrame( + val dataset = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) @@ -85,6 +85,6 @@ class RegressionEvaluatorSuite } test("should support all NumericType labels and not support other types") { - 
MLTestingUtils.checkNumericTypes(new RegressionEvaluator, sqlContext) + MLTestingUtils.checkNumericTypes(new RegressionEvaluator, spark) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 714b9db3aa..e91f758112 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -39,7 +39,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize continuous features with default parameter") { val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) - val dataFrame: DataFrame = sqlContext.createDataFrame( + val dataFrame: DataFrame = spark.createDataFrame( data.zip(defaultBinarized)).toDF("feature", "expected") val binarizer: Binarizer = new Binarizer() @@ -55,7 +55,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize continuous features with setter") { val threshold: Double = 0.2 val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) - val dataFrame: DataFrame = sqlContext.createDataFrame( + val dataFrame: DataFrame = spark.createDataFrame( data.zip(thresholdBinarized)).toDF("feature", "expected") val binarizer: Binarizer = new Binarizer() @@ -71,7 +71,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize vector of continuous features with default parameter") { val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) - val dataFrame: DataFrame = sqlContext.createDataFrame(Seq( + val dataFrame: DataFrame = spark.createDataFrame(Seq( (Vectors.dense(data), Vectors.dense(defaultBinarized)) )).toDF("feature", "expected") @@ -88,7 +88,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize vector of continuous features with setter") { val threshold: Double = 0.2 val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) - val dataFrame: DataFrame = sqlContext.createDataFrame(Seq( + val dataFrame: DataFrame = spark.createDataFrame(Seq( (Vectors.dense(data), Vectors.dense(defaultBinarized)) )).toDF("feature", "expected") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 9ea7d43176..98b2316d78 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -39,7 +39,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val validData = Array(-0.5, -0.3, 0.0, 0.2) val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0) val dataFrame: DataFrame = - sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") + spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") @@ -55,13 +55,13 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa // Check for exceptions when using a set of invalid feature values. 
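Almost every hunk in this section is the same mechanical substitution. A minimal before/after sketch of that pattern, assuming the shared MLlibTestSparkContext harness now exposes a `spark: SparkSession` where it previously exposed `sqlContext` (that trait change is outside this diff); `featureExpectedDF` is a hypothetical helper, not code from the patch:

import org.apache.spark.sql.{DataFrame, SparkSession}

// Before (old receiver, same overload):
//   val df = sqlContext.createDataFrame(data.zip(expected)).toDF("feature", "expected")
// After: SparkSession exposes the same createDataFrame overloads, so only the
// receiver changes.
def featureExpectedDF(spark: SparkSession,
                      data: Array[Double],
                      expected: Array[Double]): DataFrame =
  spark.createDataFrame(data.zip(expected)).toDF("feature", "expected")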
val invalidData1: Array[Double] = Array(-0.9) ++ validData val invalidData2 = Array(0.51) ++ validData - val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx") + val badDF1 = spark.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx") withClue("Invalid feature value -0.9 was not caught as an invalid feature!") { intercept[SparkException] { bucketizer.transform(badDF1).collect() } } - val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx") + val badDF2 = spark.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx") withClue("Invalid feature value 0.51 was not caught as an invalid feature!") { intercept[SparkException] { bucketizer.transform(badDF2).collect() @@ -74,7 +74,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9) val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0) val dataFrame: DataFrame = - sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") + spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 7827db2794..4c6d9c5e26 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -24,14 +24,17 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("Test Chi-Square selector") { - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder + .master("local[2]") + .appName("ChiSqSelectorSuite") + .getOrCreate() + import spark.implicits._ val data = Seq( LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 7641e3b8cf..b82e3e90b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -35,7 +35,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext private def split(s: String): Seq[String] = s.split("\\s+") test("CountVectorizerModel common cases") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, split("a b c d"), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), (1, split("a b b c d a"), @@ -55,7 +55,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizer common cases") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, split("a b c d e"), Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), @@ -76,7 +76,7 @@ class CountVectorizerSuite 
extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizer vocabSize and minDF") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), @@ -118,7 +118,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext test("CountVectorizer throws exception when vocab is empty") { intercept[IllegalArgumentException] { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, split("a a b b c c")), (1, split("aa bb cc"))) ).toDF("id", "words") @@ -132,7 +132,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizerModel with minTF count") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))), (2, split("a"), Vectors.sparse(4, Seq())), @@ -151,7 +151,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizerModel with minTF freq") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))), (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))), @@ -170,7 +170,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizerModel and CountVectorizer with binary") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, split("a a a a b b b b c d"), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), (1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 36cafa290f..dbd5ae8345 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -63,7 +63,7 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead } val expectedResult = Vectors.dense(expectedResultBuffer) - val dataset = sqlContext.createDataFrame(Seq( + val dataset = spark.createDataFrame(Seq( DCTTestData(data, expectedResult) )) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 44bad4aba4..89d67d8e6f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -34,7 +34,7 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } test("hashingTF") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, "a a b b c d".split(" ").toSeq) )).toDF("id", "words") val n = 100 @@ -54,7 +54,7 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } test("applying binary term freqs") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, "a a b c c c".split(" ").toSeq) )).toDF("id", "words") val n = 100 diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index bc958c1585..208ea84913 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -60,7 +60,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead }) val expected = scaleDataWithIDF(data, idf) - val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") val idfModel = new IDF() .setInputCol("features") @@ -86,7 +86,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead }) val expected = scaleDataWithIDF(data, idf) - val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") val idfModel = new IDF() .setInputCol("features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 0d4e00668d..3409928007 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -59,7 +59,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def } test("numeric interaction") { - val data = sqlContext.createDataFrame( + val data = spark.createDataFrame( Seq( (2, Vectors.dense(3.0, 4.0)), (1, Vectors.dense(1.0, 5.0))) @@ -74,7 +74,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def col("b").as("b", groupAttr.toMetadata())) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") val res = trans.transform(df) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0))) @@ -90,7 +90,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def } test("nominal interaction") { - val data = sqlContext.createDataFrame( + val data = spark.createDataFrame( Seq( (2, Vectors.dense(3.0, 4.0)), (1, Vectors.dense(1.0, 5.0))) @@ -106,7 +106,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def col("b").as("b", groupAttr.toMetadata())) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") val res = trans.transform(df) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0))) @@ -126,7 +126,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def } test("default attr names") { - val data = sqlContext.createDataFrame( + val data = spark.createDataFrame( Seq( (2, Vectors.dense(0.0, 4.0), 1.0), (1, Vectors.dense(1.0, 5.0), 10.0)) @@ -142,7 +142,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def col("c").as("c", NumericAttribute.defaultAttr.toMetadata())) val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features") val res = trans.transform(df) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)), (1, Vectors.dense(1.0, 
5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala index e083d47136..73d69ebfee 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -36,7 +36,7 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De Vectors.sparse(3, Array(0, 2), Array(-1, -1)), Vectors.sparse(3, Array(0), Array(-0.75))) - val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") val scaler = new MaxAbsScaler() .setInputCol("features") .setOutputCol("scaled") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index 87206c777e..e495c8e571 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -38,7 +38,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De Vectors.sparse(3, Array(0, 2), Array(5, 5)), Vectors.sparse(3, Array(0), Array(-2.5))) - val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") val scaler = new MinMaxScaler() .setInputCol("features") .setOutputCol("scaled") @@ -57,7 +57,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De test("MinMaxScaler arguments max must be larger than min") { withClue("arguments max must be larger than min") { - val dummyDF = sqlContext.createDataFrame(Seq( + val dummyDF = spark.createDataFrame(Seq( (1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature") intercept[IllegalArgumentException] { val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index a9421e6825..e5288d9259 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -34,7 +34,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe val nGram = new NGram() .setInputCol("inputTokens") .setOutputCol("nGrams") - val dataset = sqlContext.createDataFrame(Seq( + val dataset = spark.createDataFrame(Seq( NGramTestData( Array("Test", "for", "ngram", "."), Array("Test for", "for ngram", "ngram .") @@ -47,7 +47,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setInputCol("inputTokens") .setOutputCol("nGrams") .setN(4) - val dataset = sqlContext.createDataFrame(Seq( + val dataset = spark.createDataFrame(Seq( NGramTestData( Array("a", "b", "c", "d", "e"), Array("a b c d", "b c d e") @@ -60,7 +60,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setInputCol("inputTokens") .setOutputCol("nGrams") .setN(4) - val dataset = sqlContext.createDataFrame(Seq( + val dataset = spark.createDataFrame(Seq( NGramTestData( Array(), Array() @@ -73,7 +73,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setInputCol("inputTokens") 
.setOutputCol("nGrams") .setN(6) - val dataset = sqlContext.createDataFrame(Seq( + val dataset = spark.createDataFrame(Seq( NGramTestData( Array("a", "b", "c", "d", "e"), Array() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index 4688339019..241a1e9fb5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -61,7 +61,7 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa Vectors.sparse(3, Seq()) ) - dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) + dataFrame = spark.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normalized_features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 49803aef71..06ffbc386f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -32,7 +32,7 @@ class OneHotEncoderSuite def stringIndexed(): DataFrame = { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df = spark.createDataFrame(data).toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -81,7 +81,7 @@ class OneHotEncoderSuite test("input column with ML attribute") { val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") - val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size") + val df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size") .select(col("size").as("size", attr.toMetadata())) val encoder = new OneHotEncoder() .setInputCol("size") @@ -94,7 +94,7 @@ class OneHotEncoderSuite } test("input column without ML attribute") { - val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index") + val df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index") val encoder = new OneHotEncoder() .setInputCol("index") .setOutputCol("encoded") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index f372ec5826..4befa84dbb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -49,7 +49,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val pc = mat.computePrincipalComponents(3) val expected = mat.multiply(pc).rows - val df = sqlContext.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected") + val df = spark.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected") val pca = new PCA() .setInputCol("features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index 86dbee1cf4..e3adbba9d5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -59,7 +59,7 @@ class PolynomialExpansionSuite Vectors.sparse(19, Array.empty, Array.empty)) test("Polynomial expansion with default parameter") { - val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected") + val df = spark.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() .setInputCol("features") @@ -76,7 +76,7 @@ class PolynomialExpansionSuite } test("Polynomial expansion with setter") { - val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected") + val df = spark.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() .setInputCol("features") @@ -94,7 +94,7 @@ class PolynomialExpansionSuite } test("Polynomial expansion with degree 1 is identity on vectors") { - val df = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") + val df = spark.createDataFrame(data.zip(data)).toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() .setInputCol("features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index f8476953d8..46e7495297 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -32,12 +32,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("transform numeric data") { val formula = new RFormula().setFormula("id ~ v1 + v2") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") val model = formula.fit(original) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0)) @@ -50,7 +50,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("features column already exists") { val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x") - val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") intercept[IllegalArgumentException] { formula.fit(original) } @@ -61,7 +61,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("label column already exists") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") - val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") val model = formula.fit(original) val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) @@ -70,7 +70,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("label column already exists but is not double type") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") - val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") + val original = spark.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") val model = formula.fit(original) 
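The RFormulaSuite hunks above all follow the suite's build-fit-transform-compare flow. A condensed sketch of that flow using the values from the "transform numeric data" hunk; the expected column names assume RFormula's default featuresCol ("features") and labelCol ("label"), and `rFormulaRoundTrip` is a hypothetical wrapper, not part of the patch:

import org.apache.spark.ml.feature.RFormula
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SparkSession

def rFormulaRoundTrip(spark: SparkSession): Unit = {
  val formula = new RFormula().setFormula("id ~ v1 + v2")
  val original = spark.createDataFrame(
    Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
  val model = formula.fit(original)
  val result = model.transform(original)
  // Expected output: the two numeric terms assembled into a vector column plus
  // the label copied (as double) from the left-hand side of the formula.
  val expected = spark.createDataFrame(
    Seq(
      (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0),
      (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0))
  ).toDF("id", "v1", "v2", "features", "label")
  assert(result.collect().sameElements(expected.collect()))
}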
intercept[IllegalArgumentException] { model.transformSchema(original.schema) @@ -82,7 +82,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("allow missing label column for test datasets") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("label") - val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y") + val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y") val model = formula.fit(original) val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) @@ -91,14 +91,14 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("allow empty label") { - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0)) ).toDF("id", "a", "b") val formula = new RFormula().setFormula("~ a + b") val model = formula.fit(original) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)), (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)), @@ -110,13 +110,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("encodes string terms") { val formula = new RFormula().setFormula("id ~ a + b") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), @@ -129,13 +129,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("index string label") { val formula = new RFormula().setFormula("id ~ a + b") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), @@ -148,7 +148,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("attribute generation") { val formula = new RFormula().setFormula("id ~ a + b") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) @@ -165,7 +165,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("vector attribute generation") { val formula = new RFormula().setFormula("id ~ vec") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) ).toDF("id", "vec") val model = formula.fit(original) @@ -181,7 +181,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("vector attribute 
generation with unnamed input attrs") { val formula = new RFormula().setFormula("id ~ vec2") - val base = sqlContext.createDataFrame( + val base = spark.createDataFrame( Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) ).toDF("id", "vec") val metadata = new AttributeGroup( @@ -203,12 +203,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("numeric interaction") { val formula = new RFormula().setFormula("a ~ b:c:d") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((1, 2, 4, 2), (2, 3, 4, 1)) ).toDF("a", "b", "c", "d") val model = formula.fit(original) val result = model.transform(original) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (1, 2, 4, 2, Vectors.dense(16.0), 1.0), (2, 3, 4, 1, Vectors.dense(12.0), 2.0)) @@ -223,12 +223,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("factor numeric interaction") { val formula = new RFormula().setFormula("id ~ a:b") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), @@ -250,12 +250,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("factor factor interaction") { val formula = new RFormula().setFormula("id ~ a:b") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) ).toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0), (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), @@ -299,7 +299,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } - val dataset = sqlContext.createDataFrame( + val dataset = spark.createDataFrame( Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) ).toDF("id", "a", "b") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index e213e17d0d..1401ea9c4b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -31,13 +31,13 @@ class SQLTransformerSuite } test("transform numeric data") { - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") val sqlTrans = new SQLTransformer().setStatement( "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") val result = sqlTrans.transform(original) val resultSchema = sqlTrans.transformSchema(original.schema) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))) .toDF("id", "v1", "v2", "v3", "v4") assert(result.schema.toString == resultSchema.toString) @@ -52,7 +52,7 @@ class SQLTransformerSuite } test("transformSchema") { - val df = 
sqlContext.range(10) + val df = spark.range(10) val outputSchema = new SQLTransformer() .setStatement("SELECT id + 1 AS id1 FROM __THIS__") .transformSchema(df.schema) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 8c5e47a22c..d62301be14 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -73,7 +73,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext } test("Standardization with default parameter") { - val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected") + val df0 = spark.createDataFrame(data.zip(resWithStd)).toDF("features", "expected") val standardScaler0 = new StandardScaler() .setInputCol("features") @@ -84,9 +84,9 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext } test("Standardization with setter") { - val df1 = sqlContext.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected") - val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected") - val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") + val df1 = spark.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected") + val df2 = spark.createDataFrame(data.zip(resWithMean)).toDF("features", "expected") + val df3 = spark.createDataFrame(data.zip(data)).toDF("features", "expected") val standardScaler1 = new StandardScaler() .setInputCol("features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 8e7e000fbc..125ad02ebc 100755 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{Dataset, Row} object StopWordsRemoverSuite extends SparkFunSuite { def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = { @@ -42,7 +42,7 @@ class StopWordsRemoverSuite val remover = new StopWordsRemover() .setInputCol("raw") .setOutputCol("filtered") - val dataSet = sqlContext.createDataFrame(Seq( + val dataSet = spark.createDataFrame(Seq( (Seq("test", "test"), Seq("test", "test")), (Seq("a", "b", "c", "d"), Seq("b", "c")), (Seq("a", "the", "an"), Seq()), @@ -60,7 +60,7 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setStopWords(stopWords) - val dataSet = sqlContext.createDataFrame(Seq( + val dataSet = spark.createDataFrame(Seq( (Seq("test", "test"), Seq()), (Seq("a", "b", "c", "d"), Seq("b", "c", "d")), (Seq("a", "the", "an"), Seq()), @@ -77,7 +77,7 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setCaseSensitive(true) - val dataSet = sqlContext.createDataFrame(Seq( + val dataSet = spark.createDataFrame(Seq( (Seq("A"), Seq("A")), (Seq("The", "the"), Seq("The")) )).toDF("raw", "expected") @@ -98,7 +98,7 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setStopWords(stopWords) - val dataSet = sqlContext.createDataFrame(Seq( + val 
dataSet = spark.createDataFrame(Seq( (Seq("acaba", "ama", "biri"), Seq()), (Seq("hep", "her", "scala"), Seq("scala")) )).toDF("raw", "expected") @@ -112,7 +112,7 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setStopWords(stopWords.toArray) - val dataSet = sqlContext.createDataFrame(Seq( + val dataSet = spark.createDataFrame(Seq( (Seq("python", "scala", "a"), Seq("python", "scala", "a")), (Seq("Python", "Scala", "swift"), Seq("Python", "Scala", "swift")) )).toDF("raw", "expected") @@ -126,7 +126,7 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setStopWords(stopWords.toArray) - val dataSet = sqlContext.createDataFrame(Seq( + val dataSet = spark.createDataFrame(Seq( (Seq("python", "scala", "a"), Seq()), (Seq("Python", "Scala", "swift"), Seq("swift")) )).toDF("raw", "expected") @@ -148,7 +148,7 @@ class StopWordsRemoverSuite val remover = new StopWordsRemover() .setInputCol("raw") .setOutputCol(outputCol) - val dataSet = sqlContext.createDataFrame(Seq( + val dataSet = spark.createDataFrame(Seq( (Seq("The", "the", "swift"), Seq("swift")) )).toDF("raw", outputCol) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index d0f3cdc841..c221d4aa55 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -39,7 +39,7 @@ class StringIndexerSuite test("StringIndexer") { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df = spark.createDataFrame(data).toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -63,8 +63,8 @@ class StringIndexerSuite test("StringIndexerUnseen") { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2) val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2) - val df = sqlContext.createDataFrame(data).toDF("id", "label") - val df2 = sqlContext.createDataFrame(data2).toDF("id", "label") + val df = spark.createDataFrame(data).toDF("id", "label") + val df2 = spark.createDataFrame(data2).toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -93,7 +93,7 @@ class StringIndexerSuite test("StringIndexer with a numeric input column") { val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) - val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df = spark.createDataFrame(data).toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -114,12 +114,12 @@ class StringIndexerSuite val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) .setInputCol("label") .setOutputCol("labelIndex") - val df = sqlContext.range(0L, 10L).toDF() + val df = spark.range(0L, 10L).toDF() assert(indexerModel.transform(df).collect().toSet === df.collect().toSet) } test("StringIndexerModel can't overwrite output column") { - val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output") + val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output") val indexer = new StringIndexer() .setInputCol("input") .setOutputCol("output") @@ -153,7 +153,7 @@ class StringIndexerSuite test("IndexToString.transform") { val labels = Array("a", "b", "c") - val df0 = 
sqlContext.createDataFrame(Seq( + val df0 = spark.createDataFrame(Seq( (0, "a"), (1, "b"), (2, "c"), (0, "a") )).toDF("index", "expected") @@ -180,7 +180,7 @@ class StringIndexerSuite test("StringIndexer, IndexToString are inverses") { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df = spark.createDataFrame(data).toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -213,7 +213,7 @@ class StringIndexerSuite test("StringIndexer metadata") { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df = spark.createDataFrame(data).toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index 123ddfe42c..f30bdc3ddc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -57,13 +57,13 @@ class RegexTokenizerSuite .setPattern("\\w+|\\p{Punct}") .setInputCol("rawText") .setOutputCol("tokens") - val dataset0 = sqlContext.createDataFrame(Seq( + val dataset0 = spark.createDataFrame(Seq( TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")), TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct")) )) testRegexTokenizer(tokenizer0, dataset0) - val dataset1 = sqlContext.createDataFrame(Seq( + val dataset1 = spark.createDataFrame(Seq( TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")), TokenizerTestData("Te,st. punct", Array("punct")) )) @@ -73,7 +73,7 @@ class RegexTokenizerSuite val tokenizer2 = new RegexTokenizer() .setInputCol("rawText") .setOutputCol("tokens") - val dataset2 = sqlContext.createDataFrame(Seq( + val dataset2 = spark.createDataFrame(Seq( TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")), TokenizerTestData("Te,st. 
punct", Array("te,st.", "punct")) )) @@ -85,7 +85,7 @@ class RegexTokenizerSuite .setInputCol("rawText") .setOutputCol("tokens") .setToLowercase(false) - val dataset = sqlContext.createDataFrame(Seq( + val dataset = spark.createDataFrame(Seq( TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")), TokenizerTestData("java scala", Array("java", "scala")) )) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index dce994fdbd..250011c859 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -57,7 +57,7 @@ class VectorAssemblerSuite } test("VectorAssembler") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L) )).toDF("id", "x", "y", "name", "z", "n") val assembler = new VectorAssembler() @@ -70,7 +70,7 @@ class VectorAssemblerSuite } test("transform should throw an exception in case of unsupported type") { - val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c") + val df = spark.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c") val assembler = new VectorAssembler() .setInputCols(Array("a", "b", "c")) .setOutputCol("features") @@ -87,7 +87,7 @@ class VectorAssemblerSuite NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"), NumericAttribute.defaultAttr.withName("salary"))) val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0))) - val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad") + val df = spark.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad") .select( col("browser").as("browser", browser.toMetadata()), col("hour").as("hour", hour.toMetadata()), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 1ffc62b38e..d1c0270a02 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -85,11 +85,11 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext checkPair(densePoints1Seq, sparsePoints1Seq) checkPair(densePoints2Seq, sparsePoints2Seq) - densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) - sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) - densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) - sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData)) - badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData)) + densePoints1 = spark.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) + sparsePoints1 = spark.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) + densePoints2 = spark.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) + sparsePoints2 = spark.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData)) + badPoints = spark.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData)) } private def getIndexer: VectorIndexer = @@ -102,7 +102,7 @@ class VectorIndexerSuite extends 
SparkFunSuite with MLlibTestSparkContext } test("Cannot fit an empty DataFrame") { - val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) + val rdd = spark.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) val vectorIndexer = getIndexer intercept[IllegalArgumentException] { vectorIndexer.fit(rdd) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index 6bb4678dc5..88a077f9a1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -79,7 +79,7 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]]) val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) } - val df = sqlContext.createDataFrame(rdd, + val df = spark.createDataFrame(rdd, StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField()))) val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 80c177b8d3..8cbe0f3def 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -36,8 +36,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("Word2Vec") { - val sqlContext = this.sqlContext - import sqlContext.implicits._ + val spark = this.spark + import spark.implicits._ val sentence = "a b " * 100 + "a c " * 10 val numOfWords = sentence.split(" ").size @@ -78,8 +78,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("getVectors") { - val sqlContext = this.sqlContext - import sqlContext.implicits._ + val spark = this.spark + import spark.implicits._ val sentence = "a b " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) @@ -119,8 +119,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("findSynonyms") { - val sqlContext = this.sqlContext - import sqlContext.implicits._ + val spark = this.spark + import spark.implicits._ val sentence = "a b " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) @@ -146,8 +146,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("window size") { - val sqlContext = this.sqlContext - import sqlContext.implicits._ + val spark = this.spark + import spark.implicits._ val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 1704037395..9da0c32dee 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -38,7 +38,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import 
org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -305,8 +305,8 @@ class ALSSuite numUserBlocks: Int = 2, numItemBlocks: Int = 3, targetRMSE: Double = 0.05): Unit = { - val sqlContext = this.sqlContext - import sqlContext.implicits._ + val spark = this.spark + import spark.implicits._ val als = new ALS() .setRank(rank) .setRegParam(regParam) @@ -460,8 +460,8 @@ class ALSSuite allEstimatorParamSettings.foreach { case (p, v) => als.set(als.getParam(p), v) } - val sqlContext = this.sqlContext - import sqlContext.implicits._ + val spark = this.spark + import spark.implicits._ val model = als.fit(ratings.toDF()) // Test Estimator save/load @@ -535,8 +535,11 @@ class ALSCleanerSuite extends SparkFunSuite { // Generate test data val (training, _) = ALSSuite.genImplicitTestData(sc, 20, 5, 1, 0.2, 0) // Implicitly test the cleaning of parents during ALS training - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder + .master("local[2]") + .appName("ALSCleanerSuite") + .getOrCreate() + import spark.implicits._ val als = new ALS() .setRank(1) .setRegParam(1e-5) @@ -577,8 +580,8 @@ class ALSStorageSuite } test("default and non-default storage params set correct RDD StorageLevels") { - val sqlContext = this.sqlContext - import sqlContext.implicits._ + val spark = this.spark + import spark.implicits._ val data = Seq( (0, 0, 1.0), (0, 1, 2.0), diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 76891ad562..f8fc775676 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -37,13 +37,13 @@ class AFTSurvivalRegressionSuite override def beforeAll(): Unit = { super.beforeAll() - datasetUnivariate = sqlContext.createDataFrame( + datasetUnivariate = spark.createDataFrame( sc.parallelize(generateAFTInput( 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0))) - datasetMultivariate = sqlContext.createDataFrame( + datasetMultivariate = spark.createDataFrame( sc.parallelize(generateAFTInput( 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) - datasetUnivariateScaled = sqlContext.createDataFrame( + datasetUnivariateScaled = spark.createDataFrame( sc.parallelize(generateAFTInput( 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x => AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor) @@ -356,7 +356,7 @@ class AFTSurvivalRegressionSuite test("should support all NumericType labels") { val aft = new AFTSurvivalRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression]( - aft, isClassification = false, sqlContext) { (expected, actual) => + aft, isClassification = false, spark) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients === actual.coefficients) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index e9fb2677b2..d9f26ad8dc 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -120,7 +120,7 @@ class DecisionTreeRegressorSuite test("should support all NumericType labels and not support other types") { val dt = new DecisionTreeRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor]( - dt, isClassification = false, sqlContext) { (expected, actual) => + dt, isClassification = false, spark) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 216377959e..f6ea5bb741 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -72,7 +72,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext } test("GBTRegressor behaves reasonably on toy data") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( LabeledPoint(10, Vectors.dense(1, 2, 3, 4)), LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)), LabeledPoint(11, Vectors.dense(2, 2, 3, 4)), @@ -99,7 +99,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext val path = tempDir.toURI.toString sc.setCheckpointDir(path) - val df = sqlContext.createDataFrame(data) + val df = spark.createDataFrame(data) val gbt = new GBTRegressor() .setMaxDepth(2) .setMaxIter(5) @@ -115,7 +115,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext test("should support all NumericType labels and not support other types") { val gbt = new GBTRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor]( - gbt, isClassification = false, sqlContext) { (expected, actual) => + gbt, isClassification = false, spark) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index b854be2f1f..161f8c80f8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -52,19 +52,19 @@ class GeneralizedLinearRegressionSuite import GeneralizedLinearRegressionSuite._ - datasetGaussianIdentity = sqlContext.createDataFrame( + datasetGaussianIdentity = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "gaussian", link = "identity"), 2)) - datasetGaussianLog = sqlContext.createDataFrame( + datasetGaussianLog = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "gaussian", link = "log"), 2)) - datasetGaussianInverse = sqlContext.createDataFrame( + datasetGaussianInverse = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = 
Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, @@ -80,40 +80,40 @@ class GeneralizedLinearRegressionSuite generateMultinomialLogisticInput(coefficients, xMean, xVariance, addIntercept = true, nPoints, seed) - sqlContext.createDataFrame(sc.parallelize(testData, 2)) + spark.createDataFrame(sc.parallelize(testData, 2)) } - datasetPoissonLog = sqlContext.createDataFrame( + datasetPoissonLog = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "poisson", link = "log"), 2)) - datasetPoissonIdentity = sqlContext.createDataFrame( + datasetPoissonIdentity = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "poisson", link = "identity"), 2)) - datasetPoissonSqrt = sqlContext.createDataFrame( + datasetPoissonSqrt = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "poisson", link = "sqrt"), 2)) - datasetGammaInverse = sqlContext.createDataFrame( + datasetGammaInverse = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "gamma", link = "inverse"), 2)) - datasetGammaIdentity = sqlContext.createDataFrame( + datasetGammaIdentity = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "gamma", link = "identity"), 2)) - datasetGammaLog = sqlContext.createDataFrame( + datasetGammaLog = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, @@ -540,7 +540,7 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), @@ -668,7 +668,7 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)), Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), @@ -782,7 +782,7 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), Instance(3.0, 3.0, 
Vectors.dense(2.0, 11.0)), @@ -899,7 +899,7 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), @@ -1021,14 +1021,14 @@ class GeneralizedLinearRegressionSuite val glr = new GeneralizedLinearRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[ GeneralizedLinearRegressionModel, GeneralizedLinearRegression]( - glr, isClassification = false, sqlContext) { (expected, actual) => + glr, isClassification = false, spark) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients === actual.coefficients) } } test("glm accepts Dataset[LabeledPoint]") { - val context = sqlContext + val context = spark import context.implicits._ new GeneralizedLinearRegression() .setFamily("gaussian") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 3a10ad7ed0..9bf7542b12 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -28,13 +28,13 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { - sqlContext.createDataFrame( + spark.createDataFrame( labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) } ).toDF("label", "features", "weight") } private def generatePredictionInput(features: Seq[Double]): DataFrame = { - sqlContext.createDataFrame(features.map(Tuple1.apply)) + spark.createDataFrame(features.map(Tuple1.apply)) .toDF("features") } @@ -145,7 +145,7 @@ class IsotonicRegressionSuite } test("vector features column with feature index") { - val dataset = sqlContext.createDataFrame(Seq( + val dataset = spark.createDataFrame(Seq( (4.0, Vectors.dense(0.0, 1.0)), (3.0, Vectors.dense(0.0, 2.0)), (5.0, Vectors.sparse(2, Array(1), Array(3.0)))) @@ -184,7 +184,7 @@ class IsotonicRegressionSuite test("should support all NumericType labels and not support other types") { val ir = new IsotonicRegression() MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression]( - ir, isClassification = false, sqlContext) { (expected, actual) => + ir, isClassification = false, spark) { (expected, actual) => assert(expected.boundaries === actual.boundaries) assert(expected.predictions === actual.predictions) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index eb19d13093..10f547b673 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -42,7 +42,7 @@ class LinearRegressionSuite override def beforeAll(): Unit = { super.beforeAll() - datasetWithDenseFeature = sqlContext.createDataFrame( + datasetWithDenseFeature = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), xVariance = Array(0.7, 1.2), nPoints = 
10000, seed, eps = 0.1), 2)) @@ -50,7 +50,7 @@ class LinearRegressionSuite datasetWithDenseFeatureWithoutIntercept is not needed for correctness testing but is useful for illustrating training model without intercept */ - datasetWithDenseFeatureWithoutIntercept = sqlContext.createDataFrame( + datasetWithDenseFeatureWithoutIntercept = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( intercept = 0.0, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2)) @@ -59,7 +59,7 @@ class LinearRegressionSuite // When feature size is larger than 4096, normal optimizer is choosed // as the solver of linear regression in the case of "auto" mode. val featureSize = 4100 - datasetWithSparseFeature = sqlContext.createDataFrame( + datasetWithSparseFeature = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble()).toArray, xMean = Seq.fill(featureSize)(r.nextDouble()).toArray, @@ -74,7 +74,7 @@ class LinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - datasetWithWeight = sqlContext.createDataFrame( + datasetWithWeight = spark.createDataFrame( sc.parallelize(Seq( Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), @@ -90,14 +90,14 @@ class LinearRegressionSuite w <- c(1, 2, 3, 4) df.const.label <- as.data.frame(cbind(A, b.const)) */ - datasetWithWeightConstantLabel = sqlContext.createDataFrame( + datasetWithWeightConstantLabel = spark.createDataFrame( sc.parallelize(Seq( Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)), Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2)) - datasetWithWeightZeroLabel = sqlContext.createDataFrame( + datasetWithWeightZeroLabel = spark.createDataFrame( sc.parallelize(Seq( Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)), @@ -828,8 +828,8 @@ class LinearRegressionSuite } val data2 = weightedSignedData ++ weightedNoiseData - (sqlContext.createDataFrame(sc.parallelize(data1, 4)), - sqlContext.createDataFrame(sc.parallelize(data2, 4))) + (spark.createDataFrame(sc.parallelize(data1, 4)), + spark.createDataFrame(sc.parallelize(data2, 4))) } val trainer1a = (new LinearRegression).setFitIntercept(true) @@ -1010,7 +1010,7 @@ class LinearRegressionSuite test("should support all NumericType labels and not support other types") { val lr = new LinearRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression]( - lr, isClassification = false, sqlContext) { (expected, actual) => + lr, isClassification = false, spark) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients === actual.coefficients) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index ca400e1914..72f3c65eb8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -98,7 +98,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex test("should support all NumericType labels and not support other types") { val rf = new 
RandomForestRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor]( - rf, isClassification = false, sqlContext) { (expected, actual) => + rf, isClassification = false, spark) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 1d7144f4e5..7d0e01fd8f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -56,7 +56,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { } test("select as sparse vector") { - val df = sqlContext.read.format("libsvm").load(path) + val df = spark.read.format("libsvm").load(path) assert(df.columns(0) == "label") assert(df.columns(1) == "features") val row1 = df.first() @@ -66,7 +66,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { } test("select as dense vector") { - val df = sqlContext.read.format("libsvm").options(Map("vectorType" -> "dense")) + val df = spark.read.format("libsvm").options(Map("vectorType" -> "dense")) .load(path) assert(df.columns(0) == "label") assert(df.columns(1) == "features") @@ -78,7 +78,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { } test("select a vector with specifying the longer dimension") { - val df = sqlContext.read.option("numFeatures", "100").format("libsvm") + val df = spark.read.option("numFeatures", "100").format("libsvm") .load(path) val row1 = df.first() val v = row1.getAs[SparseVector](1) @@ -86,27 +86,27 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { } test("write libsvm data and read it again") { - val df = sqlContext.read.format("libsvm").load(path) + val df = spark.read.format("libsvm").load(path) val tempDir2 = new File(tempDir, "read_write_test") val writepath = tempDir2.toURI.toString // TODO: Remove requirement to coalesce by supporting multiple reads. 
df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath) - val df2 = sqlContext.read.format("libsvm").load(writepath) + val df2 = spark.read.format("libsvm").load(writepath) val row1 = df2.first() val v = row1.getAs[SparseVector](1) assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) } test("write libsvm data failed due to invalid schema") { - val df = sqlContext.read.format("text").load(path) + val df = spark.read.format("text").load(path) intercept[SparkException] { df.write.format("libsvm").save(path + "_2") } } test("select features from libsvm relation") { - val df = sqlContext.read.format("libsvm").load(path) + val df = spark.read.format("libsvm").load(path) df.select("features").rdd.map { case Row(d: Vector) => d }.first df.select("features").collect } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala index fecf372c3d..de92b51eb0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala @@ -37,8 +37,8 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext val numIterations = 20 val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2) val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2) - val trainDF = sqlContext.createDataFrame(trainRdd) - val validateDF = sqlContext.createDataFrame(validateRdd) + val trainDF = spark.createDataFrame(trainRdd) + val validateDF = spark.createDataFrame(validateRdd) val algos = Array(Regression, Regression, Classification) val losses = Array(SquaredError, AbsoluteError, LogLoss) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index e3f09899d7..12ade4c92f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, SparkSession} private[ml] object TreeTests extends SparkFunSuite { @@ -42,8 +42,12 @@ private[ml] object TreeTests extends SparkFunSuite { data: RDD[LabeledPoint], categoricalFeatures: Map[Int, Int], numClasses: Int): DataFrame = { - val sqlContext = SQLContext.getOrCreate(data.sparkContext) - import sqlContext.implicits._ + val spark = SparkSession.builder + .master("local[2]") + .appName("TreeTests") + .getOrCreate() + import spark.implicits._ + val df = data.toDF() val numFeatures = data.first().features.size val featuresAttributes = Range(0, numFeatures).map { feature => diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 061d04c932..85df6da7a1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -39,7 +39,7 @@ class CrossValidatorSuite override def beforeAll(): Unit = { super.beforeAll() - dataset = sqlContext.createDataFrame( + dataset = spark.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) } 
@@ -67,7 +67,7 @@ class CrossValidatorSuite } test("cross validation with linear regression") { - val dataset = sqlContext.createDataFrame( + val dataset = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index df9ba418b8..f8d3de19b0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.StructType class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("train validation with logistic regression") { - val dataset = sqlContext.createDataFrame( + val dataset = spark.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) val lr = new LogisticRegression @@ -58,7 +58,7 @@ class TrainValidationSplitSuite } test("train validation with linear regression") { - val dataset = sqlContext.createDataFrame( + val dataset = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index d9e6fd5aae..4fe473bbac 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -38,17 +38,17 @@ object MLTestingUtils extends SparkFunSuite { def checkNumericTypes[M <: Model[M], T <: Estimator[M]]( estimator: T, isClassification: Boolean, - sqlContext: SQLContext)(check: (M, M) => Unit): Unit = { + spark: SparkSession)(check: (M, M) => Unit): Unit = { val dfs = if (isClassification) { - genClassifDFWithNumericLabelCol(sqlContext) + genClassifDFWithNumericLabelCol(spark) } else { - genRegressionDFWithNumericLabelCol(sqlContext) + genRegressionDFWithNumericLabelCol(spark) } val expected = estimator.fit(dfs(DoubleType)) val actuals = dfs.keys.filter(_ != DoubleType).map(t => estimator.fit(dfs(t))) actuals.foreach(actual => check(expected, actual)) - val dfWithStringLabels = sqlContext.createDataFrame(Seq( + val dfWithStringLabels = spark.createDataFrame(Seq( ("0", Vectors.dense(0, 2, 3), 0.0) )).toDF("label", "features", "censor") val thrown = intercept[IllegalArgumentException] { @@ -58,13 +58,13 @@ object MLTestingUtils extends SparkFunSuite { "Column label must be of type NumericType but was actually of type StringType")) } - def checkNumericTypes[T <: Evaluator](evaluator: T, sqlContext: SQLContext): Unit = { - val dfs = genEvaluatorDFWithNumericLabelCol(sqlContext, "label", "prediction") + def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = { + val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction") val expected = 
evaluator.evaluate(dfs(DoubleType)) val actuals = dfs.keys.filter(_ != DoubleType).map(t => evaluator.evaluate(dfs(t))) actuals.foreach(actual => assert(expected === actual)) - val dfWithStringLabels = sqlContext.createDataFrame(Seq( + val dfWithStringLabels = spark.createDataFrame(Seq( ("0", 0d) )).toDF("label", "prediction") val thrown = intercept[IllegalArgumentException] { @@ -75,10 +75,10 @@ object MLTestingUtils extends SparkFunSuite { } def genClassifDFWithNumericLabelCol( - sqlContext: SQLContext, + spark: SparkSession, labelColName: String = "label", featuresColName: String = "features"): Map[NumericType, DataFrame] = { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, Vectors.dense(0, 2, 3)), (1, Vectors.dense(0, 3, 1)), (0, Vectors.dense(0, 2, 2)), @@ -95,11 +95,11 @@ object MLTestingUtils extends SparkFunSuite { } def genRegressionDFWithNumericLabelCol( - sqlContext: SQLContext, + spark: SparkSession, labelColName: String = "label", featuresColName: String = "features", censorColName: String = "censor"): Map[NumericType, DataFrame] = { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, Vectors.dense(0)), (1, Vectors.dense(1)), (2, Vectors.dense(2)), @@ -117,10 +117,10 @@ object MLTestingUtils extends SparkFunSuite { } def genEvaluatorDFWithNumericLabelCol( - sqlContext: SQLContext, + spark: SparkSession, labelColName: String = "label", predictionColName: String = "prediction"): Map[NumericType, DataFrame] = { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, 0d), (1, 1d), (2, 2d), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 7f9e340f54..ba8d36f45f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -23,23 +23,22 @@ import org.scalatest.Suite import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.util.TempDirectory -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.util.Utils trait MLlibTestSparkContext extends TempDirectory { self: Suite => + @transient var spark: SparkSession = _ @transient var sc: SparkContext = _ - @transient var sqlContext: SQLContext = _ @transient var checkpointDir: String = _ override def beforeAll() { super.beforeAll() - val conf = new SparkConf() - .setMaster("local[2]") - .setAppName("MLlibUnitTest") - sc = new SparkContext(conf) - SQLContext.clearActive() - sqlContext = new SQLContext(sc) - SQLContext.setActive(sqlContext) + spark = SparkSession.builder + .master("local[2]") + .appName("MLlibUnitTest") + .getOrCreate() + sc = spark.sparkContext + checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString sc.setCheckpointDir(checkpointDir) } @@ -47,12 +46,11 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite => override def afterAll() { try { Utils.deleteRecursively(new File(checkpointDir)) - sqlContext = null SQLContext.clearActive() - if (sc != null) { - sc.stop() + if (spark != null) { + spark.stop() } - sc = null + spark = null } finally { super.afterAll() } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index 189cc3972c..f2ae40e644 100644 --- 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -28,14 +28,13 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -44,21 +43,22 @@ import org.apache.spark.sql.types.StructType; // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. public class JavaApplySchemaSuite implements Serializable { - private transient JavaSparkContext javaCtx; - private transient SQLContext sqlContext; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - SparkContext context = new SparkContext("local[*]", "testing"); - javaCtx = new JavaSparkContext(context); - sqlContext = new SQLContext(context); + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sqlContext.sparkContext().stop(); - sqlContext = null; - javaCtx = null; + spark.stop(); + spark = null; } public static class Person implements Serializable { @@ -94,7 +94,7 @@ public class JavaApplySchemaSuite implements Serializable { person2.setAge(28); personList.add(person2); - JavaRDD rowRDD = javaCtx.parallelize(personList).map( + JavaRDD rowRDD = jsc.parallelize(personList).map( new Function() { @Override public Row call(Person person) throws Exception { @@ -107,9 +107,9 @@ public class JavaApplySchemaSuite implements Serializable { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - Dataset df = sqlContext.createDataFrame(rowRDD, schema); + Dataset df = spark.createDataFrame(rowRDD, schema); df.registerTempTable("people"); - List actual = sqlContext.sql("SELECT * FROM people").collectAsList(); + List actual = spark.sql("SELECT * FROM people").collectAsList(); List expected = new ArrayList<>(2); expected.add(RowFactory.create("Michael", 29)); @@ -130,7 +130,7 @@ public class JavaApplySchemaSuite implements Serializable { person2.setAge(28); personList.add(person2); - JavaRDD rowRDD = javaCtx.parallelize(personList).map( + JavaRDD rowRDD = jsc.parallelize(personList).map( new Function() { @Override public Row call(Person person) { @@ -143,9 +143,9 @@ public class JavaApplySchemaSuite implements Serializable { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - Dataset df = sqlContext.createDataFrame(rowRDD, schema); + Dataset df = spark.createDataFrame(rowRDD, schema); df.registerTempTable("people"); - List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD() + List actual = spark.sql("SELECT * FROM people").toJavaRDD() .map(new Function() { @Override public String call(Row row) { @@ -162,7 +162,7 @@ public class JavaApplySchemaSuite implements Serializable { @Test 
public void applySchemaToJSON() { - JavaRDD jsonRDD = javaCtx.parallelize(Arrays.asList( + JavaRDD jsonRDD = jsc.parallelize(Arrays.asList( "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " + "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " + "\"boolean\":true, \"null\":null}", @@ -199,18 +199,18 @@ public class JavaApplySchemaSuite implements Serializable { null, "this is another simple string.")); - Dataset df1 = sqlContext.read().json(jsonRDD); + Dataset df1 = spark.read().json(jsonRDD); StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); df1.registerTempTable("jsonTable1"); - List actual1 = sqlContext.sql("select * from jsonTable1").collectAsList(); + List actual1 = spark.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - Dataset df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD); + Dataset df2 = spark.read().schema(expectedSchema).json(jsonRDD); StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); df2.registerTempTable("jsonTable2"); - List actual2 = sqlContext.sql("select * from jsonTable2").collectAsList(); + List actual2 = spark.sql("select * from jsonTable2").collectAsList(); Assert.assertEquals(expectedResult, actual2); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 1eb680dc4c..324ebbae38 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -20,12 +20,7 @@ package test.org.apache.spark.sql; import java.io.Serializable; import java.net.URISyntaxException; import java.net.URL; -import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.ArrayList; +import java.util.*; import scala.collection.JavaConverters; import scala.collection.Seq; @@ -34,46 +29,45 @@ import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; import org.junit.*; -import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.*; -import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.*; +import org.apache.spark.util.sketch.BloomFilter; import org.apache.spark.util.sketch.CountMinSketch; import static org.apache.spark.sql.functions.*; import static org.apache.spark.sql.types.DataTypes.*; -import org.apache.spark.util.sketch.BloomFilter; public class JavaDataFrameSuite { + private transient TestSparkSession spark; private transient JavaSparkContext jsc; - private transient TestSQLContext context; @Before public void setUp() { // Trigger static initializer of TestData - SparkContext sc = new SparkContext("local[*]", "testing"); - jsc = new JavaSparkContext(sc); - context = new TestSQLContext(sc); - context.loadTestData(); + spark = new TestSparkSession(); + jsc = new JavaSparkContext(spark.sparkContext()); + spark.loadTestData(); } @After public void tearDown() { - context.sparkContext().stop(); - context = null; - jsc = null; + spark.stop(); + spark = null; } @Test public void 
testExecution() { - Dataset df = context.table("testData").filter("key = 1"); + Dataset df = spark.table("testData").filter("key = 1"); Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0)); } @Test public void testCollectAndTake() { - Dataset df = context.table("testData").filter("key = 1 or key = 2 or key = 3"); + Dataset df = spark.table("testData").filter("key = 1 or key = 2 or key = 3"); Assert.assertEquals(3, df.select("key").collectAsList().size()); Assert.assertEquals(2, df.select("key").takeAsList(2).size()); } @@ -83,7 +77,7 @@ public class JavaDataFrameSuite { */ @Test public void testVarargMethods() { - Dataset df = context.table("testData"); + Dataset df = spark.table("testData"); df.toDF("key1", "value1"); @@ -112,7 +106,7 @@ public class JavaDataFrameSuite { df.select(coalesce(col("key"))); // Varargs with mathfunctions - Dataset df2 = context.table("testData2"); + Dataset df2 = spark.table("testData2"); df2.select(exp("a"), exp("b")); df2.select(exp(log("a"))); df2.select(pow("a", "a"), pow("b", 2.0)); @@ -126,7 +120,7 @@ public class JavaDataFrameSuite { @Ignore public void testShow() { // This test case is intended ignored, but to make sure it compiles correctly - Dataset df = context.table("testData"); + Dataset df = spark.table("testData"); df.show(); df.show(1000); } @@ -194,7 +188,7 @@ public class JavaDataFrameSuite { public void testCreateDataFrameFromLocalJavaBeans() { Bean bean = new Bean(); List data = Arrays.asList(bean); - Dataset df = context.createDataFrame(data, Bean.class); + Dataset df = spark.createDataFrame(data, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -202,7 +196,7 @@ public class JavaDataFrameSuite { public void testCreateDataFrameFromJavaBeans() { Bean bean = new Bean(); JavaRDD rdd = jsc.parallelize(Arrays.asList(bean)); - Dataset df = context.createDataFrame(rdd, Bean.class); + Dataset df = spark.createDataFrame(rdd, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -210,7 +204,7 @@ public class JavaDataFrameSuite { public void testCreateDataFromFromList() { StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); List rows = Arrays.asList(RowFactory.create(0)); - Dataset df = context.createDataFrame(rows, schema); + Dataset df = spark.createDataFrame(rows, schema); List result = df.collectAsList(); Assert.assertEquals(1, result.size()); } @@ -239,7 +233,7 @@ public class JavaDataFrameSuite { @Test public void testCrosstab() { - Dataset df = context.table("testData2"); + Dataset df = spark.table("testData2"); Dataset crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); Assert.assertEquals("a_b", columnNames[0]); @@ -258,7 +252,7 @@ public class JavaDataFrameSuite { @Test public void testFrequentItems() { - Dataset df = context.table("testData2"); + Dataset df = spark.table("testData2"); String[] cols = {"a"}; Dataset results = df.stat().freqItems(cols, 0.2); Assert.assertTrue(results.collectAsList().get(0).getSeq(0).contains(1)); @@ -266,21 +260,21 @@ public class JavaDataFrameSuite { @Test public void testCorrelation() { - Dataset df = context.table("testData2"); + Dataset df = spark.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { - Dataset df = context.table("testData2"); + Dataset df = spark.table("testData2"); Double result = df.stat().cov("a", "b"); Assert.assertTrue(Math.abs(result) < 1.0e-6); } @Test 
public void testSampleBy() { - Dataset df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); + Dataset df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); Dataset sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); List actual = sampled.groupBy("key").count().orderBy("key").collectAsList(); Assert.assertEquals(0, actual.get(0).getLong(0)); @@ -291,7 +285,7 @@ public class JavaDataFrameSuite { @Test public void pivot() { - Dataset df = context.table("courseSales"); + Dataset df = spark.table("courseSales"); List actual = df.groupBy("year") .pivot("course", Arrays.asList("dotNET", "Java")) .agg(sum("earnings")).orderBy("year").collectAsList(); @@ -324,10 +318,10 @@ public class JavaDataFrameSuite { @Test public void testGenericLoad() { - Dataset df1 = context.read().format("text").load(getResource("text-suite.txt")); + Dataset df1 = spark.read().format("text").load(getResource("text-suite.txt")); Assert.assertEquals(4L, df1.count()); - Dataset df2 = context.read().format("text").load( + Dataset df2 = spark.read().format("text").load( getResource("text-suite.txt"), getResource("text-suite2.txt")); Assert.assertEquals(5L, df2.count()); @@ -335,10 +329,10 @@ public class JavaDataFrameSuite { @Test public void testTextLoad() { - Dataset ds1 = context.read().text(getResource("text-suite.txt")); + Dataset ds1 = spark.read().text(getResource("text-suite.txt")); Assert.assertEquals(4L, ds1.count()); - Dataset ds2 = context.read().text( + Dataset ds2 = spark.read().text( getResource("text-suite.txt"), getResource("text-suite2.txt")); Assert.assertEquals(5L, ds2.count()); @@ -346,7 +340,7 @@ public class JavaDataFrameSuite { @Test public void testCountMinSketch() { - Dataset df = context.range(1000); + Dataset df = spark.range(1000); CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42); Assert.assertEquals(sketch1.totalCount(), 1000); @@ -371,7 +365,7 @@ public class JavaDataFrameSuite { @Test public void testBloomFilter() { - Dataset df = context.range(1000); + Dataset df = spark.range(1000); BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03); Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f1b1c22e4a..8354a5bdac 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,46 +23,43 @@ import java.sql.Date; import java.sql.Timestamp; import java.util.*; -import com.google.common.base.Objects; -import org.junit.rules.ExpectedException; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; import scala.Tuple5; +import com.google.common.base.Objects; import org.junit.*; +import org.junit.rules.ExpectedException; import org.apache.spark.Accumulator; -import org.apache.spark.SparkContext; -import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.*; import org.apache.spark.sql.*; -import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.catalyst.encoders.OuterScopes; import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.StructType; - -import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.functions.col; +import static 
org.apache.spark.sql.functions.expr; import static org.apache.spark.sql.types.DataTypes.*; public class JavaDatasetSuite implements Serializable { + private transient TestSparkSession spark; private transient JavaSparkContext jsc; - private transient TestSQLContext context; @Before public void setUp() { // Trigger static initializer of TestData - SparkContext sc = new SparkContext("local[*]", "testing"); - jsc = new JavaSparkContext(sc); - context = new TestSQLContext(sc); - context.loadTestData(); + spark = new TestSparkSession(); + jsc = new JavaSparkContext(spark.sparkContext()); + spark.loadTestData(); } @After public void tearDown() { - context.sparkContext().stop(); - context = null; - jsc = null; + spark.stop(); + spark = null; } private Tuple2 tuple2(T1 t1, T2 t2) { @@ -72,7 +69,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testCollect() { List data = Arrays.asList("hello", "world"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); List collected = ds.collectAsList(); Assert.assertEquals(Arrays.asList("hello", "world"), collected); } @@ -80,7 +77,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testTake() { List data = Arrays.asList("hello", "world"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); List collected = ds.takeAsList(1); Assert.assertEquals(Arrays.asList("hello"), collected); } @@ -88,7 +85,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testToLocalIterator() { List data = Arrays.asList("hello", "world"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); Iterator iter = ds.toLocalIterator(); Assert.assertEquals("hello", iter.next()); Assert.assertEquals("world", iter.next()); @@ -98,7 +95,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testCommonOperation() { List data = Arrays.asList("hello", "world"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); Assert.assertEquals("hello", ds.first()); Dataset filtered = ds.filter(new FilterFunction() { @@ -149,7 +146,7 @@ public class JavaDatasetSuite implements Serializable { public void testForeach() { final Accumulator accum = jsc.accumulator(0); List data = Arrays.asList("a", "b", "c"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); ds.foreach(new ForeachFunction() { @Override @@ -163,7 +160,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testReduce() { List data = Arrays.asList(1, 2, 3); - Dataset ds = context.createDataset(data, Encoders.INT()); + Dataset ds = spark.createDataset(data, Encoders.INT()); int reduced = ds.reduce(new ReduceFunction() { @Override @@ -177,7 +174,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testGroupBy() { List data = Arrays.asList("a", "foo", "bar"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); KeyValueGroupedDataset grouped = ds.groupByKey( new MapFunction() { @Override @@ -227,7 +224,7 @@ public class JavaDatasetSuite implements Serializable { toSet(reduced.collectAsList())); List data2 = Arrays.asList(2, 6, 10); - Dataset ds2 = context.createDataset(data2, 
Encoders.INT()); + Dataset ds2 = spark.createDataset(data2, Encoders.INT()); KeyValueGroupedDataset grouped2 = ds2.groupByKey( new MapFunction() { @Override @@ -261,7 +258,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testSelect() { List data = Arrays.asList(2, 6); - Dataset ds = context.createDataset(data, Encoders.INT()); + Dataset ds = spark.createDataset(data, Encoders.INT()); Dataset> selected = ds.select( expr("value + 1"), @@ -275,12 +272,12 @@ public class JavaDatasetSuite implements Serializable { @Test public void testSetOperation() { List data = Arrays.asList("abc", "abc", "xyz"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); Assert.assertEquals(asSet("abc", "xyz"), toSet(ds.distinct().collectAsList())); List data2 = Arrays.asList("xyz", "foo", "foo"); - Dataset ds2 = context.createDataset(data2, Encoders.STRING()); + Dataset ds2 = spark.createDataset(data2, Encoders.STRING()); Dataset intersected = ds.intersect(ds2); Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList()); @@ -307,9 +304,9 @@ public class JavaDatasetSuite implements Serializable { @Test public void testJoin() { List data = Arrays.asList(1, 2, 3); - Dataset ds = context.createDataset(data, Encoders.INT()).as("a"); + Dataset ds = spark.createDataset(data, Encoders.INT()).as("a"); List data2 = Arrays.asList(2, 3, 4); - Dataset ds2 = context.createDataset(data2, Encoders.INT()).as("b"); + Dataset ds2 = spark.createDataset(data2, Encoders.INT()).as("b"); Dataset> joined = ds.joinWith(ds2, col("a.value").equalTo(col("b.value"))); @@ -322,21 +319,21 @@ public class JavaDatasetSuite implements Serializable { public void testTupleEncoder() { Encoder> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING()); List> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b")); - Dataset> ds2 = context.createDataset(data2, encoder2); + Dataset> ds2 = spark.createDataset(data2, encoder2); Assert.assertEquals(data2, ds2.collectAsList()); Encoder> encoder3 = Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING()); List> data3 = Arrays.asList(new Tuple3<>(1, 2L, "a")); - Dataset> ds3 = context.createDataset(data3, encoder3); + Dataset> ds3 = spark.createDataset(data3, encoder3); Assert.assertEquals(data3, ds3.collectAsList()); Encoder> encoder4 = Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING()); List> data4 = Arrays.asList(new Tuple4<>(1, "b", 2L, "a")); - Dataset> ds4 = context.createDataset(data4, encoder4); + Dataset> ds4 = spark.createDataset(data4, encoder4); Assert.assertEquals(data4, ds4.collectAsList()); Encoder> encoder5 = @@ -345,7 +342,7 @@ public class JavaDatasetSuite implements Serializable { List> data5 = Arrays.asList(new Tuple5<>(1, "b", 2L, "a", true)); Dataset> ds5 = - context.createDataset(data5, encoder5); + spark.createDataset(data5, encoder5); Assert.assertEquals(data5, ds5.collectAsList()); } @@ -356,7 +353,7 @@ public class JavaDatasetSuite implements Serializable { Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING()); List, String>> data = Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b")); - Dataset, String>> ds = context.createDataset(data, encoder); + Dataset, String>> ds = spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); // test (int, (string, string, long)) @@ -366,7 +363,7 @@ public class JavaDatasetSuite implements Serializable { List>> data2 
= Arrays.asList(tuple2(1, new Tuple3<>("a", "b", 3L))); Dataset>> ds2 = - context.createDataset(data2, encoder2); + spark.createDataset(data2, encoder2); Assert.assertEquals(data2, ds2.collectAsList()); // test (int, ((string, long), string)) @@ -376,7 +373,7 @@ public class JavaDatasetSuite implements Serializable { List, String>>> data3 = Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b"))); Dataset, String>>> ds3 = - context.createDataset(data3, encoder3); + spark.createDataset(data3, encoder3); Assert.assertEquals(data3, ds3.collectAsList()); } @@ -390,7 +387,7 @@ public class JavaDatasetSuite implements Serializable { 1.7976931348623157E308, new BigDecimal("0.922337203685477589"), Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE)); Dataset> ds = - context.createDataset(data, encoder); + spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); } @@ -441,7 +438,7 @@ public class JavaDatasetSuite implements Serializable { Encoder encoder = Encoders.kryo(KryoSerializable.class); List data = Arrays.asList( new KryoSerializable("hello"), new KryoSerializable("world")); - Dataset ds = context.createDataset(data, encoder); + Dataset ds = spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); } @@ -450,14 +447,14 @@ public class JavaDatasetSuite implements Serializable { Encoder encoder = Encoders.javaSerialization(JavaSerializable.class); List data = Arrays.asList( new JavaSerializable("hello"), new JavaSerializable("world")); - Dataset ds = context.createDataset(data, encoder); + Dataset ds = spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); } @Test public void testRandomSplit() { List data = Arrays.asList("hello", "world", "from", "spark"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); double[] arraySplit = {1, 2, 3}; List> randomSplit = ds.randomSplitAsList(arraySplit, 1); @@ -647,14 +644,14 @@ public class JavaDatasetSuite implements Serializable { obj2.setF(Arrays.asList(300L, null, 400L)); List data = Arrays.asList(obj1, obj2); - Dataset ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class)); + Dataset ds = spark.createDataset(data, Encoders.bean(SimpleJavaBean.class)); Assert.assertEquals(data, ds.collectAsList()); NestedJavaBean obj3 = new NestedJavaBean(); obj3.setA(obj1); List data2 = Arrays.asList(obj3); - Dataset ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class)); + Dataset ds2 = spark.createDataset(data2, Encoders.bean(NestedJavaBean.class)); Assert.assertEquals(data2, ds2.collectAsList()); Row row1 = new GenericRow(new Object[]{ @@ -678,7 +675,7 @@ public class JavaDatasetSuite implements Serializable { .add("d", createArrayType(StringType)) .add("e", createArrayType(StringType)) .add("f", createArrayType(LongType)); - Dataset ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema) + Dataset ds3 = spark.createDataFrame(Arrays.asList(row1, row2), schema) .as(Encoders.bean(SimpleJavaBean.class)); Assert.assertEquals(data, ds3.collectAsList()); } @@ -692,7 +689,7 @@ public class JavaDatasetSuite implements Serializable { obj.setB(new Date(0)); obj.setC(java.math.BigDecimal.valueOf(1)); Dataset ds = - context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); + spark.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); ds.collect(); } @@ -776,7 +773,7 @@ public class JavaDatasetSuite implements 
Serializable { }) }); - Dataset df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset df = spark.createDataFrame(Collections.singletonList(row), schema); Dataset ds = df.as(Encoders.bean(NestedSmallBean.class)); SmallBean smallBean = new SmallBean(); @@ -793,7 +790,7 @@ public class JavaDatasetSuite implements Serializable { { Row row = new GenericRow(new Object[] { null }); - Dataset df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset df = spark.createDataFrame(Collections.singletonList(row), schema); Dataset ds = df.as(Encoders.bean(NestedSmallBean.class)); NestedSmallBean nestedSmallBean = new NestedSmallBean(); @@ -810,7 +807,7 @@ public class JavaDatasetSuite implements Serializable { }) }); - Dataset df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset df = spark.createDataFrame(Collections.singletonList(row), schema); Dataset ds = df.as(Encoders.bean(NestedSmallBean.class)); ds.collect(); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 4a78dca7fe..2274912521 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -24,33 +24,30 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.SparkContext; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.api.java.UDF2; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. 
public class JavaUDFSuite implements Serializable { - private transient JavaSparkContext sc; - private transient SQLContext sqlContext; + private transient SparkSession spark; @Before public void setUp() { - SparkContext _sc = new SparkContext("local[*]", "testing"); - sqlContext = new SQLContext(_sc); - sc = new JavaSparkContext(_sc); + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); } @After public void tearDown() { - sqlContext.sparkContext().stop(); - sqlContext = null; - sc = null; + spark.stop(); + spark = null; } @SuppressWarnings("unchecked") @@ -60,14 +57,14 @@ public class JavaUDFSuite implements Serializable { // sqlContext.registerFunction( // "stringLengthTest", (String str) -> str.length(), DataType.IntegerType); - sqlContext.udf().register("stringLengthTest", new UDF1() { + spark.udf().register("stringLengthTest", new UDF1() { @Override public Integer call(String str) { return str.length(); } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); + Row result = spark.sql("SELECT stringLengthTest('test')").head(); Assert.assertEquals(4, result.getInt(0)); } @@ -80,14 +77,14 @@ public class JavaUDFSuite implements Serializable { // (String str1, String str2) -> str1.length() + str2.length, // DataType.IntegerType); - sqlContext.udf().register("stringLengthTest", new UDF2() { + spark.udf().register("stringLengthTest", new UDF2() { @Override public Integer call(String str1, String str2) { return str1.length() + str2.length(); } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); + Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); Assert.assertEquals(9, result.getInt(0)); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java index 7863177093..059c2d9f2c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java @@ -26,36 +26,30 @@ import scala.Tuple2; import org.junit.After; import org.junit.Before; -import org.apache.spark.SparkContext; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.KeyValueGroupedDataset; -import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.test.TestSparkSession; /** * Common test base shared across this and Java8DatasetAggregatorSuite. 
*/ public class JavaDatasetAggregatorSuiteBase implements Serializable { - protected transient JavaSparkContext jsc; - protected transient TestSQLContext context; + private transient TestSparkSession spark; @Before public void setUp() { // Trigger static initializer of TestData - SparkContext sc = new SparkContext("local[*]", "testing"); - jsc = new JavaSparkContext(sc); - context = new TestSQLContext(sc); - context.loadTestData(); + spark = new TestSparkSession(); + spark.loadTestData(); } @After public void tearDown() { - context.sparkContext().stop(); - context = null; - jsc = null; + spark.stop(); + spark = null; } protected Tuple2 tuple2(T1 t1, T2 t2) { @@ -66,7 +60,7 @@ public class JavaDatasetAggregatorSuiteBase implements Serializable { Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); List> data = Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); - Dataset> ds = context.createDataset(data, encoder); + Dataset> ds = spark.createDataset(data, encoder); return ds.groupByKey( new MapFunction, String>() { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 9e65158eb0..d0435e4d43 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -19,14 +19,16 @@ package test.org.apache.spark.sql.sources; import java.io.File; import java.io.IOException; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; @@ -37,8 +39,8 @@ import org.apache.spark.util.Utils; public class JavaSaveLoadSuite { - private transient JavaSparkContext sc; - private transient SQLContext sqlContext; + private transient SparkSession spark; + private transient JavaSparkContext jsc; File path; Dataset df; @@ -52,9 +54,11 @@ public class JavaSaveLoadSuite { @Before public void setUp() throws IOException { - SparkContext _sc = new SparkContext("local[*]", "testing"); - sqlContext = new SQLContext(_sc); - sc = new JavaSparkContext(_sc); + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); @@ -66,16 +70,15 @@ public class JavaSaveLoadSuite { for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } - JavaRDD rdd = sc.parallelize(jsonObjects); - df = sqlContext.read().json(rdd); + JavaRDD rdd = jsc.parallelize(jsonObjects); + df = spark.read().json(rdd); df.registerTempTable("jsonTable"); } @After public void tearDown() { - sqlContext.sparkContext().stop(); - sqlContext = null; - sc = null; + spark.stop(); + spark = null; } @Test @@ -83,7 +86,7 @@ public class JavaSaveLoadSuite { Map options = new HashMap<>(); options.put("path", path.toString()); df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); - Dataset loadedDF = sqlContext.read().format("json").options(options).load(); + Dataset loadedDF = spark.read().format("json").options(options).load(); checkAnswer(loadedDF, 
df.collectAsList()); } @@ -96,8 +99,8 @@ public class JavaSaveLoadSuite { List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); + Dataset loadedDF = spark.read().format("json").schema(schema).options(options).load(); - checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); + checkAnswer(loadedDF, spark.sql("SELECT b FROM jsonTable").collectAsList()); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 5ef20267f8..800316cde7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -36,7 +36,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext import testImplicits._ def rddIdOf(tableName: String): Int = { - val plan = sqlContext.table(tableName).queryExecution.sparkPlan + val plan = spark.table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => relation.cachedColumnBuffers.id @@ -73,41 +73,41 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - sqlContext.cacheTable("tempTable") + spark.catalog.cacheTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - sqlContext.uncacheTable("tempTable") + spark.catalog.uncacheTable("tempTable") } test("unpersist an uncached table will not raise exception") { - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != sqlContext.cacheManager.lookupCachedData(testData)) + assert(None != spark.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.cacheManager.lookupCachedData(testData)) } test("cache table as select") { sql("CACHE TABLE tempTable AS SELECT key FROM testData") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - sqlContext.uncacheTable("tempTable") + spark.catalog.uncacheTable("tempTable") } test("uncaching temp table") { testData.select('key).registerTempTable("tempTable1") testData.select('key).registerTempTable("tempTable2") - sqlContext.cacheTable("tempTable1") + spark.catalog.cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - sqlContext.uncacheTable("tempTable2") + spark.catalog.uncacheTable("tempTable2") // Should this be cached? 
assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) @@ -117,101 +117,101 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val data = "*" * 1000 sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() .registerTempTable("bigData") - sqlContext.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(sqlContext.table("bigData").count() === 200000L) - sqlContext.table("bigData").unpersist(blocking = true) + spark.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(spark.table("bigData").count() === 200000L) + spark.table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - sqlContext.table("testData").cache() - assertCached(sqlContext.table("testData")) - sqlContext.table("testData").unpersist(blocking = true) + spark.table("testData").cache() + assertCached(spark.table("testData")) + spark.table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - sqlContext.table("testData").cache() - sqlContext.table("testData").count() - sqlContext.table("testData").unpersist(blocking = true) - assertCached(sqlContext.table("testData"), 0) + spark.table("testData").cache() + spark.table("testData").count() + spark.table("testData").unpersist(blocking = true) + assertCached(spark.table("testData"), 0) } test("isCached") { - sqlContext.cacheTable("testData") + spark.catalog.cacheTable("testData") - assertCached(sqlContext.table("testData")) - assert(sqlContext.table("testData").queryExecution.withCachedData match { + assertCached(spark.table("testData")) + assert(spark.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - sqlContext.uncacheTable("testData") - assert(!sqlContext.isCached("testData")) - assert(sqlContext.table("testData").queryExecution.withCachedData match { + spark.catalog.uncacheTable("testData") + assert(!spark.catalog.isCached("testData")) + assert(spark.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!sqlContext.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!spark.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - sqlContext.cacheTable("testData") - assertCached(sqlContext.table("testData")) + spark.catalog.cacheTable("testData") + assertCached(spark.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - sqlContext.table("testData").queryExecution.withCachedData.collect { + spark.table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - sqlContext.cacheTable("testData") + spark.catalog.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - sqlContext.table("testData").queryExecution.withCachedData.collect { + spark.table("testData").queryExecution.withCachedData.collect { case r @ InMemoryRelation(_, _, _, _, _: InMemoryTableScanExec, _) => r }.size } - sqlContext.uncacheTable("testData") + spark.catalog.uncacheTable("testData") } test("read from cached table and uncache") { - sqlContext.cacheTable("testData") - checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) - assertCached(sqlContext.table("testData")) + spark.catalog.cacheTable("testData") + checkAnswer(spark.table("testData"), testData.collect().toSeq) + 
assertCached(spark.table("testData")) - sqlContext.uncacheTable("testData") - checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) - assertCached(sqlContext.table("testData"), 0) + spark.catalog.uncacheTable("testData") + checkAnswer(spark.table("testData"), testData.collect().toSeq) + assertCached(spark.table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - sqlContext.uncacheTable("testData") + spark.catalog.uncacheTable("testData") } } test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") - sqlContext.cacheTable("selectStar") + spark.catalog.cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - sqlContext.uncacheTable("selectStar") + spark.catalog.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - sqlContext.cacheTable("testData") + spark.catalog.cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - sqlContext.uncacheTable("testData") + spark.catalog.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(sqlContext.table("testData")) + assertCached(spark.table("testData")) val rddId = rddIdOf("testData") assert( @@ -219,7 +219,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!sqlContext.isCached("testData"), "Table 'testData' should not be cached") + assert(!spark.catalog.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") @@ -228,14 +228,14 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(sqlContext.table("testCacheTable")) + assertCached(spark.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - sqlContext.uncacheTable("testCacheTable") + spark.catalog.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -243,14 +243,14 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("CACHE TABLE tableName AS SELECT ...") { sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(sqlContext.table("testCacheTable")) + assertCached(spark.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - sqlContext.uncacheTable("testCacheTable") + spark.catalog.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -258,7 +258,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(sqlContext.table("testData")) + 
assertCached(spark.table("testData")) val rddId = rddIdOf("testData") assert( @@ -270,7 +270,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - sqlContext.uncacheTable("testData") + spark.catalog.uncacheTable("testData") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -278,7 +278,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - sqlContext.table("testData").queryExecution.withCachedData.collect { + spark.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) @@ -287,62 +287,62 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("Drops temporary table") { testData.select('key).registerTempTable("t1") - sqlContext.table("t1") - sqlContext.dropTempTable("t1") - intercept[AnalysisException](sqlContext.table("t1")) + spark.table("t1") + spark.catalog.dropTempTable("t1") + intercept[AnalysisException](spark.table("t1")) } test("Drops cached temporary table") { testData.select('key).registerTempTable("t1") testData.select('key).registerTempTable("t2") - sqlContext.cacheTable("t1") + spark.catalog.cacheTable("t1") - assert(sqlContext.isCached("t1")) - assert(sqlContext.isCached("t2")) + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) - sqlContext.dropTempTable("t1") - intercept[AnalysisException](sqlContext.table("t1")) - assert(!sqlContext.isCached("t2")) + spark.catalog.dropTempTable("t1") + intercept[AnalysisException](spark.table("t1")) + assert(!spark.catalog.isCached("t2")) } test("Clear all cache") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") - sqlContext.clearCache() - assert(sqlContext.cacheManager.isEmpty) + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") + spark.catalog.clearCache() + assert(spark.cacheManager.isEmpty) sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") sql("Clear CACHE") - assert(sqlContext.cacheManager.isEmpty) + assert(spark.cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() - val accId1 = sqlContext.table("t1").queryExecution.withCachedData.collect { + val accId1 = spark.table("t1").queryExecution.withCachedData.collect { case i: InMemoryRelation => i.batchStats.id }.head - val accId2 = sqlContext.table("t1").queryExecution.withCachedData.collect { + val accId2 = spark.table("t1").queryExecution.withCachedData.collect { case i: InMemoryRelation => 
i.batchStats.id }.head - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") assert(AccumulatorContext.get(accId1).isEmpty) assert(AccumulatorContext.get(accId2).isEmpty) @@ -351,7 +351,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("SPARK-10327 Cache Table is not working while subquery has alias in its project list") { sparkContext.parallelize((1, 1) :: (2, 2) :: Nil) .toDF("key", "value").selectExpr("key", "value", "key+1").registerTempTable("abc") - sqlContext.cacheTable("abc") + spark.catalog.cacheTable("abc") val sparkPlan = sql( """select a.key, b.key, c.key from @@ -374,15 +374,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext table3x.registerTempTable("testData3x") sql("SELECT key, value FROM testData3x ORDER BY key").registerTempTable("orderedTable") - sqlContext.cacheTable("orderedTable") - assertCached(sqlContext.table("orderedTable")) + spark.catalog.cacheTable("orderedTable") + assertCached(spark.table("orderedTable")) // Should not have an exchange as the query is already sorted on the group by key. verifyNumExchanges(sql("SELECT key, count(*) FROM orderedTable GROUP BY key"), 0) checkAnswer( sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"), sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect()) - sqlContext.uncacheTable("orderedTable") - sqlContext.dropTempTable("orderedTable") + spark.catalog.uncacheTable("orderedTable") + spark.catalog.dropTempTable("orderedTable") // Set up two tables distributed in the same way. Try this with the data distributed into // different number of partitions. @@ -390,8 +390,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext withTempTable("t1", "t2") { testData.repartition(numPartitions, $"key").registerTempTable("t1") testData2.repartition(numPartitions, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") // Joining them should result in no exchanges. verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0) @@ -403,8 +403,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"), sql("SELECT count(*) FROM testData GROUP BY key")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } } @@ -412,8 +412,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext withTempTable("t1", "t2") { testData.repartition(6, $"key").registerTempTable("t1") testData2.repartition(3, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 1) @@ -421,16 +421,16 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } // One side of join is not partitioned in the desired way. Need to shuffle one side. 
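
The block of tests above repartitions both inputs on the join key, registers and caches them, and then checks that the join introduces no extra exchange. A hedged sketch of that pattern, assuming a running SparkSession named `spark`; the data and table names are made up, only the call sequence mirrors the suite:

```scala
import spark.implicits._

Seq((1, "a"), (2, "b"), (3, "c")).toDF("key", "value")
  .repartition(6, $"key").registerTempTable("t1")
Seq((1, 10), (2, 20)).toDF("a", "b")
  .repartition(6, $"a").registerTempTable("t2")

spark.catalog.cacheTable("t1")
spark.catalog.cacheTable("t2")

// Both cached relations are hash-partitioned on the join key, so the planner
// can join them without adding another shuffle exchange.
spark.sql("SELECT key, value, a, b FROM t1 JOIN t2 ON t1.key = t2.a").show()

spark.catalog.uncacheTable("t1")
spark.catalog.uncacheTable("t2")
```
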
withTempTable("t1", "t2") { testData.repartition(6, $"value").registerTempTable("t1") testData2.repartition(6, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 1) @@ -438,15 +438,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } withTempTable("t1", "t2") { testData.repartition(6, $"value").registerTempTable("t1") testData2.repartition(12, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 1) @@ -454,8 +454,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } // One side of join is not partitioned in the desired way. Since the number of partitions of @@ -464,30 +464,30 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext withTempTable("t1", "t2") { testData.repartition(6, $"value").registerTempTable("t1") testData2.repartition(3, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 2) checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } // repartition's column ordering is different from group by column ordering. // But they use the same set of columns. withTempTable("t1") { testData.repartition(6, $"value", $"key").registerTempTable("t1") - sqlContext.cacheTable("t1") + spark.catalog.cacheTable("t1") val query = sql("SELECT value, key from t1 group by key, value") verifyNumExchanges(query, 0) checkAnswer( query, testData.distinct().select($"value", $"key")) - sqlContext.uncacheTable("t1") + spark.catalog.uncacheTable("t1") } // repartition's column ordering is different from join condition's column ordering. 
@@ -499,8 +499,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext df1.repartition(6, $"value", $"key").registerTempTable("t1") val df2 = testData2.select($"a", $"b".cast("string")) df2.repartition(6, $"a", $"b").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a and t1.value = t2.b") @@ -509,8 +509,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, df1.join(df2, $"key" === $"a" && $"value" === $"b").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 19fe29a202..a5aecca13f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -29,7 +29,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { import testImplicits._ private lazy val booleanData = { - sqlContext.createDataFrame(sparkContext.parallelize( + spark.createDataFrame(sparkContext.parallelize( Row(false, false) :: Row(false, true) :: Row(true, false) :: @@ -287,7 +287,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("isNaN") { - val testData = sqlContext.createDataFrame(sparkContext.parallelize( + val testData = spark.createDataFrame(sparkContext.parallelize( Row(Double.NaN, Float.NaN) :: Row(math.log(-1), math.log(-3).toFloat) :: Row(null, null) :: @@ -308,7 +308,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("nanvl") { - val testData = sqlContext.createDataFrame(sparkContext.parallelize( + val testData = spark.createDataFrame(sparkContext.parallelize( Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil), StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType), StructField("c", DoubleType), StructField("d", DoubleType), @@ -351,7 +351,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("=!=") { - val nullData = sqlContext.createDataFrame(sparkContext.parallelize( + val nullData = spark.createDataFrame(sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -370,7 +370,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { nullData.filter($"a" <=> $"b"), Row(1, 1) :: Row(null, null) :: Nil) - val nullData2 = sqlContext.createDataFrame(sparkContext.parallelize( + val nullData2 = spark.createDataFrame(sparkContext.parallelize( Row("abc") :: Row(null) :: Row("xyz") :: Nil), @@ -596,7 +596,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { withTempPath { dir => val data = sparkContext.parallelize(0 to 10).toDF("id") data.write.parquet(dir.getCanonicalPath) - val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(input_file_name()) + val answer = spark.read.parquet(dir.getCanonicalPath).select(input_file_name()) .head.getString(0) assert(answer.contains(dir.getCanonicalPath)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 
63f4b759a0..8a99866a33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -70,7 +70,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0))) ) - val decimalDataWithNulls = sqlContext.sparkContext.parallelize( + val decimalDataWithNulls = spark.sparkContext.parallelize( DecimalData(1, 1) :: DecimalData(1, null) :: DecimalData(2, 1) :: @@ -114,7 +114,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null, null, 113000.0) :: Nil ) - val df0 = sqlContext.sparkContext.parallelize(Seq( + val df0 = spark.sparkContext.parallelize(Seq( Fact(20151123, 18, 35, "room1", 18.6), Fact(20151123, 18, 35, "room2", 22.4), Fact(20151123, 18, 36, "room1", 17.4), @@ -207,12 +207,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) + spark.conf.set(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key, false) checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) + spark.conf.set(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key, true) } test("agg without groups") { @@ -433,10 +433,10 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("SPARK-14664: Decimal sum/avg over window should work.") { checkAnswer( - sqlContext.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), + spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil) checkAnswer( - sqlContext.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"), + spark.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"), Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 0414fa1c91..031e66b57c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -154,7 +154,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // SPARK-12275: no physical plan for BroadcastHint in some condition withTempPath { path => df1.write.parquet(path.getCanonicalPath) - val pf1 = sqlContext.read.parquet(path.getCanonicalPath) + val pf1 = spark.read.parquet(path.getCanonicalPath) assert(df1.join(broadcast(pf1)).count() === 4) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index c6d67519b0..fa8fa06907 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -81,11 +81,11 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ } test("pivot max values enforced") { - sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1) + spark.conf.set(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key, 1) intercept[AnalysisException]( courseSales.groupBy("year").pivot("course") ) - sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, + spark.conf.set(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key, SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) } @@ -104,7 
+104,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ // pivot with extra columns to trigger optimization .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString)) .agg(sum($"earnings")) - val queryExecution = sqlContext.executePlan(df.queryExecution.logical) + val queryExecution = spark.executePlan(df.queryExecution.logical) assert(queryExecution.simpleString.contains("pivotfirst")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0ea7727e45..ab7733b239 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -236,7 +236,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("sampleBy") { - val df = sqlContext.range(0, 100).select((col("id") % 3).as("key")) + val df = spark.range(0, 100).select((col("id") % 3).as("key")) val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), @@ -247,7 +247,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { // `CountMinSketch`es that meet required specs. Test cases for `CountMinSketch` can be found in // `CountMinSketchSuite` in project spark-sketch. test("countMinSketch") { - val df = sqlContext.range(1000) + val df = spark.range(1000) val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42) assert(sketch1.totalCount() === 1000) @@ -279,7 +279,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { // This test only verifies some basic requirements, more correctness tests can be found in // `BloomFilterSuite` in project spark-sketch. test("Bloom filter") { - val df = sqlContext.range(1000) + val df = spark.range(1000) val filter1 = df.stat.bloomFilter("id", 1000, 0.03) assert(filter1.expectedFpp() - 0.03 < 1e-3) @@ -304,7 +304,7 @@ class DataFrameStatPerfSuite extends QueryTest with SharedSQLContext with Loggin // Turn on this test if you want to test the performance of approximate quantiles. 
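
Several hunks here swap `sqlContext.conf.setConf(SQLConf.X, v)` for `spark.conf.set(SQLConf.X.key, v)`, i.e. runtime configuration now goes through the public `spark.conf` with the entry's string key. A small sketch using the same pivot config entry the suite touches (assumes a running SparkSession named `spark`; `SQLConf` is a test-side internal import):

```scala
import org.apache.spark.sql.internal.SQLConf

// was: sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
spark.conf.set(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key, 1)
assert(spark.conf.get(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key) == "1")

// Restore the default afterwards, as the test does.
spark.conf.set(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key,
  SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
```
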
ignore("computing quantiles should not take much longer than describe()") { - val df = sqlContext.range(5000000L).toDF("col1").cache() + val df = spark.range(5000000L).toDF("col1").cache() def seconds(f: => Any): Double = { // Do some warmup logDebug("warmup...") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 80a93ee6d4..f77403c13e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -99,8 +99,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))) val schema2 = StructType(Array(StructField("label", IntegerType, false), StructField("point", new ExamplePointUDT(), false))) - val df1 = sqlContext.createDataFrame(rowRDD1, schema1) - val df2 = sqlContext.createDataFrame(rowRDD2, schema2) + val df1 = spark.createDataFrame(rowRDD1, schema1) + val df2 = spark.createDataFrame(rowRDD2, schema2) checkAnswer( df1.union(df2).orderBy("label"), @@ -109,8 +109,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("empty data frame") { - assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) - assert(sqlContext.emptyDataFrame.count() === 0) + assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(spark.emptyDataFrame.count() === 0) } test("head and take") { @@ -369,7 +369,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake checkAnswer( - sqlContext.range(2).toDF().limit(2147483638), + spark.range(2).toDF().limit(2147483638), Row(0) :: Row(1) :: Nil ) } @@ -672,12 +672,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val parquetDir = new File(dir, "parquet").getCanonicalPath df.write.parquet(parquetDir) - val parquetDF = sqlContext.read.parquet(parquetDir) + val parquetDF = spark.read.parquet(parquetDir) assert(parquetDF.inputFiles.nonEmpty) val jsonDir = new File(dir, "json").getCanonicalPath df.write.json(jsonDir) - val jsonDF = sqlContext.read.json(jsonDir) + val jsonDF = spark.read.json(jsonDir) assert(parquetDF.inputFiles.nonEmpty) val unioned = jsonDF.union(parquetDF).inputFiles.sorted @@ -801,7 +801,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) - val df = sqlContext.createDataFrame(rowRDD, schema) + val df = spark.createDataFrame(rowRDD, schema) df.rdd.collect() } @@ -818,14 +818,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = sqlContext.read.json(sparkContext.makeRDD( + val df = spark.read.json(sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = sqlContext.read.json(sparkContext.makeRDD( + val df2 = spark.read.json(sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -881,53 +881,53 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-7150 range api") { // numSlice is greater than length - val res1 = 
sqlContext.range(0, 10, 1, 15).select("id") + val res1 = spark.range(0, 10, 1, 15).select("id") assert(res1.count == 10) assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res2 = sqlContext.range(3, 15, 3, 2).select("id") + val res2 = spark.range(3, 15, 3, 2).select("id") assert(res2.count == 4) assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - val res3 = sqlContext.range(1, -2).select("id") + val res3 = spark.range(1, -2).select("id") assert(res3.count == 0) // start is positive, end is negative, step is negative - val res4 = sqlContext.range(1, -2, -2, 6).select("id") + val res4 = spark.range(1, -2, -2, 6).select("id") assert(res4.count == 2) assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) // start, end, step are negative - val res5 = sqlContext.range(-3, -8, -2, 1).select("id") + val res5 = spark.range(-3, -8, -2, 1).select("id") assert(res5.count == 3) assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) // start, end are negative, step is positive - val res6 = sqlContext.range(-8, -4, 2, 1).select("id") + val res6 = spark.range(-8, -4, 2, 1).select("id") assert(res6.count == 2) assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) - val res7 = sqlContext.range(-10, -9, -20, 1).select("id") + val res7 = spark.range(-10, -9, -20, 1).select("id") assert(res7.count == 0) - val res8 = sqlContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") assert(res8.count == 3) assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - val res9 = sqlContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") assert(res9.count == 2) assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) // only end provided as argument - val res10 = sqlContext.range(10).select("id") + val res10 = spark.range(10).select("id") assert(res10.count == 10) assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res11 = sqlContext.range(-1).select("id") + val res11 = spark.range(-1).select("id") assert(res11.count == 0) // using the default slice number - val res12 = sqlContext.range(3, 15, 3).select("id") + val res12 = spark.range(3, 15, 3).select("id") assert(res12.count == 4) assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) } @@ -993,13 +993,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // pass case: parquet table (HadoopFsRelation) df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath) - val pdf = sqlContext.read.parquet(tempParquetFile.getCanonicalPath) + val pdf = spark.read.parquet(tempParquetFile.getCanonicalPath) pdf.registerTempTable("parquet_base") insertion.write.insertInto("parquet_base") // pass case: json table (InsertableRelation) df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath) - val jdf = sqlContext.read.json(tempJsonFile.getCanonicalPath) + val jdf = spark.read.json(tempJsonFile.getCanonicalPath) jdf.registerTempTable("json_base") insertion.write.mode(SaveMode.Overwrite).insertInto("json_base") @@ -1019,7 +1019,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) // error case: insert into an OneRowRelation - Dataset.ofRows(sqlContext.sparkSession, 
OneRowRelation).registerTempTable("one_row") + Dataset.ofRows(spark, OneRowRelation).registerTempTable("one_row") val e3 = intercept[AnalysisException] { insertion.write.insertInto("one_row") } @@ -1062,7 +1062,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-9323: DataFrame.orderBy should support nested column name") { - val df = sqlContext.read.json(sparkContext.makeRDD( + val df = spark.read.json(sparkContext.makeRDD( """{"a": {"b": 1}}""" :: Nil)) checkAnswer(df.orderBy("a.b"), Row(Row(1))) } @@ -1091,10 +1091,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val dir2 = new File(dir, "dir2").getCanonicalPath df2.write.format("json").save(dir2) - checkAnswer(sqlContext.read.format("json").load(dir1, dir2), + checkAnswer(spark.read.format("json").load(dir1, dir2), Row(1, 22) :: Row(2, 23) :: Nil) - checkAnswer(sqlContext.read.format("json").load(dir1), + checkAnswer(spark.read.format("json").load(dir1), Row(1, 22) :: Nil) } } @@ -1116,7 +1116,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { - val input = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + val input = spark.read.json(spark.sparkContext.makeRDD( (1 to 10).map(i => s"""{"id": $i}"""))) val df = input.select($"id", rand(0).as('r)) @@ -1185,7 +1185,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { withTempPath { path => Seq(2012 -> "a").toDF("year", "val").write.partitionBy("year").parquet(path.getAbsolutePath) - val df = sqlContext.read.parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath) checkAnswer(df.filter($"yEAr" > 2000).select($"val"), Row("a")) } } @@ -1244,7 +1244,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { verifyExchangingAgg(testData.repartition($"key", $"value") .groupBy("key").count()) - val data = sqlContext.sparkContext.parallelize( + val data = spark.sparkContext.parallelize( (1 to 100).map(i => TestData2(i % 10, i))).toDF() // Distribute and order by. 
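
The DataFrameSuite changes above route all reads through `spark.read` and reach the underlying SparkContext via `spark.sparkContext`. A short sketch of that shape, assuming a running SparkSession named `spark`; the output path is hypothetical:

```scala
// Build a DataFrame from an RDD of JSON strings, then round-trip it through Parquet.
val jsonRDD = spark.sparkContext.makeRDD("""{"id": 1}""" :: """{"id": 2}""" :: Nil)
val fromRdd = spark.read.json(jsonRDD)                       // was sqlContext.read.json(...)

fromRdd.write.mode("overwrite").parquet("/tmp/spark-read-sketch")
val fromPath = spark.read.parquet("/tmp/spark-read-sketch")  // was sqlContext.read.parquet(...)
assert(fromPath.count() == 2)
```
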
@@ -1308,7 +1308,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { withTempPath { path => val p = path.getAbsolutePath Seq(2012 -> "a").toDF("year", "val").write.partitionBy("yEAr").parquet(p) - checkAnswer(sqlContext.read.parquet(p).select("YeaR"), Row(2012)) + checkAnswer(spark.read.parquet(p).select("YeaR"), Row(2012)) } } } @@ -1317,7 +1317,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-11633: LogicalRDD throws TreeNode Exception: Failed to Copy Node") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1))) - val df = sqlContext.createDataFrame( + val df = spark.createDataFrame( rdd, new StructType().add("f1", IntegerType).add("f2", IntegerType), needsConversion = false).select($"F1", $"f2".as("f2")) @@ -1344,7 +1344,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil) - sqlContext.udf.register("boxedUDF", + spark.udf.register("boxedUDF", (i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer) checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil) @@ -1393,7 +1393,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("reuse exchange") { withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "2") { - val df = sqlContext.range(100).toDF() + val df = spark.range(100).toDF() val join = df.join(df, "id") val plan = join.queryExecution.executedPlan checkAnswer(join, df) @@ -1415,14 +1415,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("sameResult() on aggregate") { - val df = sqlContext.range(100) + val df = spark.range(100) val agg1 = df.groupBy().count() val agg2 = df.groupBy().count() // two aggregates with different ExprId within them should have same result assert(agg1.queryExecution.executedPlan.sameResult(agg2.queryExecution.executedPlan)) val agg3 = df.groupBy().sum() assert(!agg1.queryExecution.executedPlan.sameResult(agg3.queryExecution.executedPlan)) - val df2 = sqlContext.range(101) + val df2 = spark.range(101) val agg4 = df2.groupBy().count() assert(!agg1.queryExecution.executedPlan.sameResult(agg4.queryExecution.executedPlan)) } @@ -1454,24 +1454,24 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("assertAnalyzed shouldn't replace original stack trace") { val e = intercept[AnalysisException] { - sqlContext.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b) + spark.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b) } assert(e.getStackTrace.head.getClassName != classOf[QueryExecution].getName) } test("SPARK-13774: Check error message for non existent path without globbed paths") { - val e = intercept[AnalysisException] (sqlContext.read.format("csv"). + val e = intercept[AnalysisException] (spark.read.format("csv"). load("/xyz/file2", "/xyz/file21", "/abc/files555", "a")).getMessage() assert(e.startsWith("Path does not exist")) } test("SPARK-13774: Check error message for not existent globbed paths") { - val e = intercept[AnalysisException] (sqlContext.read.format("text"). + val e = intercept[AnalysisException] (spark.read.format("text"). load( "/xyz/*")).getMessage() assert(e.startsWith("Path does not exist")) - val e1 = intercept[AnalysisException] (sqlContext.read.json("/mnt/*/*-xyz.json").rdd). + val e1 = intercept[AnalysisException] (spark.read.json("/mnt/*/*-xyz.json").rdd). 
getMessage() assert(e1.startsWith("Path does not exist")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 06584ec21e..a957d5ba25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -249,14 +249,14 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B try { f(tableName) } finally { - sqlContext.dropTempTable(tableName) + spark.catalog.dropTempTable(tableName) } } test("time window in SQL with single string expression") { withTempTable { table => checkAnswer( - sqlContext.sql(s"""select window(time, "10 seconds"), value from $table""") + spark.sql(s"""select window(time, "10 seconds"), value from $table""") .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), Seq( Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), @@ -270,7 +270,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("time window in SQL with with two expressions") { withTempTable { table => checkAnswer( - sqlContext.sql( + spark.sql( s"""select window(time, "10 seconds", 10000000), value from $table""") .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), Seq( @@ -285,7 +285,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("time window in SQL with with three expressions") { withTempTable { table => checkAnswer( - sqlContext.sql( + spark.sql( s"""select window(time, "10 seconds", 10000000, "5 seconds"), value from $table""") .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index 68e99d6a6b..fe6ba83b4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -48,7 +48,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { .add("b3", FloatType) .add("b4", DoubleType)) - val df = sqlContext.createDataFrame(data, schema) + val df = spark.createDataFrame(data, schema) assert(df.select("b").first() === Row(struct)) } @@ -70,7 +70,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { .add("b5b", StringType)) .add("b6", StringType)) - val df = sqlContext.createDataFrame(data, schema) + val df = spark.createDataFrame(data, schema) assert(df.select("b").first() === Row(outerStruct)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index ae9fb80c68..d8e241c62f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.functions._ @@ -31,14 +31,14 @@ object DatasetBenchmark { case class Data(l: Long, s: String) - def backToBackMap(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = 
{ - import sqlContext.implicits._ + def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ - val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("back-to-back map", numRows) val func = (d: Data) => Data(d.l + 1, d.s) - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD") { iter => var res = rdd var i = 0 @@ -72,17 +72,17 @@ object DatasetBenchmark { benchmark } - def backToBackFilter(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = { - import sqlContext.implicits._ + def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ - val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("back-to-back filter", numRows) val func = (d: Data, i: Int) => d.l % (100L + i) == 0L val funcs = 0.until(numChains).map { i => (d: Data) => func(d, i) } - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD") { iter => var res = rdd var i = 0 @@ -130,13 +130,13 @@ object DatasetBenchmark { override def outputEncoder: Encoder[Long] = Encoders.scalaLong } - def aggregate(sqlContext: SQLContext, numRows: Long): Benchmark = { - import sqlContext.implicits._ + def aggregate(spark: SparkSession, numRows: Long): Benchmark = { + import spark.implicits._ - val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("aggregate", numRows) - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD sum") { iter => rdd.aggregate(0L)(_ + _.l, _ + _) } @@ -157,15 +157,17 @@ object DatasetBenchmark { } def main(args: Array[String]): Unit = { - val sparkContext = new SparkContext("local[*]", "Dataset benchmark") - val sqlContext = new SQLContext(sparkContext) + val spark = SparkSession.builder + .master("local[*]") + .appName("Dataset benchmark") + .getOrCreate() val numRows = 100000000 val numChains = 10 - val benchmark = backToBackMap(sqlContext, numRows, numChains) - val benchmark2 = backToBackFilter(sqlContext, numRows, numChains) - val benchmark3 = aggregate(sqlContext, numRows) + val benchmark = backToBackMap(spark, numRows, numChains) + val benchmark2 = backToBackFilter(spark, numRows, numChains) + val benchmark3 = aggregate(spark, numRows) /* Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 942cc09b6d..8c0906b746 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -39,7 +39,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { 2, 3, 4) // Drop the cache. 
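
The DatasetBenchmark rewrite above drops the explicit `new SparkContext` plus `new SQLContext` pair in `main` and threads a SparkSession through the helper methods instead. A minimal sketch of that setup under the same builder settings (row count shrunk here purely for illustration):

```scala
import org.apache.spark.sql.SparkSession

// was: val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
//      val sqlContext   = new SQLContext(sparkContext)
val spark = SparkSession.builder
  .master("local[*]")
  .appName("Dataset benchmark")
  .getOrCreate()
import spark.implicits._

// The helpers now take the session directly; both the DataFrame and the raw RDD
// inputs are derived from it.
val numRows = 1000L
val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast("string").as("s"))
val rdd = spark.sparkContext.range(1, numRows).map(l => (l, l.toString))
```
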
cached.unpersist() - assert(!sqlContext.isCached(cached), "The Dataset should not be cached.") + assert(spark.cacheManager.lookupCachedData(cached).isEmpty, "The Dataset should not be cached.") } test("persist and then rebind right encoder when join 2 datasets") { @@ -56,9 +56,11 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { assertCached(joined, 2) ds1.unpersist() - assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.") + assert(spark.cacheManager.lookupCachedData(ds1).isEmpty, + "The Dataset ds1 should not be cached.") ds2.unpersist() - assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.") + assert(spark.cacheManager.lookupCachedData(ds2).isEmpty, + "The Dataset ds2 should not be cached.") } test("persist and then groupBy columns asKey, map") { @@ -73,8 +75,9 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { assertCached(agged.filter(_._1 == "b")) ds.unpersist() - assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.") + assert(spark.cacheManager.lookupCachedData(ds).isEmpty, "The Dataset ds should not be cached.") agged.unpersist() - assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.") + assert(spark.cacheManager.lookupCachedData(agged).isEmpty, + "The Dataset agged should not be cached.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 3cb4e52c6d..3c8c862c22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -46,12 +46,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("range") { - assert(sqlContext.range(10).map(_ + 1).reduce(_ + _) == 55) - assert(sqlContext.range(10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) - assert(sqlContext.range(0, 10).map(_ + 1).reduce(_ + _) == 55) - assert(sqlContext.range(0, 10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) - assert(sqlContext.range(0, 10, 1, 2).map(_ + 1).reduce(_ + _) == 55) - assert(sqlContext.range(0, 10, 1, 2).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(spark.range(10).map(_ + 1).reduce(_ + _) == 55) + assert(spark.range(10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(spark.range(0, 10).map(_ + 1).reduce(_ + _) == 55) + assert(spark.range(0, 10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(spark.range(0, 10, 1, 2).map(_ + 1).reduce(_ + _) == 55) + assert(spark.range(0, 10, 1, 2).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) } test("SPARK-12404: Datatype Helper Serializability") { @@ -472,7 +472,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("SPARK-14696: implicit encoders for boxed types") { - assert(sqlContext.range(1).map { i => i : java.lang.Long }.head == 0L) + assert(spark.range(1).map { i => i : java.lang.Long }.head == 0L) } test("SPARK-11894: Incorrect results are returned when using null") { @@ -510,8 +510,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { )) def buildDataset(rows: Row*): Dataset[NestedStruct] = { - val rowRDD = sqlContext.sparkContext.parallelize(rows) - sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct] + val rowRDD = spark.sparkContext.parallelize(rows) + spark.createDataFrame(rowRDD, schema).as[NestedStruct] } checkDataset( @@ -626,7 +626,7 @@ class DatasetSuite extends QueryTest with 
SharedSQLContext { } test("SPARK-14554: Dataset.map may generate wrong java code for wide table") { - val wideDF = sqlContext.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) + val wideDF = spark.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) // Make sure the generated code for this plan can compile and execute. checkDataset(wideDF.map(_.getLong(0)), 0L until 10 : _*) } @@ -654,7 +654,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { dataset.join(actual, dataset("user") === actual("id")).collect() } - test("SPARK-15097: implicits on dataset's sqlContext can be imported") { + test("SPARK-15097: implicits on dataset's spark can be imported") { val dataset = Seq(1, 2, 3).toDS() checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4) } @@ -735,10 +735,10 @@ object JavaData { def apply(a: Int): JavaData = new JavaData(a) } -/** Used to test importing dataset.sqlContext.implicits._ */ +/** Used to test importing dataset.spark.implicits._ */ object DatasetTransform { def addOne(ds: Dataset[Int]): Dataset[Int] = { - import ds.sqlContext.implicits._ + import ds.sparkSession.implicits._ ds.map(_ + 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index b1987c6908..a41b465548 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -51,7 +51,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { test("insert an extraStrategy") { try { - sqlContext.experimental.extraStrategies = TestStrategy :: Nil + spark.experimental.extraStrategies = TestStrategy :: Nil val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") checkAnswer( @@ -62,7 +62,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { df.select("a", "b"), Row("so slow", 1)) } finally { - sqlContext.experimental.extraStrategies = Nil + spark.experimental.extraStrategies = Nil } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8cbad04e23..da567db5ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -37,7 +37,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = sqlContext.sessionState.planner.JoinSelection(join) + val planned = spark.sessionState.planner.JoinSelection(join) assert(planned.size === 1) } @@ -60,7 +60,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("join operator selection") { - sqlContext.cacheManager.clearCache() + spark.cacheManager.clearCache() withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") { Seq( @@ -112,7 +112,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { // } test("broadcasted hash join operator selection") { - sqlContext.cacheManager.clearCache() + spark.cacheManager.clearCache() sql("CACHE TABLE testData") Seq( ("SELECT * FROM testData join testData2 ON key = a", @@ -126,7 +126,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted hash outer join operator selection") { - sqlContext.cacheManager.clearCache() + spark.cacheManager.clearCache() sql("CACHE TABLE testData") 
sql("CACHE TABLE testData2") Seq( @@ -144,7 +144,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = sqlContext.sessionState.planner.JoinSelection(join) + val planned = spark.sessionState.planner.JoinSelection(join) assert(planned.size === 1) } @@ -435,7 +435,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted existence join operator selection") { - sqlContext.cacheManager.clearCache() + spark.cacheManager.clearCache() sql("CACHE TABLE testData") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { @@ -461,17 +461,17 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("cross join with broadcast") { sql("CACHE TABLE testData") - val sizeInByteOfTestData = statisticSizeInByte(sqlContext.table("testData")) + val sizeInByteOfTestData = statisticSizeInByte(spark.table("testData")) // we set the threshold is greater than statistic of the cached table testData withSQLConf( SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) { - assert(statisticSizeInByte(sqlContext.table("testData2")) > - sqlContext.conf.autoBroadcastJoinThreshold) + assert(statisticSizeInByte(spark.table("testData2")) > + spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) - assert(statisticSizeInByte(sqlContext.table("testData")) < - sqlContext.conf.autoBroadcastJoinThreshold) + assert(statisticSizeInByte(spark.table("testData")) < + spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 9f6c86a575..c88dfe5f24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -33,36 +33,36 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex } after { - sqlContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) } test("get all tables") { checkAnswer( - sqlContext.tables().filter("tableName = 'listtablessuitetable'"), + spark.wrapped.tables().filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) checkAnswer( sql("SHOW tables").filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - sqlContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) - assert(sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0) + assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0) } test("getting all tables with a database name has no impact on returned table names") { checkAnswer( - sqlContext.tables("default").filter("tableName = 'listtablessuitetable'"), + spark.wrapped.tables("default").filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) checkAnswer( sql("show TABLES in default").filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - sqlContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) - 
assert(sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0) + assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0) } test("query the returned DataFrame of tables") { @@ -70,7 +70,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(sqlContext.tables(), sql("SHOW TABLes")).foreach { + Seq(spark.wrapped.tables(), sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) @@ -81,9 +81,9 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex Row(true, "listtablessuitetable") ) checkAnswer( - sqlContext.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + spark.wrapped.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), Row("tables", true)) - sqlContext.dropTempTable("tables") + spark.catalog.dropTempTable("tables") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala new file mode 100644 index 0000000000..1732977ee5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import _root_.io.netty.util.internal.logging.{InternalLoggerFactory, Slf4JLoggerFactory} +import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfterEach +import org.scalatest.Suite + +/** Manages a local `spark` {@link SparkSession} variable, correctly stopping it after each test. */ +trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => + + @transient var spark: SparkSession = _ + + override def beforeAll() { + super.beforeAll() + InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) + } + + override def afterEach() { + try { + resetSparkContext() + } finally { + super.afterEach() + } + } + + def resetSparkContext(): Unit = { + LocalSparkSession.stop(spark) + spark = null + } + +} + +object LocalSparkSession { + def stop(spark: SparkSession) { + if (spark != null) { + spark.stop() + } + // To avoid RPC rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + } + + /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. 
*/ + def withSparkSession[T](sc: SparkSession)(f: SparkSession => T): T = { + try { + f(sc) + } finally { + stop(sc) + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index df8b3b7d87..a1a9b66c1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types.ObjectType abstract class QueryTest extends PlanTest { - protected def sqlContext: SQLContext + protected def spark: SparkSession // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -81,7 +81,7 @@ abstract class QueryTest extends PlanTest { expectedAnswer: T*): Unit = { checkAnswer( ds.toDF(), - sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) + spark.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) checkDecoding(ds, expectedAnswer: _*) } @@ -267,7 +267,7 @@ abstract class QueryTest extends PlanTest { val jsonBackPlan = try { - TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext) + TreeNode.fromJSON[LogicalPlan](jsonString, spark.sparkContext) } catch { case NonFatal(e) => fail( @@ -282,7 +282,7 @@ abstract class QueryTest extends PlanTest { def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = { case l: LogicalRDD => val origin = logicalRDDs.pop() - LogicalRDD(l.output, origin.rdd)(sqlContext.sparkSession) + LogicalRDD(l.output, origin.rdd)(spark) case l: LocalRelation => val origin = localRelations.pop() l.copy(data = origin.data) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1ff288cd19..e401abef29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -57,7 +57,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("show functions") { def getFunctions(pattern: String): Seq[Row] = { - StringUtils.filterPattern(sqlContext.sessionState.functionRegistry.listFunction(), pattern) + StringUtils.filterPattern(spark.sessionState.functionRegistry.listFunction(), pattern) .map(Row(_)) } checkAnswer(sql("SHOW functions"), getFunctions("*")) @@ -88,7 +88,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-14415: All functions should have own descriptions") { - for (f <- sqlContext.sessionState.functionRegistry.listFunction()) { + for (f <- spark.sessionState.functionRegistry.listFunction()) { if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f)) { checkKeywordsNotExist(sql(s"describe function `$f`"), "N/A.") } @@ -102,7 +102,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { (43, 81, 24) ).toDF("a", "b", "c").registerTempTable("cachedData") - sqlContext.cacheTable("cachedData") + spark.catalog.cacheTable("cachedData") checkAnswer( sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), Row(0) :: Row(81) :: Nil) @@ -193,7 +193,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("grouping on nested fields") { - sqlContext.read.json(sparkContext.parallelize( + spark.read.json(sparkContext.parallelize( """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) .registerTempTable("rows") @@ -211,7 +211,7 @@ class 
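The new LocalSparkSession.scala file above ships both a mixin trait that manages a `spark` variable and a `withSparkSession` loan helper. A hypothetical usage of the helper (the builder settings and the body are illustrative only, and the snippet assumes it lives alongside these test sources):

```scala
import org.apache.spark.sql.{LocalSparkSession, SparkSession}

val session = SparkSession.builder()
  .master("local")
  .appName("with-spark-session-example")  // illustrative app name
  .getOrCreate()

// stop() runs in a finally block, so no explicit try/finally is needed here;
// it also clears spark.driver.port to avoid RPC port rebinding between tests.
val n = LocalSparkSession.withSparkSession(session) { spark =>
  spark.range(10).count()
}
assert(n == 10)
```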
SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6201 IN type conversion") { - sqlContext.read.json( + spark.read.json( sparkContext.parallelize( Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) .registerTempTable("d") @@ -222,7 +222,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-11226 Skip empty line in json file") { - sqlContext.read.json( + spark.read.json( sparkContext.parallelize( Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}", ""))) .registerTempTable("d") @@ -258,9 +258,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregation with codegen") { // Prepare a table that we can group some rows. - sqlContext.table("testData") - .union(sqlContext.table("testData")) - .union(sqlContext.table("testData")) + spark.table("testData") + .union(spark.table("testData")) + .union(spark.table("testData")) .registerTempTable("testData3x") try { @@ -333,7 +333,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { "SELECT sum('a'), avg('a'), count(null) FROM testData", Row(null, null, 0) :: Nil) } finally { - sqlContext.dropTempTable("testData3x") + spark.catalog.dropTempTable("testData3x") } } @@ -1041,7 +1041,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SET commands semantics using sql()") { - sqlContext.conf.clear() + spark.wrapped.conf.clear() val testKey = "test.key.0" val testVal = "test.val.0" val nonexistentKey = "nonexistent" @@ -1082,17 +1082,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sql(s"SET $nonexistentKey"), Row(nonexistentKey, "") ) - sqlContext.conf.clear() + spark.wrapped.conf.clear() } test("SET commands with illegal or inappropriate argument") { - sqlContext.conf.clear() + spark.wrapped.conf.clear() // Set negative mapred.reduce.tasks for automatically determining // the number of reducers is not supported intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2")) - sqlContext.conf.clear() + spark.wrapped.conf.clear() } test("apply schema") { @@ -1110,7 +1110,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val df1 = sqlContext.createDataFrame(rowRDD1, schema1) + val df1 = spark.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), @@ -1140,7 +1140,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df2 = sqlContext.createDataFrame(rowRDD2, schema2) + val df2 = spark.createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), @@ -1165,7 +1165,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val df3 = sqlContext.createDataFrame(rowRDD3, schema2) + val df3 = spark.createDataFrame(rowRDD3, schema2) df3.registerTempTable("applySchema3") checkAnswer( @@ -1210,7 +1210,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = sqlContext.createDataFrame(person.rdd, schemaWithMeta) + val personWithMeta = 
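The "apply schema" hunks above all follow the same shape: build a Row RDD and a StructType, then hand both to the session instead of the SQLContext. A minimal sketch (column names, data, and the temp table name are illustrative):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(
  StructField("key", IntegerType, nullable = false) ::
  StructField("value", StringType, nullable = true) :: Nil)
val rowRDD = spark.sparkContext.parallelize(Seq(Row(1, "one"), Row(2, "two")))

// was: sqlContext.createDataFrame(rowRDD, schema)
val df = spark.createDataFrame(rowRDD, schema)
df.registerTempTable("applySchemaExample")
assert(spark.sql("SELECT count(*) FROM applySchemaExample").head().getLong(0) == 2L)
```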
spark.createDataFrame(person.rdd, schemaWithMeta) def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } @@ -1226,7 +1226,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-3371 Renaming a function expression with group by gives error") { - sqlContext.udf.register("len", (s: String) => s.length) + spark.udf.register("len", (s: String) => s.length) checkAnswer( sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), Row(1)) @@ -1409,7 +1409,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-3483 Special chars in column names") { val data = sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) - sqlContext.read.json(data).registerTempTable("records") + spark.read.json(data).registerTempTable("records") sql("SELECT `key?number1`, `key.number2` FROM records") } @@ -1450,15 +1450,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-4322 Grouping field with struct field as sub expression") { - sqlContext.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + spark.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) .registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) - sqlContext.dropTempTable("data") + spark.catalog.dropTempTable("data") - sqlContext.read.json( + spark.read.json( sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) - sqlContext.dropTempTable("data") + spark.catalog.dropTempTable("data") } test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { @@ -1504,7 +1504,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6145: ORDER BY test for nested fields") { - sqlContext.read.json(sparkContext.makeRDD( + spark.read.json(sparkContext.makeRDD( """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") @@ -1517,14 +1517,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6145: special cases") { - sqlContext.read.json(sparkContext.makeRDD( + spark.read.json(sparkContext.makeRDD( """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - sqlContext.read.json(sparkContext.makeRDD( + spark.read.json(sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") @@ -1628,7 +1628,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-7067: order by queries for complex ExtractValue chain") { withTempTable("t") { - sqlContext.read.json(sparkContext.makeRDD( + spark.read.json(sparkContext.makeRDD( """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) } @@ -1776,7 +1776,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // We don't support creating a temporary table while specifying a database intercept[AnalysisException] { - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE db.t |USING parquet @@ -1787,7 +1787,7 @@ class SQLQuerySuite extends QueryTest with 
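For reference, the dominant pattern in the SQLQuerySuite hunks above is reading JSON from an RDD of strings through the session and registering the result as a temp table. A small sketch with an illustrative payload and table name:

```scala
// was: sqlContext.read.json(...)
val jsonRDD = spark.sparkContext.makeRDD("""{"a": {"b": 1}, "c": 2}""" :: Nil)
spark.read.json(jsonRDD).registerTempTable("jsonExample")

assert(spark.sql("SELECT a.b FROM jsonExample").head().getLong(0) == 1L)
spark.catalog.dropTempTable("jsonExample")  // was: sqlContext.dropTempTable(...)
```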
SharedSQLContext { }.getMessage // If you use backticks to quote the name then it's OK. - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE `db.t` |USING parquet @@ -1795,7 +1795,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { | path '$path' |) """.stripMargin) - checkAnswer(sqlContext.table("`db.t`"), df) + checkAnswer(spark.table("`db.t`"), df) } } @@ -1818,7 +1818,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("run sql directly on files") { - val df = sqlContext.range(100).toDF() + val df = spark.range(100).toDF() withTempPath(f => { df.write.json(f.getCanonicalPath) checkAnswer(sql(s"select id from json.`${f.getCanonicalPath}`"), @@ -1880,7 +1880,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-11303: filter should not be pushed down into sample") { - val df = sqlContext.range(100) + val df = spark.range(100) List(true, false).foreach { withReplacement => val sampled = df.sample(withReplacement, 0.1, 1) val sampledOdd = sampled.filter("id % 2 != 0") @@ -2059,7 +2059,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // Identity udf that tracks the number of times it is called. val countAcc = sparkContext.accumulator(0, "CallCount") - sqlContext.udf.register("testUdf", (x: Int) => { + spark.udf.register("testUdf", (x: Int) => { countAcc.++=(1) x }) @@ -2093,9 +2093,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 1) // Try disabling it via configuration. - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + spark.conf.set("spark.sql.subexpressionElimination.enabled", "false") verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + spark.conf.set("spark.sql.subexpressionElimination.enabled", "true") verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index ddab918629..b489b74fec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.test.SharedSQLContext class SerializationSuite extends SparkFunSuite with SharedSQLContext { test("[SPARK-5235] SQLContext should be serializable") { - val _sqlContext = new SQLContext(sparkContext) - new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext) + val spark = SparkSession.builder.getOrCreate() + new JavaSerializer(new SparkConf()).newInstance().serialize(spark.wrapped) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index 6fb1aca769..1ab562f873 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -290,7 +290,7 @@ trait StreamTest extends QueryTest with Timeouts { verify(currentStream == null, "stream already running") lastStream = currentStream currentStream = - sqlContext + spark .streams .startQuery( StreamExecution.nextName, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 6809f26968..c7b95c2683 
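The subexpression-elimination hunk above shows the runtime-config replacement for `sqlContext.setConf`. A sketch of the same toggle-and-restore pattern; the query that would run in between is elided:

```scala
val key = "spark.sql.subexpressionElimination.enabled"

spark.conf.set(key, "false")           // was: sqlContext.setConf(key, "false")
assert(spark.conf.get(key) == "false")
// ... run the query whose UDF call count is being verified ...
spark.conf.set(key, "true")            // restore, as the test above does
```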
100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -281,7 +281,7 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("number format function") { - val df = sqlContext.range(1) + val df = spark.range(1) checkAnswer( df.select(format_number(lit(5L), 4)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index ec950332c5..427f24a9f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -26,7 +26,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("built-in fixed arity expressions") { - val df = sqlContext.emptyDataFrame + val df = spark.emptyDataFrame df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") } @@ -55,23 +55,23 @@ class UDFSuite extends QueryTest with SharedSQLContext { val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") df.registerTempTable("tmp_table") checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) - sqlContext.dropTempTable("tmp_table") + spark.catalog.dropTempTable("tmp_table") } test("SPARK-8005 input_file_name") { withTempPath { dir => val data = sparkContext.parallelize(0 to 10, 2).toDF("id") data.write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") val answer = sql("select input_file_name() from test_table").head().getString(0) assert(answer.contains(dir.getCanonicalPath)) assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2) - sqlContext.dropTempTable("test_table") + spark.catalog.dropTempTable("test_table") } } test("error reporting for incorrect number of arguments") { - val df = sqlContext.emptyDataFrame + val df = spark.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } @@ -79,7 +79,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("error reporting for undefined functions") { - val df = sqlContext.emptyDataFrame + val df = spark.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("a_function_that_does_not_exist()") } @@ -88,22 +88,22 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("Simple UDF") { - sqlContext.udf.register("strLenScala", (_: String).length) + spark.udf.register("strLenScala", (_: String).length) assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { - sqlContext.udf.register("random0", () => { Math.random()}) + spark.udf.register("random0", () => { Math.random()}) assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - sqlContext.udf.register("strLenScala", (_: String).length + (_: Int)) + spark.udf.register("strLenScala", (_: String).length + (_: Int)) assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("UDF in a WHERE") { - sqlContext.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + spark.udf.register("oneArgFilter", (n: Int) => { n > 80 }) val df = sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() @@ -115,7 +115,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a HAVING") { - sqlContext.udf.register("havingFilter", (n: 
Long) => { n > 5 }) + spark.udf.register("havingFilter", (n: Long) => { n > 5 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") @@ -134,7 +134,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a GROUP BY") { - sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") @@ -151,10 +151,10 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDFs everywhere") { - sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) - sqlContext.udf.register("havingFilter", (n: Long) => { n > 2000 }) - sqlContext.udf.register("whereFilter", (n: Int) => { n < 150 }) - sqlContext.udf.register("timesHundred", (n: Long) => { n * 100 }) + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) + spark.udf.register("havingFilter", (n: Long) => { n > 2000 }) + spark.udf.register("whereFilter", (n: Int) => { n < 150 }) + spark.udf.register("timesHundred", (n: Long) => { n * 100 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") @@ -173,7 +173,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("struct UDF") { - sqlContext.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + spark.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = sql("SELECT returnStruct('test', 'test2') as ret") @@ -182,27 +182,27 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("udf that is transformed") { - sqlContext.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + spark.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } test("type coercion for udf inputs") { - sqlContext.udf.register("intExpected", (x: Int) => x) + spark.udf.register("intExpected", (x: Int) => x) // pass a decimal to intExpected. 
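The UDFSuite hunks above only swap the registration entry point from `sqlContext.udf` to `spark.udf`; the function bodies are untouched. A minimal registration-and-use sketch (the UDF name, data, and query are illustrative):

```scala
import spark.implicits._

spark.udf.register("isLongWord", (s: String) => s.length > 5)  // hypothetical UDF

val words = Seq("spark", "session", "sql").toDF("w")
words.registerTempTable("wordsExample")
assert(spark.sql("SELECT count(*) FROM wordsExample WHERE isLongWord(w)")
  .head().getLong(0) == 1L)
spark.catalog.dropTempTable("wordsExample")
```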
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } test("udf in different types") { - sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s) }) - sqlContext.udf.register("decimalDataFunc", + spark.udf.register("testDataFunc", (n: Int, s: String) => { (n, s) }) + spark.udf.register("decimalDataFunc", (a: java.math.BigDecimal, b: java.math.BigDecimal) => { (a, b) }) - sqlContext.udf.register("binaryDataFunc", (a: Array[Byte], b: Int) => { (a, b) }) - sqlContext.udf.register("arrayDataFunc", + spark.udf.register("binaryDataFunc", (a: Array[Byte], b: Int) => { (a, b) }) + spark.udf.register("arrayDataFunc", (data: Seq[Int], nestedData: Seq[Seq[Int]]) => { (data, nestedData) }) - sqlContext.udf.register("mapDataFunc", + spark.udf.register("mapDataFunc", (data: scala.collection.Map[Int, String]) => { data }) - sqlContext.udf.register("complexDataFunc", + spark.udf.register("complexDataFunc", (m: Map[String, Int], a: Seq[Int], b: Boolean) => { (m, a, b) } ) checkAnswer( @@ -235,7 +235,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("SPARK-11716 UDFRegistration does not include the input data type in returned UDF") { - val myUDF = sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) }) + val myUDF = spark.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) }) // Without the fix, this will fail because we fail to cast data type of b to string // because myUDF does not know its input data type. With the fix, this query should not diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index a49aaa8b73..3057e016c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -94,7 +94,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT } test("UDTs and UDFs") { - sqlContext.udf.register("testType", (d: UDT.MyDenseVector) => d.isInstanceOf[UDT.MyDenseVector]) + spark.udf.register("testType", (d: UDT.MyDenseVector) => d.isInstanceOf[UDT.MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( sql("SELECT testType(features) from points"), @@ -106,7 +106,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT val path = dir.getCanonicalPath pointsRDD.write.parquet(path) checkAnswer( - sqlContext.read.parquet(path), + spark.read.parquet(path), Seq( Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) @@ -118,7 +118,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT val path = dir.getCanonicalPath pointsRDD.repartition(1).write.parquet(path) checkAnswer( - sqlContext.read.parquet(path), + spark.read.parquet(path), Seq( Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) @@ -146,7 +146,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT )) val stringRDD = sparkContext.parallelize(data) - val jsonRDD = sqlContext.read.schema(schema).json(stringRDD) + val jsonRDD = spark.read.schema(schema).json(stringRDD) checkAnswer( jsonRDD, Row(1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: @@ -167,7 +167,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT )) val stringRDD = sparkContext.parallelize(data) - val jsonDataset = 
sqlContext.read.schema(schema).json(stringRDD) + val jsonDataset = spark.read.schema(schema).json(stringRDD) .as[(Int, UDT.MyDenseVector)] checkDataset( jsonDataset, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 01d485ce2d..70a00a43f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql.execution import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite} import org.apache.spark.sql._ import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.TestSQLContext class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -251,7 +250,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } def withSQLContext( - f: SQLContext => Unit, + f: SparkSession => Unit, targetNumPostShufflePartitions: Int, minNumPostShufflePartitions: Option[Int]): Unit = { val sparkConf = @@ -272,9 +271,11 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, "-1") } - val sparkContext = new SparkContext(sparkConf) - val sqlContext = new TestSQLContext(sparkContext) - try f(sqlContext) finally sparkContext.stop() + + val spark = SparkSession.builder + .config(sparkConf) + .getOrCreate() + try f(spark) finally spark.stop() } Seq(Some(3), None).foreach { minNumPostShufflePartitions => @@ -284,9 +285,9 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test(s"determining the number of reducers: aggregate operator$testNameNote") { - val test = { sqlContext: SQLContext => + val test = { spark: SparkSession => val df = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 20 as key", "id as value") val agg = df.groupBy("key").count @@ -294,7 +295,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Check the answer first. checkAnswer( agg, - sqlContext.range(0, 20).selectExpr("id", "50 as cnt").collect()) + spark.range(0, 20).selectExpr("id", "50 as cnt").collect()) // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. @@ -325,13 +326,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test(s"determining the number of reducers: join operator$testNameNote") { - val test = { sqlContext: SQLContext => + val test = { spark: SparkSession => val df1 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key1", "id as value1") val df2 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key2", "id as value2") @@ -339,10 +340,10 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Check the answer first. 
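The ExchangeCoordinatorSuite hunk above replaces `new SparkContext(conf)` plus `new TestSQLContext(sc)` with a single builder call that takes the prepared SparkConf. A sketch of that construction, with illustrative conf values and body:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

val sparkConf = new SparkConf(false)
  .setMaster("local[*]")
  .setAppName("exchange-coordinator-example")

val spark = SparkSession.builder
  .config(sparkConf)
  .getOrCreate()

try {
  spark.range(0, 100, 1, 4).selectExpr("id % 10 as key").groupBy("key").count().collect()
} finally {
  spark.stop()  // the updated test does the same in its finally clause
}
```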
val expectedAnswer = - sqlContext + spark .range(0, 1000) .selectExpr("id % 500 as key", "id as value") - .union(sqlContext.range(0, 1000).selectExpr("id % 500 as key", "id as value")) + .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value")) checkAnswer( join, expectedAnswer.collect()) @@ -376,16 +377,16 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test(s"determining the number of reducers: complex query 1$testNameNote") { - val test = { sqlContext: SQLContext => + val test = { spark: SparkSession => val df1 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key1", "id as value1") .groupBy("key1") .count .toDF("key1", "cnt1") val df2 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key2", "id as value2") .groupBy("key2") @@ -396,7 +397,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Check the answer first. val expectedAnswer = - sqlContext + spark .range(0, 500) .selectExpr("id", "2 as cnt") checkAnswer( @@ -428,16 +429,16 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test(s"determining the number of reducers: complex query 2$testNameNote") { - val test = { sqlContext: SQLContext => + val test = { spark: SparkSession => val df1 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key1", "id as value1") .groupBy("key1") .count .toDF("key1", "cnt1") val df2 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key2", "id as value2") @@ -448,7 +449,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Check the answer first. val expectedAnswer = - sqlContext + spark .range(0, 1000) .selectExpr("id % 500 as key", "2 as cnt", "id as value") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index ba16810cee..36cde3233d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -50,7 +50,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { } test("BroadcastExchange same result") { - val df = sqlContext.range(10) + val df = spark.range(10) val plan = df.queryExecution.executedPlan val output = plan.output assert(plan sameResult plan) @@ -75,7 +75,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { } test("ShuffleExchange same result") { - val df = sqlContext.range(10) + val df = spark.range(10) val plan = df.queryExecution.executedPlan val output = plan.output assert(plan sameResult plan) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 3b2911d056..d2e1ea12fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -38,7 +38,7 @@ class PlannerSuite extends SharedSQLContext { setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { - val planner = sqlContext.sessionState.planner + val planner = spark.sessionState.planner import planner._ val plannedOption = Aggregation(query).headOption val planned = @@ -78,7 +78,7 @@ class PlannerSuite extends SharedSQLContext { val 
schema = StructType(fields) val row = Row.fromSeq(Seq.fill(fields.size)(null)) val rowRDD = sparkContext.parallelize(row :: Nil) - sqlContext.createDataFrame(rowRDD, schema).registerTempTable("testLimit") + spark.createDataFrame(rowRDD, schema).registerTempTable("testLimit") val planned = sql( """ @@ -136,7 +136,7 @@ class PlannerSuite extends SharedSQLContext { sql("CACHE TABLE tiny") val a = testData.as("a") - val b = sqlContext.table("tiny").as("b") + val b = spark.table("tiny").as("b") val planned = a.join(b, $"a.key" === $"b.key").queryExecution.sparkPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoinExec => join } @@ -145,7 +145,7 @@ class PlannerSuite extends SharedSQLContext { assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(sortMergeJoins.isEmpty, "Should not use shuffled hash join") - sqlContext.clearCache() + spark.catalog.clearCache() } } } @@ -154,8 +154,8 @@ class PlannerSuite extends SharedSQLContext { withTempPath { file => val path = file.getCanonicalPath testData.write.parquet(path) - val df = sqlContext.read.parquet(path) - sqlContext.registerDataFrameAsTable(df, "testPushed") + val df = spark.read.parquet(path) + spark.wrapped.registerDataFrameAsTable(df, "testPushed") withTempTable("testPushed") { val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan @@ -295,7 +295,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") @@ -315,7 +315,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) } @@ -333,7 +333,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") @@ -353,7 +353,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"Exchange should not have been added:\n$outputPlan") @@ -376,7 +376,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(outputOrdering, outputOrdering) ) - val 
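Note on `spark.wrapped` in the hunks above: SQLContext-only helpers such as `registerDataFrameAsTable` and `tables()` are reached through the SQLContext that wraps the session. A small sketch, assuming it sits in the `org.apache.spark.sql` test packages where those helpers are visible; the table name is illustrative:

```scala
val df = spark.range(5).toDF("id")
spark.wrapped.registerDataFrameAsTable(df, "wrappedExample")

assert(spark.wrapped.tables().filter("tableName = 'wrappedExample'").count() == 1)
spark.catalog.dropTempTable("wrappedExample")
```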
outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"No Exchanges should have been added:\n$outputPlan") @@ -392,7 +392,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingB)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: SortExec => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") @@ -408,7 +408,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingA)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { fail(s"No sorts should have been added:\n$outputPlan") @@ -425,7 +425,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingA, orderingB)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: SortExec => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") @@ -444,7 +444,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq.empty)), None) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) { fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") @@ -464,7 +464,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq.empty)), None) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) { fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") @@ -493,7 +493,7 @@ class PlannerSuite extends SharedSQLContext { shuffle, shuffle) - val outputPlan = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = ReuseExchange(spark.sessionState.conf).apply(inputPlan) if (outputPlan.collect { case e: ReusedExchangeExec => true }.size != 1) { fail(s"Should re-use the shuffle:\n$outputPlan") } @@ -510,7 +510,7 @@ class PlannerSuite extends SharedSQLContext { ShuffleExchange(finalPartitioning, inputPlan), ShuffleExchange(finalPartitioning, inputPlan)) - val outputPlan2 = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan2) + val outputPlan2 = ReuseExchange(spark.sessionState.conf).apply(inputPlan2) if (outputPlan2.collect { case e: 
ReusedExchangeExec => true }.size != 2) { fail(s"Should re-use the two shuffles:\n$outputPlan2") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index c9f517ca34..ad41111bec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.util.Properties import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession class SQLExecutionSuite extends SparkFunSuite { @@ -50,16 +50,19 @@ class SQLExecutionSuite extends SparkFunSuite { } test("concurrent query execution with fork-join pool (SPARK-13747)") { - val sc = new SparkContext("local[*]", "test") - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder + .master("local[*]") + .appName("test") + .getOrCreate() + + import spark.implicits._ try { // Should not throw IllegalArgumentException (1 to 100).par.foreach { _ => - sc.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b").count() + spark.sparkContext.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b").count() } } finally { - sc.stop() + spark.sparkContext.stop() } } @@ -67,8 +70,8 @@ class SQLExecutionSuite extends SparkFunSuite { * Trigger SPARK-10548 by mocking a parent and its child thread executing queries concurrently. */ private def testConcurrentQueryExecution(sc: SparkContext): Unit = { - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder.getOrCreate() + import spark.implicits._ // Initialize local properties. This is necessary for the test to pass. sc.getLocalProperties diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 073e0b3f00..d7eae21f9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -21,7 +21,7 @@ import scala.language.implicitConversions import scala.util.control.NonFatal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.test.SQLTestUtils @@ -30,7 +30,7 @@ import org.apache.spark.sql.test.SQLTestUtils * class's test helper methods can be used, see [[SortSuite]]. */ private[sql] abstract class SparkPlanTest extends SparkFunSuite { - protected def sqlContext: SQLContext + protected def spark: SparkSession /** * Runs the plan and makes sure the answer matches the expected result. 
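The SQLExecutionSuite hunk above is the canonical builder-plus-implicits migration: one `getOrCreate()` call replaces the explicit SparkContext and SQLContext, and `import spark.implicits._` replaces `import sqlContext.implicits._`. A self-contained sketch with an illustrative app name:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder
  .master("local[*]")
  .appName("sql-execution-example")
  .getOrCreate()
import spark.implicits._

try {
  val df = spark.sparkContext.parallelize(1 to 5).map(i => (i, i)).toDF("a", "b")
  assert(df.count() == 5)
} finally {
  spark.stop()
}
```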
@@ -90,9 +90,10 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match { - case Some(errorMessage) => fail(errorMessage) - case None => + SparkPlanTest + .checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.wrapped) match { + case Some(errorMessage) => fail(errorMessage) + case None => } } @@ -114,7 +115,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite { expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean = true): Unit = { SparkPlanTest.checkAnswer( - input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match { + input, planFunction, expectedPlanFunction, sortAnswers, spark.wrapped) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -141,13 +142,13 @@ object SparkPlanTest { planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean, - sqlContext: SQLContext): Option[String] = { + spark: SQLContext): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) val expectedAnswer: Seq[Row] = try { - executePlan(expectedOutputPlan, sqlContext) + executePlan(expectedOutputPlan, spark) } catch { case NonFatal(e) => val errorMessage = @@ -162,7 +163,7 @@ object SparkPlanTest { } val actualAnswer: Seq[Row] = try { - executePlan(outputPlan, sqlContext) + executePlan(outputPlan, spark) } catch { case NonFatal(e) => val errorMessage = @@ -202,12 +203,12 @@ object SparkPlanTest { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean, - sqlContext: SQLContext): Option[String] = { + spark: SQLContext): Option[String] = { val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) val sparkAnswer: Seq[Row] = try { - executePlan(outputPlan, sqlContext) + executePlan(outputPlan, spark) } catch { case NonFatal(e) => val errorMessage = @@ -230,8 +231,8 @@ object SparkPlanTest { } } - private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { - val execution = new QueryExecution(sqlContext.sparkSession, null) { + private def executePlan(outputPlan: SparkPlan, spark: SQLContext): Seq[Row] = { + val execution = new QueryExecution(spark.sparkSession, null) { override lazy val sparkPlan: SparkPlan = outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 233104ae84..ada60f6919 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -28,14 +28,14 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { test("range/filter should be combined") { - val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1") + val df = spark.range(10).filter("id = 1").selectExpr("id + 1") val plan = df.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined) assert(df.collect() === Array(Row(2))) } 
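The WholeStageCodegen test above checks the executed plan rather than the result alone. A condensed sketch of that plan inspection (same query shape as the test, shown here only to spell out the pattern):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.WholeStageCodegenExec

val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
val plan = df.queryExecution.executedPlan

// The range and filter should have been fused into a single codegen stage.
assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
assert(df.collect().toSeq == Seq(Row(2L)))
```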
test("Aggregate should be included in WholeStageCodegen") { - val df = sqlContext.range(10).groupBy().agg(max(col("id")), avg(col("id"))) + val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id"))) val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && @@ -44,7 +44,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } test("Aggregate with grouping keys should be included in WholeStageCodegen") { - val df = sqlContext.range(3).groupBy("id").count().orderBy("id") + val df = spark.range(3).groupBy("id").count().orderBy("id") val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && @@ -53,10 +53,10 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } test("BroadcastHashJoin should be included in WholeStageCodegen") { - val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2"))) + val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2"))) val schema = new StructType().add("k", IntegerType).add("v", StringType) - val smallDF = sqlContext.createDataFrame(rdd, schema) - val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id")) + val smallDF = spark.createDataFrame(rdd, schema) + val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id")) assert(df.queryExecution.executedPlan.find(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined) @@ -64,7 +64,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } test("Sort should be included in WholeStageCodegen") { - val df = sqlContext.range(3, 0, -1).toDF().sort(col("id")) + val df = spark.range(3, 0, -1).toDF().sort(col("id")) val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && @@ -75,7 +75,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { test("MapElements should be included in WholeStageCodegen") { import testImplicits._ - val ds = sqlContext.range(10).map(_.toString) + val ds = spark.range(10).map(_.toString) val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && @@ -84,7 +84,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } test("typed filter should be included in WholeStageCodegen") { - val ds = sqlContext.range(10).filter(_ % 2 == 0) + val ds = spark.range(10).filter(_ % 2 == 0) val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && @@ -93,7 +93,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } test("back-to-back typed filter should be included in WholeStageCodegen") { - val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0) + val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0) val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 50c8745a28..88269a6a2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -21,6 
+21,7 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ @@ -32,7 +33,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { setupTestData() test("simple columnar query") { - val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan + val plan = spark.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -42,14 +43,14 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // TODO: Improve this test when we have better statistics sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) .toDF().registerTempTable("sizeTst") - sqlContext.cacheTable("sizeTst") + spark.catalog.cacheTable("sizeTst") assert( - sqlContext.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > - sqlContext.conf.autoBroadcastJoinThreshold) + spark.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > + spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) } test("projection") { - val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan + val plan = spark.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -58,7 +59,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan + val plan = spark.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -70,7 +71,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) - sqlContext.cacheTable("repeatedData") + spark.catalog.cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -82,7 +83,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) - sqlContext.cacheTable("nullableRepeatedData") + spark.catalog.cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -97,7 +98,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) - sqlContext.cacheTable("timestamps") + spark.catalog.cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), @@ -109,7 +110,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) - sqlContext.cacheTable("withEmptyParts") + spark.catalog.cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), @@ -178,35 +179,35 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { (i to i + 10).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, Row((i - 0.25).toFloat, Seq(true, false, null))) } - 
sqlContext.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + spark.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. sql("cache table InMemoryCache_different_data_types") // Make sure the table is indeed cached. - sqlContext.table("InMemoryCache_different_data_types").queryExecution.executedPlan + spark.table("InMemoryCache_different_data_types").queryExecution.executedPlan assert( - sqlContext.isCached("InMemoryCache_different_data_types"), + spark.catalog.isCached("InMemoryCache_different_data_types"), "InMemoryCache_different_data_types should be cached.") // Issue a query and check the results. checkAnswer( sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), - sqlContext.table("InMemoryCache_different_data_types").collect()) - sqlContext.dropTempTable("InMemoryCache_different_data_types") + spark.table("InMemoryCache_different_data_types").collect()) + spark.catalog.dropTempTable("InMemoryCache_different_data_types") } test("SPARK-10422: String column in InMemoryColumnarCache needs to override clone method") { - val df = sqlContext.range(1, 100).selectExpr("id % 10 as id") + val df = spark.range(1, 100).selectExpr("id % 10 as id") .rdd.map(id => Tuple1(s"str_$id")).toDF("i") val cached = df.cache() // count triggers the caching action. It should not throw. cached.count() // Make sure, the DataFrame is indeed cached. - assert(sqlContext.cacheManager.lookupCachedData(cached).nonEmpty) + assert(spark.cacheManager.lookupCachedData(cached).nonEmpty) // Check result. checkAnswer( cached, - sqlContext.range(1, 100).selectExpr("id % 10 as id") + spark.range(1, 100).selectExpr("id % 10 as id") .rdd.map(id => Tuple1(s"str_$id")).toDF("i") ) @@ -215,7 +216,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-10859: Predicates pushed to InMemoryColumnarTableScan are not evaluated correctly") { - val data = sqlContext.range(10).selectExpr("id", "cast(id as string) as s") + val data = spark.range(10).selectExpr("id", "cast(id as string) as s") data.cache() assert(data.count() === 10) assert(data.filter($"s" === "3").count() === 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 9164074a3e..48c798986b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -32,23 +32,24 @@ class PartitionBatchPruningSuite import testImplicits._ - private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize - private lazy val originalInMemoryPartitionPruning = sqlContext.conf.inMemoryPartitionPruning + private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE) + private lazy val originalInMemoryPartitionPruning = + spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING) override protected def beforeAll(): Unit = { super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) + spark.conf.set(SQLConf.COLUMN_BATCH_SIZE.key, 10) // Enable in-memory partition pruning - sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + spark.conf.set(SQLConf.IN_MEMORY_PARTITION_PRUNING.key, true) // Enable in-memory table 
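The columnar-cache hunks above move all cache bookkeeping onto `spark.catalog`. A minimal lifecycle sketch with an illustrative table name:

```scala
spark.range(1, 100).selectExpr("id % 10 as id").registerTempTable("cacheExample")

spark.catalog.cacheTable("cacheExample")      // was: sqlContext.cacheTable(...)
assert(spark.catalog.isCached("cacheExample"))

spark.catalog.uncacheTable("cacheExample")    // was: sqlContext.uncacheTable(...)
assert(!spark.catalog.isCached("cacheExample"))
spark.catalog.dropTempTable("cacheExample")
```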
scan accumulators - sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + spark.conf.set("spark.sql.inMemoryTableScanStatistics.enable", "true") } override protected def afterAll(): Unit = { try { - sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + spark.conf.set(SQLConf.COLUMN_BATCH_SIZE.key, originalColumnBatchSize) + spark.conf.set(SQLConf.IN_MEMORY_PARTITION_PRUNING.key, originalInMemoryPartitionPruning) } finally { super.afterAll() } @@ -63,12 +64,12 @@ class PartitionBatchPruningSuite TestData(key, string) }, 5).toDF() pruningData.registerTempTable("pruningData") - sqlContext.cacheTable("pruningData") + spark.catalog.cacheTable("pruningData") } override protected def afterEach(): Unit = { try { - sqlContext.uncacheTable("pruningData") + spark.catalog.uncacheTable("pruningData") } finally { super.afterEach() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 3586ddf7b6..5fbab2382a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -37,7 +37,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { override def afterEach(): Unit = { try { // drop all databases, tables and functions after each test - sqlContext.sessionState.catalog.reset() + spark.sessionState.catalog.reset() } finally { super.afterEach() } @@ -66,7 +66,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { private def createDatabase(catalog: SessionCatalog, name: String): Unit = { catalog.createDatabase( - CatalogDatabase(name, "", sqlContext.conf.warehousePath, Map()), ignoreIfExists = false) + CatalogDatabase(name, "", spark.sessionState.conf.warehousePath, Map()), + ignoreIfExists = false) } private def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable = { @@ -111,7 +112,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("the qualified path of a database is stored in the catalog") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog withTempDir { tmpDir => val path = tmpDir.toString @@ -274,7 +275,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { databaseNames.foreach { dbName => val dbNameWithoutBackTicks = cleanIdentifier(dbName) - assert(!sqlContext.sessionState.catalog.databaseExists(dbNameWithoutBackTicks)) + assert(!spark.sessionState.catalog.databaseExists(dbNameWithoutBackTicks)) var message = intercept[AnalysisException] { sql(s"DROP DATABASE $dbName") @@ -334,7 +335,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("create table in default db") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent1 = TableIdentifier("tab1", None) createTable(catalog, tableIdent1) val expectedTableIdent = tableIdent1.copy(database = Some("default")) @@ -343,7 +344,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("create table in a specific db") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog createDatabase(catalog, "dbx") val tableIdent1 = 
TableIdentifier("tab1", Some("dbx")) createTable(catalog, tableIdent1) @@ -352,7 +353,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: rename") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent1 = TableIdentifier("tab1", Some("dbx")) val tableIdent2 = TableIdentifier("tab2", Some("dbx")) createDatabase(catalog, "dbx") @@ -444,7 +445,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: set properties") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -471,7 +472,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: unset properties") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -512,7 +513,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: bucketing is not supported") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -523,7 +524,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: skew is not supported") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -560,7 +561,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: rename partition") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1") val part2 = Map("b" -> "2") @@ -661,7 +662,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("drop table - temporary table") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog sql( """ |CREATE TEMPORARY TABLE tab1 @@ -686,7 +687,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } private def testDropTable(isDatasourceTable: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -705,7 +706,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { // SQLContext does not support create view. 
Log an error message, if tab1 does not exists sql("DROP VIEW tab1") - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -726,7 +727,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } private def testSetLocation(isDatasourceTable: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val partSpec = Map("a" -> "1") createDatabase(catalog, "dbx") @@ -784,7 +785,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } private def testSetSerde(isDatasourceTable: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -830,7 +831,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } private def testAddPartitions(isDatasourceTable: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1") val part2 = Map("b" -> "2") @@ -880,7 +881,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } private def testDropPartitions(isDatasourceTable: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1") val part2 = Map("b" -> "2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index ac2af77a6e..52dda8c6ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -281,7 +281,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi )) val fakeRDD = new FileScanRDD( - sqlContext.sparkSession, + spark, (file: PartitionedFile) => Iterator.empty, Seq(partition) ) @@ -399,7 +399,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi util.stringToFile(file, "*" * size) } - val df = sqlContext.read + val df = spark.read .format(classOf[TestFileFormat].getName) .load(tempDir.getCanonicalPath) @@ -409,7 +409,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi l.copy(relation = r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil)))) } - Dataset.ofRows(sqlContext.sparkSession, bucketed) + Dataset.ofRows(spark, bucketed) } else { df } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index 297731c70c..89d57653ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -27,7 +27,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { 
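For reference, a minimal standalone sketch of the write-then-read round trip that the HadoopFsRelationSuite hunk below performs through the session; the path, app name, and row count are invented for the example, and the suite's sizeInBytes assertion is omitted:

import org.apache.spark.sql.SparkSession

// Build a local session; the suites themselves get one from SharedSQLContext.
val spark = SparkSession.builder().master("local").appName("ParquetRoundTripSketch").getOrCreate()

// Write a small dataset with the session and read it back through the same session.
val dir = "/tmp/hadoop_fs_relation_sketch"   // hypothetical temp directory
spark.range(1000).write.mode("overwrite").parquet(dir)
val df = spark.read.parquet(dir)
assert(df.count() == 1000)

spark.stop()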
test("sizeInBytes should be the total size of all files") { withTempDir{ dir => dir.delete() - sqlContext.range(1000).write.parquet(dir.toString) + spark.range(1000).write.parquet(dir.toString) // ignore hidden files val allFiles = dir.listFiles(new FilenameFilter { override def accept(dir: File, name: String): Boolean = { @@ -35,7 +35,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { } }) val totalSize = allFiles.map(_.length()).sum - val df = sqlContext.read.parquet(dir.toString) + val df = spark.read.parquet(dir.toString) assert(df.queryExecution.logical.statistics.sizeInBytes === BigInt(totalSize)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 28e59055fa..b6cdc8cfab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -91,7 +91,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("simple csv test") { - val cars = sqlContext + val cars = spark .read .format("csv") .option("header", "false") @@ -101,7 +101,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("simple csv test with calling another function to load") { - val cars = sqlContext + val cars = spark .read .option("header", "false") .csv(testFile(carsFile)) @@ -110,7 +110,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("simple csv test with type inference") { - val cars = sqlContext + val cars = spark .read .format("csv") .option("header", "true") @@ -121,7 +121,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test inferring booleans") { - val result = sqlContext.read + val result = spark.read .format("csv") .option("header", "true") .option("inferSchema", "true") @@ -133,7 +133,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test with alternative delimiter and quote") { - val cars = sqlContext.read + val cars = spark.read .format("csv") .options(Map("quote" -> "\'", "delimiter" -> "|", "header" -> "true")) .load(testFile(carsAltFile)) @@ -142,7 +142,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("parse unescaped quotes with maxCharsPerColumn") { - val rows = sqlContext.read + val rows = spark.read .format("csv") .option("maxCharsPerColumn", "4") .load(testFile(unescapedQuotesFile)) @@ -154,7 +154,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("bad encoding name") { val exception = intercept[UnsupportedCharsetException] { - sqlContext + spark .read .format("csv") .option("charset", "1-9588-osi") @@ -166,7 +166,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test different encoding") { // scalastyle:off - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE carsTable USING csv |OPTIONS (path "${testFile(carsFile8859)}", header "true", @@ -174,12 +174,12 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { """.stripMargin.replaceAll("\n", " ")) // scalastyle:on - verifyCars(sqlContext.table("carsTable"), withHeader = true) + verifyCars(spark.table("carsTable"), withHeader = true) } test("test aliases sep and encoding for delimiter and charset") { // scalastyle:off - val cars = sqlContext + val 
cars = spark .read .format("csv") .option("header", "true") @@ -192,17 +192,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("DDL test with tab separated file") { - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE carsTable USING csv |OPTIONS (path "${testFile(carsTsvFile)}", header "true", delimiter "\t") """.stripMargin.replaceAll("\n", " ")) - verifyCars(sqlContext.table("carsTable"), numFields = 6, withHeader = true, checkHeader = false) + verifyCars(spark.table("carsTable"), numFields = 6, withHeader = true, checkHeader = false) } test("DDL test parsing decimal type") { - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE carsTable |(yearMade double, makeName string, modelName string, priceTag decimal, @@ -212,11 +212,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { """.stripMargin.replaceAll("\n", " ")) assert( - sqlContext.sql("SELECT makeName FROM carsTable where priceTag > 60000").collect().size === 1) + spark.sql("SELECT makeName FROM carsTable where priceTag > 60000").collect().size === 1) } test("test for DROPMALFORMED parsing mode") { - val cars = sqlContext.read + val cars = spark.read .format("csv") .options(Map("header" -> "true", "mode" -> "dropmalformed")) .load(testFile(carsFile)) @@ -226,7 +226,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test for FAILFAST parsing mode") { val exception = intercept[SparkException]{ - sqlContext.read + spark.read .format("csv") .options(Map("header" -> "true", "mode" -> "failfast")) .load(testFile(carsFile)).collect() @@ -236,7 +236,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for tokens more than the fields in the schema") { - val cars = sqlContext + val cars = spark .read .format("csv") .option("header", "false") @@ -247,7 +247,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test with null quote character") { - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .option("quote", "") @@ -258,7 +258,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test with empty file and known schema") { - val result = sqlContext.read + val result = spark.read .format("csv") .schema(StructType(List(StructField("column", StringType, false)))) .load(testFile(emptyFile)) @@ -268,25 +268,25 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("DDL test with empty file") { - sqlContext.sql(s""" + spark.sql(s""" |CREATE TEMPORARY TABLE carsTable |(yearMade double, makeName string, modelName string, comments string, grp string) |USING csv |OPTIONS (path "${testFile(emptyFile)}", header "false") """.stripMargin.replaceAll("\n", " ")) - assert(sqlContext.sql("SELECT count(*) FROM carsTable").collect().head(0) === 0) + assert(spark.sql("SELECT count(*) FROM carsTable").collect().head(0) === 0) } test("DDL test with schema") { - sqlContext.sql(s""" + spark.sql(s""" |CREATE TEMPORARY TABLE carsTable |(yearMade double, makeName string, modelName string, comments string, blank string) |USING csv |OPTIONS (path "${testFile(carsFile)}", header "true") """.stripMargin.replaceAll("\n", " ")) - val cars = sqlContext.table("carsTable") + val cars = spark.table("carsTable") verifyCars(cars, withHeader = true, checkHeader = false, checkValues = false) assert( cars.schema.fieldNames === Array("yearMade", "makeName", "modelName", "comments", "blank")) @@ 
-295,7 +295,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("save csv") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .load(testFile(carsFile)) @@ -304,7 +304,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("header", "true") .csv(csvDir) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .load(csvDir) @@ -316,7 +316,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("save csv with quote") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .load(testFile(carsFile)) @@ -327,7 +327,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("quote", "\"") .save(csvDir) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .option("quote", "\"") @@ -338,7 +338,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("commented lines in CSV data") { - val results = sqlContext.read + val results = spark.read .format("csv") .options(Map("comment" -> "~", "header" -> "false")) .load(testFile(commentsFile)) @@ -353,7 +353,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("inferring schema with commented lines in CSV data") { - val results = sqlContext.read + val results = spark.read .format("csv") .options(Map("comment" -> "~", "header" -> "false", "inferSchema" -> "true")) .load(testFile(commentsFile)) @@ -372,7 +372,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { "header" -> "true", "inferSchema" -> "true", "dateFormat" -> "dd/MM/yyyy hh:mm") - val results = sqlContext.read + val results = spark.read .format("csv") .options(options) .load(testFile(datesFile)) @@ -393,7 +393,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { "header" -> "true", "inferSchema" -> "false", "dateFormat" -> "dd/MM/yyyy hh:mm") - val results = sqlContext.read + val results = spark.read .format("csv") .options(options) .schema(customSchema) @@ -416,7 +416,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("setting comment to null disables comment support") { - val results = sqlContext.read + val results = spark.read .format("csv") .options(Map("comment" -> "", "header" -> "false")) .load(testFile(disableCommentsFile)) @@ -439,7 +439,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { StructField("model", StringType, nullable = false), StructField("comment", StringType, nullable = true), StructField("blank", StringType, nullable = true))) - val cars = sqlContext.read + val cars = spark.read .format("csv") .schema(dataSchema) .options(Map("header" -> "true", "nullValue" -> "null")) @@ -454,7 +454,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("save csv with compression codec option") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .load(testFile(carsFile)) @@ -468,7 +468,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val compressedFiles = new File(csvDir).listFiles() 
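As context for these CSVSuite hunks, a minimal sketch of the spark.read / DataFrameWriter CSV calls they migrate to; the input path, output directory, and option values are illustrative only:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").appName("CsvOptionsSketch").getOrCreate()

// Read a CSV file with a header row and let Spark infer the column types.
val cars = spark.read
  .format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .load("/tmp/cars.csv")            // hypothetical input file

// Write it back out compressed; the part files end in .csv.gz.
cars.write
  .mode("overwrite")
  .option("header", "true")
  .option("compression", "gzip")
  .csv("/tmp/cars_out")             // hypothetical output directory

spark.stop()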
assert(compressedFiles.exists(_.getName.endsWith(".csv.gz"))) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .load(csvDir) @@ -486,7 +486,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { ) withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .options(extraOptions) @@ -502,7 +502,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val compressedFiles = new File(csvDir).listFiles() assert(compressedFiles.exists(!_.getName.endsWith(".csv.gz"))) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .options(extraOptions) @@ -513,7 +513,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Schema inference correctly identifies the datatype when data is sparse.") { - val df = sqlContext.read + val df = spark.read .format("csv") .option("header", "true") .option("inferSchema", "true") @@ -525,7 +525,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("old csv data source name works") { - val cars = sqlContext + val cars = spark .read .format("com.databricks.spark.csv") .option("header", "false") @@ -535,7 +535,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("nulls, NaNs and Infinity values can be parsed") { - val numbers = sqlContext + val numbers = spark .read .format("csv") .schema(StructType(List( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala index 1742df31bb..c31dffedbd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -27,16 +27,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowComments off") { val str = """{'name': /* hello */ 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.json(rdd) assert(df.schema.head.name == "_corrupt_record") } test("allowComments on") { val str = """{'name': /* hello */ 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowComments", "true").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowComments", "true").json(rdd) assert(df.schema.head.name == "name") assert(df.first().getString(0) == "Reynold Xin") @@ -44,16 +44,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowSingleQuotes off") { val str = """{'name': 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowSingleQuotes", "false").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowSingleQuotes", "false").json(rdd) assert(df.schema.head.name == "_corrupt_record") } test("allowSingleQuotes on") { val str = """{'name': 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.json(rdd) + val rdd = 
spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.json(rdd) assert(df.schema.head.name == "name") assert(df.first().getString(0) == "Reynold Xin") @@ -61,16 +61,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowUnquotedFieldNames off") { val str = """{name: 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.json(rdd) assert(df.schema.head.name == "_corrupt_record") } test("allowUnquotedFieldNames on") { val str = """{name: 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowUnquotedFieldNames", "true").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowUnquotedFieldNames", "true").json(rdd) assert(df.schema.head.name == "name") assert(df.first().getString(0) == "Reynold Xin") @@ -78,16 +78,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowNumericLeadingZeros off") { val str = """{"age": 0018}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.json(rdd) assert(df.schema.head.name == "_corrupt_record") } test("allowNumericLeadingZeros on") { val str = """{"age": 0018}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowNumericLeadingZeros", "true").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowNumericLeadingZeros", "true").json(rdd) assert(df.schema.head.name == "age") assert(df.first().getLong(0) == 18) @@ -97,16 +97,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { // JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS. 
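A minimal sketch of the spark.read.option(...).json(rdd) pattern these JsonParsingOptionsSuite hunks switch to, reusing one of the suite's own sample records; the session setup is the only invented part:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").appName("JsonOptionsSketch").getOrCreate()

val rdd = spark.sparkContext.parallelize(Seq("""{name: 'Reynold Xin'}"""))

// Without the option, the unquoted field name sends the record to _corrupt_record.
val strict = spark.read.json(rdd)
// With the option enabled, the record parses into a regular "name" column.
val relaxed = spark.read.option("allowUnquotedFieldNames", "true").json(rdd)

strict.printSchema()
relaxed.show()

spark.stop()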
ignore("allowNonNumericNumbers off") { val str = """{"age": NaN}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.json(rdd) assert(df.schema.head.name == "_corrupt_record") } ignore("allowNonNumericNumbers on") { val str = """{"age": NaN}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowNonNumericNumbers", "true").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowNonNumericNumbers", "true").json(rdd) assert(df.schema.head.name == "age") assert(df.first().getDouble(0).isNaN) @@ -114,16 +114,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowBackslashEscapingAnyCharacter off") { val str = """{"name": "Cazen Lee", "price": "\$10"}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "false").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowBackslashEscapingAnyCharacter", "false").json(rdd) assert(df.schema.head.name == "_corrupt_record") } test("allowBackslashEscapingAnyCharacter on") { val str = """{"name": "Cazen Lee", "price": "\$10"}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "true").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowBackslashEscapingAnyCharacter", "true").json(rdd) assert(df.schema.head.name == "name") assert(df.schema.last.name == "price") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index b1279abd63..63fe4658d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -229,7 +229,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Complex field and type inferring with null in sampling") { - val jsonDF = sqlContext.read.json(jsonNullStruct) + val jsonDF = spark.read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -248,7 +248,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Primitive field and type inferring") { - val jsonDF = sqlContext.read.json(primitiveFieldAndType) + val jsonDF = spark.read.json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -276,7 +276,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Complex field and type inferring") { - val jsonDF = sqlContext.read.json(complexFieldAndType1) + val jsonDF = spark.read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -375,7 +375,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("GetField operation on complex data type") { - val jsonDF = sqlContext.read.json(complexFieldAndType1) + val jsonDF = spark.read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -391,7 +391,7 @@ class JsonSuite 
extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in primitive field values") { - val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) + val jsonDF = spark.read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -463,7 +463,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) + val jsonDF = spark.read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. @@ -516,7 +516,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in complex field values") { - val jsonDF = sqlContext.read.json(complexFieldValueTypeConflict) + val jsonDF = spark.read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -540,7 +540,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in array elements") { - val jsonDF = sqlContext.read.json(arrayElementTypeConflict) + val jsonDF = spark.read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -568,7 +568,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Handling missing fields") { - val jsonDF = sqlContext.read.json(missingFields) + val jsonDF = spark.read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -588,7 +588,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = sqlContext.read.json(path) + val jsonDF = spark.read.json(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -620,7 +620,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = sqlContext.read.option("primitivesAsString", "true").json(path) + val jsonDF = spark.read.option("primitivesAsString", "true").json(path) val expectedSchema = StructType( StructField("bigInteger", StringType, true) :: @@ -648,7 +648,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Loading a JSON dataset primitivesAsString returns complex fields as strings") { - val jsonDF = sqlContext.read.option("primitivesAsString", "true").json(complexFieldAndType1) + val jsonDF = spark.read.option("primitivesAsString", "true").json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -746,7 +746,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Loading a JSON dataset prefersDecimal returns schema with float types as BigDecimal") { - val jsonDF = sqlContext.read.option("prefersDecimal", "true").json(primitiveFieldAndType) + val jsonDF = spark.read.option("prefersDecimal", "true").json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ 
-777,7 +777,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val mixedIntegerAndDoubleRecords = sparkContext.parallelize( """{"a": 3, "b": 1.1}""" :: s"""{"a": 3.1, "b": 0.${"0" * 38}1}""" :: Nil) - val jsonDF = sqlContext.read + val jsonDF = spark.read .option("prefersDecimal", "true") .json(mixedIntegerAndDoubleRecords) @@ -796,7 +796,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Infer big integers correctly even when it does not fit in decimal") { - val jsonDF = sqlContext.read + val jsonDF = spark.read .json(bigIntegerRecords) // The value in `a` field will be a double as it does not fit in decimal. For `b` field, @@ -810,7 +810,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Infer floating-point values correctly even when it does not fit in decimal") { - val jsonDF = sqlContext.read + val jsonDF = spark.read .option("prefersDecimal", "true") .json(floatingValueRecords) @@ -823,7 +823,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(expectedSchema === jsonDF.schema) checkAnswer(jsonDF, Row(1.0E-39D, BigDecimal(0.01))) - val mergedJsonDF = sqlContext.read + val mergedJsonDF = spark.read .option("prefersDecimal", "true") .json(floatingValueRecords ++ bigIntegerRecords) @@ -881,7 +881,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = sqlContext.read.schema(schema).json(path) + val jsonDF1 = spark.read.schema(schema).json(path) assert(schema === jsonDF1.schema) @@ -898,7 +898,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "this is a simple string.") ) - val jsonDF2 = sqlContext.read.schema(schema).json(primitiveFieldAndType) + val jsonDF2 = spark.read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ -919,7 +919,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) + val jsonWithSimpleMap = spark.read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -947,7 +947,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = sqlContext.read.schema(schemaWithComplexMap).json(mapType2) + val jsonWithComplexMap = spark.read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -973,7 +973,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = sqlContext.read.json(complexFieldAndType2) + val jsonDF = spark.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -991,7 +991,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-3390 Complex arrays") { - val jsonDF = sqlContext.read.json(complexFieldAndType2) + val jsonDF = spark.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -1014,7 +1014,7 @@ class JsonSuite extends QueryTest with 
SharedSQLContext with TestJsonData { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = sqlContext.read.json(jsonArray) + val jsonDF = spark.read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -1035,7 +1035,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("a", StringType, true) :: Nil) // `FAILFAST` mode should throw an exception for corrupt records. val exceptionOne = intercept[SparkException] { - sqlContext.read + spark.read .option("mode", "FAILFAST") .json(corruptRecords) .collect() @@ -1043,7 +1043,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode: {")) val exceptionTwo = intercept[SparkException] { - sqlContext.read + spark.read .option("mode", "FAILFAST") .schema(schema) .json(corruptRecords) @@ -1060,7 +1060,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val schemaTwo = StructType( StructField("a", StringType, true) :: Nil) // `DROPMALFORMED` mode should skip corrupt records - val jsonDFOne = sqlContext.read + val jsonDFOne = spark.read .option("mode", "DROPMALFORMED") .json(corruptRecords) checkAnswer( @@ -1069,7 +1069,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) assert(jsonDFOne.schema === schemaOne) - val jsonDFTwo = sqlContext.read + val jsonDFTwo = spark.read .option("mode", "DROPMALFORMED") .schema(schemaTwo) .json(corruptRecords) @@ -1083,7 +1083,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // Test if we can query corrupt records. withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { withTempTable("jsonTable") { - val jsonDF = sqlContext.read.json(corruptRecords) + val jsonDF = spark.read.json(corruptRecords) jsonDF.registerTempTable("jsonTable") val schema = StructType( StructField("_unparsed", StringType, true) :: @@ -1134,7 +1134,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-13953 Rename the corrupt record field via option") { - val jsonDF = sqlContext.read + val jsonDF = spark.read .option("columnNameOfCorruptRecord", "_malformed") .json(corruptRecords) val schema = StructType( @@ -1155,7 +1155,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-4068: nulls in arrays") { - val jsonDF = sqlContext.read.json(nullsInArrays) + val jsonDF = spark.read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -1201,7 +1201,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = sqlContext.createDataFrame(rowRDD1, schema1) + val df1 = spark.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() @@ -1224,7 +1224,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = sqlContext.createDataFrame(rowRDD2, schema2) + val df3 = spark.createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") val df4 = df3.toDF val result2 = df4.toJSON.collect() @@ -1232,8 +1232,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === 
"{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = sqlContext.read.json(primitiveFieldAndType) - val primTable = sqlContext.read.json(jsonDF.toJSON.rdd) + val jsonDF = spark.read.json(primitiveFieldAndType) + val primTable = spark.read.json(jsonDF.toJSON.rdd) primTable.registerTempTable("primitiveTable") checkAnswer( sql("select * from primitiveTable"), @@ -1245,8 +1245,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "this is a simple string.") ) - val complexJsonDF = sqlContext.read.json(complexFieldAndType1) - val compTable = sqlContext.read.json(complexJsonDF.toJSON.rdd) + val complexJsonDF = spark.read.json(complexFieldAndType1) + val compTable = spark.read.json(complexJsonDF.toJSON.rdd) compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( @@ -1316,7 +1316,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) val d1 = DataSource( - sqlContext.sparkSession, + spark, userSpecifiedSchema = None, partitionColumns = Array.empty[String], bucketSpec = None, @@ -1324,7 +1324,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { options = Map("path" -> path)).resolveRelation() val d2 = DataSource( - sqlContext.sparkSession, + spark, userSpecifiedSchema = None, partitionColumns = Array.empty[String], bucketSpec = None, @@ -1345,16 +1345,16 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { withTempDir { dir => val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val df = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) + val df = spark.read.schema(schemaWithSimpleMap).json(mapType1) val path = dir.getAbsolutePath df.write.mode("overwrite").parquet(path) // order of MapType is not defined - assert(sqlContext.read.parquet(path).count() == 5) + assert(spark.read.parquet(path).count() == 5) - val df2 = sqlContext.read.json(corruptRecords) + val df2 = spark.read.json(corruptRecords) df2.write.mode("overwrite").parquet(path) - checkAnswer(sqlContext.read.parquet(path), df2.collect()) + checkAnswer(spark.read.parquet(path), df2.collect()) } } } @@ -1387,7 +1387,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "col1", "abd") - sqlContext.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") + spark.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") checkAnswer(sql( "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) checkAnswer(sql( @@ -1447,7 +1447,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(4.75.toFloat, Seq(false, true)), new UDT.MyDenseVector(Array(0.25, 2.25, 4.25))) val data = - Row.fromSeq(Seq("Spark " + sqlContext.sparkContext.version) ++ constantValues) :: Nil + Row.fromSeq(Seq("Spark " + spark.sparkContext.version) ++ constantValues) :: Nil // Data generated by previous versions. // scalastyle:off @@ -1462,7 +1462,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // scalastyle:on // Generate data for the current version. 
- val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data, 1), schema) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema) withTempPath { path => df.write.format("json").mode("overwrite").save(path.getCanonicalPath) @@ -1486,13 +1486,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "Spark 1.4.1", "Spark 1.5.0", "Spark 1.5.0", - "Spark " + sqlContext.sparkContext.version, - "Spark " + sqlContext.sparkContext.version) + "Spark " + spark.sparkContext.version, + "Spark " + spark.sparkContext.version) val expectedResult = col0Values.map { v => Row.fromSeq(Seq(v) ++ constantValues) } checkAnswer( - sqlContext.read.format("json").schema(schema).load(path.getCanonicalPath), + spark.read.format("json").schema(schema).load(path.getCanonicalPath), expectedResult ) } @@ -1502,16 +1502,16 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(2) + val df = spark.range(2) df.write.json(path + "/p=1") df.write.json(path + "/p=2") - assert(sqlContext.read.json(path).count() === 4) + assert(spark.read.json(path).count() === 4) val extraOptions = Map( "mapred.input.pathFilter.class" -> classOf[TestFileFilter].getName, "mapreduce.input.pathFilter.class" -> classOf[TestFileFilter].getName ) - assert(sqlContext.read.options(extraOptions).json(path).count() === 2) + assert(spark.read.options(extraOptions).json(path).count() === 2) } } @@ -1525,12 +1525,12 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { { // We need to make sure we can infer the schema. - val jsonDF = sqlContext.read.json(additionalCorruptRecords) + val jsonDF = spark.read.json(additionalCorruptRecords) assert(jsonDF.schema === schema) } { - val jsonDF = sqlContext.read.schema(schema).json(additionalCorruptRecords) + val jsonDF = spark.read.schema(schema).json(additionalCorruptRecords) jsonDF.registerTempTable("jsonTable") // In HiveContext, backticks should be used to access columns starting with a underscore. 
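A small sketch of the spark.read.schema(...).json(...) shape used throughout the JsonSuite hunks above; the field names and records are made up for the example:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

val spark = SparkSession.builder().master("local").appName("JsonSchemaSketch").getOrCreate()

val schema = new StructType()
  .add("name", StringType)
  .add("age", LongType)

val records = spark.sparkContext.parallelize(Seq(
  """{"name": "a", "age": 1}""",
  """{"name": "b"}"""))             // the missing field simply becomes null

// With a user-supplied schema, no inference pass is needed and the schema is used as-is.
val df = spark.read.schema(schema).json(records)
assert(df.schema === schema)
df.show()

spark.stop()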
@@ -1563,7 +1563,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("a", StructType( StructField("b", StringType) :: Nil )) :: Nil) - val jsonDF = sqlContext.read.schema(schema).json(path) + val jsonDF = spark.read.schema(schema).json(path) assert(jsonDF.count() == 2) } } @@ -1575,7 +1575,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = sqlContext.read.json(path) + val jsonDF = spark.read.json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write .format("json") @@ -1585,7 +1585,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val compressedFiles = new File(jsonDir).listFiles() assert(compressedFiles.exists(_.getName.endsWith(".json.gz"))) - val jsonCopy = sqlContext.read + val jsonCopy = spark.read .format("json") .load(jsonDir) @@ -1611,7 +1611,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = sqlContext.read.json(path) + val jsonDF = spark.read.json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write .format("json") @@ -1622,7 +1622,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val compressedFiles = new File(jsonDir).listFiles() assert(compressedFiles.exists(!_.getName.endsWith(".json.gz"))) - val jsonCopy = sqlContext.read + val jsonCopy = spark.read .format("json") .options(extraOptions) .load(jsonDir) @@ -1637,7 +1637,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Casting long as timestamp") { withTempTable("jsonTable") { val schema = (new StructType).add("ts", TimestampType) - val jsonDF = sqlContext.read.schema(schema).json(timestampAsLong) + val jsonDF = spark.read.schema(schema).json(timestampAsLong) jsonDF.registerTempTable("jsonTable") @@ -1657,8 +1657,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val json = s""" |{"a": [{$nested}], "b": [{$nested}]} """.stripMargin - val rdd = sqlContext.sparkContext.makeRDD(Seq(json)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.makeRDD(Seq(json)) + val df = spark.read.json(rdd) assert(df.schema.size === 2) df.collect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 2873c6a881..f4a3336643 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.execution.datasources.json import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession private[json] trait TestJsonData { - protected def sqlContext: SQLContext + protected def spark: SparkSession def primitiveFieldAndType: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -35,7 +35,7 @@ private[json] trait TestJsonData { }""" :: Nil) def primitiveFieldValueTypeConflict: RDD[String] = - 
sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -46,14 +46,14 @@ private[json] trait TestJsonData { "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) def jsonNullStruct: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) def complexFieldValueTypeConflict: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -64,14 +64,14 @@ private[json] trait TestJsonData { "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) def arrayElementTypeConflict: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) def missingFields: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: @@ -79,7 +79,7 @@ private[json] trait TestJsonData { """{"e":"str"}""" :: Nil) def complexFieldAndType1: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -95,7 +95,7 @@ private[json] trait TestJsonData { }""" :: Nil) def complexFieldAndType2: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -149,7 +149,7 @@ private[json] trait TestJsonData { }""" :: Nil) def mapType1: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: @@ -157,7 +157,7 @@ private[json] trait TestJsonData { """{"map": {"e": null}}""" :: Nil) def mapType2: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -166,21 +166,21 @@ private[json] trait TestJsonData { """{"map": {"f": {"field1": null}}}""" :: Nil) def nullsInArrays: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: Nil) def jsonArray: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: 
Nil) def corruptRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -189,7 +189,7 @@ private[json] trait TestJsonData { """]""" :: Nil) def additionalCorruptRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"dummy":"test"}""" :: """[1,2,3]""" :: """":"test", "a":1}""" :: @@ -197,7 +197,7 @@ private[json] trait TestJsonData { """ ","ian":"test"}""" :: Nil) def emptyRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{""" :: """""" :: """{"a": {}}""" :: @@ -206,23 +206,23 @@ private[json] trait TestJsonData { """]""" :: Nil) def timestampAsLong: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"ts":1451732645}""" :: Nil) def arrayAndStructRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"a": {"b": 1}}""" :: """{"a": []}""" :: Nil) def floatingValueRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( s"""{"a": 0.${"0" * 38}1, "b": 0.01}""" :: Nil) def bigIntegerRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( s"""{"a": 1${"0" * 38}, "b": 92233720368547758070}""" :: Nil) - lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) + lazy val singleRow: RDD[String] = spark.sparkContext.parallelize("""{"a":123}""" :: Nil) - def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]()) + def empty: RDD[String] = spark.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index f98ea8c5ae..6509e04e85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -67,7 +67,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => Row( i % 2 == 0, i, @@ -114,7 +114,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => if (i % 3 == 0) { Row.apply(Seq.fill(7)(null): _*) } else { @@ -155,7 +155,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => Row( Seq.tabulate(3)(i => s"val_$i"), if (i % 3 == 0) null else Seq.tabulate(3)(identity)) @@ -182,7 +182,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => Row(Seq.tabulate(3, 3)((i, j) => i * 3 + j)) }) } @@ -205,7 +205,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared 
logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => Row(Seq.tabulate(3)(i => i.toString -> Seq.tabulate(3)(j => i + j)).toMap) }) } @@ -221,7 +221,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => Row( Seq.tabulate(3)(n => s"arr_${i + n}"), Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, @@ -267,7 +267,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared } } - checkAnswer(sqlContext.read.parquet(path).filter('suit === "SPADES"), Row("SPADES")) + checkAnswer(spark.read.parquet(path).filter('suit === "SPADES"), Row("SPADES")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 45cc6810d4..57cd70e191 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -38,7 +38,7 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq } protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { - val hadoopConf = sqlContext.sessionState.newHadoopConf() + val hadoopConf = spark.sessionState.newHadoopConf() val fsPath = new Path(path) val fs = fsPath.getFileSystem(hadoopConf) val parquetFiles = fs.listStatus(fsPath, new PathFilter { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 65635e3c06..45fd6a5d80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -304,7 +304,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file checkAnswer( - sqlContext.read.parquet(dir.getCanonicalPath).filter("part = 1"), + spark.read.parquet(dir.getCanonicalPath).filter("part = 1"), (1 to 3).map(i => Row(i, i.toString, 1))) } } @@ -321,7 +321,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file checkAnswer( - sqlContext.read.parquet(dir.getCanonicalPath).filter("a > 0 and (part = 0 or a > 1)"), + spark.read.parquet(dir.getCanonicalPath).filter("a > 0 and (part = 0 or a > 1)"), (2 to 3).map(i => Row(i, i.toString, 1))) } } @@ -339,7 +339,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // The filter "a > 1 or b < 2" will not get pushed down, and the projection is empty, // this query will throw an exception since the project from combinedFilter expect // two projection while the - val df1 = sqlContext.read.parquet(dir.getCanonicalPath) + val df1 = 
spark.read.parquet(dir.getCanonicalPath) assert(df1.filter("a > 1 or b < 2").count() == 2) } @@ -358,7 +358,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // test the generate new projection case // when projects != partitionAndNormalColumnProjs - val df1 = sqlContext.read.parquet(dir.getCanonicalPath) + val df1 = spark.read.parquet(dir.getCanonicalPath) checkAnswer( df1.filter("a > 1 or b > 2").orderBy("a").selectExpr("a", "b", "c", "d"), @@ -381,7 +381,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "c = 1" filter gets pushed down, this query will throw an exception which // Parquet emits. This is a Parquet issue (PARQUET-389). - val df = sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a") + val df = spark.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a") checkAnswer( df, Row(1, "1", null)) @@ -394,7 +394,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex df.write.parquet(pathThree) // We will remove the temporary metadata when writing Parquet file. - val schema = sqlContext.read.parquet(pathThree).schema + val schema = spark.read.parquet(pathThree).schema assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) val pathFour = s"${dir.getCanonicalPath}/table4" @@ -407,7 +407,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "s.c = 1" filter gets pushed down, this query will throw an exception which // Parquet emits. - val dfStruct3 = sqlContext.read.parquet(pathFour, pathFive).filter("s.c = 1") + val dfStruct3 = spark.read.parquet(pathFour, pathFive).filter("s.c = 1") .selectExpr("s") checkAnswer(dfStruct3, Row(Row(null, 1))) @@ -420,7 +420,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex dfStruct3.write.parquet(pathSix) // We will remove the temporary metadata when writing Parquet file. - val forPathSix = sqlContext.read.parquet(pathSix).schema + val forPathSix = spark.read.parquet(pathSix).schema assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) // sanity test: make sure optional metadata field is not wrongly set. @@ -429,7 +429,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val pathEight = s"${dir.getCanonicalPath}/table8" (4 to 6).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathEight) - val df2 = sqlContext.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b") + val df2 = spark.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b") checkAnswer( df2, Row(1, "1")) @@ -449,7 +449,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex withTempPath { dir => val path = s"${dir.getCanonicalPath}/part=1" (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) - val df = sqlContext.read.parquet(path).filter("a = 2") + val df = spark.read.parquet(path).filter("a = 2") // The result should be single row. // When a filter is pushed to Parquet, Parquet can apply it to every row. 
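As a reference for these ParquetFilterSuite hunks, a minimal write-then-filter sketch; the path and data are invented, and the suite's stripSparkFilter helper is not used here:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").appName("ParquetFilterSketch").getOrCreate()
import spark.implicits._

val path = "/tmp/parquet_filter_sketch/part=1"   // hypothetical location
(1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.mode("overwrite").parquet(path)

// "a = 2" is eligible for Parquet filter push-down; either way the result is a single row.
val filtered = spark.read.parquet(path).filter("a = 2")
assert(filtered.count() == 1)

spark.stop()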
@@ -470,11 +470,11 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b").write.parquet(path) checkAnswer( - sqlContext.read.parquet(path).where("not (a = 2) or not(b in ('1'))"), + spark.read.parquet(path).where("not (a = 2) or not(b in ('1'))"), (1 to 5).map(i => Row(i, (i % 2).toString))) checkAnswer( - sqlContext.read.parquet(path).where("not (a = 2 and b in ('1'))"), + spark.read.parquet(path).where("not (a = 2 and b in ('1'))"), (1 to 5).map(i => Row(i, (i % 2).toString))) } } @@ -527,19 +527,19 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // When a filter is pushed to Parquet, Parquet can apply it to every row. // So, we can check the number of rows returned from the Parquet // to make sure our filter pushdown work. - val df = sqlContext.read.parquet(path).where("b in (0,2)") + val df = spark.read.parquet(path).where("b in (0,2)") assert(stripSparkFilter(df).count == 3) - val df1 = sqlContext.read.parquet(path).where("not (b in (1))") + val df1 = spark.read.parquet(path).where("not (b in (1))") assert(stripSparkFilter(df1).count == 3) - val df2 = sqlContext.read.parquet(path).where("not (b in (1,3) or a <= 2)") + val df2 = spark.read.parquet(path).where("not (b in (1,3) or a <= 2)") assert(stripSparkFilter(df2).count == 2) - val df3 = sqlContext.read.parquet(path).where("not (b in (1,3) and a <= 2)") + val df3 = spark.read.parquet(path).where("not (b in (1,3) and a <= 2)") assert(stripSparkFilter(df3).count == 4) - val df4 = sqlContext.read.parquet(path).where("not (a <= 2)") + val df4 = spark.read.parquet(path).where("not (a <= 2)") assert(stripSparkFilter(df4).count == 3) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 32fe5ba127..d0107aae5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -113,7 +113,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { location => val path = new Path(location.getCanonicalPath) - val conf = sqlContext.sessionState.newHadoopConf() + val conf = spark.sessionState.newHadoopConf() writeMetadata(parquetSchema, path, conf) readParquetFile(path.toString)(df => { val sparkTypes = df.schema.map(_.dataType) @@ -132,7 +132,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { testStandardAndLegacyModes("fixed-length decimals") { def makeDecimalRDD(decimal: DecimalType): DataFrame = { - sqlContext + spark .range(1000) // Parquet doesn't allow column names with spaces, have to add an alias here. // Minus 500 here so that negative decimals are also tested. 
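In the ParquetIOSuite hunk above, the fixed-length decimal test keeps its logic and only swaps the DataFrame factory: spark.range replaces sqlContext.range, and the Hadoop configuration now comes from spark.sessionState. A hedged helper in the same spirit; the object and method names are invented for illustration:

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DecimalType

object DecimalRoundTripSketch {
  // Loosely mirrors makeDecimalRDD: shift by 500 so negative decimals are covered,
  // and alias the expression because Parquet rejects spaces in column names.
  def makeDecimalDF(spark: SparkSession, dt: DecimalType): DataFrame =
    spark.range(1000)
      .select((col("id") - 500).cast(dt).as("dec"))
}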
@@ -250,10 +250,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { location => val path = new Path(location.getCanonicalPath) - val conf = sqlContext.sessionState.newHadoopConf() + val conf = spark.sessionState.newHadoopConf() writeMetadata(parquetSchema, path, conf) val errorMessage = intercept[Throwable] { - sqlContext.read.parquet(path.toString).printSchema() + spark.read.parquet(path.toString).printSchema() }.toString assert(errorMessage.contains("Parquet type not supported")) } @@ -271,15 +271,15 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { location => val path = new Path(location.getCanonicalPath) - val conf = sqlContext.sessionState.newHadoopConf() + val conf = spark.sessionState.newHadoopConf() writeMetadata(parquetSchema, path, conf) - val sparkTypes = sqlContext.read.parquet(path.toString).schema.map(_.dataType) + val sparkTypes = spark.read.parquet(path.toString).schema.map(_.dataType) assert(sparkTypes === expectedSparkTypes) } } test("compression codec") { - val hadoopConf = sqlContext.sessionState.newHadoopConf() + val hadoopConf = spark.sessionState.newHadoopConf() def compressionCodecFor(path: String, codecName: String): String = { val codecs = for { footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) @@ -296,7 +296,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { def checkCompressionCodec(codec: CompressionCodecName): Unit = { withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) { withParquetFile(data) { path => - assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) { + assertResult(spark.conf.get(SQLConf.PARQUET_COMPRESSION).toUpperCase) { compressionCodecFor(path, codec.name()) } } @@ -304,7 +304,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } // Checks default compression codec - checkCompressionCodec(CompressionCodecName.fromConf(sqlContext.conf.parquetCompressionCodec)) + checkCompressionCodec( + CompressionCodecName.fromConf(spark.conf.get(SQLConf.PARQUET_COMPRESSION))) checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) checkCompressionCodec(CompressionCodecName.GZIP) @@ -351,7 +352,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } test("write metadata") { - val hadoopConf = sqlContext.sessionState.newHadoopConf() + val hadoopConf = spark.sessionState.newHadoopConf() withTempPath { file => val path = new Path(file.toURI.toString) val fs = FileSystem.getLocal(hadoopConf) @@ -433,7 +434,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { location => val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) val path = new Path(location.getCanonicalPath) - val conf = sqlContext.sessionState.newHadoopConf() + val conf = spark.sessionState.newHadoopConf() writeMetadata(parquetSchema, path, conf, extraMetadata) readParquetFile(path.toString) { df => @@ -455,7 +456,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { ) withTempPath { dir => val message = intercept[SparkException] { - sqlContext.range(0, 1).write.options(extraOptions).parquet(dir.getCanonicalPath) + spark.range(0, 1).write.options(extraOptions).parquet(dir.getCanonicalPath) }.getCause.getMessage assert(message === "Intentional exception for testing purposes") } @@ -465,10 +466,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with 
SharedSQLContext { // In 1.3.0, save to fs other than file: without configuring core-site.xml would get: // IllegalArgumentException: Wrong FS: hdfs://..., expected: file:/// intercept[Throwable] { - sqlContext.read.parquet("file:///nonexistent") + spark.read.parquet("file:///nonexistent") } val errorMessage = intercept[Throwable] { - sqlContext.read.parquet("hdfs://nonexistent") + spark.read.parquet("hdfs://nonexistent") }.toString assert(errorMessage.contains("UnknownHostException")) } @@ -486,14 +487,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { dir => val m1 = intercept[SparkException] { - sqlContext.range(1).coalesce(1).write.options(extraOptions).parquet(dir.getCanonicalPath) + spark.range(1).coalesce(1).write.options(extraOptions).parquet(dir.getCanonicalPath) }.getCause.getMessage assert(m1.contains("Intentional exception for testing purposes")) } withTempPath { dir => val m2 = intercept[SparkException] { - val df = sqlContext.range(1).select('id as 'a, 'id as 'b).coalesce(1) + val df = spark.range(1).select('id as 'a, 'id as 'b).coalesce(1) df.write.partitionBy("a").options(extraOptions).parquet(dir.getCanonicalPath) }.getCause.getMessage assert(m2.contains("Intentional exception for testing purposes")) @@ -512,11 +513,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { ParquetOutputFormat.ENABLE_DICTIONARY -> "true" ) - val hadoopConf = sqlContext.sessionState.newHadoopConfWithOptions(extraOptions) + val hadoopConf = spark.sessionState.newHadoopConfWithOptions(extraOptions) withTempPath { dir => val path = s"${dir.getCanonicalPath}/part-r-0.parquet" - sqlContext.range(1 << 16).selectExpr("(id % 4) AS i") + spark.range(1 << 16).selectExpr("(id % 4) AS i") .coalesce(1).write.options(extraOptions).mode("overwrite").parquet(path) val blockMetadata = readFooter(new Path(path), hadoopConf).getBlocks.asScala.head @@ -531,7 +532,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("null and non-null strings") { // Create a dataset where the first values are NULL and then some non-null values. The // number of non-nulls needs to be bigger than the ParquetReader batch size. 
- val data: Dataset[String] = sqlContext.range(200).map (i => + val data: Dataset[String] = spark.range(200).map (i => if (i < 150) null else "a" ) @@ -554,7 +555,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { checkAnswer( // Decimal column in this file is encoded using plain dictionary readResourceParquetFile("dec-in-i32.parquet"), - sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) + spark.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) } } } @@ -565,7 +566,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { checkAnswer( // Decimal column in this file is encoded using plain dictionary readResourceParquetFile("dec-in-i64.parquet"), - sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) + spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) } } } @@ -576,7 +577,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { checkAnswer( // Decimal column in this file is encoded using plain dictionary readResourceParquetFile("dec-in-fixed-len.parquet"), - sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) + spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) } } } @@ -589,7 +590,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { var hash2: Int = 0 (false :: true :: Nil).foreach { v => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> v.toString) { - val df = sqlContext.read.parquet(dir.getCanonicalPath) + val df = spark.read.parquet(dir.getCanonicalPath) val rows = df.queryExecution.toRdd.map(_.copy()).collect() val unsafeRows = rows.map(_.asInstanceOf[UnsafeRow]) if (!v) { @@ -607,7 +608,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("VectorizedParquetRecordReader - direct path read") { val data = (0 to 10).map(i => (i, (i + 'a').toChar.toString)) withTempPath { dir => - sqlContext.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) + spark.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0); { val reader = new VectorizedParquetRecordReader diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index 83b65fb419..9dc56292c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -81,7 +81,7 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS logParquetSchema(protobufStylePath) checkAnswer( - sqlContext.read.parquet(dir.getCanonicalPath), + spark.read.parquet(dir.getCanonicalPath), Seq( Row(Seq(0, 1)), Row(Seq(2, 3)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index b4d35be05d..8707e13461 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -400,7 +400,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha // Introduce _temporary dir to the base dir the robustness of the schema discovery process. new File(base.getCanonicalPath, "_temporary").mkdir() - sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") + spark.read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -484,7 +484,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") + spark.read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -532,7 +532,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) + val parquetRelation = spark.read.format("parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -572,7 +572,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) + val parquetRelation = spark.read.format("parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -604,7 +604,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), makePartitionDir(base, defaultPartitionName, "pi" -> 2)) - sqlContext + spark .read .option("mergeSchema", "true") .format("parquet") @@ -622,7 +622,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("SPARK-7749 Non-partitioned table should have empty partition spec") { withTempPath { dir => (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) - val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution + val queryExecution = spark.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { case LogicalRelation(relation: HadoopFsRelation, _, _) => assert(relation.partitionSpec === PartitionSpec.emptySpec) @@ -636,7 +636,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha withTempPath { dir => val df = Seq("/", "[]", "?").zipWithIndex.map(_.swap).toDF("i", "s") df.write.format("parquet").partitionBy("s").save(dir.getCanonicalPath) - checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), df.collect()) + checkAnswer(spark.read.parquet(dir.getCanonicalPath), df.collect()) } } @@ -676,12 +676,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) - val df = sqlContext.createDataFrame(sparkContext.parallelize(row :: Nil), schema) + val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema) withTempPath { dir => df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) val fields = schema.map(f => Column(f.name).cast(f.dataType)) - 
checkAnswer(sqlContext.read.load(dir.toString).select(fields: _*), row) + checkAnswer(spark.read.load(dir.toString).select(fields: _*), row) } } @@ -697,7 +697,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store")) Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) - checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df) + checkAnswer(spark.read.format("parquet").load(dir.getCanonicalPath), df) } } @@ -714,7 +714,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Files.touch(new File(s"${tablePath.getCanonicalPath}/", "_SUCCESS")) Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) - checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + checkAnswer(spark.read.format("parquet").load(tablePath.getCanonicalPath), df) } withTempPath { dir => @@ -731,7 +731,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Files.touch(new File(s"${tablePath.getCanonicalPath}/", "_SUCCESS")) Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) - checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + checkAnswer(spark.read.format("parquet").load(tablePath.getCanonicalPath), df) } } @@ -746,7 +746,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha .save(tablePath.getCanonicalPath) val twoPartitionsDF = - sqlContext + spark .read .option("basePath", tablePath.getCanonicalPath) .parquet( @@ -756,7 +756,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha checkAnswer(twoPartitionsDF, df.filter("b != 3")) intercept[AssertionError] { - sqlContext + spark .read .parquet( s"${tablePath.getCanonicalPath}/b=1", @@ -829,7 +829,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1", "_SUCCESS")) Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1/c=1", "_SUCCESS")) Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1/c=1/d=1", "_SUCCESS")) - checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + checkAnswer(spark.read.format("parquet").load(tablePath.getCanonicalPath), df) } } } @@ -884,9 +884,9 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha withTempPath { dir => withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") { val path = dir.getCanonicalPath - val df = sqlContext.range(5).select('id as 'a, 'id as 'b, 'id as 'c).coalesce(1) + val df = spark.range(5).select('id as 'a, 'id as 'b, 'id as 'c).coalesce(1) df.write.partitionBy("b", "c").parquet(path) - checkAnswer(sqlContext.read.parquet(path), df) + checkAnswer(spark.read.parquet(path), df) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index f1e9726c38..f9f9f80352 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -46,24 +46,24 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("appending") { val data = (0 until 10).map(i => 
(i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + spark.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") // Query appends, don't test with both read modes. withParquetTable(data, "t", false) { sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) + checkAnswer(spark.table("t"), (data ++ data).map(Row.fromTuple)) } - sqlContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("tmp"), ignoreIfNotExists = true) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + spark.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) + checkAnswer(spark.table("t"), data.map(Row.fromTuple)) } - sqlContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("tmp"), ignoreIfNotExists = true) } @@ -128,9 +128,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val schema = StructType(List(StructField("d", DecimalType(18, 0), false), StructField("time", TimestampType, false)).toArray) withTempPath { file => - val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema) + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) df.write.parquet(file.getCanonicalPath) - val df2 = sqlContext.read.parquet(file.getCanonicalPath) + val df2 = spark.read.parquet(file.getCanonicalPath) checkAnswer(df2, df.collect().toSeq) } } @@ -139,12 +139,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + spark.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) // delete summary files, so if we don't merge part-files, one column will not be included. 
Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + assert(spark.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -163,9 +163,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + spark.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + assert(spark.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -181,19 +181,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { withTempPath { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) + spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + spark.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) // Disables the global SQL option for schema merging withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { assertResult(2) { // Disables schema merging via data source option - sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length + spark.read.option("mergeSchema", "false").parquet(basePath).columns.length } assertResult(3) { // Enables schema merging via data source option - sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length + spark.read.option("mergeSchema", "true").parquet(basePath).columns.length } } } @@ -204,10 +204,10 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val basePath = dir.getCanonicalPath val schema = StructType(Array(StructField("name", DecimalType(10, 5), false))) val rowRDD = sparkContext.parallelize(Array(Row(Decimal("67123.45")))) - val df = sqlContext.createDataFrame(rowRDD, schema) + val df = spark.createDataFrame(rowRDD, schema) df.write.parquet(basePath) - val decimal = sqlContext.read.parquet(basePath).first().getDecimal(0) + val decimal = spark.read.parquet(basePath).first().getDecimal(0) assert(Decimal("67123.45") === Decimal(decimal)) } } @@ -227,7 +227,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { checkAnswer( - sqlContext.read.option("mergeSchema", "true").parquet(path), + spark.read.option("mergeSchema", "true").parquet(path), Seq( Row(Row(1, 1, null)), Row(Row(2, 2, null)), @@ -240,7 +240,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-10301 requested schema clipping - same schema") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS 
s").coalesce(1) df.write.parquet(path) val userDefinedSchema = @@ -253,7 +253,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(0L, 1L))) } } @@ -261,12 +261,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-11997 parquet with null partition values") { withTempPath { dir => val path = dir.getCanonicalPath - sqlContext.range(1, 3) + spark.range(1, 3) .selectExpr("if(id % 2 = 0, null, id) AS n", "id") .write.partitionBy("n").parquet(path) checkAnswer( - sqlContext.read.parquet(path).filter("n is null"), + spark.read.parquet(path).filter("n is null"), Row(2, null)) } } @@ -275,7 +275,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext ignore("SPARK-10301 requested schema clipping - schemas with disjoint sets of fields") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) df.write.parquet(path) val userDefinedSchema = @@ -288,7 +288,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(null, null))) } } @@ -296,7 +296,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-10301 requested schema clipping - requested schema contains physical schema") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) df.write.parquet(path) val userDefinedSchema = @@ -311,13 +311,13 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(0L, 1L, null, null))) } withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'd', id + 3) AS s").coalesce(1) + val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'd', id + 3) AS s").coalesce(1) df.write.parquet(path) val userDefinedSchema = @@ -332,7 +332,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(0L, null, null, 3L))) } } @@ -340,7 +340,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-10301 requested schema clipping - physical schema contains requested schema") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext + val df = spark .range(1) .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s") .coalesce(1) @@ -357,13 +357,13 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(0L, 1L))) } 
withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext + val df = spark .range(1) .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s") .coalesce(1) @@ -380,7 +380,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(0L, 3L))) } } @@ -388,7 +388,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-10301 requested schema clipping - schemas overlap but don't contain each other") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext + val df = spark .range(1) .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") .coalesce(1) @@ -406,7 +406,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(1L, 2L, null))) } } @@ -415,7 +415,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext + val df = spark .range(1) .selectExpr("NAMED_STRUCT('a', ARRAY(NAMED_STRUCT('b', id, 'c', id))) AS s") .coalesce(1) @@ -436,7 +436,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(Seq(Row(0, null))))) } } @@ -445,12 +445,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val path = dir.getCanonicalPath - val df1 = sqlContext + val df1 = spark .range(1) .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") .coalesce(1) - val df2 = sqlContext + val df2 = spark .range(1, 2) .selectExpr("NAMED_STRUCT('c', id + 2, 'b', id + 1, 'd', id + 3) AS s") .coalesce(1) @@ -467,7 +467,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Seq( Row(Row(0, 1, null)), Row(Row(null, 2, 4)))) @@ -478,12 +478,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val path = dir.getCanonicalPath - val df1 = sqlContext + val df1 = spark .range(1) .selectExpr("NAMED_STRUCT('a', id, 'c', id + 2) AS s") .coalesce(1) - val df2 = sqlContext + val df2 = spark .range(1, 2) .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") .coalesce(1) @@ -492,7 +492,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext df2.write.mode(SaveMode.Append).parquet(path) checkAnswer( - sqlContext + spark .read .option("mergeSchema", "true") .parquet(path) @@ -507,7 +507,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext + val df = spark .range(1) .selectExpr( """NAMED_STRUCT( @@ -532,7 +532,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(NestedStruct(1, 
2L, 3.5D)))) } } @@ -585,9 +585,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(1000).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) + val df = spark.range(1000).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) df.write.mode(SaveMode.Overwrite).parquet(path) - checkAnswer(sqlContext.read.parquet(path), df) + checkAnswer(spark.read.parquet(path), df) } } @@ -595,11 +595,11 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withSQLConf("spark.sql.codegen.maxFields" -> "100") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(100).select(Seq.tabulate(110) {i => ('id + i).as(s"c$i")} : _*) + val df = spark.range(100).select(Seq.tabulate(110) {i => ('id + i).as(s"c$i")} : _*) df.write.mode(SaveMode.Overwrite).parquet(path) // donot return batch, because whole stage codegen is disabled for wide table (>200 columns) - val df2 = sqlContext.read.parquet(path) + val df2 = spark.read.parquet(path) assert(df2.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isEmpty, "Should not return batch") checkAnswer(df2, df) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index cef541f044..373d3a3a0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -21,9 +21,9 @@ import java.io.File import scala.collection.JavaConverters._ import scala.util.Try -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkConf import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.util.{Benchmark, Utils} /** @@ -34,12 +34,16 @@ import org.apache.spark.util.{Benchmark, Utils} object ParquetReadBenchmark { val conf = new SparkConf() conf.set("spark.sql.parquet.compression.codec", "snappy") - val sc = new SparkContext("local[1]", "test-sql-context", conf) - val sqlContext = new SQLContext(sc) + + val spark = SparkSession.builder + .master("local[1]") + .appName("test-sql-context") + .config(conf) + .getOrCreate() // Set default configs. Individual cases will change them if necessary. 
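The benchmark's configuration handling moves from sqlContext.conf.setConfString to the session-scoped spark.conf. A sketch of the save/set/restore pattern used by the withSQLConf helper rewritten just below; the wrapper object here is illustrative only:

import scala.util.Try

import org.apache.spark.sql.SparkSession

object RuntimeConfSketch {
  def withSQLConf(spark: SparkSession)(pairs: (String, String)*)(f: => Unit): Unit = {
    val (keys, values) = pairs.unzip
    // Remember the current values so they can be restored afterwards.
    val saved = keys.map(k => Try(spark.conf.get(k)).toOption)
    keys.zip(values).foreach { case (k, v) => spark.conf.set(k, v) }
    try f finally {
      keys.zip(saved).foreach {
        case (k, Some(v)) => spark.conf.set(k, v) // restore the previous value
        case (k, None)    => spark.conf.unset(k)  // or clear keys that were unset before
      }
    }
  }
}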
- sqlContext.conf.setConfString(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") - sqlContext.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") def withTempPath(f: File => Unit): Unit = { val path = Utils.createTempDir() @@ -48,17 +52,17 @@ object ParquetReadBenchmark { } def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(sqlContext.dropTempTable) + try f finally tableNames.foreach(spark.catalog.dropTempTable) } def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) - (keys, values).zipped.foreach(sqlContext.conf.setConfString) + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => sqlContext.conf.setConfString(key, value) - case (key, None) => sqlContext.conf.unsetConf(key) + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) } } } @@ -71,18 +75,18 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql("select cast(id as INT) as id from t1") + spark.range(values).registerTempTable("t1") + spark.sql("select cast(id as INT) as id from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") sqlBenchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(id) from tempTable").collect() + spark.sql("select sum(id) from tempTable").collect() } sqlBenchmark.addCase("SQL Parquet MR") { iter => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - sqlContext.sql("select sum(id) from tempTable").collect() + spark.sql("select sum(id) from tempTable").collect() } } @@ -155,20 +159,20 @@ object ParquetReadBenchmark { def intStringScanBenchmark(values: Int): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1") + spark.range(values).registerTempTable("t1") + spark.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") val benchmark = new Benchmark("Int and String Scan", values) benchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect + spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect } benchmark.addCase("SQL Parquet MR") { iter => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect + spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect } } @@ -189,20 +193,20 @@ object ParquetReadBenchmark { def stringDictionaryScanBenchmark(values: Int): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - 
sqlContext.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1") + spark.range(values).registerTempTable("t1") + spark.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") val benchmark = new Benchmark("String Dictionary", values) benchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(length(c1)) from tempTable").collect + spark.sql("select sum(length(c1)) from tempTable").collect } benchmark.addCase("SQL Parquet MR") { iter => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - sqlContext.sql("select sum(length(c1)) from tempTable").collect + spark.sql("select sum(length(c1)) from tempTable").collect } } @@ -221,23 +225,23 @@ object ParquetReadBenchmark { def partitionTableScanBenchmark(values: Int): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql("select id % 2 as p, cast(id as INT) as id from t1") + spark.range(values).registerTempTable("t1") + spark.sql("select id % 2 as p, cast(id as INT) as id from t1") .write.partitionBy("p").parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") val benchmark = new Benchmark("Partitioned Table", values) benchmark.addCase("Read data column") { iter => - sqlContext.sql("select sum(id) from tempTable").collect + spark.sql("select sum(id) from tempTable").collect } benchmark.addCase("Read partition column") { iter => - sqlContext.sql("select sum(p) from tempTable").collect + spark.sql("select sum(p) from tempTable").collect } benchmark.addCase("Read both columns") { iter => - sqlContext.sql("select sum(p), sum(id) from tempTable").collect + spark.sql("select sum(p), sum(id) from tempTable").collect } /* @@ -256,16 +260,16 @@ object ParquetReadBenchmark { def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + + spark.range(values).registerTempTable("t1") + spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") val benchmark = new Benchmark("String with Nulls Scan", values) benchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(length(c2)) from tempTable where c1 is " + + spark.sql("select sum(length(c2)) from tempTable where c1 is " + "not NULL and c2 is not NULL").collect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 90e3d50714..c43b142de2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -453,11 
+453,11 @@ class ParquetSchemaSuite extends ParquetSchemaTest { test("schema merging failure error message") { withTempPath { dir => val path = dir.getCanonicalPath - sqlContext.range(3).write.parquet(s"$path/p=1") - sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") + spark.range(3).write.parquet(s"$path/p=1") + spark.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") val message = intercept[SparkException] { - sqlContext.read.option("mergeSchema", "true").parquet(path).schema + spark.read.option("mergeSchema", "true").parquet(path).schema }.getMessage assert(message.contains("Failed merging schema of file")) @@ -466,13 +466,13 @@ class ParquetSchemaSuite extends ParquetSchemaTest { // test for second merging (after read Parquet schema in parallel done) withTempPath { dir => val path = dir.getCanonicalPath - sqlContext.range(3).write.parquet(s"$path/p=1") - sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") + spark.range(3).write.parquet(s"$path/p=1") + spark.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") - sqlContext.sparkContext.conf.set("spark.default.parallelism", "20") + spark.sparkContext.conf.set("spark.default.parallelism", "20") val message = intercept[SparkException] { - sqlContext.read.option("mergeSchema", "true").parquet(path).schema + spark.read.option("mergeSchema", "true").parquet(path).schema }.getMessage assert(message.contains("Failed merging schema:")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index e8c524e9e5..b5fc51603e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -52,7 +52,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { (true :: false :: Nil).foreach { vectorized => if (!vectorized || testVectorized) { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { - f(sqlContext.read.parquet(path.toString)) + f(spark.read.parquet(path.toString)) } } } @@ -66,7 +66,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) + spark.createDataFrame(data).write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -90,14 +90,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T], tableName: String, testVectorized: Boolean = true) (f: => Unit): Unit = { withParquetDataFrame(data, testVectorized) { df => - sqlContext.registerDataFrameAsTable(df, tableName) + spark.wrapped.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + spark.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( @@ -173,6 +173,6 @@ private[sql] trait ParquetTest extends SQLTestUtils { protected def readResourceParquetFile(name: String): DataFrame = { val url = Thread.currentThread().getContextClassLoader.getResource(name) - sqlContext.read.parquet(url.toString) + 
spark.read.parquet(url.toString) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala index 88a3d878f9..ff5706999a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -32,7 +32,7 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar |${readParquetSchema(parquetFilePath.toString)} """.stripMargin) - checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i => + checkAnswer(spark.read.parquet(parquetFilePath.toString), (0 until 10).map { i => val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS") val nonNullablePrimitiveValues = Seq( @@ -139,7 +139,7 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar logParquetSchema(path) checkAnswer( - sqlContext.read.parquet(path), + spark.read.parquet(path), Seq( Row(Seq(Seq(0, 1), Seq(2, 3))), Row(Seq(Seq(4, 5), Seq(6, 7))))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala index fd56265297..08b7eb3cf7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkConf import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.util.Benchmark @@ -36,9 +36,10 @@ object TPCDSBenchmark { conf.set("spark.driver.memory", "3g") conf.set("spark.executor.memory", "3g") conf.set("spark.sql.autoBroadcastJoinThreshold", (20 * 1024 * 1024).toString) + conf.setMaster("local[1]") + conf.setAppName("test-sql-context") - val sc = new SparkContext("local[1]", "test-sql-context", conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession.builder.config(conf).getOrCreate() // These queries a subset of the TPCDS benchmark queries and are taken from // https://github.com/databricks/spark-sql-perf/blob/master/src/main/scala/com/databricks/spark/ @@ -1186,8 +1187,8 @@ object TPCDSBenchmark { def setupTables(dataLocation: String): Map[String, Long] = { tables.map { tableName => - sqlContext.read.parquet(s"$dataLocation/$tableName").registerTempTable(tableName) - tableName -> sqlContext.table(tableName).count() + spark.read.parquet(s"$dataLocation/$tableName").registerTempTable(tableName) + tableName -> spark.table(tableName).count() }.toMap } @@ -1195,18 +1196,18 @@ object TPCDSBenchmark { require(dataLocation.nonEmpty, "please modify the value of dataLocation to point to your local TPCDS data") val tableSizes = setupTables(dataLocation) - sqlContext.conf.setConfString(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") - sqlContext.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + 
spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") tpcds.filter(q => q._1 != "").foreach { case (name: String, query: String) => - val numRows = sqlContext.sql(query).queryExecution.logical.map { + val numRows = spark.sql(query).queryExecution.logical.map { case ur@UnresolvedRelation(t: TableIdentifier, _) => tableSizes.getOrElse(t.table, throw new RuntimeException(s"${t.table} not found.")) case _ => 0L }.sum val benchmark = new Benchmark("TPCDS Snappy (scale = 5)", numRows, 5) benchmark.addCase(name) { i => - sqlContext.sql(query).collect() + spark.sql(query).collect() } benchmark.run() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 923c0b350e..f61fce5d41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -33,20 +33,20 @@ import org.apache.spark.util.Utils class TextSuite extends QueryTest with SharedSQLContext { test("reading text file") { - verifyFrame(sqlContext.read.format("text").load(testFile)) + verifyFrame(spark.read.format("text").load(testFile)) } test("SQLContext.read.text() API") { - verifyFrame(sqlContext.read.text(testFile).toDF()) + verifyFrame(spark.read.text(testFile).toDF()) } test("SPARK-12562 verify write.text() can handle column name beyond `value`") { - val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf") + val df = spark.read.text(testFile).withColumnRenamed("value", "adwrasdf") val tempFile = Utils.createTempDir() tempFile.delete() df.write.text(tempFile.getCanonicalPath) - verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath).toDF()) + verifyFrame(spark.read.text(tempFile.getCanonicalPath).toDF()) Utils.deleteRecursively(tempFile) } @@ -55,18 +55,18 @@ class TextSuite extends QueryTest with SharedSQLContext { val tempFile = Utils.createTempDir() tempFile.delete() - val df = sqlContext.range(2) + val df = spark.range(2) intercept[AnalysisException] { df.write.text(tempFile.getCanonicalPath) } intercept[AnalysisException] { - sqlContext.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath) + spark.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath) } } test("SPARK-13503 Support to specify the option for compression codec for TEXT") { - val testDf = sqlContext.read.text(testFile) + val testDf = spark.read.text(testFile) val extensionNameMap = Map("bzip2" -> ".bz2", "deflate" -> ".deflate", "gzip" -> ".gz") extensionNameMap.foreach { case (codecName, extension) => @@ -75,7 +75,7 @@ class TextSuite extends QueryTest with SharedSQLContext { testDf.write.option("compression", codecName).mode(SaveMode.Overwrite).text(tempDirPath) val compressedFiles = new File(tempDirPath).listFiles() assert(compressedFiles.exists(_.getName.endsWith(s".txt$extension"))) - verifyFrame(sqlContext.read.text(tempDirPath).toDF()) + verifyFrame(spark.read.text(tempDirPath).toDF()) } val errMsg = intercept[IllegalArgumentException] { @@ -95,14 +95,14 @@ class TextSuite extends QueryTest with SharedSQLContext { "mapreduce.map.output.compress.codec" -> classOf[GzipCodec].getName ) withTempDir { dir => - val testDf = sqlContext.read.text(testFile) + val testDf = spark.read.text(testFile) val tempDir = Utils.createTempDir() val tempDirPath = 
tempDir.getAbsolutePath testDf.write.option("compression", "none") .options(extraOptions).mode(SaveMode.Overwrite).text(tempDirPath) val compressedFiles = new File(tempDirPath).listFiles() assert(compressedFiles.exists(!_.getName.endsWith(".txt.gz"))) - verifyFrame(sqlContext.read.options(extraOptions).text(tempDirPath).toDF()) + verifyFrame(spark.read.options(extraOptions).text(tempDirPath).toDF()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 8aa0114d98..4fc52c99fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -33,7 +33,7 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { } test("debugCodegen") { - val res = codegenString(sqlContext.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan) assert(res.contains("Subtree 1 / 2")) assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index b9df43d049..730ec43556 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} -import org.apache.spark.sql.{QueryTest, SQLContext} +import org.apache.spark.sql.{QueryTest, SparkSession} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ @@ -34,7 +34,7 @@ import org.apache.spark.sql.functions._ * without serializing the hashed relation, which does not happen in local mode. */ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { - protected var sqlContext: SQLContext = null + protected var spark: SparkSession = null /** * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. @@ -45,26 +45,26 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { .setMaster("local-cluster[2,1,1024]") .setAppName("testing") val sc = new SparkContext(conf) - sqlContext = new SQLContext(sc) + spark = SparkSession.builder.getOrCreate() } override def afterAll(): Unit = { - sqlContext.sparkContext.stop() - sqlContext = null + spark.stop() + spark = null } /** * Test whether the specified broadcast join updates the peak execution memory accumulator. 
*/ private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { - AccumulatorSuite.verifyPeakExecutionMemorySet(sqlContext.sparkContext, name) { - val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + AccumulatorSuite.verifyPeakExecutionMemorySet(spark.sparkContext, name) { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") val df3 = df1.join(broadcast(df2), joinExpression, joinType) val plan = - EnsureRequirements(sqlContext.sessionState.conf).apply(df3.queryExecution.sparkPlan) + EnsureRequirements(spark.sessionState.conf).apply(df3.queryExecution.sparkPlan) assert(plan.collect { case p: T => p }.size === 1) plan.executeCollect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 2a4a3690f2..7caeb3be54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -32,7 +32,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { import testImplicits.newProductEncoder import testImplicits.localSeqToDatasetHolder - private lazy val myUpperCaseData = sqlContext.createDataFrame( + private lazy val myUpperCaseData = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, "A"), Row(2, "B"), @@ -43,7 +43,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { Row(null, "G") )), new StructType().add("N", IntegerType).add("L", StringType)) - private lazy val myLowerCaseData = sqlContext.createDataFrame( + private lazy val myLowerCaseData = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, "a"), Row(2, "b"), @@ -99,7 +99,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { boundCondition, leftPlan, rightPlan) - EnsureRequirements(sqlContext.sessionState.conf).apply(broadcastJoin) + EnsureRequirements(spark.sessionState.conf).apply(broadcastJoin) } def makeShuffledHashJoin( @@ -113,7 +113,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { joins.ShuffledHashJoinExec(leftKeys, rightKeys, Inner, side, None, leftPlan, rightPlan) val filteredJoin = boundCondition.map(FilterExec(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) - EnsureRequirements(sqlContext.sessionState.conf).apply(filteredJoin) + EnsureRequirements(spark.sessionState.conf).apply(filteredJoin) } def makeSortMergeJoin( @@ -124,7 +124,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { rightPlan: SparkPlan) = { val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition, leftPlan, rightPlan) - EnsureRequirements(sqlContext.sessionState.conf).apply(sortMergeJoin) + EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) } test(s"$testName using BroadcastHashJoin (build=left)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index c26cb8483e..001feb0f2b 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { - private lazy val left = sqlContext.createDataFrame( + private lazy val left = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, 2.0), Row(2, 100.0), @@ -42,7 +42,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { Row(null, null) )), new StructType().add("a", IntegerType).add("b", DoubleType)) - private lazy val right = sqlContext.createDataFrame( + private lazy val right = spark.createDataFrame( sparkContext.parallelize(Seq( Row(0, 0.0), Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches @@ -82,7 +82,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext.sessionState.conf).apply( + EnsureRequirements(spark.sessionState.conf).apply( ShuffledHashJoinExec( leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), @@ -115,7 +115,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext.sessionState.conf).apply( + EnsureRequirements(spark.sessionState.conf).apply( SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index d41e88a0aa..1b82769428 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -71,21 +71,21 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { df: DataFrame, expectedNumOfJobs: Int, expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + val previousExecutionIds = spark.listener.executionIdToData.keySet withSQLConf("spark.sql.codegen.wholeStage" -> "false") { df.collect() } sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + val executionIds = spark.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs + val jobs = spark.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. 
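// A minimal sketch of the session lifecycle these suites migrate to: build a SparkSession
// instead of wrapping a SparkContext in a SQLContext, and stop the session (not the bare
// context) on teardown. The master, app name and data below are illustrative and not taken
// from any one suite in this diff.
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object SessionLifecycleSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("session-lifecycle-sketch")
    val spark = SparkSession.builder.config(conf).getOrCreate()
    try {
      // The session exposes the entry points the tests switch to:
      // createDataFrame, read, sql, conf, catalog, sessionState, sparkContext, ...
      val df = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
      assert(df.count() == 2)
    } finally {
      spark.stop()  // replaces sqlContext.sparkContext.stop()
    }
  }
}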
assert(jobs.size <= expectedNumOfJobs) if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + val metricValues = spark.listener.getExecutionMetrics(executionId) val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( df.queryExecution.executedPlan)).allNodes.filter { node => expectedMetrics.contains(node.id) @@ -128,7 +128,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // Assume the execution plan is // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Filter(nodeId = 1)) // TODO: update metrics in generated operators - val ds = sqlContext.range(10).filter('id < 5) + val ds = spark.range(10).filter('id < 5) testSparkPlanMetrics(ds.toDF(), 1, Map.empty) } @@ -157,7 +157,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("Sort metrics") { // Assume the execution plan is // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1)) - val ds = sqlContext.range(10).sort('id) + val ds = spark.range(10).sort('id) testSparkPlanMetrics(ds.toDF(), 2, Map.empty) } @@ -169,7 +169,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { withTempTable("testDataForJoin") { // Assume the execution plan is // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( + val df = spark.sql( "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( 0L -> ("SortMergeJoin", Map( @@ -187,7 +187,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { withTempTable("testDataForJoin") { // Assume the execution plan is // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( + val df = spark.sql( "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( 0L -> ("SortMergeJoin", Map( @@ -195,7 +195,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { "number of output rows" -> 8L))) ) - val df2 = sqlContext.sql( + val df2 = spark.sql( "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df2, 1, Map( 0L -> ("SortMergeJoin", Map( @@ -241,7 +241,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { withTempTable("testDataForJoin") { // Assume the execution plan is // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( + val df = spark.sql( "SELECT * FROM testData2 left JOIN testDataForJoin ON " + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") testSparkPlanMetrics(df, 3, Map( @@ -269,7 +269,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { withTempTable("testDataForJoin") { // Assume the execution plan is // ... 
-> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( + val df = spark.sql( "SELECT * FROM testData2 JOIN testDataForJoin") testSparkPlanMetrics(df, 1, Map( 0L -> ("CartesianProduct", Map( @@ -280,19 +280,19 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + val previousExecutionIds = spark.listener.executionIdToData.keySet // Assume the execution plan is // PhysicalRDD(nodeId = 0) person.select('name).write.format("json").save(file.getAbsolutePath) sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + val executionIds = spark.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs + val jobs = spark.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + val metricValues = spark.listener.getExecutionMetrics(executionId) // Because "save" will create a new DataFrame internally, we cannot get the real metric id. // However, we still can check the value. assert(metricValues.values.toSeq.exists(_ === "2")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index 7b413dda1e..a7b2cfe7d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -217,7 +217,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3", SQLConf.FILE_SINK_LOG_CLEANUP_DELAY.key -> "0") { withFileStreamSinkLog { sinkLog => - val fs = sinkLog.metadataPath.getFileSystem(sqlContext.sessionState.newHadoopConf()) + val fs = sinkLog.metadataPath.getFileSystem(spark.sessionState.newHadoopConf()) def listBatchFiles(): Set[String] = { fs.listStatus(sinkLog.metadataPath).map(_.getPath.getName).filter { fileName => @@ -263,7 +263,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { private def withFileStreamSinkLog(f: FileStreamSinkLog => Unit): Unit = { withTempDir { file => - val sinkLog = new FileStreamSinkLog(sqlContext.sparkSession, file.getCanonicalPath) + val sinkLog = new FileStreamSinkLog(spark, file.getCanonicalPath) f(sinkLog) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 5f92c5bb9b..ef2b479a56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -59,63 +59,63 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { test("HDFSMetadataLog: basic") { withTempDir { temp => val dir = new File(temp, "dir") // use non-existent directory to test whether log 
make the dir - val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, dir.getAbsolutePath) + val metadataLog = new HDFSMetadataLog[String](spark, dir.getAbsolutePath) assert(metadataLog.add(0, "batch0")) assert(metadataLog.getLatest() === Some(0 -> "batch0")) assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.getLatest() === Some(0 -> "batch0")) - assert(metadataLog.get(None, 0) === Array(0 -> "batch0")) + assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) assert(metadataLog.add(1, "batch1")) assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.get(1) === Some("batch1")) assert(metadataLog.getLatest() === Some(1 -> "batch1")) - assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) // Adding the same batch does nothing metadataLog.add(1, "batch1-duplicated") assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.get(1) === Some("batch1")) assert(metadataLog.getLatest() === Some(1 -> "batch1")) - assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) } } testQuietly("HDFSMetadataLog: fallback from FileContext to FileSystem") { - sqlContext.conf.setConfString( + spark.conf.set( s"fs.$scheme.impl", classOf[FakeFileSystem].getName) withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, s"$scheme://$temp") + val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://$temp") assert(metadataLog.add(0, "batch0")) assert(metadataLog.getLatest() === Some(0 -> "batch0")) assert(metadataLog.get(0) === Some("batch0")) - assert(metadataLog.get(None, 0) === Array(0 -> "batch0")) + assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) - val metadataLog2 = new HDFSMetadataLog[String](sqlContext.sparkSession, s"$scheme://$temp") + val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://$temp") assert(metadataLog2.get(0) === Some("batch0")) assert(metadataLog2.getLatest() === Some(0 -> "batch0")) - assert(metadataLog2.get(None, 0) === Array(0 -> "batch0")) + assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) } } test("HDFSMetadataLog: restart") { withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) assert(metadataLog.add(0, "batch0")) assert(metadataLog.add(1, "batch1")) assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.get(1) === Some("batch1")) assert(metadataLog.getLatest() === Some(1 -> "batch1")) - assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) - val metadataLog2 = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + val metadataLog2 = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) assert(metadataLog2.get(0) === Some("batch0")) assert(metadataLog2.get(1) === Some("batch1")) assert(metadataLog2.getLatest() === Some(1 -> "batch1")) - assert(metadataLog2.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog2.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) } } @@ -127,7 +127,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { new Thread() { override def run(): Unit = waiter { val metadataLog = - new 
HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + new HDFSMetadataLog[String](spark, temp.getAbsolutePath) try { var nextBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) nextBatchId += 1 @@ -146,9 +146,10 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } waiter.await(timeout(10.seconds), dismissals(10)) - val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) assert(metadataLog.getLatest() === Some(maxBatchId -> maxBatchId.toString)) - assert(metadataLog.get(None, maxBatchId) === (0 to maxBatchId).map(i => (i, i.toString))) + assert( + metadataLog.get(None, Some(maxBatchId)) === (0 to maxBatchId).map(i => (i, i.toString))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 6be94eb24f..4fa1754253 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -27,10 +27,11 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.LocalSparkSession._ import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{CompletionIterator, Utils} @@ -54,19 +55,18 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("versioning and immutability") { - withSpark(new SparkContext(sparkConf)) { sc => - val sqlContext = new SQLContext(sc) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 val rdd1 = - makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)( + makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( + spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)( increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( + spark.wrapped, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. 
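// Besides the sqlContext -> spark rename, the HDFSMetadataLogSuite hunks above also track a
// signature change: the end-batch argument of get() is now an Option, so callers pass
// Some(batchId) instead of a bare id (None means unbounded on that side). A rough sketch of
// the call shape, written as it would appear inside a test body; the SparkSession `spark`
// and the scratch path are placeholders, and HDFSMetadataLog is an internal class used here
// only because the suite above exercises it directly.
import org.apache.spark.sql.execution.streaming.HDFSMetadataLog

val metadataLog = new HDFSMetadataLog[String](spark, "/tmp/hdfs-metadata-log-sketch")
metadataLog.add(0, "batch0")
metadataLog.add(1, "batch1")
assert(metadataLog.get(None, Some(1)).toSeq == Seq(0L -> "batch0", 1L -> "batch1"))
assert(metadataLog.getLatest() == Some(1L -> "batch1"))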
@@ -79,30 +79,30 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString def makeStoreRDD( - sc: SparkContext, + spark: SparkSession, seq: Seq[String], storeVersion: Int): RDD[(String, Int)] = { - implicit val sqlContext = new SQLContext(sc) - makeRDD(sc, Seq("a")).mapPartitionsWithStateStore( + implicit val sqlContext = spark.wrapped + makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment) } // Generate RDDs and state store data - withSpark(new SparkContext(sparkConf)) { sc => + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => for (i <- 1 to 20) { - require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) + require(makeStoreRDD(spark, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) } } // With a new context, try using the earlier state store data - withSpark(new SparkContext(sparkConf)) { sc => - assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => + assert(makeStoreRDD(spark, Seq("a"), 20).collect().toSet === Set("a" -> 21)) } } test("usage with iterators - only gets and only puts") { - withSpark(new SparkContext(sparkConf)) { sc => - implicit val sqlContext = new SQLContext(sc) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => + implicit val sqlContext = spark.wrapped val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 @@ -130,15 +130,15 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } } - val rddOfGets1 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) + val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( + spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) - val rddOfPuts = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts) assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) - val rddOfGets2 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( + val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets) assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } @@ -149,8 +149,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - withSpark(new SparkContext(sparkConf)) { sc => - implicit val sqlContext = new SQLContext(sc) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => + implicit val sqlContext = spark.wrapped val coordinatorRef = sqlContext.streams.stateStoreCoordinator coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") @@ -159,7 +159,7 @@ class StateStoreRDDSuite 
extends SparkFunSuite with BeforeAndAfter with BeforeAn coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) - val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) require(rdd.partitions.length === 2) @@ -178,16 +178,20 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("distributed test") { quietly { - withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc => - implicit val sqlContext = new SQLContext(sc) + + withSparkSession( + SparkSession.builder + .config(sparkConf.setMaster("local-cluster[2, 1, 1024]")) + .getOrCreate()) { spark => + implicit val sqlContext = spark.wrapped val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 - val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( + val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 67e44849ca..9eff42ab2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -98,7 +98,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } } - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame val accumulatorIds = @@ -239,7 +239,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onOtherEvent(SparkListenerSQLExecutionStart( @@ -269,7 +269,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onOtherEvent(SparkListenerSQLExecutionStart( @@ -310,7 +310,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before onJobEnd(JobFailed)") { - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onOtherEvent(SparkListenerSQLExecutionStart( @@ -340,16 +340,16 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("SPARK-11126: no memory 
leak when running non SQL jobs") { - val previousStageNumber = sqlContext.listener.stageIdToStageMetrics.size - sqlContext.sparkContext.parallelize(1 to 10).foreach(i => ()) - sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + val previousStageNumber = spark.listener.stageIdToStageMetrics.size + spark.sparkContext.parallelize(1 to 10).foreach(i => ()) + spark.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should ignore the non SQL stage - assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber) + assert(spark.listener.stageIdToStageMetrics.size == previousStageNumber) - sqlContext.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) - sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + spark.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) + spark.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should save the SQL stage - assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1) + assert(spark.listener.stageIdToStageMetrics.size == previousStageNumber + 1) } test("SPARK-13055: history listener only tracks SQL metrics") { @@ -401,8 +401,8 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { val sc = new SparkContext(conf) try { SQLContext.clearSqlListener() - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = new SQLContext(sc) + import spark.implicits._ // Run 100 successful executions and 100 failed executions. // Each execution only has one job and one stage. for (i <- 0 until 100) { @@ -418,12 +418,12 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { } } sc.listenerBus.waitUntilEmpty(10000) - assert(sqlContext.listener.getCompletedExecutions.size <= 50) - assert(sqlContext.listener.getFailedExecutions.size <= 50) + assert(spark.listener.getCompletedExecutions.size <= 50) + assert(spark.listener.getFailedExecutions.size <= 50) // 50 for successful executions and 50 for failed executions - assert(sqlContext.listener.executionIdToData.size <= 100) - assert(sqlContext.listener.jobIdToExecutionId.size <= 100) - assert(sqlContext.listener.stageIdToStageMetrics.size <= 100) + assert(spark.listener.executionIdToData.size <= 100) + assert(spark.listener.jobIdToExecutionId.size <= 100) + assert(spark.listener.stageIdToStageMetrics.size <= 100) } finally { sc.stop() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 73c2076a30..56f848b9a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -37,13 +37,12 @@ class CatalogSuite with BeforeAndAfterEach with SharedSQLContext { - private def sparkSession: SparkSession = sqlContext.sparkSession - private def sessionCatalog: SessionCatalog = sparkSession.sessionState.catalog + private def sessionCatalog: SessionCatalog = spark.sessionState.catalog private val utils = new CatalogTestUtils { override val tableInputFormat: String = "com.fruit.eyephone.CameraInputFormat" override val tableOutputFormat: String = "com.fruit.eyephone.CameraOutputFormat" - override def newEmptyCatalog(): ExternalCatalog = sparkSession.sharedState.externalCatalog + override def newEmptyCatalog(): ExternalCatalog = spark.sharedState.externalCatalog } private def createDatabase(name: String): Unit = { @@ -87,8 +86,8 @@ class CatalogSuite private def testListColumns(tableName: String, dbName: Option[String]): 
Unit = { val tableMetadata = sessionCatalog.getTableMetadata(TableIdentifier(tableName, dbName)) val columns = dbName - .map { db => sparkSession.catalog.listColumns(db, tableName) } - .getOrElse { sparkSession.catalog.listColumns(tableName) } + .map { db => spark.catalog.listColumns(db, tableName) } + .getOrElse { spark.catalog.listColumns(tableName) } assume(tableMetadata.schema.nonEmpty, "bad test") assume(tableMetadata.partitionColumnNames.nonEmpty, "bad test") assume(tableMetadata.bucketColumnNames.nonEmpty, "bad test") @@ -108,85 +107,85 @@ class CatalogSuite } test("current database") { - assert(sparkSession.catalog.currentDatabase == "default") + assert(spark.catalog.currentDatabase == "default") assert(sessionCatalog.getCurrentDatabase == "default") createDatabase("my_db") - sparkSession.catalog.setCurrentDatabase("my_db") - assert(sparkSession.catalog.currentDatabase == "my_db") + spark.catalog.setCurrentDatabase("my_db") + assert(spark.catalog.currentDatabase == "my_db") assert(sessionCatalog.getCurrentDatabase == "my_db") val e = intercept[AnalysisException] { - sparkSession.catalog.setCurrentDatabase("unknown_db") + spark.catalog.setCurrentDatabase("unknown_db") } assert(e.getMessage.contains("unknown_db")) } test("list databases") { - assert(sparkSession.catalog.listDatabases().collect().map(_.name).toSet == Set("default")) + assert(spark.catalog.listDatabases().collect().map(_.name).toSet == Set("default")) createDatabase("my_db1") createDatabase("my_db2") - assert(sparkSession.catalog.listDatabases().collect().map(_.name).toSet == + assert(spark.catalog.listDatabases().collect().map(_.name).toSet == Set("default", "my_db1", "my_db2")) dropDatabase("my_db1") - assert(sparkSession.catalog.listDatabases().collect().map(_.name).toSet == + assert(spark.catalog.listDatabases().collect().map(_.name).toSet == Set("default", "my_db2")) } test("list tables") { - assert(sparkSession.catalog.listTables().collect().isEmpty) + assert(spark.catalog.listTables().collect().isEmpty) createTable("my_table1") createTable("my_table2") createTempTable("my_temp_table") - assert(sparkSession.catalog.listTables().collect().map(_.name).toSet == + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table1", "my_table2", "my_temp_table")) dropTable("my_table1") - assert(sparkSession.catalog.listTables().collect().map(_.name).toSet == + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table2", "my_temp_table")) dropTable("my_temp_table") - assert(sparkSession.catalog.listTables().collect().map(_.name).toSet == Set("my_table2")) + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table2")) } test("list tables with database") { - assert(sparkSession.catalog.listTables("default").collect().isEmpty) + assert(spark.catalog.listTables("default").collect().isEmpty) createDatabase("my_db1") createDatabase("my_db2") createTable("my_table1", Some("my_db1")) createTable("my_table2", Some("my_db2")) createTempTable("my_temp_table") - assert(sparkSession.catalog.listTables("default").collect().map(_.name).toSet == + assert(spark.catalog.listTables("default").collect().map(_.name).toSet == Set("my_temp_table")) - assert(sparkSession.catalog.listTables("my_db1").collect().map(_.name).toSet == + assert(spark.catalog.listTables("my_db1").collect().map(_.name).toSet == Set("my_table1", "my_temp_table")) - assert(sparkSession.catalog.listTables("my_db2").collect().map(_.name).toSet == + 
assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet == Set("my_table2", "my_temp_table")) dropTable("my_table1", Some("my_db1")) - assert(sparkSession.catalog.listTables("my_db1").collect().map(_.name).toSet == + assert(spark.catalog.listTables("my_db1").collect().map(_.name).toSet == Set("my_temp_table")) - assert(sparkSession.catalog.listTables("my_db2").collect().map(_.name).toSet == + assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet == Set("my_table2", "my_temp_table")) dropTable("my_temp_table") - assert(sparkSession.catalog.listTables("default").collect().map(_.name).isEmpty) - assert(sparkSession.catalog.listTables("my_db1").collect().map(_.name).isEmpty) - assert(sparkSession.catalog.listTables("my_db2").collect().map(_.name).toSet == + assert(spark.catalog.listTables("default").collect().map(_.name).isEmpty) + assert(spark.catalog.listTables("my_db1").collect().map(_.name).isEmpty) + assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet == Set("my_table2")) val e = intercept[AnalysisException] { - sparkSession.catalog.listTables("unknown_db") + spark.catalog.listTables("unknown_db") } assert(e.getMessage.contains("unknown_db")) } test("list functions") { assert(Set("+", "current_database", "window").subsetOf( - sparkSession.catalog.listFunctions().collect().map(_.name).toSet)) + spark.catalog.listFunctions().collect().map(_.name).toSet)) createFunction("my_func1") createFunction("my_func2") createTempFunction("my_temp_func") - val funcNames1 = sparkSession.catalog.listFunctions().collect().map(_.name).toSet + val funcNames1 = spark.catalog.listFunctions().collect().map(_.name).toSet assert(funcNames1.contains("my_func1")) assert(funcNames1.contains("my_func2")) assert(funcNames1.contains("my_temp_func")) dropFunction("my_func1") dropTempFunction("my_temp_func") - val funcNames2 = sparkSession.catalog.listFunctions().collect().map(_.name).toSet + val funcNames2 = spark.catalog.listFunctions().collect().map(_.name).toSet assert(!funcNames2.contains("my_func1")) assert(funcNames2.contains("my_func2")) assert(!funcNames2.contains("my_temp_func")) @@ -194,14 +193,14 @@ class CatalogSuite test("list functions with database") { assert(Set("+", "current_database", "window").subsetOf( - sparkSession.catalog.listFunctions("default").collect().map(_.name).toSet)) + spark.catalog.listFunctions("default").collect().map(_.name).toSet)) createDatabase("my_db1") createDatabase("my_db2") createFunction("my_func1", Some("my_db1")) createFunction("my_func2", Some("my_db2")) createTempFunction("my_temp_func") - val funcNames1 = sparkSession.catalog.listFunctions("my_db1").collect().map(_.name).toSet - val funcNames2 = sparkSession.catalog.listFunctions("my_db2").collect().map(_.name).toSet + val funcNames1 = spark.catalog.listFunctions("my_db1").collect().map(_.name).toSet + val funcNames2 = spark.catalog.listFunctions("my_db2").collect().map(_.name).toSet assert(funcNames1.contains("my_func1")) assert(!funcNames1.contains("my_func2")) assert(funcNames1.contains("my_temp_func")) @@ -210,14 +209,14 @@ class CatalogSuite assert(funcNames2.contains("my_temp_func")) dropFunction("my_func1", Some("my_db1")) dropTempFunction("my_temp_func") - val funcNames1b = sparkSession.catalog.listFunctions("my_db1").collect().map(_.name).toSet - val funcNames2b = sparkSession.catalog.listFunctions("my_db2").collect().map(_.name).toSet + val funcNames1b = spark.catalog.listFunctions("my_db1").collect().map(_.name).toSet + val funcNames2b = 
spark.catalog.listFunctions("my_db2").collect().map(_.name).toSet assert(!funcNames1b.contains("my_func1")) assert(!funcNames1b.contains("my_temp_func")) assert(funcNames2b.contains("my_func2")) assert(!funcNames2b.contains("my_temp_func")) val e = intercept[AnalysisException] { - sparkSession.catalog.listFunctions("unknown_db") + spark.catalog.listFunctions("unknown_db") } assert(e.getMessage.contains("unknown_db")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index b87f482941..7ead97bbf6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -33,61 +33,61 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { test("programmatic ways of basic setting and getting") { // Set a conf first. - sqlContext.setConf(testKey, testVal) + spark.conf.set(testKey, testVal) // Clear the conf. - sqlContext.conf.clear() + spark.wrapped.conf.clear() // After clear, only overrideConfs used by unit test should be in the SQLConf. - assert(sqlContext.getAllConfs === TestSQLContext.overrideConfs) + assert(spark.conf.getAll === TestSQLContext.overrideConfs) - sqlContext.setConf(testKey, testVal) - assert(sqlContext.getConf(testKey) === testVal) - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - assert(sqlContext.getAllConfs.contains(testKey)) + spark.conf.set(testKey, testVal) + assert(spark.conf.get(testKey) === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) + assert(spark.conf.getAll.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. 
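// The SQLConfSuite hunks here replace sqlContext.setConf/getConf/getAllConfs with the
// session-scoped RuntimeConfig (spark.conf), keeping spark.wrapped.conf only as an escape
// hatch for clearing everything between tests. A rough sketch of the new surface, written
// as it would appear inside a test body with an active SparkSession `spark`; the key and
// values are made up for illustration.
spark.conf.set("spark.sql.test.sketchKey", "value1")
assert(spark.conf.get("spark.sql.test.sketchKey") == "value1")              // throws if unset
assert(spark.conf.get("spark.sql.test.sketchKey", "fallback") == "value1")  // form with default
assert(spark.conf.getAll.contains("spark.sql.test.sketchKey"))              // Map[String, String]
// The suite above also reads typed entries, e.g. spark.conf.get(SQLConf.SHUFFLE_PARTITIONS).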
- assert(sqlContext.getConf(testKey) === testVal) - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - assert(sqlContext.getAllConfs.contains(testKey)) + assert(spark.conf.get(testKey) === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) + assert(spark.conf.getAll.contains(testKey)) - sqlContext.conf.clear() + spark.wrapped.conf.clear() } test("parse SQL set commands") { - sqlContext.conf.clear() + spark.wrapped.conf.clear() sql(s"set $testKey=$testVal") - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) sql("set some.property=20") - assert(sqlContext.getConf("some.property", "0") === "20") + assert(spark.conf.get("some.property", "0") === "20") sql("set some.property = 40") - assert(sqlContext.getConf("some.property", "0") === "40") + assert(spark.conf.get("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" sql(s"set $key=$vs") - assert(sqlContext.getConf(key, "0") === vs) + assert(spark.conf.get(key, "0") === vs) sql(s"set $key=") - assert(sqlContext.getConf(key, "0") === "") + assert(spark.conf.get(key, "0") === "") - sqlContext.conf.clear() + spark.wrapped.conf.clear() } test("deprecated property") { - sqlContext.conf.clear() - val original = sqlContext.conf.numShufflePartitions + spark.wrapped.conf.clear() + val original = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) try{ sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(sqlContext.conf.numShufflePartitions === 10) + assert(spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) === 10) } finally { sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original") } } test("invalid conf value") { - sqlContext.conf.clear() + spark.wrapped.conf.clear() val e = intercept[IllegalArgumentException] { sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") } @@ -95,35 +95,35 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { } test("Test SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE's method") { - sqlContext.conf.clear() + spark.wrapped.conf.clear() - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100") - assert(sqlContext.conf.targetPostShuffleInputSize === 100) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 100) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1k") - assert(sqlContext.conf.targetPostShuffleInputSize === 1024) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1k") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 1024) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1M") - assert(sqlContext.conf.targetPostShuffleInputSize === 1048576) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1M") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 1048576) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1g") - assert(sqlContext.conf.targetPostShuffleInputSize === 1073741824) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1g") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 1073741824) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1") - assert(sqlContext.conf.targetPostShuffleInputSize === -1) + 
spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === -1) // Test overflow exception intercept[IllegalArgumentException] { // This value exceeds Long.MaxValue - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "90000000000g") + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "90000000000g") } intercept[IllegalArgumentException] { // This value less than Long.MinValue - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g") + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g") } - sqlContext.conf.clear() + spark.wrapped.conf.clear() } test("SparkSession can access configs set in SparkConf") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 47a1017caa..44d1b9ddda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -337,39 +337,39 @@ class JDBCSuite extends SparkFunSuite } test("Basic API") { - assert(sqlContext.read.jdbc( + assert(spark.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) } test("Basic API with FetchSize") { val properties = new Properties properties.setProperty("fetchSize", "2") - assert(sqlContext.read.jdbc( + assert(spark.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + assert(spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) .collect().length === 3) } test("Partitioning on column that might have null values.") { assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties) .collect().length === 4) assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties) .collect().length === 4) // partitioning on a nullable quoted column assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties) .collect().length === 4) } @@ -429,9 +429,9 @@ class JDBCSuite extends SparkFunSuite } test("test DATE types") { - val rows = sqlContext.read.jdbc( + val rows = spark.read.jdbc( urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - val cachedRows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val cachedRows = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(1).getAs[java.sql.Date](1) === null) @@ -439,8 +439,8 @@ class JDBCSuite extends SparkFunSuite } test("test DATE types in cache") { - val rows = 
sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().registerTempTable("mycached_date") val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) @@ -448,7 +448,7 @@ class JDBCSuite extends SparkFunSuite } test("test types for null value") { - val rows = sqlContext.read.jdbc( + val rows = spark.read.jdbc( urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } @@ -495,7 +495,7 @@ class JDBCSuite extends SparkFunSuite test("Remap types via JdbcDialects") { JdbcDialects.registerDialect(testH2Dialect) - val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty) val rows = df.collect() assert(rows(0).get(0).isInstanceOf[String]) @@ -629,7 +629,7 @@ class JDBCSuite extends SparkFunSuite // Regression test for bug SPARK-11788 val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543"); val date = java.sql.Date.valueOf("1995-01-01") - val jdbcDf = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val jdbcDf = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) val rows = jdbcDf.where($"B" > date && $"C" > timestamp).collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(0).getAs[java.sql.Timestamp](2) @@ -639,7 +639,7 @@ class JDBCSuite extends SparkFunSuite test("test credentials in the properties are not in plan output") { val df = sql("SELECT * FROM parts") val explain = ExplainCommand(df.queryExecution.logical, extended = true) - sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + spark.executePlan(explain).executedPlan.executeCollect().foreach { r => assert(!List("testPass", "testUser").exists(r.toString.contains)) } // test the JdbcRelation toString output @@ -649,9 +649,9 @@ class JDBCSuite extends SparkFunSuite } test("test credentials in the connection url are not in the plan output") { - val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) val explain = ExplainCommand(df.queryExecution.logical, extended = true) - sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + spark.executePlan(explain).executedPlan.executeCollect().foreach { r => assert(!List("testPass", "testUser").exists(r.toString.contains)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index e23ee66931..48fa5f9822 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -88,50 +88,50 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df = 
spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) - assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + assert(2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) assert( - 2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) + 2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) - val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.DROPTEST", properties) - assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) - assert(1 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(1 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url, "TEST.APPENDTEST", new Properties) df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) - assert(3 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) - assert(2 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) + assert(3 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) - assert(1 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) + val df = 
spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { @@ -141,14 +141,14 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { test("INSERT to JDBC Datasource") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 92061133cd..754aa32a30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -24,7 +24,7 @@ private[sql] abstract class DataSourceTest extends QueryTest { // We want to test some edge cases. protected lazy val caseInsensitiveContext: SQLContext = { - val ctx = new SQLContext(sqlContext.sparkContext) + val ctx = new SQLContext(spark.sparkContext) ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index a9b1970a7c..a2decadbe0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -29,11 +29,11 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val df = sqlContext.range(100).select($"id", lit(1).as("data")) + val df = spark.range(100).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - sqlContext.read.load(path.getCanonicalPath), + spark.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) @@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val base = sqlContext.range(100) + val base = spark.range(100) val df = base.union(base).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - sqlContext.read.load(path.getCanonicalPath), + spark.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) @@ -58,7 +58,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempPath { f => val path = f.getAbsolutePath Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path) - 
assert(sqlContext.read.parquet(path).schema.map(_.name) == Seq("j", "i")) + assert(spark.read.parquet(path).schema.map(_.name) == Seq("j", "i")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala index 3d69c8a187..a743cdde40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala @@ -41,13 +41,13 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with override val streamingTimeout = 20.seconds before { - assert(sqlContext.streams.active.isEmpty) - sqlContext.streams.resetTerminated() + assert(spark.streams.active.isEmpty) + spark.streams.resetTerminated() } after { - assert(sqlContext.streams.active.isEmpty) - sqlContext.streams.resetTerminated() + assert(spark.streams.active.isEmpty) + spark.streams.resetTerminated() } testQuietly("listing") { @@ -57,26 +57,26 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with withQueriesOn(ds1, ds2, ds3) { queries => require(queries.size === 3) - assert(sqlContext.streams.active.toSet === queries.toSet) + assert(spark.streams.active.toSet === queries.toSet) val (q1, q2, q3) = (queries(0), queries(1), queries(2)) - assert(sqlContext.streams.get(q1.name).eq(q1)) - assert(sqlContext.streams.get(q2.name).eq(q2)) - assert(sqlContext.streams.get(q3.name).eq(q3)) + assert(spark.streams.get(q1.name).eq(q1)) + assert(spark.streams.get(q2.name).eq(q2)) + assert(spark.streams.get(q3.name).eq(q3)) intercept[IllegalArgumentException] { - sqlContext.streams.get("non-existent-name") + spark.streams.get("non-existent-name") } q1.stop() - assert(sqlContext.streams.active.toSet === Set(q2, q3)) + assert(spark.streams.active.toSet === Set(q2, q3)) val ex1 = withClue("no error while getting non-active query") { intercept[IllegalArgumentException] { - sqlContext.streams.get(q1.name) + spark.streams.get(q1.name) } } assert(ex1.getMessage.contains(q1.name), "error does not contain name of query to be fetched") - assert(sqlContext.streams.get(q2.name).eq(q2)) + assert(spark.streams.get(q2.name).eq(q2)) m2.addData(0) // q2 should terminate with error @@ -86,11 +86,11 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with } withClue("no error while getting non-active query") { intercept[IllegalArgumentException] { - sqlContext.streams.get(q2.name).eq(q2) + spark.streams.get(q2.name).eq(q2) } } - assert(sqlContext.streams.active.toSet === Set(q3)) + assert(spark.streams.active.toSet === Set(q3)) } } @@ -98,7 +98,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with val datasets = Seq.fill(5)(makeDataset._2) withQueriesOn(datasets: _*) { queries => require(queries.size === datasets.size) - assert(sqlContext.streams.active.toSet === queries.toSet) + assert(spark.streams.active.toSet === queries.toSet) // awaitAnyTermination should be blocking testAwaitAnyTermination(ExpectBlocked) @@ -112,7 +112,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with testAwaitAnyTermination(ExpectNotBlocked) // Resetting termination should make awaitAnyTermination() blocking again - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() testAwaitAnyTermination(ExpectBlocked) // Terminate a query asynchronously with exception and see awaitAnyTermination 
throws @@ -125,7 +125,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with testAwaitAnyTermination(ExpectException[SparkException]) // Resetting termination should make awaitAnyTermination() blocking again - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() testAwaitAnyTermination(ExpectBlocked) // Terminate multiple queries, one with failure and see whether awaitAnyTermination throws @@ -144,7 +144,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with val datasets = Seq.fill(6)(makeDataset._2) withQueriesOn(datasets: _*) { queries => require(queries.size === datasets.size) - assert(sqlContext.streams.active.toSet === queries.toSet) + assert(spark.streams.active.toSet === queries.toSet) // awaitAnyTermination should be blocking or non-blocking depending on timeout values testAwaitAnyTermination( @@ -173,7 +173,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with ExpectNotBlocked, awaitTimeout = 4 seconds, expectedReturnedValue = true) // Resetting termination should make awaitAnyTermination() blocking again - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() testAwaitAnyTermination( ExpectBlocked, awaitTimeout = 4 seconds, @@ -196,7 +196,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with testBehaviorFor = 4 seconds) // Terminate a query asynchronously outside the timeout, awaitAnyTerm should be blocked - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() val q3 = stopRandomQueryAsync(2 seconds, withError = true) testAwaitAnyTermination( ExpectNotBlocked, @@ -214,7 +214,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with // Terminate multiple queries, one with failure and see whether awaitAnyTermination throws // the exception - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() val q4 = stopRandomQueryAsync(10 milliseconds, withError = false) testAwaitAnyTermination( @@ -238,7 +238,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with val df = ds.toDF val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - query = sqlContext + query = spark .streams .startQuery( StreamExecution.nextName, @@ -272,10 +272,10 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with def awaitTermFunc(): Unit = { if (awaitTimeout != null && awaitTimeout.toMillis > 0) { - val returnedValue = sqlContext.streams.awaitAnyTermination(awaitTimeout.toMillis) + val returnedValue = spark.streams.awaitAnyTermination(awaitTimeout.toMillis) assert(returnedValue === expectedReturnedValue, "Returned value does not match expected") } else { - sqlContext.streams.awaitAnyTermination() + spark.streams.awaitAnyTermination() } } @@ -287,7 +287,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with import scala.concurrent.ExecutionContext.Implicits.global - val activeQueries = sqlContext.streams.active + val activeQueries = spark.streams.active val queryToStop = activeQueries(Random.nextInt(activeQueries.length)) Future { Thread.sleep(stopAfter.toMillis) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index c7b2b99822..cb53b2b1aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala 
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -54,18 +54,18 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) override def sourceSchema( - sqlContext: SQLContext, + spark: SQLContext, schema: Option[StructType], providerName: String, parameters: Map[String, String]): (String, StructType) = { LastOptions.parameters = parameters LastOptions.schema = schema - LastOptions.mockStreamSourceProvider.sourceSchema(sqlContext, schema, providerName, parameters) + LastOptions.mockStreamSourceProvider.sourceSchema(spark, schema, providerName, parameters) ("dummySource", fakeSchema) } override def createSource( - sqlContext: SQLContext, + spark: SQLContext, metadataPath: String, schema: Option[StructType], providerName: String, @@ -73,14 +73,14 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { LastOptions.parameters = parameters LastOptions.schema = schema LastOptions.mockStreamSourceProvider.createSource( - sqlContext, metadataPath, schema, providerName, parameters) + spark, metadataPath, schema, providerName, parameters) new Source { override def schema: StructType = fakeSchema override def getOffset: Option[Offset] = Some(new LongOffset(0)) override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - import sqlContext.implicits._ + import spark.implicits._ Seq[Int]().toDS().toDF() } @@ -88,12 +88,12 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { } override def createSink( - sqlContext: SQLContext, + spark: SQLContext, parameters: Map[String, String], partitionColumns: Seq[String]): Sink = { LastOptions.parameters = parameters LastOptions.partitionColumns = partitionColumns - LastOptions.mockStreamSinkProvider.createSink(sqlContext, parameters, partitionColumns) + LastOptions.mockStreamSinkProvider.createSink(spark, parameters, partitionColumns) new Sink { override def addBatch(batchId: Long, data: DataFrame): Unit = {} } @@ -107,11 +107,11 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath after { - sqlContext.streams.active.foreach(_.stop()) + spark.streams.active.foreach(_.stop()) } test("resolve default source") { - sqlContext.read + spark.read .format("org.apache.spark.sql.streaming.test") .stream() .write @@ -122,7 +122,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("resolve full class") { - sqlContext.read + spark.read .format("org.apache.spark.sql.streaming.test.DefaultSource") .stream() .write @@ -136,7 +136,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B val map = new java.util.HashMap[String, String] map.put("opt3", "3") - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .option("opt1", "1") .options(Map("opt2" -> "2")) @@ -164,7 +164,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("partitioning") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() @@ -204,7 +204,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("stream paths") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .stream("/test") @@ -223,7 
+223,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("test different data types for options") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .option("intOpt", 56) .option("boolOpt", false) @@ -253,7 +253,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B /** Start a query with a specific name */ def startQueryWithName(name: String = ""): ContinuousQuery = { - sqlContext.read + spark.read .format("org.apache.spark.sql.streaming.test") .stream("/test") .write @@ -265,7 +265,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B /** Start a query without specifying a name */ def startQueryWithoutName(): ContinuousQuery = { - sqlContext.read + spark.read .format("org.apache.spark.sql.streaming.test") .stream("/test") .write @@ -276,7 +276,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B /** Get the names of active streams */ def activeStreamNames: Set[String] = { - val streams = sqlContext.streams.active + val streams = spark.streams.active val names = streams.map(_.name).toSet assert(streams.length === names.size, s"names of active queries are not unique: $names") names @@ -307,11 +307,11 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B q1.stop() val q5 = startQueryWithName("name") assert(activeStreamNames.contains("name")) - sqlContext.streams.active.foreach(_.stop()) + spark.streams.active.foreach(_.stop()) } test("trigger") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream("/test") @@ -339,11 +339,11 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B val checkpointLocation = newMetadataDir - val df1 = sqlContext.read + val df1 = spark.read .format("org.apache.spark.sql.streaming.test") .stream() - val df2 = sqlContext.read + val df2 = spark.read .format("org.apache.spark.sql.streaming.test") .stream() @@ -355,14 +355,14 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B q.stop() verify(LastOptions.mockStreamSourceProvider).createSource( - sqlContext, + spark.wrapped, checkpointLocation + "/sources/0", None, "org.apache.spark.sql.streaming.test", Map.empty) verify(LastOptions.mockStreamSourceProvider).createSource( - sqlContext, + spark.wrapped, checkpointLocation + "/sources/1", None, "org.apache.spark.sql.streaming.test", @@ -372,35 +372,35 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath test("check trigger() can only be called on continuous queries") { - val df = sqlContext.read.text(newTextInput) + val df = spark.read.text(newTextInput) val w = df.write.option("checkpointLocation", newMetadataDir) val e = intercept[AnalysisException](w.trigger(ProcessingTime("10 seconds"))) assert(e.getMessage == "trigger() can only be called on continuous queries;") } test("check queryName() can only be called on continuous queries") { - val df = sqlContext.read.text(newTextInput) + val df = spark.read.text(newTextInput) val w = df.write.option("checkpointLocation", newMetadataDir) val e = intercept[AnalysisException](w.queryName("queryName")) assert(e.getMessage == "queryName() can only be called on continuous queries;") } test("check startStream() can only be called on continuous queries") { - val df = 
sqlContext.read.text(newTextInput) + val df = spark.read.text(newTextInput) val w = df.write.option("checkpointLocation", newMetadataDir) val e = intercept[AnalysisException](w.startStream()) assert(e.getMessage == "startStream() can only be called on continuous queries;") } test("check startStream(path) can only be called on continuous queries") { - val df = sqlContext.read.text(newTextInput) + val df = spark.read.text(newTextInput) val w = df.write.option("checkpointLocation", newMetadataDir) val e = intercept[AnalysisException](w.startStream("non_exist_path")) assert(e.getMessage == "startStream() can only be called on continuous queries;") } test("check mode(SaveMode) can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -409,7 +409,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check mode(string) can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -418,7 +418,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check bucketBy() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -427,7 +427,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check sortBy() can only be called on non-continuous queries;") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -436,7 +436,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check save(path) can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -445,7 +445,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check save() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -454,7 +454,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check insertInto() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -463,7 +463,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check saveAsTable() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -472,7 +472,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check jdbc() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -481,7 +481,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check json() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -490,7 +490,7 @@ class 
DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check parquet() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -499,7 +499,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check orc() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -508,7 +508,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check text() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -517,7 +517,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check csv() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index e937fc3e87..6238b74ffa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -40,11 +40,11 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val hadoopConf = sqlContext.sparkContext.hadoopConfiguration + val hadoopConf = spark.sparkContext.hadoopConfiguration val fileFormat = new parquet.DefaultSource() def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = { - val df = sqlContext + val df = spark .range(start, end, 1, numPartitions) .select($"id", lit(100).as("data")) val writer = new FileStreamSinkWriter( @@ -56,7 +56,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { val files1 = writeRange(0, 10, 2) assert(files1.size === 2, s"unexpected number of files: $files1") checkFilesExist(path, files1, "file not written") - checkAnswer(sqlContext.read.load(path.getCanonicalPath), (0 until 10).map(Row(_, 100))) + checkAnswer(spark.read.load(path.getCanonicalPath), (0 until 10).map(Row(_, 100))) // Append and check whether new files are written correctly and old files still exist val files2 = writeRange(10, 20, 3) @@ -64,7 +64,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { assert(files2.intersect(files1).isEmpty, "old files returned") checkFilesExist(path, files2, s"New file not written") checkFilesExist(path, files1, s"Old file not found") - checkAnswer(sqlContext.read.load(path.getCanonicalPath), (0 until 20).map(Row(_, 100))) + checkAnswer(spark.read.load(path.getCanonicalPath), (0 until 20).map(Row(_, 100))) } test("FileStreamSinkWriter - partitioned data") { @@ -72,11 +72,11 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val hadoopConf = sqlContext.sparkContext.hadoopConfiguration + val hadoopConf = spark.sparkContext.hadoopConfiguration val fileFormat = new parquet.DefaultSource() def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = { - val df = sqlContext + val df = spark .range(start, end, 1, numPartitions) .flatMap(x => 
Iterator(x, x, x)).toDF("id") .select($"id", lit(100).as("data1"), lit(1000).as("data2")) @@ -103,7 +103,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { checkOneFileWrittenPerKey(0 until 10, files1) val answer1 = (0 until 10).flatMap(x => Iterator(x, x, x)).map(Row(100, 1000, _)) - checkAnswer(sqlContext.read.load(path.getCanonicalPath), answer1) + checkAnswer(spark.read.load(path.getCanonicalPath), answer1) // Append and check whether new files are written correctly and old files still exist val files2 = writeRange(0, 20, 3) @@ -114,7 +114,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { checkOneFileWrittenPerKey(0 until 20, files2) val answer2 = (0 until 20).flatMap(x => Iterator(x, x, x)).map(Row(100, 1000, _)) - checkAnswer(sqlContext.read.load(path.getCanonicalPath), answer1 ++ answer2) + checkAnswer(spark.read.load(path.getCanonicalPath), answer1 ++ answer2) } test("FileStreamSink - unpartitioned writing and batch reading") { @@ -139,7 +139,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { query.processAllAvailable() } - val outputDf = sqlContext.read.parquet(outputDir).as[Int] + val outputDf = spark.read.parquet(outputDir).as[Int] checkDataset(outputDf, 1, 2, 3) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index a62852b512..4b95d65627 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -103,9 +103,9 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { val reader = if (schema.isDefined) { - sqlContext.read.format(format).schema(schema.get) + spark.read.format(format).schema(schema.get) } else { - sqlContext.read.format(format) + spark.read.format(format) } reader.stream(path) } @@ -149,7 +149,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { format: Option[String], path: Option[String], schema: Option[StructType] = None): StructType = { - val reader = sqlContext.read + val reader = spark.read format.foreach(reader.format) schema.foreach(reader.schema) val df = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala index 50703e532f..4efb7cf52d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala @@ -100,7 +100,7 @@ class FileStressSuite extends StreamTest with SharedSQLContext { } writer.start() - val input = sqlContext.read.format("text").stream(inputDir) + val input = spark.read.format("text").stream(inputDir) def startStream(): ContinuousQuery = { val output = input @@ -150,6 +150,6 @@ class FileStressSuite extends StreamTest with SharedSQLContext { streamThread.join() logError(s"Stream restarted $failures times.") - assert(sqlContext.read.parquet(outputDir).distinct().count() == numRecords) + assert(spark.read.parquet(outputDir).distinct().count() == numRecords) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala index 74ca3977d6..09c35bbf2c 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala @@ -44,13 +44,13 @@ class MemorySinkSuite extends StreamTest with SharedSQLContext { query.processAllAvailable() checkDataset( - sqlContext.table("memStream").as[Int], + spark.table("memStream").as[Int], 1, 2, 3) input.addData(4, 5, 6) query.processAllAvailable() checkDataset( - sqlContext.table("memStream").as[Int], + spark.table("memStream").as[Int], 1, 2, 3, 4, 5, 6) query.stop() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index bcd3cba55a..6a8b280174 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -94,7 +94,7 @@ class StreamSuite extends StreamTest with SharedSQLContext { .startStream(outputDir.getAbsolutePath) try { query.processAllAvailable() - val outputDf = sqlContext.read.parquet(outputDir.getAbsolutePath).as[Long] + val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long] checkDataset[Long](outputDf, (0L to 10L).toArray: _*) } finally { query.stop() @@ -103,7 +103,7 @@ class StreamSuite extends StreamTest with SharedSQLContext { } } - val df = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream() + val df = spark.read.format(classOf[FakeDefaultSource].getName).stream() assertDF(df) assertDF(df) } @@ -162,13 +162,13 @@ class FakeDefaultSource extends StreamSourceProvider { private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) override def sourceSchema( - sqlContext: SQLContext, + spark: SQLContext, schema: Option[StructType], providerName: String, parameters: Map[String, String]): (String, StructType) = ("fakeSource", fakeSchema) override def createSource( - sqlContext: SQLContext, + spark: SQLContext, metadataPath: String, schema: Option[StructType], providerName: String, @@ -190,7 +190,7 @@ class FakeDefaultSource extends StreamSourceProvider { override def getBatch(start: Option[Offset], end: Offset): DataFrame = { val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1 - sqlContext.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a") + spark.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 7fa6760b71..03369c5a48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -20,17 +20,17 @@ package org.apache.spark.sql.test import java.nio.charset.StandardCharsets import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} +import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits} /** * A collection of sample data used in SQL tests. 
*/ private[sql] trait SQLTestData { self => - protected def sqlContext: SQLContext + protected def spark: SparkSession // Helper object to import SQL implicits without a concrete SQLContext private object internalImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self.sqlContext + protected override def _sqlContext: SQLContext = self.spark.wrapped } import internalImplicits._ @@ -39,21 +39,21 @@ private[sql] trait SQLTestData { self => // Note: all test data should be lazy because the SQLContext is not set up yet. protected lazy val emptyTestData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( Seq.empty[Int].map(i => TestData(i, i.toString))).toDF() df.registerTempTable("emptyTestData") df } protected lazy val testData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() df.registerTempTable("testData") df } protected lazy val testData2: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( TestData2(1, 1) :: TestData2(1, 2) :: TestData2(2, 1) :: @@ -65,7 +65,7 @@ private[sql] trait SQLTestData { self => } protected lazy val testData3: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( TestData3(1, None) :: TestData3(2, Some(2)) :: Nil).toDF() df.registerTempTable("testData3") @@ -73,14 +73,14 @@ private[sql] trait SQLTestData { self => } protected lazy val negativeData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() df.registerTempTable("negativeData") df } protected lazy val largeAndSmallInts: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( LargeAndSmallInts(2147483644, 1) :: LargeAndSmallInts(1, 2) :: LargeAndSmallInts(2147483645, 1) :: @@ -92,7 +92,7 @@ private[sql] trait SQLTestData { self => } protected lazy val decimalData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( DecimalData(1, 1) :: DecimalData(1, 2) :: DecimalData(2, 1) :: @@ -104,7 +104,7 @@ private[sql] trait SQLTestData { self => } protected lazy val binaryData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) :: BinaryData("22".getBytes(StandardCharsets.UTF_8), 5) :: BinaryData("122".getBytes(StandardCharsets.UTF_8), 3) :: @@ -115,7 +115,7 @@ private[sql] trait SQLTestData { self => } protected lazy val upperCaseData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( UpperCaseData(1, "A") :: UpperCaseData(2, "B") :: UpperCaseData(3, "C") :: @@ -127,7 +127,7 @@ private[sql] trait SQLTestData { self => } protected lazy val lowerCaseData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: @@ -137,7 +137,7 @@ private[sql] trait SQLTestData { self => } protected lazy val arrayData: RDD[ArrayData] = { - val rdd = sqlContext.sparkContext.parallelize( + val rdd = spark.sparkContext.parallelize( ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) 
rdd.toDF().registerTempTable("arrayData") @@ -145,7 +145,7 @@ private[sql] trait SQLTestData { self => } protected lazy val mapData: RDD[MapData] = { - val rdd = sqlContext.sparkContext.parallelize( + val rdd = spark.sparkContext.parallelize( MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: @@ -156,13 +156,13 @@ private[sql] trait SQLTestData { self => } protected lazy val repeatedData: RDD[StringData] = { - val rdd = sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) + val rdd = spark.sparkContext.parallelize(List.fill(2)(StringData("test"))) rdd.toDF().registerTempTable("repeatedData") rdd } protected lazy val nullableRepeatedData: RDD[StringData] = { - val rdd = sqlContext.sparkContext.parallelize( + val rdd = spark.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) rdd.toDF().registerTempTable("nullableRepeatedData") @@ -170,7 +170,7 @@ private[sql] trait SQLTestData { self => } protected lazy val nullInts: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( NullInts(1) :: NullInts(2) :: NullInts(3) :: @@ -180,7 +180,7 @@ private[sql] trait SQLTestData { self => } protected lazy val allNulls: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( NullInts(null) :: NullInts(null) :: NullInts(null) :: @@ -190,7 +190,7 @@ private[sql] trait SQLTestData { self => } protected lazy val nullStrings: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( NullStrings(1, "abc") :: NullStrings(2, "ABC") :: NullStrings(3, null) :: Nil).toDF() @@ -199,13 +199,13 @@ private[sql] trait SQLTestData { self => } protected lazy val tableName: DataFrame = { - val df = sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() + val df = spark.sparkContext.parallelize(TableName("test") :: Nil).toDF() df.registerTempTable("tableName") df } protected lazy val unparsedStrings: RDD[String] = { - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( "1, A1, true, null" :: "2, B2, false, null" :: "3, C3, true, null" :: @@ -214,13 +214,13 @@ private[sql] trait SQLTestData { self => // An RDD with 4 elements and 8 partitions protected lazy val withEmptyParts: RDD[IntField] = { - val rdd = sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + val rdd = spark.sparkContext.parallelize((1 to 4).map(IntField), 8) rdd.toDF().registerTempTable("withEmptyParts") rdd } protected lazy val person: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( Person(0, "mike", 30) :: Person(1, "jim", 20) :: Nil).toDF() df.registerTempTable("person") @@ -228,7 +228,7 @@ private[sql] trait SQLTestData { self => } protected lazy val salary: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( Salary(0, 2000.0) :: Salary(1, 1000.0) :: Nil).toDF() df.registerTempTable("salary") @@ -236,7 +236,7 @@ private[sql] trait SQLTestData { self => } protected lazy val complexData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: Nil).toDF() @@ -245,7 +245,7 @@ private[sql] trait 
SQLTestData { self => } protected lazy val courseSales: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( CourseSales("dotNET", 2012, 10000) :: CourseSales("Java", 2012, 20000) :: CourseSales("dotNET", 2012, 5000) :: @@ -259,7 +259,7 @@ private[sql] trait SQLTestData { self => * Initialize all test data such that all temp tables are properly registered. */ def loadTestData(): Unit = { - assert(sqlContext != null, "attempted to initialize test data before SQLContext.") + assert(spark != null, "attempted to initialize test data before SparkSession.") emptyTestData testData testData2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 6d2b95e83a..a49a8c9f2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -50,23 +50,23 @@ private[sql] trait SQLTestUtils with BeforeAndAfterAll with SQLTestData { self => - protected def sparkContext = sqlContext.sparkContext + protected def sparkContext = spark.sparkContext // Whether to materialize all test data before the first test is run private var loadTestDataBeforeTests = false // Shorthand for running a query using our SQLContext - protected lazy val sql = sqlContext.sql _ + protected lazy val sql = spark.sql _ /** * A helper object for importing SQL implicits. * - * Note that the alternative of importing `sqlContext.implicits._` is not possible here. + * Note that the alternative of importing `spark.implicits._` is not possible here. * This is because we create the [[SQLContext]] immediately before the first test is run, * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self.sqlContext + protected override def _sqlContext: SQLContext = self.spark.wrapped } /** @@ -92,12 +92,12 @@ private[sql] trait SQLTestUtils */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) - (keys, values).zipped.foreach(sqlContext.conf.setConfString) + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => sqlContext.conf.setConfString(key, value) - case (key, None) => sqlContext.conf.unsetConf(key) + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) } } } @@ -138,9 +138,9 @@ private[sql] trait SQLTestUtils // temp tables that never got created. functions.foreach { case (functionName, isTemporary) => val withTemporary = if (isTemporary) "TEMPORARY" else "" - sqlContext.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") + spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") assert( - !sqlContext.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), + !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), s"Function $functionName should have been dropped. But, it still exists.") } } @@ -153,7 +153,7 @@ private[sql] trait SQLTestUtils try f finally { // If the test failed part way, we don't want to mask the failure by failing to remove // temp tables that never got created. 
- try tableNames.foreach(sqlContext.dropTempTable) catch { + try tableNames.foreach(spark.catalog.dropTempTable) catch { case _: NoSuchTableException => } } @@ -165,7 +165,7 @@ private[sql] trait SQLTestUtils protected def withTable(tableNames: String*)(f: => Unit): Unit = { try f finally { tableNames.foreach { name => - sqlContext.sql(s"DROP TABLE IF EXISTS $name") + spark.sql(s"DROP TABLE IF EXISTS $name") } } } @@ -176,7 +176,7 @@ private[sql] trait SQLTestUtils protected def withView(viewNames: String*)(f: => Unit): Unit = { try f finally { viewNames.foreach { name => - sqlContext.sql(s"DROP VIEW IF EXISTS $name") + spark.sql(s"DROP VIEW IF EXISTS $name") } } } @@ -191,12 +191,12 @@ private[sql] trait SQLTestUtils val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" try { - sqlContext.sql(s"CREATE DATABASE $dbName") + spark.sql(s"CREATE DATABASE $dbName") } catch { case cause: Throwable => fail("Failed to create temporary database", cause) } - try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + try f(dbName) finally spark.sql(s"DROP DATABASE $dbName CASCADE") } /** @@ -204,8 +204,8 @@ private[sql] trait SQLTestUtils * `f` returns. */ protected def activateDatabase(db: String)(f: => Unit): Unit = { - sqlContext.sessionState.catalog.setCurrentDatabase(db) - try f finally sqlContext.sessionState.catalog.setCurrentDatabase("default") + spark.sessionState.catalog.setCurrentDatabase(db) + try f finally spark.sessionState.catalog.setCurrentDatabase("default") } /** @@ -221,7 +221,7 @@ private[sql] trait SQLTestUtils .execute() .map(row => Row.fromSeq(row.copy().toSeq(schema))) - sqlContext.createDataFrame(childRDD, schema) + spark.createDataFrame(childRDD, schema) } /** @@ -229,7 +229,7 @@ private[sql] trait SQLTestUtils * way to construct [[DataFrame]] directly out of local data without relying on implicits. */ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - Dataset.ofRows(sqlContext.sparkSession, plan) + Dataset.ofRows(spark, plan) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 914c6a5509..620bfa995a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,37 +17,42 @@ package org.apache.spark.sql.test -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SQLContext +import org.apache.spark.SparkConf +import org.apache.spark.sql.{SparkSession, SQLContext} /** - * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]]. + * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. */ trait SharedSQLContext extends SQLTestUtils { protected val sparkConf = new SparkConf() /** - * The [[TestSQLContext]] to use for all tests in this suite. + * The [[TestSparkSession]] to use for all tests in this suite. * * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local * mode with the default test configurations. */ - private var _ctx: TestSQLContext = null + private var _spark: TestSparkSession = null + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + */ + protected implicit def spark: SparkSession = _spark /** * The [[TestSQLContext]] to use for all tests in this suite. 
*/ - protected implicit def sqlContext: SQLContext = _ctx + protected implicit def sqlContext: SQLContext = _spark.wrapped /** - * Initialize the [[TestSQLContext]]. + * Initialize the [[TestSparkSession]]. */ protected override def beforeAll(): Unit = { SQLContext.clearSqlListener() - if (_ctx == null) { - _ctx = new TestSQLContext(sparkConf) + if (_spark == null) { + _spark = new TestSparkSession(sparkConf) } // Ensure we have initialized the context before calling parent code super.beforeAll() @@ -58,9 +63,9 @@ trait SharedSQLContext extends SQLTestUtils { */ protected override def afterAll(): Unit = { try { - if (_ctx != null) { - _ctx.sparkContext.stop() - _ctx = null + if (_spark != null) { + _spark.stop() + _spark = null } } finally { super.afterAll() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 5ef80b9aa3..785e3452a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -18,44 +18,32 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.{SessionState, SQLConf} /** - * A special [[SQLContext]] prepared for testing. + * A special [[SparkSession]] prepared for testing. */ -private[sql] class TestSQLContext( - @transient override val sparkSession: SparkSession, - isRootContext: Boolean) - extends SQLContext(sparkSession, isRootContext) { self => - - def this(sc: SparkContext) { - this(new TestSparkSession(sc), true) - } - +private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => def this(sparkConf: SparkConf) { this(new SparkContext("local[2]", "test-sql-context", sparkConf.set("spark.sql.testkey", "true"))) } def this() { - this(new SparkConf) + this { + val conf = new SparkConf() + conf.set("spark.sql.testkey", "true") + + val spark = SparkSession.builder + .master("local[2]") + .appName("test-sql-context") + .config(conf) + .getOrCreate() + spark.sparkContext + } } - // Needed for Java tests - def loadTestData(): Unit = { - testData.loadTestData() - } - - private object testData extends SQLTestData { - protected override def sqlContext: SQLContext = self - } - -} - - -private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => - @transient protected[sql] override lazy val sessionState: SessionState = new SessionState(self) { override lazy val conf: SQLConf = { @@ -70,6 +58,14 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { } } + // Needed for Java tests + def loadTestData(): Unit = { + testData.loadTestData() + } + + private object testData extends SQLTestData { + protected override def spark: SparkSession = self + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala index 54acd4db3c..8788898fc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala @@ -36,11 +36,11 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with import testImplicits._ after { - sqlContext.streams.active.foreach(_.stop()) - 
assert(sqlContext.streams.active.isEmpty) + spark.streams.active.foreach(_.stop()) + assert(spark.streams.active.isEmpty) assert(addedListeners.isEmpty) // Make sure we don't leak any events to the next test - sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + spark.sparkContext.listenerBus.waitUntilEmpty(10000) } test("single listener") { @@ -112,17 +112,17 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with val listener1 = new QueryStatusCollector val listener2 = new QueryStatusCollector - sqlContext.streams.addListener(listener1) + spark.streams.addListener(listener1) assert(isListenerActive(listener1) === true) assert(isListenerActive(listener2) === false) - sqlContext.streams.addListener(listener2) + spark.streams.addListener(listener2) assert(isListenerActive(listener1) === true) assert(isListenerActive(listener2) === true) - sqlContext.streams.removeListener(listener1) + spark.streams.removeListener(listener1) assert(isListenerActive(listener1) === false) assert(isListenerActive(listener2) === true) } finally { - addedListeners.foreach(sqlContext.streams.removeListener) + addedListeners.foreach(spark.streams.removeListener) } } @@ -146,18 +146,18 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with private def withListenerAdded(listener: ContinuousQueryListener)(body: => Unit): Unit = { try { failAfter(1 minute) { - sqlContext.streams.addListener(listener) + spark.streams.addListener(listener) body } } finally { - sqlContext.streams.removeListener(listener) + spark.streams.removeListener(listener) } } private def addedListeners(): Array[ContinuousQueryListener] = { val listenerBusMethod = PrivateMethod[ContinuousQueryListenerBus]('listenerBus) - val listenerBus = sqlContext.streams invokePrivate listenerBusMethod() + val listenerBus = spark.streams invokePrivate listenerBusMethod() listenerBus.listeners.toArray.map(_.asInstanceOf[ContinuousQueryListener]) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 8a0578c1ff..3ae5ce610d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -39,7 +39,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { metrics += ((funcName, qe, duration)) } } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val df = Seq(1 -> "a").toDF("i", "j") df.select("i").collect() @@ -55,7 +55,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(1)._2.analyzed.isInstanceOf[Aggregate]) assert(metrics(1)._3 > 0) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } test("execute callback functions when a DataFrame action failed") { @@ -68,7 +68,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { // Only test failed case here, so no need to implement `onSuccess` override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {} } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val errorUdf = udf[Int, Int] { _ => throw new RuntimeException("udf error") } val df = sparkContext.makeRDD(Seq(1 -> "a")).toDF("i", "j") @@ -82,7 +82,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { 
assert(metrics(0)._2.analyzed.isInstanceOf[Project]) assert(metrics(0)._3.getMessage == e.getMessage) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } test("get numRows metrics by callback") { @@ -99,7 +99,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { metrics += metric.value } } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() df.collect() @@ -111,7 +111,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(1) === 1) assert(metrics(2) === 2) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } // TODO: Currently some LongSQLMetric use -1 as initial value, so if the accumulator is never @@ -131,10 +131,10 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { metrics += bottomAgg.longMetric("dataSize").value } } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val sparkListener = new SaveInfoListener - sqlContext.sparkContext.addSparkListener(sparkListener) + spark.sparkContext.addSparkListener(sparkListener) val df = (1 to 100).map(i => i -> i.toString).toDF("i", "j") df.groupBy("i").count().collect() @@ -157,6 +157,6 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(0) == topAggDataSize) assert(metrics(1) == bottomAggDataSize) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index 154ada3daa..9bf84ab1fb 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.hive.test import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.SparkSession import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLContext trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { - protected val sqlContext: SQLContext = TestHive + protected val spark: SparkSession = TestHive.sparkSession protected val hiveContext: TestHiveContext = TestHive protected override def afterAll(): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala index a7782abc39..72736ee55b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala @@ -34,13 +34,13 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { val bytes = Array[Byte](1, 2, 3, 4) Seq((bytes, "AQIDBA==")).toDF("a", "b").write.saveAsTable("t0") - sqlContext + spark .range(10) .select('id as 'key, concat(lit("val_"), 'id) as 'value) .write .saveAsTable("t1") - sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2") + spark.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2") } override protected def afterAll(): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index 34c2773581..9abefa5f28 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -33,16 +33,16 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { sql("DROP TABLE IF EXISTS parquet_t2") sql("DROP TABLE IF EXISTS t0") - sqlContext.range(10).write.saveAsTable("parquet_t0") + spark.range(10).write.saveAsTable("parquet_t0") sql("CREATE TABLE t0 AS SELECT * FROM parquet_t0") - sqlContext + spark .range(10) .select('id as 'key, concat(lit("val_"), 'id) as 'value) .write .saveAsTable("parquet_t1") - sqlContext + spark .range(10) .select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd) .write @@ -52,7 +52,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { when(id % 3 === 0, lit(null)).otherwise(array('id, 'id + 1)) } - sqlContext + spark .range(10) .select( createArray('id).as("arr"), @@ -394,7 +394,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { Seq("orc", "json", "parquet").foreach { format => val tableName = s"${format}_parquet_t0" withTable(tableName) { - sqlContext.range(10).write.format(format).saveAsTable(tableName) + spark.range(10).write.format(format).saveAsTable(tableName) checkHiveQl(s"SELECT id FROM $tableName") } } @@ -458,7 +458,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { } test("plans with non-SQL expressions") { - sqlContext.udf.register("foo", (_: Int) * 2) + spark.udf.register("foo", (_: Int) * 2) intercept[UnsupportedOperationException](new SQLBuilder(sql("SELECT foo(id) FROM t0")).toSQL) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala index 27c9e992de..31755f56ec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala @@ -64,7 +64,7 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { """.stripMargin) } - checkAnswer(sqlContext.sql(generatedSQL), Dataset.ofRows(sqlContext.sparkSession, plan)) + checkAnswer(spark.sql(generatedSQL), Dataset.ofRows(spark, plan)) } protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 61910b8e6b..093cd3a96c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -30,8 +30,8 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd override protected def beforeEach(): Unit = { super.beforeEach() - if (sqlContext.tableNames().contains("src")) { - sqlContext.dropTempTable("src") + if (spark.wrapped.tableNames().contains("src")) { + spark.catalog.dropTempTable("src") } Seq((1, "")).toDF("key", "value").registerTempTable("src") Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") @@ -39,8 +39,8 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd override protected def afterEach(): Unit = { try { - sqlContext.dropTempTable("src") - sqlContext.dropTempTable("dupAttributes") + spark.catalog.dropTempTable("src") + 
spark.catalog.dropTempTable("dupAttributes") } finally { super.afterEach() } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index a717a9978e..bfe559f0b2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -555,7 +555,7 @@ object SparkSQLConfTest extends Logging { object SPARK_9757 extends QueryTest { import org.apache.spark.sql.functions._ - protected var sqlContext: SQLContext = _ + protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") @@ -567,7 +567,7 @@ object SPARK_9757 extends QueryTest { .set("spark.ui.enabled", "false")) val hiveContext = new TestHiveContext(sparkContext) - sqlContext = hiveContext + spark = hiveContext.sparkSession import hiveContext.implicits._ val dir = Utils.createTempDir() @@ -602,7 +602,7 @@ object SPARK_9757 extends QueryTest { object SPARK_11009 extends QueryTest { import org.apache.spark.sql.functions._ - protected var sqlContext: SQLContext = _ + protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") @@ -613,10 +613,10 @@ object SPARK_11009 extends QueryTest { .set("spark.sql.shuffle.partitions", "100")) val hiveContext = new TestHiveContext(sparkContext) - sqlContext = hiveContext + spark = hiveContext.sparkSession try { - val df = sqlContext.range(1 << 20) + val df = spark.range(1 << 20) val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B")) val ws = Window.partitionBy(df2("A")).orderBy(df2("B")) val df3 = df2.select(df2("A"), df2("B"), row_number().over(ws).alias("rn")).filter("rn < 0") @@ -633,7 +633,7 @@ object SPARK_14244 extends QueryTest { import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ - protected var sqlContext: SQLContext = _ + protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") @@ -644,13 +644,13 @@ object SPARK_14244 extends QueryTest { .set("spark.sql.shuffle.partitions", "100")) val hiveContext = new TestHiveContext(sparkContext) - sqlContext = hiveContext + spark = hiveContext.sparkSession import hiveContext.implicits._ try { val window = Window.orderBy('id) - val df = sqlContext.range(2).select(cume_dist().over(window).as('cdist)).orderBy('cdist) + val df = spark.range(2).select(cume_dist().over(window).as('cdist)).orderBy('cdist) checkAnswer(df, Seq(Row(0.5D), Row(1.0D))) } finally { sparkContext.stop() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 52aba328de..82d3e49f92 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -251,7 +251,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") // this will pick up the output partitioning from the table definition - sqlContext.table("source").write.insertInto("partitioned") + spark.table("source").write.insertInto("partitioned") checkAnswer(sql("SELECT * FROM partitioned"), data.collect().toSeq) } @@ -272,7 +272,7 @@ class InsertIntoHiveTableSuite 
extends QueryTest with TestHiveSingleton with Bef sql( """CREATE TABLE partitioned (id bigint, data string) |PARTITIONED BY (part1 string, part2 string)""".stripMargin) - sqlContext.table("source").write.insertInto("partitioned") + spark.table("source").write.insertInto("partitioned") checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq) } @@ -283,7 +283,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") val data = (1 to 10).map(i => (i.toLong, s"data-$i")).toDF("id", "data") - val logical = InsertIntoTable(sqlContext.table("partitioned").logicalPlan, + val logical = InsertIntoTable(spark.table("partitioned").logicalPlan, Map("part" -> None), data.logicalPlan, overwrite = false, ifNotExists = false) assert(!logical.resolved, "Should not resolve: missing partition data") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 78c8f0043d..b2a80e70be 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -374,7 +374,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val expectedPath = sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier("ctasJsonTable")) val filesystemPath = new Path(expectedPath) - val fs = filesystemPath.getFileSystem(sqlContext.sessionState.newHadoopConf()) + val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf()) if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) // It is a managed table when we do not specify the location. @@ -701,7 +701,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv // Manually create a metastore data source table. 
CreateDataSourceTableUtils.createDataSourceTable( - sparkSession = sqlContext.sparkSession, + sparkSession = spark, tableIdent = TableIdentifier("wide_schema"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], @@ -891,18 +891,18 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv test("SPARK-8156:create table to specific database by 'use dbname' ") { val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") - sqlContext.sql("""create database if not exists testdb8156""") - sqlContext.sql("""use testdb8156""") + spark.sql("""create database if not exists testdb8156""") + spark.sql("""use testdb8156""") df.write .format("parquet") .mode(SaveMode.Overwrite) .saveAsTable("ttt3") checkAnswer( - sqlContext.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"), + spark.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"), Row("ttt3", false)) - sqlContext.sql("""use default""") - sqlContext.sql("""drop database if exists testdb8156 CASCADE""") + spark.sql("""use default""") + spark.sql("""drop database if exists testdb8156 CASCADE""") } @@ -911,7 +911,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType))) CreateDataSourceTableUtils.createDataSourceTable( - sparkSession = sqlContext.sparkSession, + sparkSession = spark, tableIdent = TableIdentifier("not_skip_hive_metadata"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], @@ -926,7 +926,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv .forall(column => CatalystSqlParser.parseDataType(column.dataType) == StringType)) CreateDataSourceTableUtils.createDataSourceTable( - sparkSession = sqlContext.sparkSession, + sparkSession = spark, tableIdent = TableIdentifier("skip_hive_metadata"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 850cb1eda5..6c9ce208db 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { - private lazy val df = sqlContext.range(10).coalesce(1).toDF() + private lazy val df = spark.range(10).coalesce(1).toDF() private def checkTablePath(dbName: String, tableName: String): Unit = { val metastoreTable = hiveContext.sharedState.externalCatalog.getTable(dbName, tableName) @@ -36,12 +36,12 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) - checkAnswer(sqlContext.table("t"), df) + assert(spark.wrapped.tableNames().contains("t")) + checkAnswer(spark.table("t"), df) } - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df) + assert(spark.wrapped.tableNames(db).contains("t")) + checkAnswer(spark.table(s"$db.t"), df) checkTablePath(db, "t") } @@ -50,8 +50,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle test(s"saveAsTable() 
to non-default database - without USE - Overwrite") { withTempDatabase { db => df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df) + assert(spark.wrapped.tableNames(db).contains("t")) + checkAnswer(spark.table(s"$db.t"), df) checkTablePath(db, "t") } @@ -64,9 +64,9 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle val path = dir.getCanonicalPath df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - sqlContext.createExternalTable("t", path, "parquet") - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table("t"), df) + spark.catalog.createExternalTable("t", path, "parquet") + assert(spark.wrapped.tableNames(db).contains("t")) + checkAnswer(spark.table("t"), df) sql( s""" @@ -76,8 +76,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle | path '$path' |) """.stripMargin) - assert(sqlContext.tableNames(db).contains("t1")) - checkAnswer(sqlContext.table("t1"), df) + assert(spark.wrapped.tableNames(db).contains("t1")) + checkAnswer(spark.table("t1"), df) } } } @@ -88,10 +88,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempPath { dir => val path = dir.getCanonicalPath df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - sqlContext.createExternalTable(s"$db.t", path, "parquet") + spark.catalog.createExternalTable(s"$db.t", path, "parquet") - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df) + assert(spark.wrapped.tableNames(db).contains("t")) + checkAnswer(spark.table(s"$db.t"), df) sql( s""" @@ -101,8 +101,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle | path '$path' |) """.stripMargin) - assert(sqlContext.tableNames(db).contains("t1")) - checkAnswer(sqlContext.table(s"$db.t1"), df) + assert(spark.wrapped.tableNames(db).contains("t1")) + checkAnswer(spark.table(s"$db.t1"), df) } } } @@ -112,12 +112,12 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") df.write.mode(SaveMode.Append).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) - checkAnswer(sqlContext.table("t"), df.union(df)) + assert(spark.wrapped.tableNames().contains("t")) + checkAnswer(spark.table("t"), df.union(df)) } - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df.union(df)) + assert(spark.wrapped.tableNames(db).contains("t")) + checkAnswer(spark.table(s"$db.t"), df.union(df)) checkTablePath(db, "t") } @@ -127,8 +127,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") df.write.mode(SaveMode.Append).saveAsTable(s"$db.t") - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df.union(df)) + assert(spark.wrapped.tableNames(db).contains("t")) + checkAnswer(spark.table(s"$db.t"), df.union(df)) checkTablePath(db, "t") } @@ -138,10 +138,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) + assert(spark.wrapped.tableNames().contains("t")) df.write.insertInto(s"$db.t") - checkAnswer(sqlContext.table(s"$db.t"), 
df.union(df)) + checkAnswer(spark.table(s"$db.t"), df.union(df)) } } } @@ -150,13 +150,13 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) + assert(spark.wrapped.tableNames().contains("t")) } - assert(sqlContext.tableNames(db).contains("t")) + assert(spark.wrapped.tableNames(db).contains("t")) df.write.insertInto(s"$db.t") - checkAnswer(sqlContext.table(s"$db.t"), df.union(df)) + checkAnswer(spark.table(s"$db.t"), df.union(df)) } } @@ -164,10 +164,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => activateDatabase(db) { sql("CREATE TABLE t (key INT)") - checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + checkAnswer(spark.table("t"), spark.emptyDataFrame) } - checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + checkAnswer(spark.table(s"$db.t"), spark.emptyDataFrame) } } @@ -175,21 +175,21 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => activateDatabase(db) { sql(s"CREATE TABLE t (key INT)") - assert(sqlContext.tableNames().contains("t")) - assert(!sqlContext.tableNames("default").contains("t")) + assert(spark.wrapped.tableNames().contains("t")) + assert(!spark.wrapped.tableNames("default").contains("t")) } - assert(!sqlContext.tableNames().contains("t")) - assert(sqlContext.tableNames(db).contains("t")) + assert(!spark.wrapped.tableNames().contains("t")) + assert(spark.wrapped.tableNames(db).contains("t")) activateDatabase(db) { sql(s"DROP TABLE t") - assert(!sqlContext.tableNames().contains("t")) - assert(!sqlContext.tableNames("default").contains("t")) + assert(!spark.wrapped.tableNames().contains("t")) + assert(!spark.wrapped.tableNames("default").contains("t")) } - assert(!sqlContext.tableNames().contains("t")) - assert(!sqlContext.tableNames(db).contains("t")) + assert(!spark.wrapped.tableNames().contains("t")) + assert(!spark.wrapped.tableNames(db).contains("t")) } } @@ -208,18 +208,18 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle |LOCATION '$path' """.stripMargin) - checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + checkAnswer(spark.table("t"), spark.emptyDataFrame) df.write.parquet(s"$path/p=1") sql("ALTER TABLE t ADD PARTITION (p=1)") sql("REFRESH TABLE t") - checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1))) + checkAnswer(spark.table("t"), df.withColumn("p", lit(1))) df.write.parquet(s"$path/p=2") sql("ALTER TABLE t ADD PARTITION (p=2)") hiveContext.sessionState.refreshTable("t") checkAnswer( - sqlContext.table("t"), + spark.table("t"), df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2)))) } } @@ -240,18 +240,18 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle |LOCATION '$path' """.stripMargin) - checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + checkAnswer(spark.table(s"$db.t"), spark.emptyDataFrame) df.write.parquet(s"$path/p=1") sql(s"ALTER TABLE $db.t ADD PARTITION (p=1)") sql(s"REFRESH TABLE $db.t") - checkAnswer(sqlContext.table(s"$db.t"), df.withColumn("p", lit(1))) + checkAnswer(spark.table(s"$db.t"), df.withColumn("p", lit(1))) df.write.parquet(s"$path/p=2") sql(s"ALTER TABLE $db.t ADD PARTITION (p=2)") hiveContext.sessionState.refreshTable(s"$db.t") checkAnswer( - sqlContext.table(s"$db.t"), + spark.table(s"$db.t"), 
df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2)))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index af4dc1beec..3f6418cbe8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -70,12 +70,12 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHi |$ddl """.stripMargin) - sqlContext.sql(ddl) + spark.sql(ddl) - val schema = sqlContext.table("parquet_compat").schema - val rowRDD = sqlContext.sparkContext.parallelize(rows).coalesce(1) - sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") - sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") + val schema = spark.table("parquet_compat").schema + val rowRDD = spark.sparkContext.parallelize(rows).coalesce(1) + spark.createDataFrame(rowRDD, schema).registerTempTable("data") + spark.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") } } @@ -84,7 +84,7 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHi // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. // Have to assume all BINARY values are strings here. withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { - checkAnswer(sqlContext.read.parquet(path), rows) + checkAnswer(spark.read.parquet(path), rows) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 0ba72b033f..0f416eb24d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -177,23 +177,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te (Seq[Integer](3), null, null)).toDF("key", "value1", "value2") data3.write.saveAsTable("agg3") - val emptyDF = sqlContext.createDataFrame( + val emptyDF = spark.createDataFrame( sparkContext.emptyRDD[Row], StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)) emptyDF.registerTempTable("emptyTable") // Register UDAFs - sqlContext.udf.register("mydoublesum", new MyDoubleSum) - sqlContext.udf.register("mydoubleavg", new MyDoubleAvg) - sqlContext.udf.register("longProductSum", new LongProductSum) + spark.udf.register("mydoublesum", new MyDoubleSum) + spark.udf.register("mydoubleavg", new MyDoubleAvg) + spark.udf.register("longProductSum", new LongProductSum) } override def afterAll(): Unit = { try { - sqlContext.sql("DROP TABLE IF EXISTS agg1") - sqlContext.sql("DROP TABLE IF EXISTS agg2") - sqlContext.sql("DROP TABLE IF EXISTS agg3") - sqlContext.dropTempTable("emptyTable") + spark.sql("DROP TABLE IF EXISTS agg1") + spark.sql("DROP TABLE IF EXISTS agg2") + spark.sql("DROP TABLE IF EXISTS agg3") + spark.catalog.dropTempTable("emptyTable") } finally { super.afterAll() } @@ -210,7 +210,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("empty table") { // If there is no GROUP BY clause and the table is empty, we will generate a single row. 
checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | AVG(value), @@ -227,7 +227,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(null, 0, 0, 0, null, null, null, null, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | AVG(value), @@ -246,7 +246,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te // If there is a GROUP BY clause and the table is empty, there is no output. checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | AVG(value), @@ -266,7 +266,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("null literal") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | AVG(null), @@ -282,7 +282,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("only do grouping") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT key |FROM agg1 @@ -291,7 +291,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT DISTINCT value1, key |FROM agg2 @@ -308,7 +308,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(null, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT value1, key |FROM agg2 @@ -326,7 +326,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(null, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT DISTINCT key |FROM agg3 @@ -341,7 +341,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(Seq[Integer](3)) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT value1, key |FROM agg3 @@ -363,7 +363,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("case in-sensitive resolution") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT avg(value), kEY - 100 |FROM agg1 @@ -372,7 +372,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT sum(distinct value1), kEY - 100, count(distinct value1) |FROM agg2 @@ -381,7 +381,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT valUe * key - 100 |FROM agg1 @@ -397,7 +397,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("test average no key in output") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT avg(value) |FROM agg1 @@ -408,7 +408,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("test average") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT key, avg(value) |FROM agg1 @@ -417,7 +417,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT key, mean(value) |FROM agg1 @@ -426,7 +426,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT avg(value), key |FROM agg1 @@ -435,7 +435,7 @@ 
abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT avg(value) + 1.5, key + 10 |FROM agg1 @@ -444,7 +444,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT avg(value) FROM agg1 """.stripMargin), @@ -456,7 +456,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te // deterministic. withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | first_valUE(key), @@ -472,7 +472,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(null, 3, null, 3, 1, 3, 1, 3) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | first_valUE(key), @@ -491,7 +491,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("udaf") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | key, @@ -511,7 +511,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("interpreted aggregate function") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT mydoublesum(value), key |FROM agg1 @@ -520,14 +520,14 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT mydoublesum(value) FROM agg1 """.stripMargin), Row(89.0) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT mydoublesum(null) """.stripMargin), @@ -536,7 +536,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("interpreted and expression-based aggregation functions") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT mydoublesum(value), key, avg(value) |FROM agg1 @@ -548,7 +548,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(30.0, null, 10.0) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | mydoublesum(value + 1.5 * key), @@ -568,7 +568,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("single distinct column set") { // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. 
checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | min(distinct value1), @@ -581,7 +581,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(-60, 70.0, 101.0/9.0, 5.6, 100)) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | mydoubleavg(distinct value1), @@ -600,7 +600,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | key, @@ -618,7 +618,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | count(value1), @@ -637,7 +637,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("single distinct multiple columns set") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | key, @@ -653,7 +653,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("multiple distinct multiple columns sets") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | key, @@ -681,7 +681,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("test count") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | count(value2), @@ -704,7 +704,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(0, null, 1, 1, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | count(value2), @@ -786,28 +786,28 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te covar_tab.registerTempTable("covar_tab") checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT corr(b, c) FROM covar_tab WHERE a < 1 """.stripMargin), Row(null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT corr(b, c) FROM covar_tab WHERE a < 3 """.stripMargin), Row(null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT corr(b, c) FROM covar_tab WHERE a = 3 """.stripMargin), Row(Double.NaN) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a """.stripMargin), @@ -818,7 +818,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(5, Double.NaN) :: Row(6, Double.NaN) :: Nil) - val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0) + val corr7 = spark.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0) assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) } @@ -852,7 +852,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te } test("no aggregation function (SPARK-11486)") { - val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s") + val df = spark.range(20).selectExpr("id", "repeat(id, 1) as s") .groupBy("s").count() .groupBy().count() checkAnswer(df, Row(20) :: Nil) @@ -906,8 +906,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te } // Create a DF for the schema with random data. 
- val rdd = sqlContext.sparkContext.parallelize(data, 1) - val df = sqlContext.createDataFrame(rdd, schema) + val rdd = spark.sparkContext.parallelize(data, 1) + val df = spark.createDataFrame(rdd, schema) val allColumns = df.schema.fields.map(f => col(f.name)) val expectedAnswer = @@ -924,7 +924,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("udaf without specifying inputSchema") { withTempTable("noInputSchemaUDAF") { - sqlContext.udf.register("noInputSchema", new ScalaAggregateFunctionWithoutInputSchema) + spark.udf.register("noInputSchema", new ScalaAggregateFunctionWithoutInputSchema) val data = Row(1, Seq(Row(1), Row(2), Row(3))) :: @@ -935,13 +935,13 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te StructField("key", IntegerType) :: StructField("myArray", ArrayType(StructType(StructField("v", IntegerType) :: Nil))) :: Nil) - sqlContext.createDataFrame( + spark.createDataFrame( sparkContext.parallelize(data, 2), schema) .registerTempTable("noInputSchemaUDAF") checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT key, noInputSchema(myArray) |FROM noInputSchemaUDAF @@ -950,7 +950,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(1, 21) :: Row(2, -10) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT noInputSchema(myArray) |FROM noInputSchemaUDAF @@ -976,7 +976,7 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue // Create a new df to make sure its physical operator picks up // spark.sql.TungstenAggregate.testFallbackStartsAt. // todo: remove it? - val newActual = Dataset.ofRows(sqlContext.sparkSession, actual.logicalPlan) + val newActual = Dataset.ofRows(spark, actual.logicalPlan) QueryTest.checkAnswer(newActual, expectedAnswer) match { case Some(errorMessage) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 0f23949d98..6dcc404636 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -36,7 +36,7 @@ class HiveDDLSuite override def afterEach(): Unit = { try { // drop all databases, tables and functions after each test - sqlContext.sessionState.catalog.reset() + spark.sessionState.catalog.reset() } finally { super.afterEach() } @@ -212,7 +212,7 @@ class HiveDDLSuite test("drop views") { withTable("tab1") { val tabName = "tab1" - sqlContext.range(10).write.saveAsTable("tab1") + spark.range(10).write.saveAsTable("tab1") withView("view1") { val viewName = "view1" @@ -233,7 +233,7 @@ class HiveDDLSuite test("alter views - rename") { val tabName = "tab1" withTable(tabName) { - sqlContext.range(10).write.saveAsTable(tabName) + spark.range(10).write.saveAsTable(tabName) val oldViewName = "view1" val newViewName = "view2" withView(oldViewName, newViewName) { @@ -252,7 +252,7 @@ class HiveDDLSuite test("alter views - set/unset tblproperties") { val tabName = "tab1" withTable(tabName) { - sqlContext.range(10).write.saveAsTable(tabName) + spark.range(10).write.saveAsTable(tabName) val viewName = "view1" withView(viewName) { val catalog = hiveContext.sessionState.catalog @@ -290,7 +290,7 @@ class HiveDDLSuite test("alter views and alter table - misuse") { val tabName = "tab1" withTable(tabName) { - sqlContext.range(10).write.saveAsTable(tabName) + 
spark.range(10).write.saveAsTable(tabName) val oldViewName = "view1" val newViewName = "view2" withView(oldViewName, newViewName) { @@ -354,7 +354,7 @@ class HiveDDLSuite test("drop view using drop table") { withTable("tab1") { - sqlContext.range(10).write.saveAsTable("tab1") + spark.range(10).write.saveAsTable("tab1") withView("view1") { sql("CREATE VIEW view1 AS SELECT * FROM tab1") val message = intercept[AnalysisException] { @@ -383,7 +383,7 @@ class HiveDDLSuite } private def createDatabaseWithLocation(tmpDir: File, dirExists: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val dbName = "db1" val tabName = "tab1" val fs = new Path(tmpDir.toString).getFileSystem(hiveContext.sessionState.newHadoopConf()) @@ -442,7 +442,7 @@ class HiveDDLSuite assert(!fs.exists(dbPath)) sql(s"CREATE DATABASE $dbName") - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val expectedDBLocation = "file:" + appendTrailingSlash(dbPath.toString) + s"$dbName.db" val db1 = catalog.getDatabaseMetadata(dbName) assert(db1 == CatalogDatabase( @@ -518,7 +518,7 @@ class HiveDDLSuite test("desc table for data source table") { withTable("tab1") { val tabName = "tab1" - sqlContext.range(1).write.format("json").saveAsTable(tabName) + spark.range(1).write.format("json").saveAsTable(tabName) assert(sql(s"DESC $tabName").collect().length == 1) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index d07ac56586..dd4321d1b6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -347,7 +347,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { sql("DROP TEMPORARY FUNCTION IF EXISTS testUDTFExplode") } - sqlContext.dropTempTable("testUDF") + spark.catalog.dropTempTable("testUDF") } test("Hive UDF in group by") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1d597fe16d..2e4077df54 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -860,7 +860,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("Sorting columns are not in Generate") { withTempTable("data") { - sqlContext.range(1, 5) + spark.range(1, 5) .select(array($"id", $"id" + 1).as("a"), $"id".as("b"), (lit(10) - $"id").as("c")) .registerTempTable("data") @@ -1081,7 +1081,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { // We don't support creating a temporary table while specifying a database val message = intercept[AnalysisException] { - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE db.t |USING parquet @@ -1092,7 +1092,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { }.getMessage // If you use backticks to quote the name then it's OK. 
- sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE `db.t` |USING parquet @@ -1100,12 +1100,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { | path '$path' |) """.stripMargin) - checkAnswer(sqlContext.table("`db.t`"), df) + checkAnswer(spark.table("`db.t`"), df) } } test("SPARK-10593 same column names in lateral view") { - val df = sqlContext.sql( + val df = spark.sql( """ |select |insideLayer2.json as a2 @@ -1120,7 +1120,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ignore("SPARK-10310: " + "script transformation using default input/output SerDe and record reader/writer") { - sqlContext + spark .range(5) .selectExpr("id AS a", "id AS b") .registerTempTable("test") @@ -1138,7 +1138,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } ignore("SPARK-10310: script transformation using LazySimpleSerDe") { - sqlContext + spark .range(5) .selectExpr("id AS a", "id AS b") .registerTempTable("test") @@ -1183,7 +1183,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("run sql directly on files") { - val df = sqlContext.range(100).toDF() + val df = spark.range(100).toDF() withTempPath(f => { df.write.parquet(f.getCanonicalPath) checkAnswer(sql(s"select id from parquet.`${f.getCanonicalPath}`"), @@ -1325,14 +1325,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { Seq("3" -> "30").toDF("i", "j") .write.mode(SaveMode.Append).partitionBy("i").saveAsTable("tbl11453") checkAnswer( - sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"), + spark.read.table("tbl11453").select("i", "j").orderBy("i"), Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Nil) // make sure case sensitivity is correct. 
Seq("4" -> "40").toDF("i", "j") .write.mode(SaveMode.Append).partitionBy("I").saveAsTable("tbl11453") checkAnswer( - sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"), + spark.read.table("tbl11453").select("i", "j").orderBy("i"), Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Row("4", "40") :: Nil) } } @@ -1370,7 +1370,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("multi-insert with lateral view") { withTempTable("t1") { - sqlContext.range(10) + spark.range(10) .select(array($"id", $"id" + 1).as("arr"), $"id") .registerTempTable("source") withTable("dest1", "dest2") { @@ -1388,10 +1388,10 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { """.stripMargin) checkAnswer( - sqlContext.table("dest1"), + spark.table("dest1"), sql("SELECT id FROM source WHERE id > 3")) checkAnswer( - sqlContext.table("dest2"), + spark.table("dest2"), sql("SELECT col FROM source LATERAL VIEW EXPLODE(arr) exp AS col WHERE col > 3")) } } @@ -1404,7 +1404,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { withTempPath { dir => withTempTable("t1", "t2") { val path = dir.getCanonicalPath - val ds = sqlContext.range(10) + val ds = spark.range(10) ds.registerTempTable("t1") sql( @@ -1415,7 +1415,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { """.stripMargin) checkAnswer( - sqlContext.tables().select('isTemporary).filter('tableName === "t2"), + spark.wrapped.tables().select('isTemporary).filter('tableName === "t2"), Row(true) ) @@ -1429,7 +1429,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { "shouldn always be used together with PATH data source option" ) { withTempTable("t") { - sqlContext.range(10).registerTempTable("t") + spark.range(10).registerTempTable("t") val message = intercept[IllegalArgumentException] { sql( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala index 72f9fba13d..f37037e3c7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala @@ -30,11 +30,11 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { override def beforeAll(): Unit = { // Create a simple table with two columns: id and id1 - sqlContext.range(1, 10).selectExpr("id", "id id1").write.format("json").saveAsTable("jt") + spark.range(1, 10).selectExpr("id", "id id1").write.format("json").saveAsTable("jt") } override def afterAll(): Unit = { - sqlContext.sql(s"DROP TABLE IF EXISTS jt") + spark.sql(s"DROP TABLE IF EXISTS jt") } test("nested views (interleaved with temporary views)") { @@ -277,11 +277,11 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { withSQLConf( SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> "true") { withTable("add_col") { - sqlContext.range(10).write.saveAsTable("add_col") + spark.range(10).write.saveAsTable("add_col") withView("v") { sql("CREATE VIEW v AS SELECT * FROM add_col") - sqlContext.range(10).select('id, 'id as 'a).write.mode("overwrite").saveAsTable("add_col") - checkAnswer(sql("SELECT * FROM v"), sqlContext.range(10).toDF()) + spark.range(10).select('id, 'id as 'a).write.mode("overwrite").saveAsTable("add_col") + checkAnswer(sql("SELECT * FROM v"), spark.range(10).toDF()) } } } @@ 
-291,8 +291,8 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { // make sure the new flag can handle some complex cases like join and schema change. withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { withTable("jt1", "jt2") { - sqlContext.range(1, 10).toDF("id1").write.format("json").saveAsTable("jt1") - sqlContext.range(1, 10).toDF("id2").write.format("json").saveAsTable("jt2") + spark.range(1, 10).toDF("id1").write.format("json").saveAsTable("jt1") + spark.range(1, 10).toDF("id2").write.format("json").saveAsTable("jt2") sql("CREATE VIEW testView AS SELECT * FROM jt1 JOIN jt2 ON id1 == id2") checkAnswer(sql("SELECT * FROM testView ORDER BY id1"), (1 to 9).map(i => Row(i, i))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala index d0e7552c12..cbbeacf6ad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala @@ -353,7 +353,7 @@ class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSi checkAnswer(actual, expected) - sqlContext.dropTempTable("nums") + spark.catalog.dropTempTable("nums") } test("SPARK-7595: Window will cause resolve failed with self join") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index b97da1ffdc..965680ff0d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -75,11 +75,11 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b").write.orc(path) checkAnswer( - sqlContext.read.orc(path).where("not (a = 2) or not(b in ('1'))"), + spark.read.orc(path).where("not (a = 2) or not(b in ('1'))"), (1 to 5).map(i => Row(i, (i % 2).toString))) checkAnswer( - sqlContext.read.orc(path).where("not (a = 2 and b in ('1'))"), + spark.read.orc(path).where("not (a = 2 and b in ('1'))"), (1 to 5).map(i => Row(i, (i % 2).toString))) } } @@ -94,7 +94,7 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { .orc(path) // Check if this is compressed as ZLIB. 
- val conf = sqlContext.sessionState.newHadoopConf() + val conf = spark.sessionState.newHadoopConf() val fs = FileSystem.getLocal(conf) val maybeOrcFile = new File(path).listFiles().find(_.getName.endsWith(".zlib.orc")) assert(maybeOrcFile.isDefined) @@ -102,7 +102,7 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { val orcReader = OrcFile.createReader(orcFilePath, OrcFile.readerOptions(conf)) assert(orcReader.getCompression == CompressionKind.ZLIB) - val copyDf = sqlContext + val copyDf = spark .read .orc(path) checkAnswer(df, copyDf) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index aa9c1189db..084546f99d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -66,7 +66,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - sqlContext.read.orc(file), + spark.read.orc(file), data.toDF().collect()) } } @@ -170,7 +170,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { // Hive supports zlib, snappy and none for Hive 1.2.1. test("Compression options for writing to an ORC file (SNAPPY, ZLIB and NONE)") { withTempPath { file => - sqlContext.range(0, 10).write + spark.range(0, 10).write .option("orc.compress", "ZLIB") .orc(file.getCanonicalPath) val expectedCompressionKind = @@ -179,7 +179,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } withTempPath { file => - sqlContext.range(0, 10).write + spark.range(0, 10).write .option("orc.compress", "SNAPPY") .orc(file.getCanonicalPath) val expectedCompressionKind = @@ -188,7 +188,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } withTempPath { file => - sqlContext.range(0, 10).write + spark.range(0, 10).write .option("orc.compress", "NONE") .orc(file.getCanonicalPath) val expectedCompressionKind = @@ -200,7 +200,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { // Following codec is not supported in Hive 1.2.1, ignore it now ignore("LZO compression options for writing to an ORC file not supported in Hive 1.2.1") { withTempPath { file => - sqlContext.range(0, 10).write + spark.range(0, 10).write .option("orc.compress", "LZO") .orc(file.getCanonicalPath) val expectedCompressionKind = @@ -301,12 +301,12 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTempPath { dir => val path = dir.getCanonicalPath - sqlContext.range(0, 10).select('id as "Acol").write.format("orc").save(path) - sqlContext.read.format("orc").load(path).schema("Acol") + spark.range(0, 10).select('id as "Acol").write.format("orc").save(path) + spark.read.format("orc").load(path).schema("Acol") intercept[IllegalArgumentException] { - sqlContext.read.format("orc").load(path).schema("acol") + spark.read.format("orc").load(path).schema("acol") } - checkAnswer(sqlContext.read.format("orc").load(path).select("acol").sort("acol"), + checkAnswer(spark.read.format("orc").load(path).select("acol").sort("acol"), (0 until 10).map(Row(_))) } } @@ -317,7 +317,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTable("empty_orc") { withTempTable("empty", "single") { - sqlContext.sql( + spark.sql( s"""CREATE TABLE empty_orc(key INT, value STRING) |STORED AS ORC |LOCATION '$path' @@ -328,13 +328,13 @@ 
class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { // This creates 1 empty ORC file with Hive ORC SerDe. We are using this trick because // Spark SQL ORC data source always avoids write empty ORC files. - sqlContext.sql( + spark.sql( s"""INSERT INTO TABLE empty_orc |SELECT key, value FROM empty """.stripMargin) val errorMessage = intercept[AnalysisException] { - sqlContext.read.orc(path) + spark.read.orc(path) }.getMessage assert(errorMessage.contains("Unable to infer schema for ORC")) @@ -342,12 +342,12 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) singleRowDF.registerTempTable("single") - sqlContext.sql( + spark.sql( s"""INSERT INTO TABLE empty_orc |SELECT key, value FROM single """.stripMargin) - val df = sqlContext.read.orc(path) + val df = spark.read.orc(path) assert(df.schema === singleRowDF.schema.asNullable) checkAnswer(df, singleRowDF) } @@ -373,7 +373,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { // It needs to repartition data so that we can have several ORC files // in order to skip stripes in ORC. createDataFrame(data).toDF("a", "b").repartition(10).write.orc(path) - val df = sqlContext.read.orc(path) + val df = spark.read.orc(path) def checkPredicate(pred: Column, answer: Seq[Row]): Unit = { val sourceDf = stripSparkFilter(df.where(pred)) @@ -415,7 +415,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTable("dummy_orc") { withTempTable("single") { - sqlContext.sql( + spark.sql( s"""CREATE TABLE dummy_orc(key INT, value STRING) |STORED AS ORC |LOCATION '$path' @@ -424,12 +424,12 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) singleRowDF.registerTempTable("single") - sqlContext.sql( + spark.sql( s"""INSERT INTO TABLE dummy_orc |SELECT key, value FROM single """.stripMargin) - val df = sqlContext.sql("SELECT * FROM dummy_orc WHERE key=0") + val df = spark.sql("SELECT * FROM dummy_orc WHERE key=0") checkAnswer(df, singleRowDF) val queryExecution = df.queryExecution @@ -448,7 +448,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { val data = (0 until 10).map(i => Tuple1(Array(i))) withOrcFile(data) { file => - val actual = sqlContext + val actual = spark .read .orc(file) .where("_1 is not null") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 637c10611a..aba60da33f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -49,7 +49,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withOrcFile(data)(path => f(sqlContext.read.orc(path))) + withOrcFile(data)(path => f(spark.read.orc(path))) } /** @@ -61,7 +61,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { (data: Seq[T], tableName: String) (f: => Unit): Unit = { withOrcDataFrame(data) { df => - sqlContext.registerDataFrameAsTable(df, tableName) + spark.wrapped.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 1c1f6d910d..6e93bbde26 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -634,7 +634,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { """.stripMargin) checkAnswer( - sqlContext.read.parquet(path), + spark.read.parquet(path), Row("1st", "2nd", Seq(Row("val_a", "val_b")))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index a3e7737a7c..8bf6f224a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -105,7 +105,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle } // Read the bucket file into a dataframe, so that it's easier to test. - val readBack = sqlContext.read.format(source) + val readBack = spark.read.format(source) .load(bucketFile.getAbsolutePath) .select(columns: _*) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index 08e83b7f69..f9387fae4a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -34,7 +34,7 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton // Here we coalesce partition number to 1 to ensure that only a single task is issued. This // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary` // directory while committing/aborting the job. See SPARK-8513 for more details. 
- val df = sqlContext.range(0, 10).coalesce(1) + val df = spark.range(0, 10).coalesce(1) intercept[SparkException] { df.write.format(dataSourceName).save(file.getCanonicalPath) } @@ -49,7 +49,7 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton withTempPath { file => // fail the job in the middle of writing val divideByZero = udf((x: Int) => { x / (x - 1)}) - val df = sqlContext.range(0, 10).coalesce(1).select(divideByZero(col("id"))) + val df = spark.range(0, 10).coalesce(1).select(divideByZero(col("id"))) SimpleTextRelation.callbackCalled = false intercept[SparkException] { @@ -66,7 +66,7 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton SimpleTextRelation.failCommitter = false withTempPath { file => // fail the job in the middle of writing - val df = sqlContext.range(0, 10).coalesce(1).select(col("id").mod(2).as("key"), col("id")) + val df = spark.range(0, 10).coalesce(1).select(col("id").mod(2).as("key"), col("id")) SimpleTextRelation.callbackCalled = false SimpleTextRelation.failWriter = true diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 20c5f72ff1..f4d63334b6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types._ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with TestHiveSingleton { - import sqlContext.implicits._ + import spark.implicits._ val dataSourceName: String @@ -143,8 +143,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .add("index", IntegerType, nullable = false) .add("col", dataType, nullable = true) val rdd = - sqlContext.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) - val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1) + spark.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) + val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1) df.write .mode("overwrite") @@ -153,7 +153,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .options(extraOptions) .save(path) - val loadedDF = sqlContext + val loadedDF = spark .read .format(dataSourceName) .option("dataSchema", df.schema.json) @@ -174,7 +174,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - sqlContext.read.format(dataSourceName) + spark.read.format(dataSourceName) .option("path", file.getCanonicalPath) .option("dataSchema", dataSchema.json) .load(), @@ -188,7 +188,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes testDF.write.mode(SaveMode.Append).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - sqlContext.read.format(dataSourceName) + spark.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath).orderBy("a"), testDF.union(testDF).orderBy("a").collect()) @@ -208,7 +208,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes testDF.write.mode(SaveMode.Ignore).format(dataSourceName).save(file.getCanonicalPath) val path = new Path(file.getCanonicalPath) - val fs = 
path.getFileSystem(sqlContext.sessionState.newHadoopConf())
+      val fs = path.getFileSystem(spark.sessionState.newHadoopConf())
       assert(fs.listStatus(path).isEmpty)
     }
   }
@@ -222,7 +222,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
         .save(file.getCanonicalPath)
       checkQueries(
-        sqlContext.read.format(dataSourceName)
+        spark.read.format(dataSourceName)
           .option("dataSchema", dataSchema.json)
           .load(file.getCanonicalPath))
     }
@@ -243,7 +243,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
         .save(file.getCanonicalPath)
       checkAnswer(
-        sqlContext.read.format(dataSourceName)
+        spark.read.format(dataSourceName)
           .option("dataSchema", dataSchema.json)
           .load(file.getCanonicalPath),
         partitionedTestDF.collect())
@@ -265,7 +265,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
         .save(file.getCanonicalPath)
       checkAnswer(
-        sqlContext.read.format(dataSourceName)
+        spark.read.format(dataSourceName)
           .option("dataSchema", dataSchema.json)
           .load(file.getCanonicalPath),
         partitionedTestDF.union(partitionedTestDF).collect())
@@ -287,7 +287,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
         .save(file.getCanonicalPath)
       checkAnswer(
-        sqlContext.read.format(dataSourceName)
+        spark.read.format(dataSourceName)
           .option("dataSchema", dataSchema.json)
           .load(file.getCanonicalPath),
         partitionedTestDF.collect())
@@ -323,7 +323,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
       .saveAsTable("t")
     withTable("t") {
-      checkAnswer(sqlContext.table("t"), testDF.collect())
+      checkAnswer(spark.table("t"), testDF.collect())
     }
   }
@@ -332,7 +332,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
     testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t")
     withTable("t") {
-      checkAnswer(sqlContext.table("t"), testDF.union(testDF).orderBy("a").collect())
+      checkAnswer(spark.table("t"), testDF.union(testDF).orderBy("a").collect())
     }
   }
@@ -351,7 +351,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
     withTempTable("t") {
       testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t")
-      assert(sqlContext.table("t").collect().isEmpty)
+      assert(spark.table("t").collect().isEmpty)
     }
   }
@@ -362,18 +362,18 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
       .saveAsTable("t")
     withTable("t") {
-      checkQueries(sqlContext.table("t"))
+      checkQueries(spark.table("t"))
     }
   }
   test("saveAsTable()/load() - partitioned table - boolean type") {
-    sqlContext.range(2)
+    spark.range(2)
       .select('id, ('id % 2 === 0).as("b"))
       .write.partitionBy("b").saveAsTable("t")
     withTable("t") {
       checkAnswer(
-        sqlContext.table("t").sort('id),
+        spark.table("t").sort('id),
         Row(0, true) :: Row(1, false) :: Nil
       )
     }
@@ -395,7 +395,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
       .saveAsTable("t")
     withTable("t") {
-      checkAnswer(sqlContext.table("t"), partitionedTestDF.collect())
+      checkAnswer(spark.table("t"), partitionedTestDF.collect())
     }
   }
@@ -415,7 +415,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
       .saveAsTable("t")
     withTable("t") {
-      checkAnswer(sqlContext.table("t"), partitionedTestDF.union(partitionedTestDF).collect())
+      checkAnswer(spark.table("t"), partitionedTestDF.union(partitionedTestDF).collect())
     }
   }
@@ -435,7 +435,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
       .saveAsTable("t")
     withTable("t") {
-      checkAnswer(sqlContext.table("t"), partitionedTestDF.collect())
+      checkAnswer(spark.table("t"), partitionedTestDF.collect())
     }
   }
@@ -484,7 +484,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
         .partitionBy("p1", "p2")
         .saveAsTable("t")
-      assert(sqlContext.table("t").collect().isEmpty)
+      assert(spark.table("t").collect().isEmpty)
     }
   }
@@ -516,7 +516,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
       // Inferring schema should throw error as it should not find any file to infer
       val e = intercept[Exception] {
-        sqlContext.read.format(dataSourceName).load(dir.getCanonicalPath)
+        spark.read.format(dataSourceName).load(dir.getCanonicalPath)
       }
       e match {
@@ -533,7 +533,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
     /** Test whether data is read with the given path matches the expected answer */
     def testWithPath(path: File, expectedAnswer: Seq[Row]): Unit = {
-      val df = sqlContext.read
+      val df = spark.read
         .format(dataSourceName)
         .schema(dataInDir.schema) // avoid schema inference for any format
         .load(path.getCanonicalPath)
@@ -618,7 +618,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
     /** Check whether data is read with the given path matches the expected answer */
     def check(path: String, expectedDf: DataFrame): Unit = {
-      val df = sqlContext.read
+      val df = spark.read
        .format(dataSourceName)
        .schema(schema) // avoid schema inference for any format, expected to be same format
        .load(path)
@@ -654,7 +654,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
         basePath: Option[String] = None
       ): Unit = {
       try {
-        val reader = sqlContext.read
+        val reader = spark.read
        basePath.foreach(reader.option("basePath", _))
        val testDf = reader
          .format(dataSourceName)
@@ -739,7 +739,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
         val realData = input.collect()
-        checkAnswer(sqlContext.table("t"), realData ++ realData)
+        checkAnswer(spark.table("t"), realData ++ realData)
       }
     }
   }
@@ -754,7 +754,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
       .saveAsTable("t")
     withTable("t") {
-      checkAnswer(sqlContext.table("t").select('b, 'c, 'a), df.select('b, 'c, 'a).collect())
+      checkAnswer(spark.table("t").select('b, 'c, 'a), df.select('b, 'c, 'a).collect())
     }
   }
@@ -766,7 +766,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
   test("SPARK-8406: Avoids name collision while writing files") {
     withTempPath { dir =>
       val path = dir.getCanonicalPath
-      sqlContext
+      spark
         .range(10000)
         .repartition(250)
         .write
@@ -775,7 +775,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
         .save(path)
       assertResult(10000) {
-        sqlContext
+        spark
           .read
           .format(dataSourceName)
           .option("dataSchema", StructType(StructField("id", LongType) :: Nil).json)
@@ -794,7 +794,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
       classOf[AlwaysFailParquetOutputCommitter].getName
     )
-    val df = sqlContext.range(1, 10).toDF("i")
+    val df = spark.range(1, 10).toDF("i")
     withTempPath { dir =>
       df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath)
       // Because there data already exists,
       // with file format and AlwaysFailOutputCommitter will not be used.
       df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath)
       checkAnswer(
-        sqlContext.read
+        spark.read
           .format(dataSourceName)
           .option("dataSchema", df.schema.json)
           .options(extraOptions)
@@ -850,12 +850,12 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
     )
     withTempPath { dir =>
       val path = "file://" + dir.getCanonicalPath
-      val df1 = sqlContext.range(4)
+      val df1 = spark.range(4)
       df1.coalesce(1).write.mode("overwrite").options(options).format(dataSourceName).save(path)
       df1.coalesce(1).write.mode("append").options(options).format(dataSourceName).save(path)
       def checkLocality(): Unit = {
-        val df2 = sqlContext.read
+        val df2 = spark.read
           .format(dataSourceName)
           .option("dataSchema", df1.schema.json)
           .options(options)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
index 1d104889fe..4b4852c1d7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
@@ -126,18 +126,18 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
   test("SPARK-8604: Parquet data source should write summary file while doing appending") {
     withTempPath { dir =>
       val path = dir.getCanonicalPath
-      val df = sqlContext.range(0, 5).toDF()
+      val df = spark.range(0, 5).toDF()
       df.write.mode(SaveMode.Overwrite).parquet(path)
       val summaryPath = new Path(path, "_metadata")
       val commonSummaryPath = new Path(path, "_common_metadata")
-      val fs = summaryPath.getFileSystem(sqlContext.sessionState.newHadoopConf())
+      val fs = summaryPath.getFileSystem(spark.sessionState.newHadoopConf())
       fs.delete(summaryPath, true)
       fs.delete(commonSummaryPath, true)
       df.write.mode(SaveMode.Append).parquet(path)
-      checkAnswer(sqlContext.read.parquet(path), df.union(df))
+      checkAnswer(spark.read.parquet(path), df.union(df))
       assert(fs.exists(summaryPath))
       assert(fs.exists(commonSummaryPath))
@@ -148,8 +148,8 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
     withTempPath { dir =>
       val path = dir.getCanonicalPath
-      sqlContext.range(2).select('id as 'a, 'id as 'b).write.partitionBy("b").parquet(path)
-      val df = sqlContext.read.parquet(path).filter('a === 0).select('b)
+      spark.range(2).select('id as 'a, 'id as 'b).write.partitionBy("b").parquet(path)
+      val df = spark.read.parquet(path).filter('a === 0).select('b)
       val physicalPlan = df.queryExecution.sparkPlan
       assert(physicalPlan.collect { case p: execution.ProjectExec => p }.length === 1)
@@ -170,7 +170,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
       // The schema consists of the leading columns of the first part-file
       // in the lexicographic order.
-      assert(sqlContext.read.parquet(dir.getCanonicalPath).schema.map(_.name)
+      assert(spark.read.parquet(dir.getCanonicalPath).schema.map(_.name)
         === Seq("a", "b", "c", "d", "part"))
     }
   }
@@ -188,8 +188,8 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
         Row(5, 127.toByte), Row(6, -44.toByte), Row(7, 23.toByte), Row(8, -95.toByte),
         Row(9, 127.toByte), Row(10, 13.toByte))
-      val rdd = sqlContext.sparkContext.parallelize(data)
-      val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1)
+      val rdd = spark.sparkContext.parallelize(data)
+      val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1)
       df.write
         .mode("overwrite")
@@ -197,7 +197,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
         .option("dataSchema", df.schema.json)
         .save(path)
-      val loadedDF = sqlContext
+      val loadedDF = spark
         .read
         .format(dataSourceName)
         .option("dataSchema", df.schema.json)
@@ -221,7 +221,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
       val compressedFiles = new File(path).listFiles()
       assert(compressedFiles.exists(_.getName.endsWith(".gz.parquet")))
-      val copyDf = sqlContext
+      val copyDf = spark
         .read
         .parquet(path)
       checkAnswer(df, copyDf)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
index 9ad0887609..fa64c7dcfa 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
@@ -69,7 +69,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat
   test("test hadoop conf option propagation") {
     withTempPath { file =>
       // Test write side
-      val df = sqlContext.range(10).selectExpr("cast(id as string)")
+      val df = spark.range(10).selectExpr("cast(id as string)")
       df.write
         .option("some-random-write-option", "hahah-WRITE")
         .option("some-null-value-option", null) // test null robustness
@@ -78,7 +78,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat
       assert(SimpleTextRelation.lastHadoopConf.get.get("some-random-write-option") == "hahah-WRITE")
       // Test read side
-      val df1 = sqlContext.read
+      val df1 = spark.read
        .option("some-random-read-option", "hahah-READ")
        .option("some-null-value-option", null) // test null robustness
        .option("dataSchema", df.schema.json)