[SPARK-14500] [ML] Accept Dataset[_] instead of DataFrame in MLlib APIs
## What changes were proposed in this pull request? This PR updates MLlib APIs to accept `Dataset[_]` as input where `DataFrame` was the input type. This PR doesn't change the output type. In Java, `Dataset[_]` maps to `Dataset<?>`, which includes `Dataset<Row>`. Some implementations were changed in order to return `DataFrame`. Tests and examples were updated. Note that this is a breaking change for subclasses of Transformer/Estimator. Lol, we don't have to rename the input argument, which has been `dataset` since Spark 1.2. TODOs: - [x] update MiMaExcludes (seems all covered by explicit filters from SPARK-13920) - [x] Python - [x] add a new test to accept Dataset[LabeledPoint] - [x] remove unused imports of Dataset ## How was this patch tested? Existing unit tests with some modifications. cc: rxin jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #12274 from mengxr/SPARK-14500.
This commit is contained in:
parent
e82d95bf63
commit
1c751fcf48
|
@ -146,7 +146,7 @@ class MyJavaLogisticRegression
|
||||||
|
|
||||||
// This method is used by fit().
|
// This method is used by fit().
|
||||||
// In Java, we have to make it public since Java does not understand Scala's protected modifier.
|
// In Java, we have to make it public since Java does not understand Scala's protected modifier.
|
||||||
public MyJavaLogisticRegressionModel train(Dataset<Row> dataset) {
|
public MyJavaLogisticRegressionModel train(Dataset<?> dataset) {
|
||||||
// Extract columns from data using helper method.
|
// Extract columns from data using helper method.
|
||||||
JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset).toJavaRDD();
|
JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset).toJavaRDD();
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{IntParam, ParamMap}
|
||||||
import org.apache.spark.ml.util.Identifiable
|
import org.apache.spark.ml.util.Identifiable
|
||||||
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A simple example demonstrating how to write your own learning algorithm using Estimator,
|
* A simple example demonstrating how to write your own learning algorithm using Estimator,
|
||||||
|
@ -120,7 +120,7 @@ private class MyLogisticRegression(override val uid: String)
|
||||||
def setMaxIter(value: Int): this.type = set(maxIter, value)
|
def setMaxIter(value: Int): this.type = set(maxIter, value)
|
||||||
|
|
||||||
// This method is used by fit()
|
// This method is used by fit()
|
||||||
override protected def train(dataset: DataFrame): MyLogisticRegressionModel = {
|
override protected def train(dataset: Dataset[_]): MyLogisticRegressionModel = {
|
||||||
// Extract columns from data using helper method.
|
// Extract columns from data using helper method.
|
||||||
val oldDataset = extractLabeledPoints(dataset)
|
val oldDataset = extractLabeledPoints(dataset)
|
||||||
|
|
||||||
|
|
|
@ -19,9 +19,9 @@ package org.apache.spark.ml
|
||||||
|
|
||||||
import scala.annotation.varargs
|
import scala.annotation.varargs
|
||||||
|
|
||||||
import org.apache.spark.annotation.DeveloperApi
|
import org.apache.spark.annotation.{DeveloperApi, Since}
|
||||||
import org.apache.spark.ml.param.{ParamMap, ParamPair}
|
import org.apache.spark.ml.param.{ParamMap, ParamPair}
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.Dataset
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* :: DeveloperApi ::
|
* :: DeveloperApi ::
|
||||||
|
@ -39,8 +39,9 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
|
||||||
* Estimator's embedded ParamMap.
|
* Estimator's embedded ParamMap.
|
||||||
* @return fitted model
|
* @return fitted model
|
||||||
*/
|
*/
|
||||||
|
@Since("2.0.0")
|
||||||
@varargs
|
@varargs
|
||||||
def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = {
|
def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = {
|
||||||
val map = new ParamMap()
|
val map = new ParamMap()
|
||||||
.put(firstParamPair)
|
.put(firstParamPair)
|
||||||
.put(otherParamPairs: _*)
|
.put(otherParamPairs: _*)
|
||||||
|
@ -55,14 +56,16 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
|
||||||
* These values override any specified in this Estimator's embedded ParamMap.
|
* These values override any specified in this Estimator's embedded ParamMap.
|
||||||
* @return fitted model
|
* @return fitted model
|
||||||
*/
|
*/
|
||||||
def fit(dataset: DataFrame, paramMap: ParamMap): M = {
|
@Since("2.0.0")
|
||||||
|
def fit(dataset: Dataset[_], paramMap: ParamMap): M = {
|
||||||
copy(paramMap).fit(dataset)
|
copy(paramMap).fit(dataset)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fits a model to the input data.
|
* Fits a model to the input data.
|
||||||
*/
|
*/
|
||||||
def fit(dataset: DataFrame): M
|
@Since("2.0.0")
|
||||||
|
def fit(dataset: Dataset[_]): M
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fits multiple models to the input data with multiple sets of parameters.
|
* Fits multiple models to the input data with multiple sets of parameters.
|
||||||
|
@ -74,7 +77,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
|
||||||
* These values override any specified in this Estimator's embedded ParamMap.
|
* These values override any specified in this Estimator's embedded ParamMap.
|
||||||
* @return fitted models, matching the input parameter maps
|
* @return fitted models, matching the input parameter maps
|
||||||
*/
|
*/
|
||||||
def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = {
|
@Since("2.0.0")
|
||||||
|
def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = {
|
||||||
paramMaps.map(fit(dataset, _))
|
paramMaps.map(fit(dataset, _))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
import org.apache.spark.ml.param.{Param, ParamMap, Params}
|
import org.apache.spark.ml.param.{Param, ParamMap, Params}
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -123,8 +123,8 @@ class Pipeline @Since("1.4.0") (
|
||||||
* @param dataset input dataset
|
* @param dataset input dataset
|
||||||
* @return fitted pipeline
|
* @return fitted pipeline
|
||||||
*/
|
*/
|
||||||
@Since("1.2.0")
|
@Since("2.0.0")
|
||||||
override def fit(dataset: DataFrame): PipelineModel = {
|
override def fit(dataset: Dataset[_]): PipelineModel = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val theStages = $(stages)
|
val theStages = $(stages)
|
||||||
// Search for the last estimator.
|
// Search for the last estimator.
|
||||||
|
@ -291,10 +291,10 @@ class PipelineModel private[ml] (
|
||||||
this(uid, stages.asScala.toArray)
|
this(uid, stages.asScala.toArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.2.0")
|
@Since("2.0.0")
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur))
|
stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.2.0")
|
@Since("1.2.0")
|
||||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.ml.util.SchemaUtils
|
||||||
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
|
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
|
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ abstract class Predictor[
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
|
def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
|
||||||
|
|
||||||
override def fit(dataset: DataFrame): M = {
|
override def fit(dataset: Dataset[_]): M = {
|
||||||
// This handles a few items such as schema validation.
|
// This handles a few items such as schema validation.
|
||||||
// Developers only need to implement train().
|
// Developers only need to implement train().
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
|
@ -100,7 +100,7 @@ abstract class Predictor[
|
||||||
* @param dataset Training dataset
|
* @param dataset Training dataset
|
||||||
* @return Fitted model
|
* @return Fitted model
|
||||||
*/
|
*/
|
||||||
protected def train(dataset: DataFrame): M
|
protected def train(dataset: Dataset[_]): M
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
|
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
|
||||||
|
@ -120,7 +120,7 @@ abstract class Predictor[
|
||||||
* Extract [[labelCol]] and [[featuresCol]] from the given dataset,
|
* Extract [[labelCol]] and [[featuresCol]] from the given dataset,
|
||||||
* and put it in an RDD with strong types.
|
* and put it in an RDD with strong types.
|
||||||
*/
|
*/
|
||||||
protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
|
protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
|
||||||
dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
|
dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
|
||||||
case Row(label: Double, features: Vector) => LabeledPoint(label, features)
|
case Row(label: Double, features: Vector) => LabeledPoint(label, features)
|
||||||
}
|
}
|
||||||
|
@ -171,18 +171,18 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
|
||||||
* @param dataset input dataset
|
* @param dataset input dataset
|
||||||
* @return transformed dataset with [[predictionCol]] of type [[Double]]
|
* @return transformed dataset with [[predictionCol]] of type [[Double]]
|
||||||
*/
|
*/
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
if ($(predictionCol).nonEmpty) {
|
if ($(predictionCol).nonEmpty) {
|
||||||
transformImpl(dataset)
|
transformImpl(dataset)
|
||||||
} else {
|
} else {
|
||||||
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
|
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
|
||||||
" since no output columns were set.")
|
" since no output columns were set.")
|
||||||
dataset
|
dataset.toDF
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected def transformImpl(dataset: DataFrame): DataFrame = {
|
protected def transformImpl(dataset: Dataset[_]): DataFrame = {
|
||||||
val predictUDF = udf { (features: Any) =>
|
val predictUDF = udf { (features: Any) =>
|
||||||
predict(features.asInstanceOf[FeaturesType])
|
predict(features.asInstanceOf[FeaturesType])
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,11 +19,11 @@ package org.apache.spark.ml
|
||||||
|
|
||||||
import scala.annotation.varargs
|
import scala.annotation.varargs
|
||||||
|
|
||||||
import org.apache.spark.annotation.DeveloperApi
|
import org.apache.spark.annotation.{DeveloperApi, Since}
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
import org.apache.spark.ml.param._
|
import org.apache.spark.ml.param._
|
||||||
import org.apache.spark.ml.param.shared._
|
import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
|
@ -41,9 +41,10 @@ abstract class Transformer extends PipelineStage {
|
||||||
* @param otherParamPairs other param pairs, overwrite embedded params
|
* @param otherParamPairs other param pairs, overwrite embedded params
|
||||||
* @return transformed dataset
|
* @return transformed dataset
|
||||||
*/
|
*/
|
||||||
|
@Since("2.0.0")
|
||||||
@varargs
|
@varargs
|
||||||
def transform(
|
def transform(
|
||||||
dataset: DataFrame,
|
dataset: Dataset[_],
|
||||||
firstParamPair: ParamPair[_],
|
firstParamPair: ParamPair[_],
|
||||||
otherParamPairs: ParamPair[_]*): DataFrame = {
|
otherParamPairs: ParamPair[_]*): DataFrame = {
|
||||||
val map = new ParamMap()
|
val map = new ParamMap()
|
||||||
|
@ -58,14 +59,16 @@ abstract class Transformer extends PipelineStage {
|
||||||
* @param paramMap additional parameters, overwrite embedded params
|
* @param paramMap additional parameters, overwrite embedded params
|
||||||
* @return transformed dataset
|
* @return transformed dataset
|
||||||
*/
|
*/
|
||||||
def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame = {
|
||||||
this.copy(paramMap).transform(dataset)
|
this.copy(paramMap).transform(dataset)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Transforms the input dataset.
|
* Transforms the input dataset.
|
||||||
*/
|
*/
|
||||||
def transform(dataset: DataFrame): DataFrame
|
@Since("2.0.0")
|
||||||
|
def transform(dataset: Dataset[_]): DataFrame
|
||||||
|
|
||||||
override def copy(extra: ParamMap): Transformer
|
override def copy(extra: ParamMap): Transformer
|
||||||
}
|
}
|
||||||
|
@ -113,7 +116,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
|
||||||
StructType(outputFields)
|
StructType(outputFields)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val transformUDF = udf(this.createTransformFunc, outputDataType)
|
val transformUDF = udf(this.createTransformFunc, outputDataType)
|
||||||
dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
|
dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
|
||||||
|
|
|
@ -22,7 +22,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
|
||||||
import org.apache.spark.ml.param.shared.HasRawPredictionCol
|
import org.apache.spark.ml.param.shared.HasRawPredictionCol
|
||||||
import org.apache.spark.ml.util.SchemaUtils
|
import org.apache.spark.ml.util.SchemaUtils
|
||||||
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
|
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types.{DataType, StructType}
|
import org.apache.spark.sql.types.{DataType, StructType}
|
||||||
|
|
||||||
|
@ -92,7 +92,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
|
||||||
* @param dataset input dataset
|
* @param dataset input dataset
|
||||||
* @return transformed dataset
|
* @return transformed dataset
|
||||||
*/
|
*/
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
|
|
||||||
// Output selected columns only.
|
// Output selected columns only.
|
||||||
|
@ -123,7 +123,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
|
||||||
logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" +
|
logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" +
|
||||||
" since no output columns were set.")
|
" since no output columns were set.")
|
||||||
}
|
}
|
||||||
outputData
|
outputData.toDF
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -32,7 +32,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
||||||
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
|
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -82,7 +82,7 @@ final class DecisionTreeClassifier @Since("1.4.0") (
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def setSeed(value: Long): this.type = super.setSeed(value)
|
override def setSeed(value: Long): this.type = super.setSeed(value)
|
||||||
|
|
||||||
override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
|
override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
|
||||||
val categoricalFeatures: Map[Int, Int] =
|
val categoricalFeatures: Map[Int, Int] =
|
||||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||||
val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
|
val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
|
||||||
|
|
|
@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
|
||||||
import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
|
import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
|
||||||
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
|
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -149,7 +149,7 @@ final class GBTClassifier @Since("1.4.0") (
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override protected def train(dataset: DataFrame): GBTClassificationModel = {
|
override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
|
||||||
val categoricalFeatures: Map[Int, Int] =
|
val categoricalFeatures: Map[Int, Int] =
|
||||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||||
val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
|
val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
|
||||||
|
@ -220,7 +220,7 @@ final class GBTClassificationModel private[ml](
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
override def treeWeights: Array[Double] = _treeWeights
|
override def treeWeights: Array[Double] = _treeWeights
|
||||||
|
|
||||||
override protected def transformImpl(dataset: DataFrame): DataFrame = {
|
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
|
||||||
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
|
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
|
||||||
val predictUDF = udf { (features: Any) =>
|
val predictUDF = udf { (features: Any) =>
|
||||||
bcastModel.value.predict(features.asInstanceOf[Vector])
|
bcastModel.value.predict(features.asInstanceOf[Vector])
|
||||||
|
|
|
@ -36,7 +36,7 @@ import org.apache.spark.mllib.linalg.BLAS._
|
||||||
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
|
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
|
||||||
import org.apache.spark.mllib.util.MLUtils
|
import org.apache.spark.mllib.util.MLUtils
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions.{col, lit}
|
import org.apache.spark.sql.functions.{col, lit}
|
||||||
import org.apache.spark.sql.types.DoubleType
|
import org.apache.spark.sql.types.DoubleType
|
||||||
import org.apache.spark.storage.StorageLevel
|
import org.apache.spark.storage.StorageLevel
|
||||||
|
@ -257,12 +257,12 @@ class LogisticRegression @Since("1.2.0") (
|
||||||
this
|
this
|
||||||
}
|
}
|
||||||
|
|
||||||
override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
|
override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
|
||||||
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
|
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
|
||||||
train(dataset, handlePersistence)
|
train(dataset, handlePersistence)
|
||||||
}
|
}
|
||||||
|
|
||||||
protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean):
|
protected[spark] def train(dataset: Dataset[_], handlePersistence: Boolean):
|
||||||
LogisticRegressionModel = {
|
LogisticRegressionModel = {
|
||||||
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
||||||
val instances: RDD[Instance] =
|
val instances: RDD[Instance] =
|
||||||
|
@ -544,7 +544,7 @@ class LogisticRegressionModel private[spark] (
|
||||||
* @param dataset Test dataset to evaluate model on.
|
* @param dataset Test dataset to evaluate model on.
|
||||||
*/
|
*/
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
|
def evaluate(dataset: Dataset[_]): LogisticRegressionSummary = {
|
||||||
// Handle possible missing or invalid prediction columns
|
// Handle possible missing or invalid prediction columns
|
||||||
val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol()
|
val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol()
|
||||||
new BinaryLogisticRegressionSummary(summaryModel.transform(dataset),
|
new BinaryLogisticRegressionSummary(summaryModel.transform(dataset),
|
||||||
|
|
|
@ -29,7 +29,7 @@ import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTo
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
|
|
||||||
/** Params for Multilayer Perceptron. */
|
/** Params for Multilayer Perceptron. */
|
||||||
private[ml] trait MultilayerPerceptronParams extends PredictorParams
|
private[ml] trait MultilayerPerceptronParams extends PredictorParams
|
||||||
|
@ -199,7 +199,7 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
|
||||||
* @param dataset Training dataset
|
* @param dataset Training dataset
|
||||||
* @return Fitted model
|
* @return Fitted model
|
||||||
*/
|
*/
|
||||||
override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = {
|
override protected def train(dataset: Dataset[_]): MultilayerPerceptronClassificationModel = {
|
||||||
val myLayers = $(layers)
|
val myLayers = $(layers)
|
||||||
val labels = myLayers.last
|
val labels = myLayers.last
|
||||||
val lpData = extractLabeledPoints(dataset)
|
val lpData = extractLabeledPoints(dataset)
|
||||||
|
|
|
@ -29,7 +29,7 @@ import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesMo
|
||||||
import org.apache.spark.mllib.linalg._
|
import org.apache.spark.mllib.linalg._
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Params for Naive Bayes Classifiers.
|
* Params for Naive Bayes Classifiers.
|
||||||
|
@ -101,7 +101,7 @@ class NaiveBayes @Since("1.5.0") (
|
||||||
def setModelType(value: String): this.type = set(modelType, value)
|
def setModelType(value: String): this.type = set(modelType, value)
|
||||||
setDefault(modelType -> OldNaiveBayes.Multinomial)
|
setDefault(modelType -> OldNaiveBayes.Multinomial)
|
||||||
|
|
||||||
override protected def train(dataset: DataFrame): NaiveBayesModel = {
|
override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
|
||||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
||||||
val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))
|
val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))
|
||||||
NaiveBayesModel.fromOld(oldModel, this)
|
NaiveBayesModel.fromOld(oldModel, this)
|
||||||
|
|
|
@ -33,7 +33,7 @@ import org.apache.spark.ml.attribute._
|
||||||
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
|
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.linalg.Vector
|
import org.apache.spark.mllib.linalg.Vector
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.storage.StorageLevel
|
import org.apache.spark.storage.StorageLevel
|
||||||
|
@ -140,8 +140,8 @@ final class OneVsRestModel private[ml] (
|
||||||
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
|
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.4.0")
|
@Since("2.0.0")
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
// Check schema
|
// Check schema
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
|
|
||||||
|
@ -293,8 +293,8 @@ final class OneVsRest @Since("1.4.0") (
|
||||||
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
|
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.4.0")
|
@Since("2.0.0")
|
||||||
override def fit(dataset: DataFrame): OneVsRestModel = {
|
override def fit(dataset: Dataset[_]): OneVsRestModel = {
|
||||||
transformSchema(dataset.schema)
|
transformSchema(dataset.schema)
|
||||||
|
|
||||||
// determine number of classes either from metadata if provided, or via computation.
|
// determine number of classes either from metadata if provided, or via computation.
|
||||||
|
|
|
@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
|
||||||
import org.apache.spark.ml.param.shared._
|
import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.ml.util.SchemaUtils
|
import org.apache.spark.ml.util.SchemaUtils
|
||||||
import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors, VectorUDT}
|
import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors, VectorUDT}
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types.{DataType, StructType}
|
import org.apache.spark.sql.types.{DataType, StructType}
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ abstract class ProbabilisticClassificationModel[
|
||||||
* @param dataset input dataset
|
* @param dataset input dataset
|
||||||
* @return transformed dataset
|
* @return transformed dataset
|
||||||
*/
|
*/
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
if (isDefined(thresholds)) {
|
if (isDefined(thresholds)) {
|
||||||
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
|
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
|
||||||
|
@ -145,7 +145,7 @@ abstract class ProbabilisticClassificationModel[
|
||||||
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
|
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
|
||||||
" since no output columns were set.")
|
" since no output columns were set.")
|
||||||
}
|
}
|
||||||
outputData
|
outputData.toDF
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -31,7 +31,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
|
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
|
||||||
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
|
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
|
|
||||||
|
|
||||||
|
@ -98,7 +98,7 @@ final class RandomForestClassifier @Since("1.4.0") (
|
||||||
override def setFeatureSubsetStrategy(value: String): this.type =
|
override def setFeatureSubsetStrategy(value: String): this.type =
|
||||||
super.setFeatureSubsetStrategy(value)
|
super.setFeatureSubsetStrategy(value)
|
||||||
|
|
||||||
override protected def train(dataset: DataFrame): RandomForestClassificationModel = {
|
override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
|
||||||
val categoricalFeatures: Map[Int, Int] =
|
val categoricalFeatures: Map[Int, Int] =
|
||||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||||
val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
|
val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
|
||||||
|
@ -180,7 +180,7 @@ final class RandomForestClassificationModel private[ml] (
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
override def treeWeights: Array[Double] = _treeWeights
|
override def treeWeights: Array[Double] = _treeWeights
|
||||||
|
|
||||||
override protected def transformImpl(dataset: DataFrame): DataFrame = {
|
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
|
||||||
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
|
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
|
||||||
val predictUDF = udf { (features: Any) =>
|
val predictUDF = udf { (features: Any) =>
|
||||||
bcastModel.value.predict(features.asInstanceOf[Vector])
|
bcastModel.value.predict(features.asInstanceOf[Vector])
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.clustering.
|
import org.apache.spark.mllib.clustering.
|
||||||
{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
|
{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
|
||||||
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
|
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions.{col, udf}
|
import org.apache.spark.sql.functions.{col, udf}
|
||||||
import org.apache.spark.sql.types.{IntegerType, StructType}
|
import org.apache.spark.sql.types.{IntegerType, StructType}
|
||||||
|
|
||||||
|
@ -92,7 +92,7 @@ class BisectingKMeansModel private[ml] (
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
val predictUDF = udf((vector: Vector) => predict(vector))
|
val predictUDF = udf((vector: Vector) => predict(vector))
|
||||||
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
|
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
|
||||||
}
|
}
|
||||||
|
@ -112,7 +112,7 @@ class BisectingKMeansModel private[ml] (
|
||||||
* centers.
|
* centers.
|
||||||
*/
|
*/
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
def computeCost(dataset: DataFrame): Double = {
|
def computeCost(dataset: Dataset[_]): Double = {
|
||||||
SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
|
SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
|
||||||
val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
|
val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
|
||||||
parentModel.computeCost(data)
|
parentModel.computeCost(data)
|
||||||
|
@ -215,7 +215,7 @@ class BisectingKMeans @Since("2.0.0") (
|
||||||
def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value)
|
def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value)
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
override def fit(dataset: DataFrame): BisectingKMeansModel = {
|
override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
|
||||||
val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
|
val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
|
||||||
|
|
||||||
val bkm = new MLlibBisectingKMeans()
|
val bkm = new MLlibBisectingKMeans()
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel}
|
import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel}
|
||||||
import org.apache.spark.mllib.linalg._
|
import org.apache.spark.mllib.linalg._
|
||||||
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
|
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions.{col, udf}
|
import org.apache.spark.sql.functions.{col, udf}
|
||||||
import org.apache.spark.sql.types.{IntegerType, StructType}
|
import org.apache.spark.sql.types.{IntegerType, StructType}
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ class GaussianMixtureModel private[ml] (
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
val predUDF = udf((vector: Vector) => predict(vector))
|
val predUDF = udf((vector: Vector) => predict(vector))
|
||||||
val probUDF = udf((vector: Vector) => predictProbability(vector))
|
val probUDF = udf((vector: Vector) => predictProbability(vector))
|
||||||
dataset.withColumn($(predictionCol), predUDF(col($(featuresCol))))
|
dataset.withColumn($(predictionCol), predUDF(col($(featuresCol))))
|
||||||
|
@ -238,7 +238,7 @@ class GaussianMixture @Since("2.0.0") (
|
||||||
def setSeed(value: Long): this.type = set(seed, value)
|
def setSeed(value: Long): this.type = set(seed, value)
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
override def fit(dataset: DataFrame): GaussianMixtureModel = {
|
override def fit(dataset: Dataset[_]): GaussianMixtureModel = {
|
||||||
val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
|
val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
|
||||||
|
|
||||||
val algo = new MLlibGM()
|
val algo = new MLlibGM()
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
|
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
|
||||||
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
|
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions.{col, udf}
|
import org.apache.spark.sql.functions.{col, udf}
|
||||||
import org.apache.spark.sql.types.{IntegerType, StructType}
|
import org.apache.spark.sql.types.{IntegerType, StructType}
|
||||||
|
|
||||||
|
@ -105,8 +105,8 @@ class KMeansModel private[ml] (
|
||||||
copyValues(copied, extra)
|
copyValues(copied, extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("2.0.0")
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
val predictUDF = udf((vector: Vector) => predict(vector))
|
val predictUDF = udf((vector: Vector) => predict(vector))
|
||||||
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
|
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
|
||||||
}
|
}
|
||||||
|
@ -126,8 +126,8 @@ class KMeansModel private[ml] (
|
||||||
* model on the given data.
|
* model on the given data.
|
||||||
*/
|
*/
|
||||||
// TODO: Replace the temp fix when we have proper evaluators defined for clustering.
|
// TODO: Replace the temp fix when we have proper evaluators defined for clustering.
|
||||||
@Since("1.6.0")
|
@Since("2.0.0")
|
||||||
def computeCost(dataset: DataFrame): Double = {
|
def computeCost(dataset: Dataset[_]): Double = {
|
||||||
SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
|
SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
|
||||||
val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
|
val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
|
||||||
parentModel.computeCost(data)
|
parentModel.computeCost(data)
|
||||||
|
@ -254,8 +254,8 @@ class KMeans @Since("1.5.0") (
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
def setSeed(value: Long): this.type = set(seed, value)
|
def setSeed(value: Long): this.type = set(seed, value)
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("2.0.0")
|
||||||
override def fit(dataset: DataFrame): KMeansModel = {
|
override def fit(dataset: Dataset[_]): KMeansModel = {
|
||||||
val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
|
val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
|
||||||
|
|
||||||
val algo = new MLlibKMeans()
|
val algo = new MLlibKMeans()
|
||||||
|
|
|
@ -32,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL
|
||||||
import org.apache.spark.mllib.impl.PeriodicCheckpointer
|
import org.apache.spark.mllib.impl.PeriodicCheckpointer
|
||||||
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT}
|
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
|
||||||
import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf}
|
import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf}
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
|
@ -402,15 +402,15 @@ sealed abstract class LDAModel private[ml] (
|
||||||
* is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
|
* is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
|
||||||
* This implementation may be changed in the future.
|
* This implementation may be changed in the future.
|
||||||
*/
|
*/
|
||||||
@Since("1.6.0")
|
@Since("2.0.0")
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
if ($(topicDistributionCol).nonEmpty) {
|
if ($(topicDistributionCol).nonEmpty) {
|
||||||
val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext))
|
val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext))
|
||||||
dataset.withColumn($(topicDistributionCol), t(col($(featuresCol))))
|
dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF
|
||||||
} else {
|
} else {
|
||||||
logWarning("LDAModel.transform was called without any output columns. Set an output column" +
|
logWarning("LDAModel.transform was called without any output columns. Set an output column" +
|
||||||
" such as topicDistributionCol to produce results.")
|
" such as topicDistributionCol to produce results.")
|
||||||
dataset
|
dataset.toDF
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -455,8 +455,8 @@ sealed abstract class LDAModel private[ml] (
|
||||||
* @param dataset test corpus to use for calculating log likelihood
|
* @param dataset test corpus to use for calculating log likelihood
|
||||||
* @return variational lower bound on the log likelihood of the entire corpus
|
* @return variational lower bound on the log likelihood of the entire corpus
|
||||||
*/
|
*/
|
||||||
@Since("1.6.0")
|
@Since("2.0.0")
|
||||||
def logLikelihood(dataset: DataFrame): Double = {
|
def logLikelihood(dataset: Dataset[_]): Double = {
|
||||||
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
|
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
|
||||||
oldLocalModel.logLikelihood(oldDataset)
|
oldLocalModel.logLikelihood(oldDataset)
|
||||||
}
|
}
|
||||||
|
@ -472,8 +472,8 @@ sealed abstract class LDAModel private[ml] (
|
||||||
* @param dataset test corpus to use for calculating perplexity
|
* @param dataset test corpus to use for calculating perplexity
|
||||||
* @return Variational upper bound on log perplexity per token.
|
* @return Variational upper bound on log perplexity per token.
|
||||||
*/
|
*/
|
||||||
@Since("1.6.0")
|
@Since("2.0.0")
|
||||||
def logPerplexity(dataset: DataFrame): Double = {
|
def logPerplexity(dataset: Dataset[_]): Double = {
|
||||||
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
|
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
|
||||||
oldLocalModel.logPerplexity(oldDataset)
|
oldLocalModel.logPerplexity(oldDataset)
|
||||||
}
|
}
|
||||||
|
@ -840,8 +840,8 @@ class LDA @Since("1.6.0") (
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def copy(extra: ParamMap): LDA = defaultCopy(extra)
|
override def copy(extra: ParamMap): LDA = defaultCopy(extra)
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("2.0.0")
|
||||||
override def fit(dataset: DataFrame): LDAModel = {
|
override def fit(dataset: Dataset[_]): LDAModel = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val oldLDA = new OldLDA()
|
val oldLDA = new OldLDA()
|
||||||
.setK($(k))
|
.setK($(k))
|
||||||
|
@ -873,7 +873,7 @@ class LDA @Since("1.6.0") (
|
||||||
private[clustering] object LDA extends DefaultParamsReadable[LDA] {
|
private[clustering] object LDA extends DefaultParamsReadable[LDA] {
|
||||||
|
|
||||||
/** Get dataset for spark.mllib LDA */
|
/** Get dataset for spark.mllib LDA */
|
||||||
def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = {
|
def getOldDataset(dataset: Dataset[_], featuresCol: String): RDD[(Long, Vector)] = {
|
||||||
dataset
|
dataset
|
||||||
.withColumn("docId", monotonicallyIncreasingId())
|
.withColumn("docId", monotonicallyIncreasingId())
|
||||||
.select("docId", featuresCol)
|
.select("docId", featuresCol)
|
||||||
|
|
|
@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
|
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
|
||||||
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
|
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
|
||||||
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
|
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{Dataset, Row}
|
||||||
import org.apache.spark.sql.types.DoubleType
|
import org.apache.spark.sql.types.DoubleType
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -69,8 +69,8 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
|
||||||
|
|
||||||
setDefault(metricName -> "areaUnderROC")
|
setDefault(metricName -> "areaUnderROC")
|
||||||
|
|
||||||
@Since("1.2.0")
|
@Since("2.0.0")
|
||||||
override def evaluate(dataset: DataFrame): Double = {
|
override def evaluate(dataset: Dataset[_]): Double = {
|
||||||
val schema = dataset.schema
|
val schema = dataset.schema
|
||||||
SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
|
SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
|
||||||
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
|
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
|
||||||
|
|
|
@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
|
||||||
|
|
||||||
import org.apache.spark.annotation.{DeveloperApi, Since}
|
import org.apache.spark.annotation.{DeveloperApi, Since}
|
||||||
import org.apache.spark.ml.param.{ParamMap, Params}
|
import org.apache.spark.ml.param.{ParamMap, Params}
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.Dataset
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* :: DeveloperApi ::
|
* :: DeveloperApi ::
|
||||||
|
@ -36,8 +36,8 @@ abstract class Evaluator extends Params {
|
||||||
* @param paramMap parameter map that specifies the input columns and output metrics
|
* @param paramMap parameter map that specifies the input columns and output metrics
|
||||||
* @return metric
|
* @return metric
|
||||||
*/
|
*/
|
||||||
@Since("1.5.0")
|
@Since("2.0.0")
|
||||||
def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
|
def evaluate(dataset: Dataset[_], paramMap: ParamMap): Double = {
|
||||||
this.copy(paramMap).evaluate(dataset)
|
this.copy(paramMap).evaluate(dataset)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,8 +46,8 @@ abstract class Evaluator extends Params {
|
||||||
* @param dataset a dataset that contains labels/observations and predictions.
|
* @param dataset a dataset that contains labels/observations and predictions.
|
||||||
* @return metric
|
* @return metric
|
||||||
*/
|
*/
|
||||||
@Since("1.5.0")
|
@Since("2.0.0")
|
||||||
def evaluate(dataset: DataFrame): Double
|
def evaluate(dataset: Dataset[_]): Double
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default)
|
* Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default)
|
||||||
|
|
|
@ -22,7 +22,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
|
||||||
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
|
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
|
||||||
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
|
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
|
||||||
import org.apache.spark.mllib.evaluation.MulticlassMetrics
|
import org.apache.spark.mllib.evaluation.MulticlassMetrics
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{Dataset, Row}
|
||||||
import org.apache.spark.sql.types.DoubleType
|
import org.apache.spark.sql.types.DoubleType
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -68,8 +68,8 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
|
||||||
|
|
||||||
setDefault(metricName -> "f1")
|
setDefault(metricName -> "f1")
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("2.0.0")
|
||||||
override def evaluate(dataset: DataFrame): Double = {
|
override def evaluate(dataset: Dataset[_]): Double = {
|
||||||
val schema = dataset.schema
|
val schema = dataset.schema
|
||||||
SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
|
SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
|
||||||
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
|
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
|
||||||
|
|
|
@ -22,7 +22,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
|
||||||
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
|
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
|
||||||
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
|
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
|
||||||
import org.apache.spark.mllib.evaluation.RegressionMetrics
|
import org.apache.spark.mllib.evaluation.RegressionMetrics
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types.{DoubleType, FloatType}
|
import org.apache.spark.sql.types.{DoubleType, FloatType}
|
||||||
|
|
||||||
|
@ -70,8 +70,8 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
|
||||||
|
|
||||||
setDefault(metricName -> "rmse")
|
setDefault(metricName -> "rmse")
|
||||||
|
|
||||||
@Since("1.4.0")
|
@Since("2.0.0")
|
||||||
override def evaluate(dataset: DataFrame): Double = {
|
override def evaluate(dataset: Dataset[_]): Double = {
|
||||||
val schema = dataset.schema
|
val schema = dataset.schema
|
||||||
val predictionColName = $(predictionCol)
|
val predictionColName = $(predictionCol)
|
||||||
val predictionType = schema($(predictionCol)).dataType
|
val predictionType = schema($(predictionCol)).dataType
|
||||||
|
|
|
@ -64,7 +64,8 @@ final class Binarizer(override val uid: String)
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
val outputSchema = transformSchema(dataset.schema, logging = true)
|
val outputSchema = transformSchema(dataset.schema, logging = true)
|
||||||
val schema = dataset.schema
|
val schema = dataset.schema
|
||||||
val inputType = schema($(inputCol)).dataType
|
val inputType = schema($(inputCol)).dataType
|
||||||
|
|
|
@ -68,7 +68,8 @@ final class Bucketizer(override val uid: String)
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema)
|
transformSchema(dataset.schema)
|
||||||
val bucketizer = udf { feature: Double =>
|
val bucketizer = udf { feature: Double =>
|
||||||
Bucketizer.binarySearchForBuckets($(splits), feature)
|
Bucketizer.binarySearchForBuckets($(splits), feature)
|
||||||
|
|
|
@ -77,7 +77,8 @@ final class ChiSqSelector(override val uid: String)
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setLabelCol(value: String): this.type = set(labelCol, value)
|
def setLabelCol(value: String): this.type = set(labelCol, value)
|
||||||
|
|
||||||
override def fit(dataset: DataFrame): ChiSqSelectorModel = {
|
@Since("2.0.0")
|
||||||
|
override def fit(dataset: Dataset[_]): ChiSqSelectorModel = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val input = dataset.select($(labelCol), $(featuresCol)).rdd.map {
|
val input = dataset.select($(labelCol), $(featuresCol)).rdd.map {
|
||||||
case Row(label: Double, features: Vector) =>
|
case Row(label: Double, features: Vector) =>
|
||||||
|
@ -127,7 +128,8 @@ final class ChiSqSelectorModel private[ml] (
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setLabelCol(value: String): this.type = set(labelCol, value)
|
def setLabelCol(value: String): this.type = set(labelCol, value)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
val transformedSchema = transformSchema(dataset.schema, logging = true)
|
val transformedSchema = transformSchema(dataset.schema, logging = true)
|
||||||
val newField = transformedSchema.last
|
val newField = transformedSchema.last
|
||||||
val selector = udf { chiSqSelector.transform _ }
|
val selector = udf { chiSqSelector.transform _ }
|
||||||
|
|
|
@ -26,7 +26,7 @@ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.linalg.{Vectors, VectorUDT}
|
import org.apache.spark.mllib.linalg.{Vectors, VectorUDT}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.util.collection.OpenHashMap
|
import org.apache.spark.util.collection.OpenHashMap
|
||||||
|
@ -147,7 +147,8 @@ class CountVectorizer(override val uid: String)
|
||||||
|
|
||||||
setDefault(vocabSize -> (1 << 18), minDF -> 1)
|
setDefault(vocabSize -> (1 << 18), minDF -> 1)
|
||||||
|
|
||||||
override def fit(dataset: DataFrame): CountVectorizerModel = {
|
@Since("2.0.0")
|
||||||
|
override def fit(dataset: Dataset[_]): CountVectorizerModel = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val vocSize = $(vocabSize)
|
val vocSize = $(vocabSize)
|
||||||
val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
|
val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
|
||||||
|
@ -224,7 +225,8 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
|
||||||
/** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */
|
/** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */
|
||||||
private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None
|
private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
if (broadcastDict.isEmpty) {
|
if (broadcastDict.isEmpty) {
|
||||||
val dict = vocabulary.zipWithIndex.toMap
|
val dict = vocabulary.zipWithIndex.toMap
|
||||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{BooleanParam, IntParam, ParamMap, ParamValidat
|
||||||
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.feature
|
import org.apache.spark.mllib.feature
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions.{col, udf}
|
import org.apache.spark.sql.functions.{col, udf}
|
||||||
import org.apache.spark.sql.types.{ArrayType, StructType}
|
import org.apache.spark.sql.types.{ArrayType, StructType}
|
||||||
|
|
||||||
|
@ -77,7 +77,8 @@ class HashingTF(override val uid: String)
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setBinary(value: Boolean): this.type = set(binary, value)
|
def setBinary(value: Boolean): this.type = set(binary, value)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
val outputSchema = transformSchema(dataset.schema)
|
val outputSchema = transformSchema(dataset.schema)
|
||||||
val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
|
val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
|
||||||
val t = udf { terms: Seq[_] => hashingTF.transform(terms) }
|
val t = udf { terms: Seq[_] => hashingTF.transform(terms) }
|
||||||
|
|
|
@ -76,7 +76,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
|
def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
|
||||||
|
|
||||||
override def fit(dataset: DataFrame): IDFModel = {
|
@Since("2.0.0")
|
||||||
|
override def fit(dataset: Dataset[_]): IDFModel = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
|
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
|
||||||
val idf = new feature.IDF($(minDocFreq)).fit(input)
|
val idf = new feature.IDF($(minDocFreq)).fit(input)
|
||||||
|
@ -115,7 +116,8 @@ class IDFModel private[ml] (
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val idf = udf { vec: Vector => idfModel.transform(vec) }
|
val idf = udf { vec: Vector => idfModel.transform(vec) }
|
||||||
dataset.withColumn($(outputCol), idf(col($(inputCol))))
|
dataset.withColumn($(outputCol), idf(col($(inputCol))))
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.ml.Transformer
|
import org.apache.spark.ml.Transformer
|
||||||
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
|
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
|
@ -68,8 +68,8 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
|
||||||
StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
|
StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("2.0.0")
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
val inputFeatures = $(inputCols).map(c => dataset.schema(c))
|
val inputFeatures = $(inputCols).map(c => dataset.schema(c))
|
||||||
val featureEncoders = getFeatureEncoders(inputFeatures)
|
val featureEncoders = getFeatureEncoders(inputFeatures)
|
||||||
val featureAttrs = getFeatureAttrs(inputFeatures)
|
val featureAttrs = getFeatureAttrs(inputFeatures)
|
||||||
|
|
|
@ -66,7 +66,8 @@ class MaxAbsScaler @Since("2.0.0") (override val uid: String)
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||||
|
|
||||||
override def fit(dataset: DataFrame): MaxAbsScalerModel = {
|
@Since("2.0.0")
|
||||||
|
override def fit(dataset: Dataset[_]): MaxAbsScalerModel = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
|
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
|
||||||
val summary = Statistics.colStats(input)
|
val summary = Statistics.colStats(input)
|
||||||
|
@ -111,7 +112,8 @@ class MaxAbsScalerModel private[ml] (
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
// TODO: this looks hack, we may have to handle sparse and dense vectors separately.
|
// TODO: this looks hack, we may have to handle sparse and dense vectors separately.
|
||||||
val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x))
|
val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x))
|
||||||
|
|
|
@ -103,7 +103,8 @@ class MinMaxScaler(override val uid: String)
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setMax(value: Double): this.type = set(max, value)
|
def setMax(value: Double): this.type = set(max, value)
|
||||||
|
|
||||||
override def fit(dataset: DataFrame): MinMaxScalerModel = {
|
@Since("2.0.0")
|
||||||
|
override def fit(dataset: Dataset[_]): MinMaxScalerModel = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
|
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
|
||||||
val summary = Statistics.colStats(input)
|
val summary = Statistics.colStats(input)
|
||||||
|
@ -154,7 +155,8 @@ class MinMaxScalerModel private[ml] (
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setMax(value: Double): this.type = set(max, value)
|
def setMax(value: Double): this.type = set(max, value)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray
|
val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray
|
||||||
val minArray = originalMin.toArray
|
val minArray = originalMin.toArray
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.ml.param._
|
||||||
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.linalg.Vectors
|
import org.apache.spark.mllib.linalg.Vectors
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions.{col, udf}
|
import org.apache.spark.sql.functions.{col, udf}
|
||||||
import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}
|
import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}
|
||||||
|
|
||||||
|
@ -121,7 +121,8 @@ class OneHotEncoder(override val uid: String) extends Transformer
|
||||||
StructType(outputFields)
|
StructType(outputFields)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
// schema transformation
|
// schema transformation
|
||||||
val inputColName = $(inputCol)
|
val inputColName = $(inputCol)
|
||||||
val outputColName = $(outputCol)
|
val outputColName = $(outputCol)
|
||||||
|
|
|
@ -68,7 +68,8 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
|
||||||
/**
|
/**
|
||||||
* Computes a [[PCAModel]] that contains the principal components of the input vectors.
|
* Computes a [[PCAModel]] that contains the principal components of the input vectors.
|
||||||
*/
|
*/
|
||||||
override def fit(dataset: DataFrame): PCAModel = {
|
@Since("2.0.0")
|
||||||
|
override def fit(dataset: Dataset[_]): PCAModel = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v}
|
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v}
|
||||||
val pca = new feature.PCA(k = $(k))
|
val pca = new feature.PCA(k = $(k))
|
||||||
|
@ -124,7 +125,8 @@ class PCAModel private[ml] (
|
||||||
* NOTE: Vectors to be transformed must be the same length
|
* NOTE: Vectors to be transformed must be the same length
|
||||||
* as the source vectors given to [[PCA.fit()]].
|
* as the source vectors given to [[PCA.fit()]].
|
||||||
*/
|
*/
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val pcaModel = new feature.PCAModel($(k), pc, explainedVariance)
|
val pcaModel = new feature.PCAModel($(k), pc, explainedVariance)
|
||||||
val pcaOp = udf { pcaModel.transform _ }
|
val pcaOp = udf { pcaModel.transform _ }
|
||||||
|
|
|
@ -23,10 +23,10 @@ import org.apache.spark.annotation.{Experimental, Since}
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
import org.apache.spark.ml._
|
import org.apache.spark.ml._
|
||||||
import org.apache.spark.ml.attribute.NominalAttribute
|
import org.apache.spark.ml.attribute.NominalAttribute
|
||||||
import org.apache.spark.ml.param.{IntParam, _}
|
import org.apache.spark.ml.param._
|
||||||
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed}
|
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed}
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{Dataset, Row}
|
||||||
import org.apache.spark.sql.types.{DoubleType, StructType}
|
import org.apache.spark.sql.types.{DoubleType, StructType}
|
||||||
import org.apache.spark.util.random.XORShiftRandom
|
import org.apache.spark.util.random.XORShiftRandom
|
||||||
|
|
||||||
|
@ -87,7 +87,8 @@ final class QuantileDiscretizer(override val uid: String)
|
||||||
StructType(outputFields)
|
StructType(outputFields)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def fit(dataset: DataFrame): Bucketizer = {
|
@Since("2.0.0")
|
||||||
|
override def fit(dataset: Dataset[_]): Bucketizer = {
|
||||||
val samples = QuantileDiscretizer
|
val samples = QuantileDiscretizer
|
||||||
.getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
|
.getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
|
||||||
.map { case Row(feature: Double) => feature }
|
.map { case Row(feature: Double) => feature }
|
||||||
|
@ -112,13 +113,15 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi
|
||||||
/**
|
/**
|
||||||
* Sampling from the given dataset to collect quantile statistics.
|
* Sampling from the given dataset to collect quantile statistics.
|
||||||
*/
|
*/
|
||||||
private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = {
|
private[feature]
|
||||||
|
def getSampledInput(dataset: Dataset[_], numBins: Int, seed: Long): Array[Row] = {
|
||||||
val totalSamples = dataset.count()
|
val totalSamples = dataset.count()
|
||||||
require(totalSamples > 0,
|
require(totalSamples > 0,
|
||||||
"QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
|
"QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
|
||||||
val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
|
val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
|
||||||
val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
|
val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
|
||||||
dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
|
dataset.toDF.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -29,7 +29,7 @@ import org.apache.spark.ml.param.{Param, ParamMap}
|
||||||
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
|
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.linalg.VectorUDT
|
import org.apache.spark.mllib.linalg.VectorUDT
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -103,7 +103,8 @@ class RFormula(override val uid: String)
|
||||||
RFormulaParser.parse($(formula)).hasIntercept
|
RFormulaParser.parse($(formula)).hasIntercept
|
||||||
}
|
}
|
||||||
|
|
||||||
override def fit(dataset: DataFrame): RFormulaModel = {
|
@Since("2.0.0")
|
||||||
|
override def fit(dataset: Dataset[_]): RFormulaModel = {
|
||||||
require(isDefined(formula), "Formula must be defined first.")
|
require(isDefined(formula), "Formula must be defined first.")
|
||||||
val parsedFormula = RFormulaParser.parse($(formula))
|
val parsedFormula = RFormulaParser.parse($(formula))
|
||||||
val resolvedFormula = parsedFormula.resolve(dataset.schema)
|
val resolvedFormula = parsedFormula.resolve(dataset.schema)
|
||||||
|
@ -204,7 +205,8 @@ class RFormulaModel private[feature](
|
||||||
private[ml] val pipelineModel: PipelineModel)
|
private[ml] val pipelineModel: PipelineModel)
|
||||||
extends Model[RFormulaModel] with RFormulaBase with MLWritable {
|
extends Model[RFormulaModel] with RFormulaBase with MLWritable {
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
checkCanTransform(dataset.schema)
|
checkCanTransform(dataset.schema)
|
||||||
transformLabel(pipelineModel.transform(dataset))
|
transformLabel(pipelineModel.transform(dataset))
|
||||||
}
|
}
|
||||||
|
@ -232,10 +234,10 @@ class RFormulaModel private[feature](
|
||||||
|
|
||||||
override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
|
override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
|
||||||
|
|
||||||
private def transformLabel(dataset: DataFrame): DataFrame = {
|
private def transformLabel(dataset: Dataset[_]): DataFrame = {
|
||||||
val labelName = resolvedFormula.label
|
val labelName = resolvedFormula.label
|
||||||
if (hasLabelCol(dataset.schema)) {
|
if (hasLabelCol(dataset.schema)) {
|
||||||
dataset
|
dataset.toDF
|
||||||
} else if (dataset.schema.exists(_.name == labelName)) {
|
} else if (dataset.schema.exists(_.name == labelName)) {
|
||||||
dataset.schema(labelName).dataType match {
|
dataset.schema(labelName).dataType match {
|
||||||
case _: NumericType | BooleanType =>
|
case _: NumericType | BooleanType =>
|
||||||
|
@ -246,7 +248,7 @@ class RFormulaModel private[feature](
|
||||||
} else {
|
} else {
|
||||||
// Ignore the label field. This is a hack so that this transformer can also work on test
|
// Ignore the label field. This is a hack so that this transformer can also work on test
|
||||||
// datasets in a Pipeline.
|
// datasets in a Pipeline.
|
||||||
dataset
|
dataset.toDF
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -323,7 +325,7 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str
|
||||||
def this(columnsToPrune: Set[String]) =
|
def this(columnsToPrune: Set[String]) =
|
||||||
this(Identifiable.randomUID("columnPruner"), columnsToPrune)
|
this(Identifiable.randomUID("columnPruner"), columnsToPrune)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
|
val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
|
||||||
dataset.select(columnsToKeep.map(dataset.col): _*)
|
dataset.select(columnsToKeep.map(dataset.col): _*)
|
||||||
}
|
}
|
||||||
|
@ -396,7 +398,7 @@ private class VectorAttributeRewriter(
|
||||||
def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
|
def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
|
||||||
this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)
|
this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
val metadata = {
|
val metadata = {
|
||||||
val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
|
val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
|
||||||
val attrs = group.attributes.get.map { attr =>
|
val attrs = group.attributes.get.map { attr =>
|
||||||
|
|
|
@ -22,7 +22,7 @@ import org.apache.spark.annotation.{Experimental, Since}
|
||||||
import org.apache.spark.ml.param.{Param, ParamMap}
|
import org.apache.spark.ml.param.{Param, ParamMap}
|
||||||
import org.apache.spark.ml.Transformer
|
import org.apache.spark.ml.Transformer
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -63,8 +63,8 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
|
||||||
|
|
||||||
private val tableIdentifier: String = "__THIS__"
|
private val tableIdentifier: String = "__THIS__"
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("2.0.0")
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
val tableName = Identifiable.randomUID(uid)
|
val tableName = Identifiable.randomUID(uid)
|
||||||
dataset.registerTempTable(tableName)
|
dataset.registerTempTable(tableName)
|
||||||
val realStatement = $(statement).replace(tableIdentifier, tableName)
|
val realStatement = $(statement).replace(tableIdentifier, tableName)
|
||||||
|
|
|
@ -85,7 +85,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setWithStd(value: Boolean): this.type = set(withStd, value)
|
def setWithStd(value: Boolean): this.type = set(withStd, value)
|
||||||
|
|
||||||
override def fit(dataset: DataFrame): StandardScalerModel = {
|
@Since("2.0.0")
|
||||||
|
override def fit(dataset: Dataset[_]): StandardScalerModel = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
|
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
|
||||||
val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
|
val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
|
||||||
|
@ -135,7 +136,8 @@ class StandardScalerModel private[ml] (
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
|
val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
|
||||||
val scale = udf { scaler.transform _ }
|
val scale = udf { scaler.transform _ }
|
||||||
|
|
|
@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer
|
||||||
import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
|
import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
|
||||||
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions.{col, udf}
|
import org.apache.spark.sql.functions.{col, udf}
|
||||||
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
|
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
|
||||||
|
|
||||||
|
@ -125,7 +125,8 @@ class StopWordsRemover(override val uid: String)
|
||||||
|
|
||||||
setDefault(stopWords -> StopWords.English, caseSensitive -> false)
|
setDefault(stopWords -> StopWords.English, caseSensitive -> false)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
val outputSchema = transformSchema(dataset.schema)
|
val outputSchema = transformSchema(dataset.schema)
|
||||||
val t = if ($(caseSensitive)) {
|
val t = if ($(caseSensitive)) {
|
||||||
val stopWordsSet = $(stopWords).toSet
|
val stopWordsSet = $(stopWords).toSet
|
||||||
|
|
|
@ -26,7 +26,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
|
||||||
import org.apache.spark.ml.param._
|
import org.apache.spark.ml.param._
|
||||||
import org.apache.spark.ml.param.shared._
|
import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.util.collection.OpenHashMap
|
import org.apache.spark.util.collection.OpenHashMap
|
||||||
|
@ -80,7 +80,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
|
||||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||||
|
|
||||||
|
|
||||||
override def fit(dataset: DataFrame): StringIndexerModel = {
|
@Since("2.0.0")
|
||||||
|
override def fit(dataset: Dataset[_]): StringIndexerModel = {
|
||||||
val counts = dataset.select(col($(inputCol)).cast(StringType))
|
val counts = dataset.select(col($(inputCol)).cast(StringType))
|
||||||
.rdd
|
.rdd
|
||||||
.map(_.getString(0))
|
.map(_.getString(0))
|
||||||
|
@ -144,11 +145,12 @@ class StringIndexerModel (
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
if (!dataset.schema.fieldNames.contains($(inputCol))) {
|
if (!dataset.schema.fieldNames.contains($(inputCol))) {
|
||||||
logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
|
logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
|
||||||
"Skip StringIndexerModel.")
|
"Skip StringIndexerModel.")
|
||||||
return dataset
|
return dataset.toDF
|
||||||
}
|
}
|
||||||
validateAndTransformSchema(dataset.schema)
|
validateAndTransformSchema(dataset.schema)
|
||||||
|
|
||||||
|
@ -286,7 +288,8 @@ class IndexToString private[ml] (override val uid: String)
|
||||||
StructType(outputFields)
|
StructType(outputFields)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
val inputColSchema = dataset.schema($(inputCol))
|
val inputColSchema = dataset.schema($(inputCol))
|
||||||
// If the labels array is empty use column metadata
|
// If the labels array is empty use column metadata
|
||||||
val values = if ($(labels).isEmpty) {
|
val values = if ($(labels).isEmpty) {
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.ml.param.ParamMap
|
||||||
import org.apache.spark.ml.param.shared._
|
import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
|
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
|
@ -47,10 +47,11 @@ class VectorAssembler(override val uid: String)
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
// Schema transformation.
|
// Schema transformation.
|
||||||
val schema = dataset.schema
|
val schema = dataset.schema
|
||||||
lazy val first = dataset.first()
|
lazy val first = dataset.toDF.first()
|
||||||
val attrs = $(inputCols).flatMap { c =>
|
val attrs = $(inputCols).flatMap { c =>
|
||||||
val field = schema(c)
|
val field = schema(c)
|
||||||
val index = schema.fieldIndex(c)
|
val index = schema.fieldIndex(c)
|
||||||
|
|
|
@ -31,7 +31,7 @@ import org.apache.spark.ml.param._
|
||||||
import org.apache.spark.ml.param.shared._
|
import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
|
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions.udf
|
import org.apache.spark.sql.functions.udf
|
||||||
import org.apache.spark.sql.types.{StructField, StructType}
|
import org.apache.spark.sql.types.{StructField, StructType}
|
||||||
import org.apache.spark.util.collection.OpenHashSet
|
import org.apache.spark.util.collection.OpenHashSet
|
||||||
|
@ -108,7 +108,8 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||||
|
|
||||||
override def fit(dataset: DataFrame): VectorIndexerModel = {
|
@Since("2.0.0")
|
||||||
|
override def fit(dataset: Dataset[_]): VectorIndexerModel = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val firstRow = dataset.select($(inputCol)).take(1)
|
val firstRow = dataset.select($(inputCol)).take(1)
|
||||||
require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.")
|
require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.")
|
||||||
|
@ -345,7 +346,8 @@ class VectorIndexerModel private[ml] (
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val newField = prepOutputField(dataset.schema)
|
val newField = prepOutputField(dataset.schema)
|
||||||
val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
|
val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
|
||||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam}
|
||||||
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.linalg._
|
import org.apache.spark.mllib.linalg._
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
|
@ -89,7 +89,8 @@ final class VectorSlicer(override val uid: String)
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
// Validity checks
|
// Validity checks
|
||||||
transformSchema(dataset.schema)
|
transformSchema(dataset.schema)
|
||||||
val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
|
val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.feature
|
import org.apache.spark.mllib.feature
|
||||||
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
|
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
|
||||||
import org.apache.spark.sql.{DataFrame, SQLContext}
|
import org.apache.spark.sql.{DataFrame, Dataset, SQLContext}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
|
@ -135,7 +135,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setMinCount(value: Int): this.type = set(minCount, value)
|
def setMinCount(value: Int): this.type = set(minCount, value)
|
||||||
|
|
||||||
override def fit(dataset: DataFrame): Word2VecModel = {
|
@Since("2.0.0")
|
||||||
|
override def fit(dataset: Dataset[_]): Word2VecModel = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
|
val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
|
||||||
val wordVectors = new feature.Word2Vec()
|
val wordVectors = new feature.Word2Vec()
|
||||||
|
@ -219,7 +220,8 @@ class Word2VecModel private[ml] (
|
||||||
* Transform a sentence column to a vector column to represent the whole sentence. The transform
|
* Transform a sentence column to a vector column to represent the whole sentence. The transform
|
||||||
* is performed by averaging all word vectors it contains.
|
* is performed by averaging all word vectors it contains.
|
||||||
*/
|
*/
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val vectors = wordVectors.getVectors
|
val vectors = wordVectors.getVectors
|
||||||
.mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
|
.mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
|
||||||
|
|
|
@ -22,7 +22,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||||
import org.apache.spark.ml.attribute.AttributeGroup
|
import org.apache.spark.ml.attribute.AttributeGroup
|
||||||
import org.apache.spark.ml.feature.RFormula
|
import org.apache.spark.ml.feature.RFormula
|
||||||
import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
|
import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
|
|
||||||
private[r] class AFTSurvivalRegressionWrapper private (
|
private[r] class AFTSurvivalRegressionWrapper private (
|
||||||
pipeline: PipelineModel,
|
pipeline: PipelineModel,
|
||||||
|
@ -43,7 +43,7 @@ private[r] class AFTSurvivalRegressionWrapper private (
|
||||||
features ++ Array("Log(scale)")
|
features ++ Array("Log(scale)")
|
||||||
}
|
}
|
||||||
|
|
||||||
def transform(dataset: DataFrame): DataFrame = {
|
def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
pipeline.transform(dataset)
|
pipeline.transform(dataset)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||||
import org.apache.spark.ml.attribute.AttributeGroup
|
import org.apache.spark.ml.attribute.AttributeGroup
|
||||||
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
|
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
|
||||||
import org.apache.spark.ml.feature.VectorAssembler
|
import org.apache.spark.ml.feature.VectorAssembler
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
|
|
||||||
private[r] class KMeansWrapper private (
|
private[r] class KMeansWrapper private (
|
||||||
pipeline: PipelineModel) {
|
pipeline: PipelineModel) {
|
||||||
|
@ -52,7 +52,7 @@ private[r] class KMeansWrapper private (
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def transform(dataset: DataFrame): DataFrame = {
|
def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol)
|
pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||||
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
|
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
|
||||||
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
|
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
|
||||||
import org.apache.spark.ml.feature.{IndexToString, RFormula}
|
import org.apache.spark.ml.feature.{IndexToString, RFormula}
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
|
|
||||||
private[r] class NaiveBayesWrapper private (
|
private[r] class NaiveBayesWrapper private (
|
||||||
pipeline: PipelineModel,
|
pipeline: PipelineModel,
|
||||||
|
@ -36,7 +36,7 @@ private[r] class NaiveBayesWrapper private (
|
||||||
|
|
||||||
lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp)
|
lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp)
|
||||||
|
|
||||||
def transform(dataset: DataFrame): DataFrame = {
|
def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL)
|
pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,7 +40,7 @@ import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.linalg.CholeskyDecomposition
|
import org.apache.spark.mllib.linalg.CholeskyDecomposition
|
||||||
import org.apache.spark.mllib.optimization.NNLS
|
import org.apache.spark.mllib.optimization.NNLS
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
|
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
|
||||||
import org.apache.spark.storage.StorageLevel
|
import org.apache.spark.storage.StorageLevel
|
||||||
|
@ -200,8 +200,8 @@ class ALSModel private[ml] (
|
||||||
@Since("1.3.0")
|
@Since("1.3.0")
|
||||||
def setPredictionCol(value: String): this.type = set(predictionCol, value)
|
def setPredictionCol(value: String): this.type = set(predictionCol, value)
|
||||||
|
|
||||||
@Since("1.3.0")
|
@Since("2.0.0")
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
// Register a UDF for DataFrame, and then
|
// Register a UDF for DataFrame, and then
|
||||||
// create a new column named map(predictionCol) by running the predict UDF.
|
// create a new column named map(predictionCol) by running the predict UDF.
|
||||||
val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
|
val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
|
||||||
|
@ -385,8 +385,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
|
||||||
this
|
this
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.3.0")
|
@Since("2.0.0")
|
||||||
override def fit(dataset: DataFrame): ALSModel = {
|
override def fit(dataset: Dataset[_]): ALSModel = {
|
||||||
import dataset.sqlContext.implicits._
|
import dataset.sqlContext.implicits._
|
||||||
val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
|
val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
|
||||||
val ratings = dataset
|
val ratings = dataset
|
||||||
|
|
|
@ -32,7 +32,7 @@ import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
|
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types.{DoubleType, StructType}
|
import org.apache.spark.sql.types.{DoubleType, StructType}
|
||||||
import org.apache.spark.storage.StorageLevel
|
import org.apache.spark.storage.StorageLevel
|
||||||
|
@ -183,7 +183,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
|
||||||
* Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
|
* Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
|
||||||
* and put it in an RDD with strong types.
|
* and put it in an RDD with strong types.
|
||||||
*/
|
*/
|
||||||
protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = {
|
protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = {
|
||||||
dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
|
dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
|
||||||
.rdd.map {
|
.rdd.map {
|
||||||
case Row(features: Vector, label: Double, censor: Double) =>
|
case Row(features: Vector, label: Double, censor: Double) =>
|
||||||
|
@ -191,8 +191,8 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("2.0.0")
|
||||||
override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = {
|
override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
|
||||||
validateAndTransformSchema(dataset.schema, fitting = true)
|
validateAndTransformSchema(dataset.schema, fitting = true)
|
||||||
val instances = extractAFTPoints(dataset)
|
val instances = extractAFTPoints(dataset)
|
||||||
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
|
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
|
||||||
|
@ -299,8 +299,8 @@ class AFTSurvivalRegressionModel private[ml] (
|
||||||
math.exp(BLAS.dot(coefficients, features) + intercept)
|
math.exp(BLAS.dot(coefficients, features) + intercept)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("2.0.0")
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema)
|
transformSchema(dataset.schema)
|
||||||
val predictUDF = udf { features: Vector => predict(features) }
|
val predictUDF = udf { features: Vector => predict(features) }
|
||||||
val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}
|
val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}
|
||||||
|
|
|
@ -33,7 +33,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
||||||
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
|
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
|
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setVarianceCol(value: String): this.type = set(varianceCol, value)
|
def setVarianceCol(value: String): this.type = set(varianceCol, value)
|
||||||
|
|
||||||
override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = {
|
override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = {
|
||||||
val categoricalFeatures: Map[Int, Int] =
|
val categoricalFeatures: Map[Int, Int] =
|
||||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
||||||
|
@ -158,15 +158,16 @@ final class DecisionTreeRegressionModel private[ml] (
|
||||||
rootNode.predictImpl(features).impurityStats.calculate()
|
rootNode.predictImpl(features).impurityStats.calculate()
|
||||||
}
|
}
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
@Since("2.0.0")
|
||||||
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
transformImpl(dataset)
|
transformImpl(dataset)
|
||||||
}
|
}
|
||||||
|
|
||||||
override protected def transformImpl(dataset: DataFrame): DataFrame = {
|
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
|
||||||
val predictUDF = udf { (features: Vector) => predict(features) }
|
val predictUDF = udf { (features: Vector) => predict(features) }
|
||||||
val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) }
|
val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) }
|
||||||
var output = dataset
|
var output = dataset.toDF
|
||||||
if ($(predictionCol).nonEmpty) {
|
if ($(predictionCol).nonEmpty) {
|
||||||
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
|
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss
|
||||||
SquaredError => OldSquaredError}
|
SquaredError => OldSquaredError}
|
||||||
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
|
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -147,7 +147,7 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override protected def train(dataset: DataFrame): GBTRegressionModel = {
|
override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
|
||||||
val categoricalFeatures: Map[Int, Int] =
|
val categoricalFeatures: Map[Int, Int] =
|
||||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
||||||
|
@ -209,7 +209,7 @@ final class GBTRegressionModel private[ml](
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
override def treeWeights: Array[Double] = _treeWeights
|
override def treeWeights: Array[Double] = _treeWeights
|
||||||
|
|
||||||
override protected def transformImpl(dataset: DataFrame): DataFrame = {
|
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
|
||||||
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
|
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
|
||||||
val predictUDF = udf { (features: Any) =>
|
val predictUDF = udf { (features: Any) =>
|
||||||
bcastModel.value.predict(features.asInstanceOf[Vector])
|
bcastModel.value.predict(features.asInstanceOf[Vector])
|
||||||
|
|
|
@ -31,7 +31,7 @@ import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.linalg.{BLAS, Vector}
|
import org.apache.spark.mllib.linalg.{BLAS, Vector}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
|
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
|
||||||
|
|
||||||
|
@ -196,7 +196,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
|
||||||
def setSolver(value: String): this.type = set(solver, value)
|
def setSolver(value: String): this.type = set(solver, value)
|
||||||
setDefault(solver -> "irls")
|
setDefault(solver -> "irls")
|
||||||
|
|
||||||
override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = {
|
override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = {
|
||||||
val familyObj = Family.fromName($(family))
|
val familyObj = Family.fromName($(family))
|
||||||
val linkObj = if (isDefined(link)) {
|
val linkObj = if (isDefined(link)) {
|
||||||
Link.fromName($(link))
|
Link.fromName($(link))
|
||||||
|
|
|
@ -30,7 +30,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
|
||||||
import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression}
|
import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression}
|
||||||
import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel}
|
import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions.{col, lit, udf}
|
import org.apache.spark.sql.functions.{col, lit, udf}
|
||||||
import org.apache.spark.sql.types.{DoubleType, StructType}
|
import org.apache.spark.sql.types.{DoubleType, StructType}
|
||||||
import org.apache.spark.storage.StorageLevel
|
import org.apache.spark.storage.StorageLevel
|
||||||
|
@ -77,7 +77,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
|
||||||
* Extracts (label, feature, weight) from input dataset.
|
* Extracts (label, feature, weight) from input dataset.
|
||||||
*/
|
*/
|
||||||
protected[ml] def extractWeightedLabeledPoints(
|
protected[ml] def extractWeightedLabeledPoints(
|
||||||
dataset: DataFrame): RDD[(Double, Double, Double)] = {
|
dataset: Dataset[_]): RDD[(Double, Double, Double)] = {
|
||||||
val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) {
|
val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) {
|
||||||
val idx = $(featureIndex)
|
val idx = $(featureIndex)
|
||||||
val extract = udf { v: Vector => v(idx) }
|
val extract = udf { v: Vector => v(idx) }
|
||||||
|
@ -164,8 +164,8 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)
|
override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("2.0.0")
|
||||||
override def fit(dataset: DataFrame): IsotonicRegressionModel = {
|
override def fit(dataset: Dataset[_]): IsotonicRegressionModel = {
|
||||||
validateAndTransformSchema(dataset.schema, fitting = true)
|
validateAndTransformSchema(dataset.schema, fitting = true)
|
||||||
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
|
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
|
||||||
val instances = extractWeightedLabeledPoints(dataset)
|
val instances = extractWeightedLabeledPoints(dataset)
|
||||||
|
@ -236,8 +236,8 @@ class IsotonicRegressionModel private[ml] (
|
||||||
copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent)
|
copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("2.0.0")
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
val predict = dataset.schema($(featuresCol)).dataType match {
|
val predict = dataset.schema($(featuresCol)).dataType match {
|
||||||
case DoubleType =>
|
case DoubleType =>
|
||||||
udf { feature: Double => oldModel.predict(feature) }
|
udf { feature: Double => oldModel.predict(feature) }
|
||||||
|
|
|
@ -38,7 +38,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.mllib.linalg.BLAS._
|
import org.apache.spark.mllib.linalg.BLAS._
|
||||||
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
|
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types.DoubleType
|
import org.apache.spark.sql.types.DoubleType
|
||||||
import org.apache.spark.storage.StorageLevel
|
import org.apache.spark.storage.StorageLevel
|
||||||
|
@ -158,7 +158,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
||||||
def setSolver(value: String): this.type = set(solver, value)
|
def setSolver(value: String): this.type = set(solver, value)
|
||||||
setDefault(solver -> "auto")
|
setDefault(solver -> "auto")
|
||||||
|
|
||||||
override protected def train(dataset: DataFrame): LinearRegressionModel = {
|
override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
|
||||||
// Extract the number of features before deciding optimization solver.
|
// Extract the number of features before deciding optimization solver.
|
||||||
val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map {
|
val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map {
|
||||||
case Row(features: Vector) => features.size
|
case Row(features: Vector) => features.size
|
||||||
|
@ -417,7 +417,7 @@ class LinearRegressionModel private[ml] (
|
||||||
* @param dataset Test dataset to evaluate model on.
|
* @param dataset Test dataset to evaluate model on.
|
||||||
*/
|
*/
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
def evaluate(dataset: DataFrame): LinearRegressionSummary = {
|
def evaluate(dataset: Dataset[_]): LinearRegressionSummary = {
|
||||||
// Handle possible missing or invalid prediction columns
|
// Handle possible missing or invalid prediction columns
|
||||||
val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
|
val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
|
||||||
new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName,
|
new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName,
|
||||||
|
|
|
@ -32,7 +32,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
|
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
|
||||||
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
|
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
|
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val
|
||||||
override def setFeatureSubsetStrategy(value: String): this.type =
|
override def setFeatureSubsetStrategy(value: String): this.type =
|
||||||
super.setFeatureSubsetStrategy(value)
|
super.setFeatureSubsetStrategy(value)
|
||||||
|
|
||||||
override protected def train(dataset: DataFrame): RandomForestRegressionModel = {
|
override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = {
|
||||||
val categoricalFeatures: Map[Int, Int] =
|
val categoricalFeatures: Map[Int, Int] =
|
||||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
||||||
|
@ -164,7 +164,7 @@ final class RandomForestRegressionModel private[ml] (
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
override def treeWeights: Array[Double] = _treeWeights
|
override def treeWeights: Array[Double] = _treeWeights
|
||||||
|
|
||||||
override protected def transformImpl(dataset: DataFrame): DataFrame = {
|
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
|
||||||
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
|
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
|
||||||
val predictUDF = udf { (features: Any) =>
|
val predictUDF = udf { (features: Any) =>
|
||||||
bcastModel.value.predict(features.asInstanceOf[Vector])
|
bcastModel.value.predict(features.asInstanceOf[Vector])
|
||||||
|
|
|
@ -33,7 +33,7 @@ import org.apache.spark.ml.param._
|
||||||
import org.apache.spark.ml.param.shared.HasSeed
|
import org.apache.spark.ml.param.shared.HasSeed
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.util.MLUtils
|
import org.apache.spark.mllib.util.MLUtils
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -90,8 +90,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
def setSeed(value: Long): this.type = set(seed, value)
|
def setSeed(value: Long): this.type = set(seed, value)
|
||||||
|
|
||||||
@Since("1.4.0")
|
@Since("2.0.0")
|
||||||
override def fit(dataset: DataFrame): CrossValidatorModel = {
|
override def fit(dataset: Dataset[_]): CrossValidatorModel = {
|
||||||
val schema = dataset.schema
|
val schema = dataset.schema
|
||||||
transformSchema(schema, logging = true)
|
transformSchema(schema, logging = true)
|
||||||
val sqlCtx = dataset.sqlContext
|
val sqlCtx = dataset.sqlContext
|
||||||
|
@ -100,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
|
||||||
val epm = $(estimatorParamMaps)
|
val epm = $(estimatorParamMaps)
|
||||||
val numModels = epm.length
|
val numModels = epm.length
|
||||||
val metrics = new Array[Double](epm.length)
|
val metrics = new Array[Double](epm.length)
|
||||||
val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed))
|
val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
|
||||||
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
|
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
|
||||||
val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
|
val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
|
||||||
val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
|
val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
|
||||||
|
@ -209,8 +209,8 @@ class CrossValidatorModel private[ml] (
|
||||||
this(uid, bestModel, avgMetrics.asScala.toArray)
|
this(uid, bestModel, avgMetrics.asScala.toArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.4.0")
|
@Since("2.0.0")
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
bestModel.transform(dataset)
|
bestModel.transform(dataset)
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.ml.tuning
|
||||||
import java.util.{List => JList}
|
import java.util.{List => JList}
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
import scala.language.existentials
|
||||||
|
|
||||||
import org.apache.hadoop.fs.Path
|
import org.apache.hadoop.fs.Path
|
||||||
import org.json4s.DefaultFormats
|
import org.json4s.DefaultFormats
|
||||||
|
@ -31,7 +32,7 @@ import org.apache.spark.ml.evaluation.Evaluator
|
||||||
import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
|
import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
|
||||||
import org.apache.spark.ml.param.shared.HasSeed
|
import org.apache.spark.ml.param.shared.HasSeed
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -89,8 +90,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
def setSeed(value: Long): this.type = set(seed, value)
|
def setSeed(value: Long): this.type = set(seed, value)
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("2.0.0")
|
||||||
override def fit(dataset: DataFrame): TrainValidationSplitModel = {
|
override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
|
||||||
val schema = dataset.schema
|
val schema = dataset.schema
|
||||||
transformSchema(schema, logging = true)
|
transformSchema(schema, logging = true)
|
||||||
val sqlCtx = dataset.sqlContext
|
val sqlCtx = dataset.sqlContext
|
||||||
|
@ -207,8 +208,8 @@ class TrainValidationSplitModel private[ml] (
|
||||||
this(uid, bestModel, validationMetrics.asScala.toArray)
|
this(uid, bestModel, validationMetrics.asScala.toArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("2.0.0")
|
||||||
override def transform(dataset: DataFrame): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
bestModel.transform(dataset)
|
bestModel.transform(dataset)
|
||||||
}
|
}
|
||||||
|
|
|
@ -186,7 +186,7 @@ sealed trait Vector extends Serializable {
|
||||||
* :: AlphaComponent ::
|
* :: AlphaComponent ::
|
||||||
*
|
*
|
||||||
* User-defined type for [[Vector]] which allows easy interaction with SQL
|
* User-defined type for [[Vector]] which allows easy interaction with SQL
|
||||||
* via [[org.apache.spark.sql.DataFrame]].
|
* via [[org.apache.spark.sql.Dataset]].
|
||||||
*/
|
*/
|
||||||
@AlphaComponent
|
@AlphaComponent
|
||||||
class VectorUDT extends UserDefinedType[Vector] {
|
class VectorUDT extends UserDefinedType[Vector] {
|
||||||
|
|
|
@ -31,7 +31,7 @@ import org.apache.spark.ml.param.{IntParam, ParamMap}
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.linalg.Vectors
|
import org.apache.spark.mllib.linalg.Vectors
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
||||||
|
@ -51,6 +51,12 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
||||||
val dataset3 = mock[DataFrame]
|
val dataset3 = mock[DataFrame]
|
||||||
val dataset4 = mock[DataFrame]
|
val dataset4 = mock[DataFrame]
|
||||||
|
|
||||||
|
when(dataset0.toDF).thenReturn(dataset0)
|
||||||
|
when(dataset1.toDF).thenReturn(dataset1)
|
||||||
|
when(dataset2.toDF).thenReturn(dataset2)
|
||||||
|
when(dataset3.toDF).thenReturn(dataset3)
|
||||||
|
when(dataset4.toDF).thenReturn(dataset4)
|
||||||
|
|
||||||
when(estimator0.copy(any[ParamMap])).thenReturn(estimator0)
|
when(estimator0.copy(any[ParamMap])).thenReturn(estimator0)
|
||||||
when(model0.copy(any[ParamMap])).thenReturn(model0)
|
when(model0.copy(any[ParamMap])).thenReturn(model0)
|
||||||
when(transformer1.copy(any[ParamMap])).thenReturn(transformer1)
|
when(transformer1.copy(any[ParamMap])).thenReturn(transformer1)
|
||||||
|
@ -213,7 +219,7 @@ class WritableStage(override val uid: String) extends Transformer with MLWritabl
|
||||||
|
|
||||||
override def write: MLWriter = new DefaultParamsWriter(this)
|
override def write: MLWriter = new DefaultParamsWriter(this)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = dataset
|
override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF
|
||||||
|
|
||||||
override def transformSchema(schema: StructType): StructType = schema
|
override def transformSchema(schema: StructType): StructType = schema
|
||||||
}
|
}
|
||||||
|
@ -234,7 +240,7 @@ class UnWritableStage(override val uid: String) extends Transformer {
|
||||||
|
|
||||||
override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra)
|
override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra)
|
||||||
|
|
||||||
override def transform(dataset: DataFrame): DataFrame = dataset
|
override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF
|
||||||
|
|
||||||
override def transformSchema(schema: StructType): StructType = schema
|
override def transformSchema(schema: StructType): StructType = schema
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,13 +29,13 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions.lit
|
import org.apache.spark.sql.functions.lit
|
||||||
|
|
||||||
class LogisticRegressionSuite
|
class LogisticRegressionSuite
|
||||||
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
||||||
|
|
||||||
@transient var dataset: DataFrame = _
|
@transient var dataset: Dataset[_] = _
|
||||||
@transient var binaryDataset: DataFrame = _
|
@transient var binaryDataset: DataFrame = _
|
||||||
private val eps: Double = 1e-5
|
private val eps: Double = 1e-5
|
||||||
|
|
||||||
|
|
|
@ -26,12 +26,12 @@ import org.apache.spark.mllib.evaluation.MulticlassMetrics
|
||||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
|
|
||||||
class MultilayerPerceptronClassifierSuite
|
class MultilayerPerceptronClassifierSuite
|
||||||
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
||||||
|
|
||||||
@transient var dataset: DataFrame = _
|
@transient var dataset: Dataset[_] = _
|
||||||
|
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
|
|
|
@ -27,11 +27,11 @@ import org.apache.spark.mllib.classification.NaiveBayesSuite._
|
||||||
import org.apache.spark.mllib.linalg._
|
import org.apache.spark.mllib.linalg._
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
|
|
||||||
class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
||||||
|
|
||||||
@transient var dataset: DataFrame = _
|
@transient var dataset: Dataset[_] = _
|
||||||
|
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
|
|
|
@ -30,12 +30,12 @@ import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.types.Metadata
|
import org.apache.spark.sql.types.Metadata
|
||||||
|
|
||||||
class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
||||||
|
|
||||||
@transient var dataset: DataFrame = _
|
@transient var dataset: Dataset[_] = _
|
||||||
@transient var rdd: RDD[LabeledPoint] = _
|
@transient var rdd: RDD[LabeledPoint] = _
|
||||||
|
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
|
@ -246,7 +246,7 @@ private class MockLogisticRegression(uid: String) extends LogisticRegression(uid
|
||||||
|
|
||||||
setMaxIter(1)
|
setMaxIter(1)
|
||||||
|
|
||||||
override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
|
override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
|
||||||
val labelSchema = dataset.schema($(labelCol))
|
val labelSchema = dataset.schema($(labelCol))
|
||||||
// check for label attribute propagation.
|
// check for label attribute propagation.
|
||||||
assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))
|
assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))
|
||||||
|
|
|
@ -20,13 +20,13 @@ package org.apache.spark.ml.clustering
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
|
|
||||||
class BisectingKMeansSuite
|
class BisectingKMeansSuite
|
||||||
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
||||||
|
|
||||||
final val k = 5
|
final val k = 5
|
||||||
@transient var dataset: DataFrame = _
|
@transient var dataset: Dataset[_] = _
|
||||||
|
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
|
|
|
@ -20,14 +20,14 @@ package org.apache.spark.ml.clustering
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
|
|
||||||
|
|
||||||
class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
|
class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
with DefaultReadWriteTest {
|
with DefaultReadWriteTest {
|
||||||
|
|
||||||
final val k = 5
|
final val k = 5
|
||||||
@transient var dataset: DataFrame = _
|
@transient var dataset: Dataset[_] = _
|
||||||
|
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
|
|
|
@ -22,14 +22,14 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||||
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
|
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
|
||||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.sql.{DataFrame, SQLContext}
|
import org.apache.spark.sql.{DataFrame, Dataset, SQLContext}
|
||||||
|
|
||||||
private[clustering] case class TestRow(features: Vector)
|
private[clustering] case class TestRow(features: Vector)
|
||||||
|
|
||||||
class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
||||||
|
|
||||||
final val k = 5
|
final val k = 5
|
||||||
@transient var dataset: DataFrame = _
|
@transient var dataset: Dataset[_] = _
|
||||||
|
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
|
||||||
|
|
||||||
|
|
||||||
object LDASuite {
|
object LDASuite {
|
||||||
|
@ -64,7 +64,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
|
||||||
|
|
||||||
val k: Int = 5
|
val k: Int = 5
|
||||||
val vocabSize: Int = 30
|
val vocabSize: Int = 30
|
||||||
@transient var dataset: DataFrame = _
|
@transient var dataset: Dataset[_] = _
|
||||||
|
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
|
|
|
@ -22,7 +22,7 @@ import scala.beans.BeanInfo
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
|
|
||||||
@BeanInfo
|
@BeanInfo
|
||||||
case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
|
case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
|
||||||
|
@ -92,7 +92,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
|
||||||
|
|
||||||
object NGramSuite extends SparkFunSuite {
|
object NGramSuite extends SparkFunSuite {
|
||||||
|
|
||||||
def testNGram(t: NGram, dataset: DataFrame): Unit = {
|
def testNGram(t: NGram, dataset: Dataset[_]): Unit = {
|
||||||
t.transform(dataset)
|
t.transform(dataset)
|
||||||
.select("nGrams", "wantedNGrams")
|
.select("nGrams", "wantedNGrams")
|
||||||
.collect()
|
.collect()
|
||||||
|
|
|
@ -20,10 +20,10 @@ package org.apache.spark.ml.feature
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
|
|
||||||
object StopWordsRemoverSuite extends SparkFunSuite {
|
object StopWordsRemoverSuite extends SparkFunSuite {
|
||||||
def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = {
|
def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = {
|
||||||
t.transform(dataset)
|
t.transform(dataset)
|
||||||
.select("filtered", "expected")
|
.select("filtered", "expected")
|
||||||
.collect()
|
.collect()
|
||||||
|
|
|
@ -115,7 +115,7 @@ class StringIndexerSuite
|
||||||
.setInputCol("label")
|
.setInputCol("label")
|
||||||
.setOutputCol("labelIndex")
|
.setOutputCol("labelIndex")
|
||||||
val df = sqlContext.range(0L, 10L).toDF()
|
val df = sqlContext.range(0L, 10L).toDF()
|
||||||
assert(indexerModel.transform(df).eq(df))
|
assert(indexerModel.transform(df).collect().toSet === df.collect().toSet)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("StringIndexerModel can't overwrite output column") {
|
test("StringIndexerModel can't overwrite output column") {
|
||||||
|
|
|
@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.param.ParamsSuite
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
|
|
||||||
@BeanInfo
|
@BeanInfo
|
||||||
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
|
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
|
||||||
|
@ -106,7 +106,7 @@ class RegexTokenizerSuite
|
||||||
|
|
||||||
object RegexTokenizerSuite extends SparkFunSuite {
|
object RegexTokenizerSuite extends SparkFunSuite {
|
||||||
|
|
||||||
def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = {
|
def testRegexTokenizer(t: RegexTokenizer, dataset: Dataset[_]): Unit = {
|
||||||
t.transform(dataset)
|
t.transform(dataset)
|
||||||
.select("tokens", "wantedTokens")
|
.select("tokens", "wantedTokens")
|
||||||
.collect()
|
.collect()
|
||||||
|
|
|
@ -992,6 +992,14 @@ class GeneralizedLinearRegressionSuite
|
||||||
assert(expected.coefficients === actual.coefficients)
|
assert(expected.coefficients === actual.coefficients)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("glm accepts Dataset[LabeledPoint]") {
|
||||||
|
val context = sqlContext
|
||||||
|
import context.implicits._
|
||||||
|
new GeneralizedLinearRegression()
|
||||||
|
.setFamily("gaussian")
|
||||||
|
.fit(datasetGaussianIdentity.as[LabeledPoint])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
object GeneralizedLinearRegressionSuite {
|
object GeneralizedLinearRegressionSuite {
|
||||||
|
|
|
@ -29,13 +29,13 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||||
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
|
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
|
||||||
import org.apache.spark.mllib.linalg.Vectors
|
import org.apache.spark.mllib.linalg.Vectors
|
||||||
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
|
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.types.{StructField, StructType}
|
import org.apache.spark.sql.types.{StructField, StructType}
|
||||||
|
|
||||||
class CrossValidatorSuite
|
class CrossValidatorSuite
|
||||||
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
||||||
|
|
||||||
@transient var dataset: DataFrame = _
|
@transient var dataset: Dataset[_] = _
|
||||||
|
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
|
@ -311,7 +311,7 @@ object CrossValidatorSuite extends SparkFunSuite {
|
||||||
|
|
||||||
class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
|
class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
|
||||||
|
|
||||||
override def fit(dataset: DataFrame): MyModel = {
|
override def fit(dataset: Dataset[_]): MyModel = {
|
||||||
throw new UnsupportedOperationException
|
throw new UnsupportedOperationException
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -325,7 +325,7 @@ object CrossValidatorSuite extends SparkFunSuite {
|
||||||
|
|
||||||
class MyEvaluator extends Evaluator {
|
class MyEvaluator extends Evaluator {
|
||||||
|
|
||||||
override def evaluate(dataset: DataFrame): Double = {
|
override def evaluate(dataset: Dataset[_]): Double = {
|
||||||
throw new UnsupportedOperationException
|
throw new UnsupportedOperationException
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||||
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
|
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
|
||||||
import org.apache.spark.mllib.linalg.Vectors
|
import org.apache.spark.mllib.linalg.Vectors
|
||||||
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
|
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
class TrainValidationSplitSuite
|
class TrainValidationSplitSuite
|
||||||
|
@ -158,7 +158,7 @@ object TrainValidationSplitSuite {
|
||||||
|
|
||||||
class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
|
class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
|
||||||
|
|
||||||
override def fit(dataset: DataFrame): MyModel = {
|
override def fit(dataset: Dataset[_]): MyModel = {
|
||||||
throw new UnsupportedOperationException
|
throw new UnsupportedOperationException
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,7 +172,7 @@ object TrainValidationSplitSuite {
|
||||||
|
|
||||||
class MyEvaluator extends Evaluator {
|
class MyEvaluator extends Evaluator {
|
||||||
|
|
||||||
override def evaluate(dataset: DataFrame): Double = {
|
override def evaluate(dataset: Dataset[_]): Double = {
|
||||||
throw new UnsupportedOperationException
|
throw new UnsupportedOperationException
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.{Estimator, Model}
|
import org.apache.spark.ml.{Estimator, Model}
|
||||||
import org.apache.spark.ml.param._
|
import org.apache.spark.ml.param._
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
|
|
||||||
trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
|
trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
|
||||||
|
|
||||||
|
@ -98,7 +98,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
|
||||||
def testEstimatorAndModelReadWrite[
|
def testEstimatorAndModelReadWrite[
|
||||||
E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
|
E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
|
||||||
estimator: E,
|
estimator: E,
|
||||||
dataset: DataFrame,
|
dataset: Dataset[_],
|
||||||
testParams: Map[String, Any],
|
testParams: Map[String, Any],
|
||||||
checkModelData: (M, M) => Unit): Unit = {
|
checkModelData: (M, M) => Unit): Unit = {
|
||||||
// Set some Params to make sure set Params are serialized.
|
// Set some Params to make sure set Params are serialized.
|
||||||
|
|
Loading…
Reference in a new issue