[SPARK-8151] [MLLIB] pipeline components should correctly implement copy
Otherwise, extra params get ignored in `PipelineModel.transform`. jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Closes #6622 from mengxr/SPARK-8087 and squashes the following commits:
0e4c8c4 [Xiangrui Meng] fix merge issues
26fc1f0 [Xiangrui Meng] address comments
e607a04 [Xiangrui Meng] merge master
b85b57e [Xiangrui Meng] fix examples/compile
d6f7891 [Xiangrui Meng] rename defaultCopyWithParams to defaultCopy
84ec278 [Xiangrui Meng] remove setter checks due to generics
2cf2ed0 [Xiangrui Meng] snapshot
291814f [Xiangrui Meng] OneVsRest.copy
1dfe3bd [Xiangrui Meng] PipelineModel.copy should copy stages
(cherry picked from commit 43c7ec6384
)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
This commit is contained in:
parent
164b9d32e7
commit
1f2dafb77f
|
@ -156,6 +156,11 @@ class MyJavaLogisticRegression
|
||||||
// Create a model, and return it.
|
// Create a model, and return it.
|
||||||
return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this);
|
return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MyJavaLogisticRegression copy(ParamMap extra) {
|
||||||
|
return defaultCopy(extra);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -130,6 +130,8 @@ private class MyLogisticRegression(override val uid: String)
|
||||||
// Create a model, and return it.
|
// Create a model, and return it.
|
||||||
new MyLogisticRegressionModel(uid, weights).setParent(this)
|
new MyLogisticRegressionModel(uid, weights).setParent(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -78,7 +78,5 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
|
||||||
paramMaps.map(fit(dataset, _))
|
paramMaps.map(fit(dataset, _))
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): Estimator[M] = {
|
override def copy(extra: ParamMap): Estimator[M]
|
||||||
super.copy(extra).asInstanceOf[Estimator[M]]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,8 +45,5 @@ abstract class Model[M <: Model[M]] extends Transformer {
|
||||||
/** Indicates whether this [[Model]] has a corresponding parent. */
|
/** Indicates whether this [[Model]] has a corresponding parent. */
|
||||||
def hasParent: Boolean = parent != null
|
def hasParent: Boolean = parent != null
|
||||||
|
|
||||||
override def copy(extra: ParamMap): M = {
|
override def copy(extra: ParamMap): M
|
||||||
// The default implementation of Params.copy doesn't work for models.
|
|
||||||
throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,9 +63,7 @@ abstract class PipelineStage extends Params with Logging {
|
||||||
outputSchema
|
outputSchema
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): PipelineStage = {
|
override def copy(extra: ParamMap): PipelineStage
|
||||||
super.copy(extra).asInstanceOf[PipelineStage]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -190,6 +188,6 @@ class PipelineModel private[ml] (
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): PipelineModel = {
|
override def copy(extra: ParamMap): PipelineModel = {
|
||||||
new PipelineModel(uid, stages)
|
new PipelineModel(uid, stages.map(_.copy(extra)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -90,9 +90,7 @@ abstract class Predictor[
|
||||||
copyValues(train(dataset).setParent(this))
|
copyValues(train(dataset).setParent(this))
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): Learner = {
|
override def copy(extra: ParamMap): Learner
|
||||||
super.copy(extra).asInstanceOf[Learner]
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Train a model using the given dataset and parameters.
|
* Train a model using the given dataset and parameters.
|
||||||
|
|
|
@ -67,9 +67,7 @@ abstract class Transformer extends PipelineStage {
|
||||||
*/
|
*/
|
||||||
def transform(dataset: DataFrame): DataFrame
|
def transform(dataset: DataFrame): DataFrame
|
||||||
|
|
||||||
override def copy(extra: ParamMap): Transformer = {
|
override def copy(extra: ParamMap): Transformer
|
||||||
super.copy(extra).asInstanceOf[Transformer]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -120,4 +118,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
|
||||||
dataset.withColumn($(outputCol),
|
dataset.withColumn($(outputCol),
|
||||||
callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol))))
|
callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol))))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): T = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
package org.apache.spark.ml.classification
|
package org.apache.spark.ml.classification
|
||||||
|
|
||||||
import org.apache.spark.annotation.DeveloperApi
|
import org.apache.spark.annotation.DeveloperApi
|
||||||
|
import org.apache.spark.ml.param.ParamMap
|
||||||
import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
|
import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
|
||||||
import org.apache.spark.ml.param.shared.HasRawPredictionCol
|
import org.apache.spark.ml.param.shared.HasRawPredictionCol
|
||||||
import org.apache.spark.ml.util.SchemaUtils
|
import org.apache.spark.ml.util.SchemaUtils
|
||||||
|
|
|
@ -86,6 +86,8 @@ final class DecisionTreeClassifier(override val uid: String)
|
||||||
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
|
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
|
||||||
subsamplingRate = 1.0)
|
subsamplingRate = 1.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Experimental
|
@Experimental
|
||||||
|
|
|
@ -141,6 +141,8 @@ final class GBTClassifier(override val uid: String)
|
||||||
val oldModel = oldGBT.run(oldDataset)
|
val oldModel = oldGBT.run(oldDataset)
|
||||||
GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures)
|
GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Experimental
|
@Experimental
|
||||||
|
|
|
@ -220,6 +220,8 @@ class LogisticRegression(override val uid: String)
|
||||||
|
|
||||||
new LogisticRegressionModel(uid, weights.compressed, intercept)
|
new LogisticRegressionModel(uid, weights.compressed, intercept)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -24,7 +24,7 @@ import scala.language.existentials
|
||||||
import org.apache.spark.annotation.Experimental
|
import org.apache.spark.annotation.Experimental
|
||||||
import org.apache.spark.ml._
|
import org.apache.spark.ml._
|
||||||
import org.apache.spark.ml.attribute._
|
import org.apache.spark.ml.attribute._
|
||||||
import org.apache.spark.ml.param.Param
|
import org.apache.spark.ml.param.{Param, ParamMap}
|
||||||
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
|
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
|
||||||
import org.apache.spark.mllib.linalg.Vector
|
import org.apache.spark.mllib.linalg.Vector
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Row}
|
||||||
|
@ -133,6 +133,12 @@ final class OneVsRestModel private[ml] (
|
||||||
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
|
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
|
||||||
.drop(accColName)
|
.drop(accColName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): OneVsRestModel = {
|
||||||
|
val copied = new OneVsRestModel(
|
||||||
|
uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
|
||||||
|
copyValues(copied, extra)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -209,4 +215,12 @@ final class OneVsRest(override val uid: String)
|
||||||
val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this)
|
val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this)
|
||||||
copyValues(model)
|
copyValues(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): OneVsRest = {
|
||||||
|
val copied = defaultCopy(extra).asInstanceOf[OneVsRest]
|
||||||
|
if (isDefined(classifier)) {
|
||||||
|
copied.setClassifier($(classifier).copy(extra))
|
||||||
|
}
|
||||||
|
copied
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -97,6 +97,8 @@ final class RandomForestClassifier(override val uid: String)
|
||||||
oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
|
oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
|
||||||
RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures)
|
RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Experimental
|
@Experimental
|
||||||
|
|
|
@ -79,4 +79,6 @@ class BinaryClassificationEvaluator(override val uid: String)
|
||||||
metrics.unpersist()
|
metrics.unpersist()
|
||||||
metric
|
metric
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,7 +46,5 @@ abstract class Evaluator extends Params {
|
||||||
*/
|
*/
|
||||||
def evaluate(dataset: DataFrame): Double
|
def evaluate(dataset: DataFrame): Double
|
||||||
|
|
||||||
override def copy(extra: ParamMap): Evaluator = {
|
override def copy(extra: ParamMap): Evaluator
|
||||||
super.copy(extra).asInstanceOf[Evaluator]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
package org.apache.spark.ml.evaluation
|
package org.apache.spark.ml.evaluation
|
||||||
|
|
||||||
import org.apache.spark.annotation.Experimental
|
import org.apache.spark.annotation.Experimental
|
||||||
import org.apache.spark.ml.param.{Param, ParamValidators}
|
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
|
||||||
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
|
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
|
||||||
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
|
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
|
||||||
import org.apache.spark.mllib.evaluation.RegressionMetrics
|
import org.apache.spark.mllib.evaluation.RegressionMetrics
|
||||||
|
@ -80,4 +80,6 @@ final class RegressionEvaluator(override val uid: String)
|
||||||
}
|
}
|
||||||
metric
|
metric
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,4 +83,6 @@ final class Binarizer(override val uid: String)
|
||||||
val outputFields = inputFields :+ attr.toStructField()
|
val outputFields = inputFields :+ attr.toStructField()
|
||||||
StructType(outputFields)
|
StructType(outputFields)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
|
@ -89,6 +89,8 @@ final class Bucketizer(override val uid: String)
|
||||||
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
|
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
|
||||||
SchemaUtils.appendColumn(schema, prepOutputField(schema))
|
SchemaUtils.appendColumn(schema, prepOutputField(schema))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
private[feature] object Bucketizer {
|
private[feature] object Bucketizer {
|
||||||
|
|
|
@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
|
||||||
|
|
||||||
import org.apache.spark.annotation.Experimental
|
import org.apache.spark.annotation.Experimental
|
||||||
import org.apache.spark.ml.UnaryTransformer
|
import org.apache.spark.ml.UnaryTransformer
|
||||||
import org.apache.spark.ml.param.Param
|
import org.apache.spark.ml.param.{ParamMap, Param}
|
||||||
import org.apache.spark.ml.util.Identifiable
|
import org.apache.spark.ml.util.Identifiable
|
||||||
import org.apache.spark.mllib.feature
|
import org.apache.spark.mllib.feature
|
||||||
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
|
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
|
||||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
|
||||||
import org.apache.spark.annotation.Experimental
|
import org.apache.spark.annotation.Experimental
|
||||||
import org.apache.spark.ml.Transformer
|
import org.apache.spark.ml.Transformer
|
||||||
import org.apache.spark.ml.attribute.AttributeGroup
|
import org.apache.spark.ml.attribute.AttributeGroup
|
||||||
import org.apache.spark.ml.param.{IntParam, ParamValidators}
|
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
|
||||||
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
||||||
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
|
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
|
||||||
import org.apache.spark.mllib.feature
|
import org.apache.spark.mllib.feature
|
||||||
|
@ -74,4 +74,6 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w
|
||||||
val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
|
val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
|
||||||
SchemaUtils.appendColumn(schema, attrGroup.toStructField())
|
SchemaUtils.appendColumn(schema, attrGroup.toStructField())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): HashingTF = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,9 +45,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
|
||||||
/** @group getParam */
|
/** @group getParam */
|
||||||
def getMinDocFreq: Int = $(minDocFreq)
|
def getMinDocFreq: Int = $(minDocFreq)
|
||||||
|
|
||||||
/** @group setParam */
|
|
||||||
def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Validate and transform the input schema.
|
* Validate and transform the input schema.
|
||||||
*/
|
*/
|
||||||
|
@ -72,6 +69,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||||
|
|
||||||
|
/** @group setParam */
|
||||||
|
def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
|
||||||
|
|
||||||
override def fit(dataset: DataFrame): IDFModel = {
|
override def fit(dataset: DataFrame): IDFModel = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
|
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
|
||||||
|
@ -82,6 +82,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
|
||||||
override def transformSchema(schema: StructType): StructType = {
|
override def transformSchema(schema: StructType): StructType = {
|
||||||
validateAndTransformSchema(schema)
|
validateAndTransformSchema(schema)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): IDF = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -109,4 +111,9 @@ class IDFModel private[ml] (
|
||||||
override def transformSchema(schema: StructType): StructType = {
|
override def transformSchema(schema: StructType): StructType = {
|
||||||
validateAndTransformSchema(schema)
|
validateAndTransformSchema(schema)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): IDFModel = {
|
||||||
|
val copied = new IDFModel(uid, idfModel)
|
||||||
|
copyValues(copied, extra)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -165,4 +165,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
|
||||||
|
|
||||||
dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata))
|
dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@ import scala.collection.mutable
|
||||||
|
|
||||||
import org.apache.spark.annotation.Experimental
|
import org.apache.spark.annotation.Experimental
|
||||||
import org.apache.spark.ml.UnaryTransformer
|
import org.apache.spark.ml.UnaryTransformer
|
||||||
import org.apache.spark.ml.param.{IntParam, ParamValidators}
|
import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators}
|
||||||
import org.apache.spark.ml.util.Identifiable
|
import org.apache.spark.ml.util.Identifiable
|
||||||
import org.apache.spark.mllib.linalg._
|
import org.apache.spark.mllib.linalg._
|
||||||
import org.apache.spark.sql.types.DataType
|
import org.apache.spark.sql.types.DataType
|
||||||
|
@ -61,6 +61,8 @@ class PolynomialExpansion(override val uid: String)
|
||||||
}
|
}
|
||||||
|
|
||||||
override protected def outputDataType: DataType = new VectorUDT()
|
override protected def outputDataType: DataType = new VectorUDT()
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -92,6 +92,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
|
||||||
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
|
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
|
||||||
StructType(outputFields)
|
StructType(outputFields)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -125,4 +127,9 @@ class StandardScalerModel private[ml] (
|
||||||
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
|
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
|
||||||
StructType(outputFields)
|
StructType(outputFields)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): StandardScalerModel = {
|
||||||
|
val copied = new StandardScalerModel(uid, scaler)
|
||||||
|
copyValues(copied, extra)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,6 +83,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
|
||||||
override def transformSchema(schema: StructType): StructType = {
|
override def transformSchema(schema: StructType): StructType = {
|
||||||
validateAndTransformSchema(schema)
|
validateAndTransformSchema(schema)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -144,4 +146,9 @@ class StringIndexerModel private[ml] (
|
||||||
schema
|
schema
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): StringIndexerModel = {
|
||||||
|
val copied = new StringIndexerModel(uid, labels)
|
||||||
|
copyValues(copied, extra)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,6 +43,8 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S
|
||||||
}
|
}
|
||||||
|
|
||||||
override protected def outputDataType: DataType = new ArrayType(StringType, false)
|
override protected def outputDataType: DataType = new ArrayType(StringType, false)
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -112,4 +114,6 @@ class RegexTokenizer(override val uid: String)
|
||||||
}
|
}
|
||||||
|
|
||||||
override protected def outputDataType: DataType = new ArrayType(StringType, false)
|
override protected def outputDataType: DataType = new ArrayType(StringType, false)
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import org.apache.spark.SparkException
|
||||||
import org.apache.spark.annotation.Experimental
|
import org.apache.spark.annotation.Experimental
|
||||||
import org.apache.spark.ml.Transformer
|
import org.apache.spark.ml.Transformer
|
||||||
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
|
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
|
||||||
|
import org.apache.spark.ml.param.ParamMap
|
||||||
import org.apache.spark.ml.param.shared._
|
import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.ml.util.Identifiable
|
import org.apache.spark.ml.util.Identifiable
|
||||||
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
|
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
|
||||||
|
@ -117,6 +118,8 @@ class VectorAssembler(override val uid: String)
|
||||||
}
|
}
|
||||||
StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
|
StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
private object VectorAssembler {
|
private object VectorAssembler {
|
||||||
|
|
|
@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
|
||||||
import org.apache.spark.annotation.Experimental
|
import org.apache.spark.annotation.Experimental
|
||||||
import org.apache.spark.ml.{Estimator, Model}
|
import org.apache.spark.ml.{Estimator, Model}
|
||||||
import org.apache.spark.ml.attribute._
|
import org.apache.spark.ml.attribute._
|
||||||
import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}
|
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params}
|
||||||
import org.apache.spark.ml.param.shared._
|
import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
|
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
|
||||||
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
|
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
|
||||||
|
@ -131,6 +131,8 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
|
||||||
SchemaUtils.checkColumnType(schema, $(inputCol), dataType)
|
SchemaUtils.checkColumnType(schema, $(inputCol), dataType)
|
||||||
SchemaUtils.appendColumn(schema, $(outputCol), dataType)
|
SchemaUtils.appendColumn(schema, $(outputCol), dataType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
private object VectorIndexer {
|
private object VectorIndexer {
|
||||||
|
@ -399,4 +401,9 @@ class VectorIndexerModel private[ml] (
|
||||||
val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes)
|
val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes)
|
||||||
newAttributeGroup.toStructField()
|
newAttributeGroup.toStructField()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): VectorIndexerModel = {
|
||||||
|
val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps)
|
||||||
|
copyValues(copied, extra)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -132,6 +132,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
|
||||||
override def transformSchema(schema: StructType): StructType = {
|
override def transformSchema(schema: StructType): StructType = {
|
||||||
validateAndTransformSchema(schema)
|
validateAndTransformSchema(schema)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -180,4 +182,9 @@ class Word2VecModel private[ml] (
|
||||||
override def transformSchema(schema: StructType): StructType = {
|
override def transformSchema(schema: StructType): StructType = {
|
||||||
validateAndTransformSchema(schema)
|
validateAndTransformSchema(schema)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): Word2VecModel = {
|
||||||
|
val copied = new Word2VecModel(uid, wordVectors)
|
||||||
|
copyValues(copied, extra)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -492,13 +492,20 @@ trait Params extends Identifiable with Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a copy of this instance with the same UID and some extra params.
|
* Creates a copy of this instance with the same UID and some extra params.
|
||||||
* The default implementation tries to create a new instance with the same UID.
|
* Subclasses should implement this method and set the return type properly.
|
||||||
* Then it copies the embedded and extra parameters over and returns the new instance.
|
*
|
||||||
* Subclasses should override this method if the default approach is not sufficient.
|
* @see [[defaultCopy()]]
|
||||||
*/
|
*/
|
||||||
def copy(extra: ParamMap): Params = {
|
def copy(extra: ParamMap): Params
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Default implementation of copy with extra params.
|
||||||
|
* It tries to create a new instance with the same UID.
|
||||||
|
* Then it copies the embedded and extra parameters over and returns the new instance.
|
||||||
|
*/
|
||||||
|
protected final def defaultCopy[T <: Params](extra: ParamMap): T = {
|
||||||
val that = this.getClass.getConstructor(classOf[String]).newInstance(uid)
|
val that = this.getClass.getConstructor(classOf[String]).newInstance(uid)
|
||||||
copyValues(that, extra)
|
copyValues(that, extra).asInstanceOf[T]
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -216,6 +216,11 @@ class ALSModel private[ml] (
|
||||||
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
|
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
|
||||||
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
|
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): ALSModel = {
|
||||||
|
val copied = new ALSModel(uid, rank, userFactors, itemFactors)
|
||||||
|
copyValues(copied, extra)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -330,6 +335,8 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
|
||||||
override def transformSchema(schema: StructType): StructType = {
|
override def transformSchema(schema: StructType): StructType = {
|
||||||
validateAndTransformSchema(schema)
|
validateAndTransformSchema(schema)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): ALS = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -76,6 +76,8 @@ final class DecisionTreeRegressor(override val uid: String)
|
||||||
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
|
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
|
||||||
subsamplingRate = 1.0)
|
subsamplingRate = 1.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): DecisionTreeRegressor = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Experimental
|
@Experimental
|
||||||
|
|
|
@ -131,6 +131,8 @@ final class GBTRegressor(override val uid: String)
|
||||||
val oldModel = oldGBT.run(oldDataset)
|
val oldModel = oldGBT.run(oldDataset)
|
||||||
GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures)
|
GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Experimental
|
@Experimental
|
||||||
|
|
|
@ -186,6 +186,8 @@ class LinearRegression(override val uid: String)
|
||||||
// TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
|
// TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
|
||||||
copyValues(new LinearRegressionModel(uid, weights.compressed, intercept))
|
copyValues(new LinearRegressionModel(uid, weights.compressed, intercept))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -86,6 +86,8 @@ final class RandomForestRegressor(override val uid: String)
|
||||||
oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
|
oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
|
||||||
RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures)
|
RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Experimental
|
@Experimental
|
||||||
|
|
|
@ -149,6 +149,17 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
|
||||||
est.copy(paramMap).validateParams()
|
est.copy(paramMap).validateParams()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): CrossValidator = {
|
||||||
|
val copied = defaultCopy(extra).asInstanceOf[CrossValidator]
|
||||||
|
if (copied.isDefined(estimator)) {
|
||||||
|
copied.setEstimator(copied.getEstimator.copy(extra))
|
||||||
|
}
|
||||||
|
if (copied.isDefined(evaluator)) {
|
||||||
|
copied.setEvaluator(copied.getEvaluator.copy(extra))
|
||||||
|
}
|
||||||
|
copied
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -159,7 +159,7 @@ private object IDF {
|
||||||
* Represents an IDF model that can transform term frequency vectors.
|
* Represents an IDF model that can transform term frequency vectors.
|
||||||
*/
|
*/
|
||||||
@Experimental
|
@Experimental
|
||||||
class IDFModel private[mllib] (val idf: Vector) extends Serializable {
|
class IDFModel private[spark] (val idf: Vector) extends Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Transforms term frequency (TF) vectors to TF-IDF vectors.
|
* Transforms term frequency (TF) vectors to TF-IDF vectors.
|
||||||
|
|
|
@ -428,7 +428,7 @@ class Word2Vec extends Serializable with Logging {
|
||||||
* Word2Vec model
|
* Word2Vec model
|
||||||
*/
|
*/
|
||||||
@Experimental
|
@Experimental
|
||||||
class Word2VecModel private[mllib] (
|
class Word2VecModel private[spark] (
|
||||||
model: Map[String, Array[Float]]) extends Serializable with Saveable {
|
model: Map[String, Array[Float]]) extends Serializable with Saveable {
|
||||||
|
|
||||||
// wordList: Ordered list of words obtained from model.
|
// wordList: Ordered list of words obtained from model.
|
||||||
|
|
|
@ -102,4 +102,9 @@ public class JavaTestParams extends JavaParams {
|
||||||
setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
|
setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
|
||||||
setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0}));
|
setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public JavaTestParams copy(ParamMap extra) {
|
||||||
|
return defaultCopy(extra);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@ import org.mockito.Mockito.when
|
||||||
import org.scalatest.mock.MockitoSugar.mock
|
import org.scalatest.mock.MockitoSugar.mock
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
|
import org.apache.spark.ml.feature.HashingTF
|
||||||
import org.apache.spark.ml.param.ParamMap
|
import org.apache.spark.ml.param.ParamMap
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.DataFrame
|
||||||
|
|
||||||
|
@ -81,4 +82,13 @@ class PipelineSuite extends SparkFunSuite {
|
||||||
pipeline.fit(dataset)
|
pipeline.fit(dataset)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("PipelineModel.copy") {
|
||||||
|
val hashingTF = new HashingTF()
|
||||||
|
.setNumFeatures(100)
|
||||||
|
val model = new PipelineModel("pipeline", Array[Transformer](hashingTF))
|
||||||
|
val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10))
|
||||||
|
require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
|
||||||
|
"copy should handle extra stage params")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,15 +19,15 @@ package org.apache.spark.ml.classification
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.impl.TreeTests
|
import org.apache.spark.ml.impl.TreeTests
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
|
import org.apache.spark.ml.tree.LeafNode
|
||||||
import org.apache.spark.mllib.linalg.Vectors
|
import org.apache.spark.mllib.linalg.Vectors
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
|
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
|
||||||
DecisionTreeSuite => OldDecisionTreeSuite}
|
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.DataFrame
|
||||||
|
|
||||||
|
|
||||||
class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
|
class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
import DecisionTreeClassifierSuite.compareAPIs
|
import DecisionTreeClassifierSuite.compareAPIs
|
||||||
|
@ -55,6 +55,12 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
||||||
OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures())
|
OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new DecisionTreeClassifier)
|
||||||
|
val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))
|
||||||
|
ParamsSuite.checkParams(model)
|
||||||
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Tests calling train()
|
// Tests calling train()
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -19,6 +19,9 @@ package org.apache.spark.ml.classification
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.impl.TreeTests
|
import org.apache.spark.ml.impl.TreeTests
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
|
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
|
||||||
|
import org.apache.spark.ml.tree.LeafNode
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
|
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
|
||||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
|
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
|
||||||
|
@ -51,6 +54,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
|
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new GBTClassifier)
|
||||||
|
val model = new GBTClassificationModel("gbtc",
|
||||||
|
Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))),
|
||||||
|
Array(1.0))
|
||||||
|
ParamsSuite.checkParams(model)
|
||||||
|
}
|
||||||
|
|
||||||
test("Binary classification with continuous features: Log Loss") {
|
test("Binary classification with continuous features: Log Loss") {
|
||||||
val categoricalFeatures = Map.empty[Int, Int]
|
val categoricalFeatures = Map.empty[Int, Int]
|
||||||
testCombinations.foreach {
|
testCombinations.foreach {
|
||||||
|
|
|
@ -18,8 +18,9 @@
|
||||||
package org.apache.spark.ml.classification
|
package org.apache.spark.ml.classification
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
|
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
|
||||||
import org.apache.spark.mllib.linalg.Vector
|
import org.apache.spark.mllib.linalg.{Vectors, Vector}
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Row}
|
||||||
|
@ -62,6 +63,12 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new LogisticRegression)
|
||||||
|
val model = new LogisticRegressionModel("logReg", Vectors.dense(0.0), 0.0)
|
||||||
|
ParamsSuite.checkParams(model)
|
||||||
|
}
|
||||||
|
|
||||||
test("logistic regression: default params") {
|
test("logistic regression: default params") {
|
||||||
val lr = new LogisticRegression
|
val lr = new LogisticRegression
|
||||||
assert(lr.getLabelCol === "label")
|
assert(lr.getLabelCol === "label")
|
||||||
|
|
|
@ -19,15 +19,18 @@ package org.apache.spark.ml.classification
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.attribute.NominalAttribute
|
import org.apache.spark.ml.attribute.NominalAttribute
|
||||||
|
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
|
||||||
import org.apache.spark.ml.util.MetadataUtils
|
import org.apache.spark.ml.util.MetadataUtils
|
||||||
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
|
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
|
||||||
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
|
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
|
||||||
import org.apache.spark.mllib.evaluation.MulticlassMetrics
|
import org.apache.spark.mllib.evaluation.MulticlassMetrics
|
||||||
|
import org.apache.spark.mllib.linalg.Vectors
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.DataFrame
|
||||||
|
import org.apache.spark.sql.types.Metadata
|
||||||
|
|
||||||
class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
|
@ -52,6 +55,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
dataset = sqlContext.createDataFrame(rdd)
|
dataset = sqlContext.createDataFrame(rdd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new OneVsRest)
|
||||||
|
val lrModel = new LogisticRegressionModel("lr", Vectors.dense(0.0), 0.0)
|
||||||
|
val model = new OneVsRestModel("ovr", Metadata.empty, Array(lrModel))
|
||||||
|
ParamsSuite.checkParams(model)
|
||||||
|
}
|
||||||
|
|
||||||
test("one-vs-rest: default params") {
|
test("one-vs-rest: default params") {
|
||||||
val numClasses = 3
|
val numClasses = 3
|
||||||
val ova = new OneVsRest()
|
val ova = new OneVsRest()
|
||||||
|
@ -102,6 +112,26 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
val output = ovr.fit(dataset).transform(dataset)
|
val output = ovr.fit(dataset).transform(dataset)
|
||||||
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
|
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("OneVsRest.copy and OneVsRestModel.copy") {
|
||||||
|
val lr = new LogisticRegression()
|
||||||
|
.setMaxIter(1)
|
||||||
|
|
||||||
|
val ovr = new OneVsRest()
|
||||||
|
withClue("copy with classifier unset should work") {
|
||||||
|
ovr.copy(ParamMap(lr.maxIter -> 10))
|
||||||
|
}
|
||||||
|
ovr.setClassifier(lr)
|
||||||
|
val ovr1 = ovr.copy(ParamMap(lr.maxIter -> 10))
|
||||||
|
require(ovr.getClassifier.getOrDefault(lr.maxIter) === 1, "copy should have no side-effects")
|
||||||
|
require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10,
|
||||||
|
"copy should handle extra classifier params")
|
||||||
|
|
||||||
|
val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1))
|
||||||
|
ovrModel.models.foreach { case m: LogisticRegressionModel =>
|
||||||
|
require(m.getThreshold === 0.1, "copy should handle extra model params")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
|
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
|
||||||
|
|
|
@ -19,6 +19,8 @@ package org.apache.spark.ml.classification
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.impl.TreeTests
|
import org.apache.spark.ml.impl.TreeTests
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
|
import org.apache.spark.ml.tree.LeafNode
|
||||||
import org.apache.spark.mllib.linalg.Vectors
|
import org.apache.spark.mllib.linalg.Vectors
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
|
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
|
||||||
|
@ -27,7 +29,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.DataFrame
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Test suite for [[RandomForestClassifier]].
|
* Test suite for [[RandomForestClassifier]].
|
||||||
*/
|
*/
|
||||||
|
@ -62,6 +63,13 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
||||||
compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses)
|
compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new RandomForestClassifier)
|
||||||
|
val model = new RandomForestClassificationModel("rfc",
|
||||||
|
Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))))
|
||||||
|
ParamsSuite.checkParams(model)
|
||||||
|
}
|
||||||
|
|
||||||
test("Binary classification with continuous features:" +
|
test("Binary classification with continuous features:" +
|
||||||
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
|
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
|
||||||
val rf = new RandomForestClassifier()
|
val rf = new RandomForestClassifier()
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.ml.evaluation
|
||||||
|
|
||||||
|
import org.apache.spark.SparkFunSuite
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
|
|
||||||
|
class BinaryClassificationEvaluatorSuite extends SparkFunSuite {
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new BinaryClassificationEvaluator)
|
||||||
|
}
|
||||||
|
}
|
|
@ -18,12 +18,17 @@
|
||||||
package org.apache.spark.ml.evaluation
|
package org.apache.spark.ml.evaluation
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.apache.spark.ml.regression.LinearRegression
|
import org.apache.spark.ml.regression.LinearRegression
|
||||||
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
|
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
|
|
||||||
class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
|
class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new RegressionEvaluator)
|
||||||
|
}
|
||||||
|
|
||||||
test("Regression Evaluator: default params") {
|
test("Regression Evaluator: default params") {
|
||||||
/**
|
/**
|
||||||
* Here is the instruction describing how to export the test data into CSV format
|
* Here is the instruction describing how to export the test data into CSV format
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
package org.apache.spark.ml.feature
|
package org.apache.spark.ml.feature
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Row}
|
||||||
|
|
||||||
|
@ -30,6 +31,10 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
|
data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new Binarizer)
|
||||||
|
}
|
||||||
|
|
||||||
test("Binarize continuous features with default parameter") {
|
test("Binarize continuous features with default parameter") {
|
||||||
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
|
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
|
||||||
val dataFrame: DataFrame = sqlContext.createDataFrame(
|
val dataFrame: DataFrame = sqlContext.createDataFrame(
|
||||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
|
|
||||||
import org.apache.spark.{SparkException, SparkFunSuite}
|
import org.apache.spark.{SparkException, SparkFunSuite}
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.apache.spark.mllib.linalg.Vectors
|
import org.apache.spark.mllib.linalg.Vectors
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
|
@ -27,6 +28,10 @@ import org.apache.spark.sql.{DataFrame, Row}
|
||||||
|
|
||||||
class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext {
|
class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new Bucketizer)
|
||||||
|
}
|
||||||
|
|
||||||
test("Bucket continuous features, without -inf,inf") {
|
test("Bucket continuous features, without -inf,inf") {
|
||||||
// Check a set of valid feature values.
|
// Check a set of valid feature values.
|
||||||
val splits = Array(-0.5, 0.0, 0.5)
|
val splits = Array(-0.5, 0.0, 0.5)
|
||||||
|
|
|
@ -28,8 +28,7 @@ import org.apache.spark.util.Utils
|
||||||
class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
|
class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
test("params") {
|
test("params") {
|
||||||
val hashingTF = new HashingTF
|
ParamsSuite.checkParams(new HashingTF)
|
||||||
ParamsSuite.checkParams(hashingTF, 3)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
test("hashingTF") {
|
test("hashingTF") {
|
||||||
|
|
|
@ -18,6 +18,8 @@
|
||||||
package org.apache.spark.ml.feature
|
package org.apache.spark.ml.feature
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
|
import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel}
|
||||||
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
|
@ -38,6 +40,12 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new IDF)
|
||||||
|
val model = new IDFModel("idf", new OldIDFModel(Vectors.dense(1.0)))
|
||||||
|
ParamsSuite.checkParams(model)
|
||||||
|
}
|
||||||
|
|
||||||
test("compute IDF with default parameter") {
|
test("compute IDF with default parameter") {
|
||||||
val numOfFeatures = 4
|
val numOfFeatures = 4
|
||||||
val data = Array(
|
val data = Array(
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
|
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.apache.spark.mllib.linalg.Vector
|
import org.apache.spark.mllib.linalg.Vector
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.DataFrame
|
||||||
|
@ -36,6 +37,10 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
indexer.transform(df)
|
indexer.transform(df)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new OneHotEncoder)
|
||||||
|
}
|
||||||
|
|
||||||
test("OneHotEncoder dropLast = false") {
|
test("OneHotEncoder dropLast = false") {
|
||||||
val transformed = stringIndexed()
|
val transformed = stringIndexed()
|
||||||
val encoder = new OneHotEncoder()
|
val encoder = new OneHotEncoder()
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
package org.apache.spark.ml.feature
|
package org.apache.spark.ml.feature
|
||||||
|
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.scalatest.exceptions.TestFailedException
|
import org.scalatest.exceptions.TestFailedException
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
|
@ -27,6 +28,10 @@ import org.apache.spark.sql.Row
|
||||||
|
|
||||||
class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext {
|
class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new PolynomialExpansion)
|
||||||
|
}
|
||||||
|
|
||||||
test("Polynomial expansion with default parameter") {
|
test("Polynomial expansion with default parameter") {
|
||||||
val data = Array(
|
val data = Array(
|
||||||
Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
|
Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
|
||||||
|
|
|
@ -19,10 +19,17 @@ package org.apache.spark.ml.feature
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
|
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
|
|
||||||
class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
|
class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new StringIndexer)
|
||||||
|
val model = new StringIndexerModel("indexer", Array("a", "b"))
|
||||||
|
ParamsSuite.checkParams(model)
|
||||||
|
}
|
||||||
|
|
||||||
test("StringIndexer") {
|
test("StringIndexer") {
|
||||||
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
|
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
|
||||||
val df = sqlContext.createDataFrame(data).toDF("id", "label")
|
val df = sqlContext.createDataFrame(data).toDF("id", "label")
|
||||||
|
|
|
@ -20,15 +20,27 @@ package org.apache.spark.ml.feature
|
||||||
import scala.beans.BeanInfo
|
import scala.beans.BeanInfo
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Row}
|
||||||
|
|
||||||
@BeanInfo
|
@BeanInfo
|
||||||
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
|
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
|
||||||
|
|
||||||
|
class TokenizerSuite extends SparkFunSuite {
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new Tokenizer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
|
class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
import org.apache.spark.ml.feature.RegexTokenizerSuite._
|
import org.apache.spark.ml.feature.RegexTokenizerSuite._
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new RegexTokenizer)
|
||||||
|
}
|
||||||
|
|
||||||
test("RegexTokenizer") {
|
test("RegexTokenizer") {
|
||||||
val tokenizer0 = new RegexTokenizer()
|
val tokenizer0 = new RegexTokenizer()
|
||||||
.setGaps(false)
|
.setGaps(false)
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
|
||||||
|
|
||||||
import org.apache.spark.{SparkException, SparkFunSuite}
|
import org.apache.spark.{SparkException, SparkFunSuite}
|
||||||
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
|
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.sql.Row
|
import org.apache.spark.sql.Row
|
||||||
|
@ -26,6 +27,10 @@ import org.apache.spark.sql.functions.col
|
||||||
|
|
||||||
class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
|
class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new VectorAssembler)
|
||||||
|
}
|
||||||
|
|
||||||
test("assemble") {
|
test("assemble") {
|
||||||
import org.apache.spark.ml.feature.VectorAssembler.assemble
|
import org.apache.spark.ml.feature.VectorAssembler.assemble
|
||||||
assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
|
assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
|
||||||
|
|
|
@ -21,6 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty}
|
||||||
|
|
||||||
import org.apache.spark.{SparkException, SparkFunSuite}
|
import org.apache.spark.{SparkException, SparkFunSuite}
|
||||||
import org.apache.spark.ml.attribute._
|
import org.apache.spark.ml.attribute._
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
|
@ -91,6 +92,12 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
private def getIndexer: VectorIndexer =
|
private def getIndexer: VectorIndexer =
|
||||||
new VectorIndexer().setInputCol("features").setOutputCol("indexed")
|
new VectorIndexer().setInputCol("features").setOutputCol("indexed")
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new VectorIndexer)
|
||||||
|
val model = new VectorIndexerModel("indexer", 1, Map.empty)
|
||||||
|
ParamsSuite.checkParams(model)
|
||||||
|
}
|
||||||
|
|
||||||
test("Cannot fit an empty DataFrame") {
|
test("Cannot fit an empty DataFrame") {
|
||||||
val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
|
val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
|
||||||
val vectorIndexer = getIndexer
|
val vectorIndexer = getIndexer
|
||||||
|
|
|
@ -18,13 +18,21 @@
|
||||||
package org.apache.spark.ml.feature
|
package org.apache.spark.ml.feature
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.sql.{Row, SQLContext}
|
import org.apache.spark.sql.{Row, SQLContext}
|
||||||
|
import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
|
||||||
|
|
||||||
class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
|
class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
|
test("params") {
|
||||||
|
ParamsSuite.checkParams(new Word2Vec)
|
||||||
|
val model = new Word2VecModel("w2v", new OldWord2VecModel(Map("a" -> Array(0.0f))))
|
||||||
|
ParamsSuite.checkParams(model)
|
||||||
|
}
|
||||||
|
|
||||||
test("Word2Vec") {
|
test("Word2Vec") {
|
||||||
val sqlContext = new SQLContext(sc)
|
val sqlContext = new SQLContext(sc)
|
||||||
import sqlContext.implicits._
|
import sqlContext.implicits._
|
||||||
|
|
|
@ -205,19 +205,27 @@ class ParamsSuite extends SparkFunSuite {
|
||||||
object ParamsSuite extends SparkFunSuite {
|
object ParamsSuite extends SparkFunSuite {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered
|
* Checks common requirements for [[Params.params]]:
|
||||||
* by names; 3) param parent has the same UID as the object's UID; 4) param name is the same as
|
* - params are ordered by names
|
||||||
* the param method name.
|
* - param parent has the same UID as the object's UID
|
||||||
|
* - param name is the same as the param method name
|
||||||
|
* - obj.copy should return the same type as the obj
|
||||||
*/
|
*/
|
||||||
def checkParams(obj: Params, expectedNumParams: Int): Unit = {
|
def checkParams(obj: Params): Unit = {
|
||||||
|
val clazz = obj.getClass
|
||||||
|
|
||||||
val params = obj.params
|
val params = obj.params
|
||||||
require(params.length === expectedNumParams,
|
|
||||||
s"Expect $expectedNumParams params but got ${params.length}: ${params.map(_.name).toSeq}.")
|
|
||||||
val paramNames = params.map(_.name)
|
val paramNames = params.map(_.name)
|
||||||
require(paramNames === paramNames.sorted)
|
require(paramNames === paramNames.sorted, "params must be ordered by names")
|
||||||
params.foreach { p =>
|
params.foreach { p =>
|
||||||
assert(p.parent === obj.uid)
|
assert(p.parent === obj.uid)
|
||||||
assert(obj.getParam(p.name) === p)
|
assert(obj.getParam(p.name) === p)
|
||||||
|
// TODO: Check that setters return self, which needs special handling for generic types.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val copyMethod = clazz.getMethod("copy", classOf[ParamMap])
|
||||||
|
val copyReturnType = copyMethod.getReturnType
|
||||||
|
require(copyReturnType === obj.getClass,
|
||||||
|
s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,5 @@ class TestParams(override val uid: String) extends Params with HasMaxIter with H
|
||||||
require(isDefined(inputCol))
|
require(isDefined(inputCol))
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): TestParams = {
|
override def copy(extra: ParamMap): TestParams = defaultCopy(extra)
|
||||||
super.copy(extra).asInstanceOf[TestParams]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,13 +18,15 @@
|
||||||
package org.apache.spark.ml.param.shared
|
package org.apache.spark.ml.param.shared
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.param.Params
|
import org.apache.spark.ml.param.{ParamMap, Params}
|
||||||
|
|
||||||
class SharedParamsSuite extends SparkFunSuite {
|
class SharedParamsSuite extends SparkFunSuite {
|
||||||
|
|
||||||
test("outputCol") {
|
test("outputCol") {
|
||||||
|
|
||||||
class Obj(override val uid: String) extends Params with HasOutputCol
|
class Obj(override val uid: String) extends Params with HasOutputCol {
|
||||||
|
override def copy(extra: ParamMap): Obj = defaultCopy(extra)
|
||||||
|
}
|
||||||
|
|
||||||
val obj = new Obj("obj")
|
val obj = new Obj("obj")
|
||||||
|
|
||||||
|
|
|
@ -96,6 +96,8 @@ object CrossValidatorSuite {
|
||||||
override def transformSchema(schema: StructType): StructType = {
|
override def transformSchema(schema: StructType): StructType = {
|
||||||
throw new UnsupportedOperationException
|
throw new UnsupportedOperationException
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
class MyEvaluator extends Evaluator {
|
class MyEvaluator extends Evaluator {
|
||||||
|
@ -105,5 +107,7 @@ object CrossValidatorSuite {
|
||||||
}
|
}
|
||||||
|
|
||||||
override val uid: String = "eval"
|
override val uid: String = "eval"
|
||||||
|
|
||||||
|
override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue