From 1c751fcf488189e5176546fe0d00f560ffcf1cec Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 11 Apr 2016 09:28:28 -0700 Subject: [PATCH] [SPARK-14500] [ML] Accept Dataset[_] instead of DataFrame in MLlib APIs ## What changes were proposed in this pull request? This PR updates MLlib APIs to accept `Dataset[_]` as input where `DataFrame` was the input type. This PR doesn't change the output type. In Java, `Dataset[_]` maps to `Dataset`, which includes `Dataset`. Some implementations were changed in order to return `DataFrame`. Tests and examples were updated. Note that this is a breaking change for subclasses of Transformer/Estimator. Lol, we don't have to rename the input argument, which has been `dataset` since Spark 1.2. TODOs: - [x] update MiMaExcludes (seems all covered by explicit filters from SPARK-13920) - [x] Python - [x] add a new test to accept Dataset[LabeledPoint] - [x] remove unused imports of Dataset ## How was this patch tested? Exiting unit tests with some modifications. cc: rxin jkbradley Author: Xiangrui Meng Closes #12274 from mengxr/SPARK-14500. --- .../examples/ml/JavaDeveloperApiExample.java | 2 +- .../examples/ml/DeveloperApiExample.scala | 4 ++-- .../scala/org/apache/spark/ml/Estimator.scala | 16 ++++++++----- .../scala/org/apache/spark/ml/Pipeline.scala | 12 +++++----- .../scala/org/apache/spark/ml/Predictor.scala | 14 +++++------ .../org/apache/spark/ml/Transformer.scala | 15 +++++++----- .../spark/ml/classification/Classifier.scala | 6 ++--- .../DecisionTreeClassifier.scala | 4 ++-- .../ml/classification/GBTClassifier.scala | 6 ++--- .../classification/LogisticRegression.scala | 8 +++---- .../MultilayerPerceptronClassifier.scala | 4 ++-- .../spark/ml/classification/NaiveBayes.scala | 4 ++-- .../spark/ml/classification/OneVsRest.scala | 10 ++++---- .../ProbabilisticClassifier.scala | 6 ++--- .../RandomForestClassifier.scala | 6 ++--- .../spark/ml/clustering/BisectingKMeans.scala | 8 +++---- .../spark/ml/clustering/GaussianMixture.scala | 6 ++--- .../apache/spark/ml/clustering/KMeans.scala | 14 +++++------ .../org/apache/spark/ml/clustering/LDA.scala | 24 +++++++++---------- .../BinaryClassificationEvaluator.scala | 6 ++--- .../spark/ml/evaluation/Evaluator.scala | 10 ++++---- .../MulticlassClassificationEvaluator.scala | 6 ++--- .../ml/evaluation/RegressionEvaluator.scala | 6 ++--- .../apache/spark/ml/feature/Binarizer.scala | 3 ++- .../apache/spark/ml/feature/Bucketizer.scala | 3 ++- .../spark/ml/feature/ChiSqSelector.scala | 6 +++-- .../spark/ml/feature/CountVectorizer.scala | 8 ++++--- .../apache/spark/ml/feature/HashingTF.scala | 5 ++-- .../org/apache/spark/ml/feature/IDF.scala | 6 +++-- .../apache/spark/ml/feature/Interaction.scala | 6 ++--- .../spark/ml/feature/MaxAbsScaler.scala | 6 +++-- .../spark/ml/feature/MinMaxScaler.scala | 6 +++-- .../spark/ml/feature/OneHotEncoder.scala | 5 ++-- .../org/apache/spark/ml/feature/PCA.scala | 6 +++-- .../ml/feature/QuantileDiscretizer.scala | 13 ++++++---- .../apache/spark/ml/feature/RFormula.scala | 18 +++++++------- .../spark/ml/feature/SQLTransformer.scala | 6 ++--- .../spark/ml/feature/StandardScaler.scala | 6 +++-- .../spark/ml/feature/StopWordsRemover.scala | 5 ++-- .../spark/ml/feature/StringIndexer.scala | 13 ++++++---- .../spark/ml/feature/VectorAssembler.scala | 7 +++--- .../spark/ml/feature/VectorIndexer.scala | 8 ++++--- .../spark/ml/feature/VectorSlicer.scala | 5 ++-- .../apache/spark/ml/feature/Word2Vec.scala | 8 ++++--- .../ml/r/AFTSurvivalRegressionWrapper.scala | 4 ++-- .../org/apache/spark/ml/r/KMeansWrapper.scala | 4 ++-- .../apache/spark/ml/r/NaiveBayesWrapper.scala | 4 ++-- .../apache/spark/ml/recommendation/ALS.scala | 10 ++++---- .../ml/regression/AFTSurvivalRegression.scala | 12 +++++----- .../ml/regression/DecisionTreeRegressor.scala | 11 +++++---- .../spark/ml/regression/GBTRegressor.scala | 6 ++--- .../GeneralizedLinearRegression.scala | 4 ++-- .../ml/regression/IsotonicRegression.scala | 12 +++++----- .../ml/regression/LinearRegression.scala | 6 ++--- .../ml/regression/RandomForestRegressor.scala | 6 ++--- .../spark/ml/tuning/CrossValidator.scala | 12 +++++----- .../ml/tuning/TrainValidationSplit.scala | 11 +++++---- .../apache/spark/mllib/linalg/Vectors.scala | 2 +- .../org/apache/spark/ml/PipelineSuite.scala | 12 +++++++--- .../LogisticRegressionSuite.scala | 4 ++-- .../MultilayerPerceptronClassifierSuite.scala | 4 ++-- .../ml/classification/NaiveBayesSuite.scala | 4 ++-- .../ml/classification/OneVsRestSuite.scala | 6 ++--- .../ml/clustering/BisectingKMeansSuite.scala | 4 ++-- .../ml/clustering/GaussianMixtureSuite.scala | 4 ++-- .../spark/ml/clustering/KMeansSuite.scala | 4 ++-- .../apache/spark/ml/clustering/LDASuite.scala | 4 ++-- .../apache/spark/ml/feature/NGramSuite.scala | 4 ++-- .../ml/feature/StopWordsRemoverSuite.scala | 4 ++-- .../spark/ml/feature/StringIndexerSuite.scala | 2 +- .../spark/ml/feature/TokenizerSuite.scala | 4 ++-- .../GeneralizedLinearRegressionSuite.scala | 8 +++++++ .../spark/ml/tuning/CrossValidatorSuite.scala | 8 +++---- .../ml/tuning/TrainValidationSplitSuite.scala | 6 ++--- .../spark/ml/util/DefaultReadWriteTest.scala | 4 ++-- 75 files changed, 296 insertions(+), 240 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index fbd8817669..0ba94786d4 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -146,7 +146,7 @@ class MyJavaLogisticRegression // This method is used by fit(). // In Java, we have to make it public since Java does not understand Scala's protected modifier. - public MyJavaLogisticRegressionModel train(Dataset dataset) { + public MyJavaLogisticRegressionModel train(Dataset dataset) { // Extract columns from data using helper method. JavaRDD oldDataset = extractLabeledPoints(dataset).toJavaRDD(); diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index c1f63c6a1d..8d127f9b35 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} /** * A simple example demonstrating how to write your own learning algorithm using Estimator, @@ -120,7 +120,7 @@ private class MyLogisticRegression(override val uid: String) def setMaxIter(value: Int): this.type = set(maxIter, value) // This method is used by fit() - override protected def train(dataset: DataFrame): MyLogisticRegressionModel = { + override protected def train(dataset: Dataset[_]): MyLogisticRegressionModel = { // Extract columns from data using helper method. val oldDataset = extractLabeledPoints(dataset) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 57e416591d..1247882d6c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -19,9 +19,9 @@ package org.apache.spark.ml import scala.annotation.varargs -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.param.{ParamMap, ParamPair} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset /** * :: DeveloperApi :: @@ -39,8 +39,9 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * Estimator's embedded ParamMap. * @return fitted model */ + @Since("2.0.0") @varargs - def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = { + def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = { val map = new ParamMap() .put(firstParamPair) .put(otherParamPairs: _*) @@ -55,14 +56,16 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * These values override any specified in this Estimator's embedded ParamMap. * @return fitted model */ - def fit(dataset: DataFrame, paramMap: ParamMap): M = { + @Since("2.0.0") + def fit(dataset: Dataset[_], paramMap: ParamMap): M = { copy(paramMap).fit(dataset) } /** * Fits a model to the input data. */ - def fit(dataset: DataFrame): M + @Since("2.0.0") + def fit(dataset: Dataset[_]): M /** * Fits multiple models to the input data with multiple sets of parameters. @@ -74,7 +77,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * These values override any specified in this Estimator's embedded ParamMap. * @return fitted models, matching the input parameter maps */ - def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { + @Since("2.0.0") + def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index afefaaa883..82066726a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -31,7 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.param.{Param, ParamMap, Params} import org.apache.spark.ml.util._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType /** @@ -123,8 +123,8 @@ class Pipeline @Since("1.4.0") ( * @param dataset input dataset * @return fitted pipeline */ - @Since("1.2.0") - override def fit(dataset: DataFrame): PipelineModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): PipelineModel = { transformSchema(dataset.schema, logging = true) val theStages = $(stages) // Search for the last estimator. @@ -291,10 +291,10 @@ class PipelineModel private[ml] ( this(uid, stages.asScala.toArray) } - @Since("1.2.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur)) + stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur)) } @Since("1.2.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index d23ae6f794..81140d1f7b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -83,7 +83,7 @@ abstract class Predictor[ /** @group setParam */ def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] - override def fit(dataset: DataFrame): M = { + override def fit(dataset: Dataset[_]): M = { // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, logging = true) @@ -100,7 +100,7 @@ abstract class Predictor[ * @param dataset Training dataset * @return Fitted model */ - protected def train(dataset: DataFrame): M + protected def train(dataset: Dataset[_]): M /** * Returns the SQL DataType corresponding to the FeaturesType type parameter. @@ -120,7 +120,7 @@ abstract class Predictor[ * Extract [[labelCol]] and [[featuresCol]] from the given dataset, * and put it in an RDD with strong types. */ - protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = { + protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = { dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } @@ -171,18 +171,18 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, * @param dataset input dataset * @return transformed dataset with [[predictionCol]] of type [[Double]] */ - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) if ($(predictionCol).nonEmpty) { transformImpl(dataset) } else { this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + " since no output columns were set.") - dataset + dataset.toDF } } - protected def transformImpl(dataset: DataFrame): DataFrame = { + protected def transformImpl(dataset: Dataset[_]): DataFrame = { val predictUDF = udf { (features: Any) => predict(features.asInstanceOf[FeaturesType]) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 2538c0f477..a3a2b55adc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml import scala.annotation.varargs -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -41,9 +41,10 @@ abstract class Transformer extends PipelineStage { * @param otherParamPairs other param pairs, overwrite embedded params * @return transformed dataset */ + @Since("2.0.0") @varargs def transform( - dataset: DataFrame, + dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DataFrame = { val map = new ParamMap() @@ -58,14 +59,16 @@ abstract class Transformer extends PipelineStage { * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ - def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + @Since("2.0.0") + def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame = { this.copy(paramMap).transform(dataset) } /** * Transforms the input dataset. */ - def transform(dataset: DataFrame): DataFrame + @Since("2.0.0") + def transform(dataset: Dataset[_]): DataFrame override def copy(extra: ParamMap): Transformer } @@ -113,7 +116,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] StructType(outputFields) } - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val transformUDF = udf(this.createTransformFunc, outputDataType) dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 8186afc17a..473e801794 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} @@ -92,7 +92,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * @param dataset input dataset * @return transformed dataset */ - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // Output selected columns only. @@ -123,7 +123,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" + " since no output columns were set.") } - outputData + outputData.toDF } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 4525bf71f6..300ae4339c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} /** @@ -82,7 +82,7 @@ final class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.6.0") override def setSeed(value: Long): this.type = super.setSeed(value) - override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = { + override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index a2150fbcc3..46e8b89d01 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ /** @@ -149,7 +149,7 @@ final class GBTClassifier @Since("1.4.0") ( } } - override protected def train(dataset: DataFrame): GBTClassificationModel = { + override protected def train(dataset: Dataset[_]): GBTClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { @@ -220,7 +220,7 @@ final class GBTClassificationModel private[ml]( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 268c3e32c3..4a3fe5c663 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -36,7 +36,7 @@ import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel @@ -257,12 +257,12 @@ class LogisticRegression @Since("1.2.0") ( this } - override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = { + override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = { val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE train(dataset, handlePersistence) } - protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean): + protected[spark] def train(dataset: Dataset[_], handlePersistence: Boolean): LogisticRegressionModel = { val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = @@ -544,7 +544,7 @@ class LogisticRegressionModel private[spark] ( * @param dataset Test dataset to evaluate model on. */ @Since("2.0.0") - def evaluate(dataset: DataFrame): LogisticRegressionSummary = { + def evaluate(dataset: Dataset[_]): LogisticRegressionSummary = { // Handle possible missing or invalid prediction columns val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol() new BinaryLogisticRegressionSummary(summaryModel.transform(dataset), diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 79bb2a8855..9ff5252e4f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -29,7 +29,7 @@ import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTo import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} /** Params for Multilayer Perceptron. */ private[ml] trait MultilayerPerceptronParams extends PredictorParams @@ -199,7 +199,7 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( * @param dataset Training dataset * @return Fitted model */ - override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = { + override protected def train(dataset: Dataset[_]): MultilayerPerceptronClassificationModel = { val myLayers = $(layers) val labels = myLayers.last val lpData = extractLabeledPoints(dataset) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 483ef0d88c..267d63b51e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -29,7 +29,7 @@ import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesMo import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} /** * Params for Naive Bayes Classifiers. @@ -101,7 +101,7 @@ class NaiveBayes @Since("1.5.0") ( def setModelType(value: String): this.type = set(modelType, value) setDefault(modelType -> OldNaiveBayes.Multinomial) - override protected def train(dataset: DataFrame): NaiveBayesModel = { + override protected def train(dataset: Dataset[_]): NaiveBayesModel = { val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) NaiveBayesModel.fromOld(oldModel, this) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 263d54ce4d..4de1b877b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -33,7 +33,7 @@ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -140,8 +140,8 @@ final class OneVsRestModel private[ml] ( validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) } - @Since("1.4.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Check schema transformSchema(dataset.schema, logging = true) @@ -293,8 +293,8 @@ final class OneVsRest @Since("1.4.0") ( validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) } - @Since("1.4.0") - override def fit(dataset: DataFrame): OneVsRestModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): OneVsRestModel = { transformSchema(dataset.schema) // determine number of classes either from metadata if provided, or via computation. diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 865614aa5c..d00fee12b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors, VectorUDT} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} @@ -95,7 +95,7 @@ abstract class ProbabilisticClassificationModel[ * @param dataset input dataset * @return transformed dataset */ - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -145,7 +145,7 @@ abstract class ProbabilisticClassificationModel[ this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + " since no output columns were set.") } - outputData + outputData.toDF } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index cb42532271..9d80b8eb68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -31,7 +31,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -98,7 +98,7 @@ final class RandomForestClassifier @Since("1.4.0") ( override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) - override protected def train(dataset: DataFrame): RandomForestClassificationModel = { + override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { @@ -180,7 +180,7 @@ final class RandomForestClassificationModel private[ml] ( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 55f751c57f..6cc9117da3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering. {BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} @@ -92,7 +92,7 @@ class BisectingKMeansModel private[ml] ( } @Since("2.0.0") - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val predictUDF = udf((vector: Vector) => predict(vector)) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } @@ -112,7 +112,7 @@ class BisectingKMeansModel private[ml] ( * centers. */ @Since("2.0.0") - def computeCost(dataset: DataFrame): Double = { + def computeCost(dataset: Dataset[_]): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } parentModel.computeCost(data) @@ -215,7 +215,7 @@ class BisectingKMeans @Since("2.0.0") ( def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value) @Since("2.0.0") - override def fit(dataset: DataFrame): BisectingKMeansModel = { + override def fit(dataset: Dataset[_]): BisectingKMeansModel = { val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } val bkm = new MLlibBisectingKMeans() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 120bf3cf9d..ead8ad7806 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.stat.distribution.MultivariateGaussian -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} @@ -80,7 +80,7 @@ class GaussianMixtureModel private[ml] ( } @Since("2.0.0") - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val predUDF = udf((vector: Vector) => predict(vector)) val probUDF = udf((vector: Vector) => predictProbability(vector)) dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) @@ -238,7 +238,7 @@ class GaussianMixture @Since("2.0.0") ( def setSeed(value: Long): this.type = set(seed, value) @Since("2.0.0") - override def fit(dataset: DataFrame): GaussianMixtureModel = { + override def fit(dataset: Dataset[_]): GaussianMixtureModel = { val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } val algo = new MLlibGM() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index a8beef8b12..d716bc6887 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} @@ -105,8 +105,8 @@ class KMeansModel private[ml] ( copyValues(copied, extra) } - @Since("1.5.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val predictUDF = udf((vector: Vector) => predict(vector)) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } @@ -126,8 +126,8 @@ class KMeansModel private[ml] ( * model on the given data. */ // TODO: Replace the temp fix when we have proper evaluators defined for clustering. - @Since("1.6.0") - def computeCost(dataset: DataFrame): Double = { + @Since("2.0.0") + def computeCost(dataset: Dataset[_]): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } parentModel.computeCost(data) @@ -254,8 +254,8 @@ class KMeans @Since("1.5.0") ( @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) - @Since("1.5.0") - override def fit(dataset: DataFrame): KMeansModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): KMeansModel = { val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } val algo = new MLlibKMeans() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 89a7a4ccf6..c57ceba4a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL import org.apache.spark.mllib.impl.PeriodicCheckpointer import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf} import org.apache.spark.sql.types.StructType @@ -402,15 +402,15 @@ sealed abstract class LDAModel private[ml] ( * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver. * This implementation may be changed in the future. */ - @Since("1.6.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { if ($(topicDistributionCol).nonEmpty) { val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext)) - dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))) + dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF } else { logWarning("LDAModel.transform was called without any output columns. Set an output column" + " such as topicDistributionCol to produce results.") - dataset + dataset.toDF } } @@ -455,8 +455,8 @@ sealed abstract class LDAModel private[ml] ( * @param dataset test corpus to use for calculating log likelihood * @return variational lower bound on the log likelihood of the entire corpus */ - @Since("1.6.0") - def logLikelihood(dataset: DataFrame): Double = { + @Since("2.0.0") + def logLikelihood(dataset: Dataset[_]): Double = { val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) oldLocalModel.logLikelihood(oldDataset) } @@ -472,8 +472,8 @@ sealed abstract class LDAModel private[ml] ( * @param dataset test corpus to use for calculating perplexity * @return Variational upper bound on log perplexity per token. */ - @Since("1.6.0") - def logPerplexity(dataset: DataFrame): Double = { + @Since("2.0.0") + def logPerplexity(dataset: Dataset[_]): Double = { val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) oldLocalModel.logPerplexity(oldDataset) } @@ -840,8 +840,8 @@ class LDA @Since("1.6.0") ( @Since("1.6.0") override def copy(extra: ParamMap): LDA = defaultCopy(extra) - @Since("1.6.0") - override def fit(dataset: DataFrame): LDAModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): LDAModel = { transformSchema(dataset.schema, logging = true) val oldLDA = new OldLDA() .setK($(k)) @@ -873,7 +873,7 @@ class LDA @Since("1.6.0") ( private[clustering] object LDA extends DefaultParamsReadable[LDA] { /** Get dataset for spark.mllib LDA */ - def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = { + def getOldDataset(dataset: Dataset[_], featuresCol: String): RDD[(Long, Vector)] = { dataset .withColumn("docId", monotonicallyIncreasingId()) .select("docId", featuresCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 337ffbe90f..bde8c275fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.types.DoubleType /** @@ -69,8 +69,8 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va setDefault(metricName -> "areaUnderROC") - @Since("1.2.0") - override def evaluate(dataset: DataFrame): Double = { + @Since("2.0.0") + override def evaluate(dataset: Dataset[_]): Double = { val schema = dataset.schema SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT)) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index 0f22cca3a7..5f765c071b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.param.{ParamMap, Params} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset /** * :: DeveloperApi :: @@ -36,8 +36,8 @@ abstract class Evaluator extends Params { * @param paramMap parameter map that specifies the input columns and output metrics * @return metric */ - @Since("1.5.0") - def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { + @Since("2.0.0") + def evaluate(dataset: Dataset[_], paramMap: ParamMap): Double = { this.copy(paramMap).evaluate(dataset) } @@ -46,8 +46,8 @@ abstract class Evaluator extends Params { * @param dataset a dataset that contains labels/observations and predictions. * @return metric */ - @Since("1.5.0") - def evaluate(dataset: DataFrame): Double + @Since("2.0.0") + def evaluate(dataset: Dataset[_]): Double /** * Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 55ff44323a..3acfc221c9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.types.DoubleType /** @@ -68,8 +68,8 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid setDefault(metricName -> "f1") - @Since("1.5.0") - override def evaluate(dataset: DataFrame): Double = { + @Since("2.0.0") + override def evaluate(dataset: Dataset[_]): Double = { val schema = dataset.schema SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 9976d7ed43..4134e2dbc5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, FloatType} @@ -70,8 +70,8 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui setDefault(metricName -> "rmse") - @Since("1.4.0") - override def evaluate(dataset: DataFrame): Double = { + @Since("2.0.0") + override def evaluate(dataset: Dataset[_]): Double = { val schema = dataset.schema val predictionColName = $(predictionCol) val predictionType = schema($(predictionCol)).dataType diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 2f8e3a0371..898ac2cc89 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -64,7 +64,8 @@ final class Binarizer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema, logging = true) val schema = dataset.schema val inputType = schema($(inputCol)).dataType diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 33abc7c99d..10e622ace6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -68,7 +68,8 @@ final class Bucketizer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) val bucketizer = udf { feature: Double => Bucketizer.binarySearchForBuckets($(splits), feature) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index b9e9d56853..cfecae7e0b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -77,7 +77,8 @@ final class ChiSqSelector(override val uid: String) /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) - override def fit(dataset: DataFrame): ChiSqSelectorModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): ChiSqSelectorModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(labelCol), $(featuresCol)).rdd.map { case Row(label: Double, features: Vector) => @@ -127,7 +128,8 @@ final class ChiSqSelectorModel private[ml] ( /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema, logging = true) val newField = transformedSchema.last val selector = udf { chiSqSelector.transform _ } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 00abbbe29c..922670a41b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vectors, VectorUDT} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap @@ -147,7 +147,8 @@ class CountVectorizer(override val uid: String) setDefault(vocabSize -> (1 << 18), minDF -> 1) - override def fit(dataset: DataFrame): CountVectorizerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): CountVectorizerModel = { transformSchema(dataset.schema, logging = true) val vocSize = $(vocabSize) val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) @@ -224,7 +225,8 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin /** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */ private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) if (broadcastDict.isEmpty) { val dict = vocabulary.zipWithIndex.toMap diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 0f7ae5a100..467ad73074 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{BooleanParam, IntParam, ParamMap, ParamValidat import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StructType} @@ -77,7 +77,8 @@ class HashingTF(override val uid: String) /** @group setParam */ def setBinary(value: Boolean): this.type = set(binary, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema) val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary)) val t = udf { terms: Seq[_] => hashingTF.transform(terms) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index f36cf503a0..5075b78c98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -76,7 +76,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa /** @group setParam */ def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) - override def fit(dataset: DataFrame): IDFModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): IDFModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val idf = new feature.IDF($(minDocFreq)).fit(input) @@ -115,7 +116,8 @@ class IDFModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val idf = udf { vec: Vector => idfModel.transform(vec) } dataset.withColumn($(outputCol), idf(col($(inputCol)))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index d3fe6e528f..9ca34e9ae2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.ml.Transformer import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -68,8 +68,8 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false)) } - @Since("1.6.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val inputFeatures = $(inputCols).map(c => dataset.schema(c)) val featureEncoders = getFeatureEncoders(inputFeatures) val featureAttrs = getFeatureAttrs(inputFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 7de5a4d5d3..e9df600c8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -66,7 +66,8 @@ class MaxAbsScaler @Since("2.0.0") (override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame): MaxAbsScalerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): MaxAbsScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val summary = Statistics.colStats(input) @@ -111,7 +112,8 @@ class MaxAbsScalerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // TODO: this looks hack, we may have to handle sparse and dense vectors separately. val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index b13684a1cb..125becbb8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -103,7 +103,8 @@ class MinMaxScaler(override val uid: String) /** @group setParam */ def setMax(value: Double): this.type = set(max, value) - override def fit(dataset: DataFrame): MinMaxScalerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): MinMaxScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val summary = Statistics.colStats(input) @@ -154,7 +155,8 @@ class MinMaxScalerModel private[ml] ( /** @group setParam */ def setMax(value: Double): this.type = set(max, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray val minArray = originalMin.toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 4f67042629..99357793db 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} @@ -121,7 +121,8 @@ class OneHotEncoder(override val uid: String) extends Transformer StructType(outputFields) } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // schema transformation val inputColName = $(inputCol) val outputColName = $(outputCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 305c3d187f..9cf722e121 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -68,7 +68,8 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams /** * Computes a [[PCAModel]] that contains the principal components of the input vectors. */ - override def fit(dataset: DataFrame): PCAModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): PCAModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v} val pca = new feature.PCA(k = $(k)) @@ -124,7 +125,8 @@ class PCAModel private[ml] ( * NOTE: Vectors to be transformed must be the same length * as the source vectors given to [[PCA.fit()]]. */ - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val pcaModel = new feature.PCAModel($(k), pc, explainedVariance) val pcaOp = udf { pcaModel.transform _ } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index e486e92c12..efe8b93d82 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -23,10 +23,10 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute -import org.apache.spark.ml.param.{IntParam, _} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed} import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.util.random.XORShiftRandom @@ -87,7 +87,8 @@ final class QuantileDiscretizer(override val uid: String) StructType(outputFields) } - override def fit(dataset: DataFrame): Bucketizer = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): Bucketizer = { val samples = QuantileDiscretizer .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed)) .map { case Row(feature: Double) => feature } @@ -112,13 +113,15 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi /** * Sampling from the given dataset to collect quantile statistics. */ - private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = { + private[feature] + def getSampledInput(dataset: Dataset[_], numBins: Int, seed: Long): Array[Row] = { val totalSamples = dataset.count() require(totalSamples > 0, "QuantileDiscretizer requires non-empty input dataset but was given an empty input.") val requiredSamples = math.max(numBins * numBins, minSamplesRequired) val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0) - dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect() + dataset.toDF.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()) + .collect() } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 12a76dbbfb..3ac6c77669 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -29,7 +29,7 @@ import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.VectorUDT -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types._ /** @@ -103,7 +103,8 @@ class RFormula(override val uid: String) RFormulaParser.parse($(formula)).hasIntercept } - override def fit(dataset: DataFrame): RFormulaModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): RFormulaModel = { require(isDefined(formula), "Formula must be defined first.") val parsedFormula = RFormulaParser.parse($(formula)) val resolvedFormula = parsedFormula.resolve(dataset.schema) @@ -204,7 +205,8 @@ class RFormulaModel private[feature]( private[ml] val pipelineModel: PipelineModel) extends Model[RFormulaModel] with RFormulaBase with MLWritable { - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { checkCanTransform(dataset.schema) transformLabel(pipelineModel.transform(dataset)) } @@ -232,10 +234,10 @@ class RFormulaModel private[feature]( override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)" - private def transformLabel(dataset: DataFrame): DataFrame = { + private def transformLabel(dataset: Dataset[_]): DataFrame = { val labelName = resolvedFormula.label if (hasLabelCol(dataset.schema)) { - dataset + dataset.toDF } else if (dataset.schema.exists(_.name == labelName)) { dataset.schema(labelName).dataType match { case _: NumericType | BooleanType => @@ -246,7 +248,7 @@ class RFormulaModel private[feature]( } else { // Ignore the label field. This is a hack so that this transformer can also work on test // datasets in a Pipeline. - dataset + dataset.toDF } } @@ -323,7 +325,7 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str def this(columnsToPrune: Set[String]) = this(Identifiable.randomUID("columnPruner"), columnsToPrune) - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_)) dataset.select(columnsToKeep.map(dataset.col): _*) } @@ -396,7 +398,7 @@ private class VectorAttributeRewriter( def this(vectorCol: String, prefixesToRewrite: Map[String, String]) = this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite) - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val metadata = { val group = AttributeGroup.fromStructField(dataset.schema(vectorCol)) val attrs = group.attributes.get.map { attr => diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index e0ca45b9a6..95fe942c6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -22,7 +22,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.Transformer import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} import org.apache.spark.sql.types.StructType /** @@ -63,8 +63,8 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor private val tableIdentifier: String = "__THIS__" - @Since("1.6.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val tableName = Identifiable.randomUID(uid) dataset.registerTempTable(tableName) val realStatement = $(statement).replace(tableIdentifier, tableName) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 26ee8e1bf1..118a6e3e6a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -85,7 +85,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM /** @group setParam */ def setWithStd(value: Boolean): this.type = set(withStd, value) - override def fit(dataset: DataFrame): StandardScalerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): StandardScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) @@ -135,7 +136,8 @@ class StandardScalerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean)) val scale = udf { scaler.transform _ } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 0a0e0b0960..b96bc48566 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StringType, StructType} @@ -125,7 +125,8 @@ class StopWordsRemover(override val uid: String) setDefault(stopWords -> StopWords.English, caseSensitive -> false) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema) val t = if ($(caseSensitive)) { val stopWordsSet = $(stopWords).toSet diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index faa0f6f407..7e0d374f02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap @@ -80,7 +80,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame): StringIndexerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): StringIndexerModel = { val counts = dataset.select(col($(inputCol)).cast(StringType)) .rdd .map(_.getString(0)) @@ -144,11 +145,12 @@ class StringIndexerModel ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { if (!dataset.schema.fieldNames.contains($(inputCol))) { logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " + "Skip StringIndexerModel.") - return dataset + return dataset.toDF } validateAndTransformSchema(dataset.schema) @@ -286,7 +288,8 @@ class IndexToString private[ml] (override val uid: String) StructType(outputFields) } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val inputColSchema = dataset.schema($(inputCol)) // If the labels array is empty use column metadata val values = if ($(labels).isEmpty) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 957e8e7a59..4d3e46e488 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -47,10 +47,11 @@ class VectorAssembler(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Schema transformation. val schema = dataset.schema - lazy val first = dataset.first() + lazy val first = dataset.toDF.first() val attrs = $(inputCols).flatMap { c => val field = schema(c) val index = schema.fieldIndex(c) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index bf4aef2a74..68b699d569 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -31,7 +31,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.util.collection.OpenHashSet @@ -108,7 +108,8 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame): VectorIndexerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): VectorIndexerModel = { transformSchema(dataset.schema, logging = true) val firstRow = dataset.select($(inputCol)).take(1) require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") @@ -345,7 +346,8 @@ class VectorIndexerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val newField = prepOutputField(dataset.schema) val transformUDF = udf { (vector: Vector) => transformFunc(vector) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index b60e82de00..7a9468b87b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.StructType @@ -89,7 +89,8 @@ final class VectorSlicer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Validity checks transformSchema(dataset.schema) val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 95bae1c8a3..a72692960f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, SQLContext} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -135,7 +135,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] /** @group setParam */ def setMinCount(value: Int): this.type = set(minCount, value) - override def fit(dataset: DataFrame): Word2VecModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): Word2VecModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) val wordVectors = new feature.Word2Vec() @@ -219,7 +220,8 @@ class Word2VecModel private[ml] ( * Transform a sentence column to a vector column to represent the whole sentence. The transform * is performed by averaging all word vectors it contains. */ - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val vectors = wordVectors.getVectors .mapValues(vv => Vectors.dense(vv.map(_.toDouble))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 40590e71c4..2ae411555f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.feature.RFormula import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} private[r] class AFTSurvivalRegressionWrapper private ( pipeline: PipelineModel, @@ -43,7 +43,7 @@ private[r] class AFTSurvivalRegressionWrapper private ( features ++ Array("Log(scale)") } - def transform(dataset: DataFrame): DataFrame = { + def transform(dataset: Dataset[_]): DataFrame = { pipeline.transform(dataset) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala index ed735a4ea3..ee513579ce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala @@ -21,7 +21,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.clustering.{KMeans, KMeansModel} import org.apache.spark.ml.feature.VectorAssembler -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} private[r] class KMeansWrapper private ( pipeline: PipelineModel) { @@ -52,7 +52,7 @@ private[r] class KMeansWrapper private ( } } - def transform(dataset: DataFrame): DataFrame = { + def transform(dataset: Dataset[_]): DataFrame = { pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index 07383d393d..2cd709d2ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -21,7 +21,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.ml.feature.{IndexToString, RFormula} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} private[r] class NaiveBayesWrapper private ( pipeline: PipelineModel, @@ -36,7 +36,7 @@ private[r] class NaiveBayesWrapper private ( lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp) - def transform(dataset: DataFrame): DataFrame = { + def transform(dataset: Dataset[_]): DataFrame = { pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 4a3ad662a0..36dce01590 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -40,7 +40,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType} import org.apache.spark.storage.StorageLevel @@ -200,8 +200,8 @@ class ALSModel private[ml] ( @Since("1.3.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) - @Since("1.3.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => @@ -385,8 +385,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] this } - @Since("1.3.0") - override def fit(dataset: DataFrame): ALSModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): ALSModel = { import dataset.sqlContext.implicits._ val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 3278974954..afed1f32b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -32,7 +32,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -183,7 +183,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset, * and put it in an RDD with strong types. */ - protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = { + protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = { dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol))) .rdd.map { case Row(features: Vector, label: Double, censor: Double) => @@ -191,8 +191,8 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S } } - @Since("1.6.0") - override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = { validateAndTransformSchema(dataset.schema, fitting = true) val instances = extractAFTPoints(dataset) val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -299,8 +299,8 @@ class AFTSurvivalRegressionModel private[ml] ( math.exp(BLAS.dot(coefficients, features) + intercept) } - @Since("1.6.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) val predictUDF = udf { features: Vector => predict(features) } val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 1289a317ee..c04c416aaf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -83,7 +83,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val /** @group setParam */ def setVarianceCol(value: String): this.type = set(varianceCol, value) - override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { + override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -158,15 +158,16 @@ final class DecisionTreeRegressionModel private[ml] ( rootNode.predictImpl(features).impurityStats.calculate() } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) transformImpl(dataset) } - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val predictUDF = udf { (features: Vector) => predict(features) } val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } - var output = dataset + var output = dataset.toDF if ($(predictionCol).nonEmpty) { output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 8eb2984f7b..0b52fe2d13 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss SquaredError => OldSquaredError} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ /** @@ -147,7 +147,7 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri } } - override protected def train(dataset: DataFrame): GBTRegressionModel = { + override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -209,7 +209,7 @@ final class GBTRegressionModel private[ml]( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 05bf64591b..00cf25dc54 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -31,7 +31,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -196,7 +196,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val def setSolver(value: String): this.type = set(solver, value) setDefault(solver -> "irls") - override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = { + override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = { val familyObj = Family.fromName($(family)) val linkObj = if (isDefined(link)) { Link.fromName($(link)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index bd0b631d89..7a78ecbdf1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -77,7 +77,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures * Extracts (label, feature, weight) from input dataset. */ protected[ml] def extractWeightedLabeledPoints( - dataset: DataFrame): RDD[(Double, Double, Double)] = { + dataset: Dataset[_]): RDD[(Double, Double, Double)] = { val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) { val idx = $(featureIndex) val extract = udf { v: Vector => v(idx) } @@ -164,8 +164,8 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri @Since("1.5.0") override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) - @Since("1.5.0") - override def fit(dataset: DataFrame): IsotonicRegressionModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): IsotonicRegressionModel = { validateAndTransformSchema(dataset.schema, fitting = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractWeightedLabeledPoints(dataset) @@ -236,8 +236,8 @@ class IsotonicRegressionModel private[ml] ( copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent) } - @Since("1.5.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val predict = dataset.schema($(featuresCol)).dataType match { case DoubleType => udf { feature: Double => oldModel.predict(feature) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index aacff4ea47..71e02730c7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -38,7 +38,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel @@ -158,7 +158,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String def setSolver(value: String): this.type = set(solver, value) setDefault(solver -> "auto") - override protected def train(dataset: DataFrame): LinearRegressionModel = { + override protected def train(dataset: Dataset[_]): LinearRegressionModel = { // Extract the number of features before deciding optimization solver. val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map { case Row(features: Vector) => features.size @@ -417,7 +417,7 @@ class LinearRegressionModel private[ml] ( * @param dataset Test dataset to evaluate model on. */ @Since("2.0.0") - def evaluate(dataset: DataFrame): LinearRegressionSummary = { + def evaluate(dataset: Dataset[_]): LinearRegressionSummary = { // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 736cd9f776..bee13c2ebf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -93,7 +93,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) - override protected def train(dataset: DataFrame): RandomForestRegressionModel = { + override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -164,7 +164,7 @@ final class RandomForestRegressionModel private[ml] ( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 4d9d4d472e..de563d4fad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -33,7 +33,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType /** @@ -90,8 +90,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("2.0.0") def setSeed(value: Long): this.type = set(seed, value) - @Since("1.4.0") - override def fit(dataset: DataFrame): CrossValidatorModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): CrossValidatorModel = { val schema = dataset.schema transformSchema(schema, logging = true) val sqlCtx = dataset.sqlContext @@ -100,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val epm = $(estimatorParamMaps) val numModels = epm.length val metrics = new Array[Double](epm.length) - val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed)) + val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() @@ -209,8 +209,8 @@ class CrossValidatorModel private[ml] ( this(uid, bestModel, avgMetrics.asScala.toArray) } - @Since("1.4.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 0f2179c2a1..12d6905510 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.tuning import java.util.{List => JList} import scala.collection.JavaConverters._ +import scala.language.existentials import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -31,7 +32,7 @@ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType /** @@ -89,8 +90,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("2.0.0") def setSeed(value: Long): this.type = set(seed, value) - @Since("1.5.0") - override def fit(dataset: DataFrame): TrainValidationSplitModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { val schema = dataset.schema transformSchema(schema, logging = true) val sqlCtx = dataset.sqlContext @@ -207,8 +208,8 @@ class TrainValidationSplitModel private[ml] ( this(uid, bestModel, validationMetrics.asScala.toArray) } - @Since("1.5.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 0f0c3a2df5..5812cdde2c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -186,7 +186,7 @@ sealed trait Vector extends Serializable { * :: AlphaComponent :: * * User-defined type for [[Vector]] which allows easy interaction with SQL - * via [[org.apache.spark.sql.DataFrame]]. + * via [[org.apache.spark.sql.Dataset]]. */ @AlphaComponent class VectorUDT extends UserDefinedType[Vector] { diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index f3321fb5a1..a8c4ac6d05 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -51,6 +51,12 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val dataset3 = mock[DataFrame] val dataset4 = mock[DataFrame] + when(dataset0.toDF).thenReturn(dataset0) + when(dataset1.toDF).thenReturn(dataset1) + when(dataset2.toDF).thenReturn(dataset2) + when(dataset3.toDF).thenReturn(dataset3) + when(dataset4.toDF).thenReturn(dataset4) + when(estimator0.copy(any[ParamMap])).thenReturn(estimator0) when(model0.copy(any[ParamMap])).thenReturn(model0) when(transformer1.copy(any[ParamMap])).thenReturn(transformer1) @@ -213,7 +219,7 @@ class WritableStage(override val uid: String) extends Transformer with MLWritabl override def write: MLWriter = new DefaultParamsWriter(this) - override def transform(dataset: DataFrame): DataFrame = dataset + override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF override def transformSchema(schema: StructType): StructType = schema } @@ -234,7 +240,7 @@ class UnWritableStage(override val uid: String) extends Transformer { override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra) - override def transform(dataset: DataFrame): DataFrame = dataset + override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF override def transformSchema(schema: StructType): StructType = schema } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 7eefaf2346..48db428130 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -29,13 +29,13 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.lit class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ @transient var binaryDataset: DataFrame = _ private val eps: Double = 1e-5 diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 06ff049b48..80547fad6a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -26,12 +26,12 @@ import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 4727cd436f..80a46fc70c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -27,11 +27,11 @@ import org.apache.spark.mllib.classification.NaiveBayesSuite._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 4131396726..f3e8fd11b2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -30,12 +30,12 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.Metadata class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ @transient var rdd: RDD[LabeledPoint] = _ override def beforeAll(): Unit = { @@ -246,7 +246,7 @@ private class MockLogisticRegression(uid: String) extends LogisticRegression(uid setMaxIter(1) - override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = { + override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = { val labelSchema = dataset.schema($(labelCol)) // check for label attribute propagation. assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 18f2c994b4..e641d79c17 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 8edd44e5f1..1a274aea29 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c684bc11cc..2076c745e2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -22,14 +22,14 @@ import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, SQLContext} private[clustering] case class TestRow(features: Vector) class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index a1c93891c7..ee8eae8f69 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} object LDASuite { @@ -64,7 +64,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val k: Int = 5 val vocabSize: Int = 30 - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index 58fda29aa1..e4e15f4331 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -22,7 +22,7 @@ import scala.beans.BeanInfo import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} @BeanInfo case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) @@ -92,7 +92,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe object NGramSuite extends SparkFunSuite { - def testNGram(t: NGram, dataset: DataFrame): Unit = { + def testNGram(t: NGram, dataset: Dataset[_]): Unit = { t.transform(dataset) .select("nGrams", "wantedNGrams") .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index a5b24c1856..3505befdf8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} object StopWordsRemoverSuite extends SparkFunSuite { - def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = { + def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = { t.transform(dataset) .select("filtered", "expected") .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 2c3255ef33..d0f3cdc841 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -115,7 +115,7 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") val df = sqlContext.range(0L, 10L).toDF() - assert(indexerModel.transform(df).eq(df)) + assert(indexerModel.transform(df).collect().toSet === df.collect().toSet) } test("StringIndexerModel can't overwrite output column") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index 36e8e5d868..299f6223b2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) @@ -106,7 +106,7 @@ class RegexTokenizerSuite object RegexTokenizerSuite extends SparkFunSuite { - def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = { + def testRegexTokenizer(t: RegexTokenizer, dataset: Dataset[_]): Unit = { t.transform(dataset) .select("tokens", "wantedTokens") .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 2265464b51..4905f3e068 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -992,6 +992,14 @@ class GeneralizedLinearRegressionSuite assert(expected.coefficients === actual.coefficients) } } + + test("glm accepts Dataset[LabeledPoint]") { + val context = sqlContext + import context.implicits._ + new GeneralizedLinearRegression() + .setFamily("gaussian") + .fit(datasetGaussianIdentity.as[LabeledPoint]) + } } object GeneralizedLinearRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 7af3c6d6ed..3e734aabc5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -29,13 +29,13 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.{StructField, StructType} class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -311,7 +311,7 @@ object CrossValidatorSuite extends SparkFunSuite { class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { - override def fit(dataset: DataFrame): MyModel = { + override def fit(dataset: Dataset[_]): MyModel = { throw new UnsupportedOperationException } @@ -325,7 +325,7 @@ object CrossValidatorSuite extends SparkFunSuite { class MyEvaluator extends Evaluator { - override def evaluate(dataset: DataFrame): Double = { + override def evaluate(dataset: Dataset[_]): Double = { throw new UnsupportedOperationException } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 4030956fab..dbee47c847 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType class TrainValidationSplitSuite @@ -158,7 +158,7 @@ object TrainValidationSplitSuite { class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { - override def fit(dataset: DataFrame): MyModel = { + override def fit(dataset: Dataset[_]): MyModel = { throw new UnsupportedOperationException } @@ -172,7 +172,7 @@ object TrainValidationSplitSuite { class MyEvaluator extends Evaluator { - override def evaluate(dataset: DataFrame): Double = { + override def evaluate(dataset: Dataset[_]): Double = { throw new UnsupportedOperationException } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 16280473c6..7ebd7eb144 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} trait DefaultReadWriteTest extends TempDirectory { self: Suite => @@ -98,7 +98,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => def testEstimatorAndModelReadWrite[ E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable]( estimator: E, - dataset: DataFrame, + dataset: Dataset[_], testParams: Map[String, Any], checkModelData: (M, M) => Unit): Unit = { // Set some Params to make sure set Params are serialized.