[SPARK-8151] [MLLIB] pipeline components should correctly implement copy

Otherwise, extra params get ignored in `PipelineModel.transform`. jkbradley
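For context, a minimal sketch of what this fixes, assuming the extra-params overload of `transform` delegates to `copy(extra)` (as the description above implies). It mirrors the new `PipelineSuite` test; note the `PipelineModel` constructor is `private[ml]`, so this only compiles inside `org.apache.spark.ml`:

```scala
import org.apache.spark.ml.{PipelineModel, Transformer}
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.param.ParamMap

val hashingTF = new HashingTF().setNumFeatures(100)
val model = new PipelineModel("pipeline", Array[Transformer](hashingTF))

// Before this patch, copy(extra) reused the original stage objects, so the
// extra param was silently dropped; now each stage is copied with the extra
// params applied.
val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10))
assert(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures == 10)
```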

Author: Xiangrui Meng <meng@databricks.com>

Closes #6622 from mengxr/SPARK-8087 and squashes the following commits:

0e4c8c4 [Xiangrui Meng] fix merge issues
26fc1f0 [Xiangrui Meng] address comments
e607a04 [Xiangrui Meng] merge master
b85b57e [Xiangrui Meng] fix examples/compile
d6f7891 [Xiangrui Meng] rename defaultCopyWithParams to defaultCopy
84ec278 [Xiangrui Meng] remove setter checks due to generics
2cf2ed0 [Xiangrui Meng] snapshot
291814f [Xiangrui Meng] OneVsRest.copy
1dfe3bd [Xiangrui Meng] PipelineModel.copy should copy stages

(cherry picked from commit 43c7ec6384)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
Xiangrui Meng 2015-06-19 09:46:51 -07:00
parent 164b9d32e7
commit 1f2dafb77f
62 changed files with 351 additions and 55 deletions


@ -156,6 +156,11 @@ class MyJavaLogisticRegression
// Create a model, and return it.
return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this);
}
@Override
public MyJavaLogisticRegression copy(ParamMap extra) {
return defaultCopy(extra);
}
}
/**


@ -130,6 +130,8 @@ private class MyLogisticRegression(override val uid: String)
// Create a model, and return it.
new MyLogisticRegressionModel(uid, weights).setParent(this)
}
override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra)
}
/**


@ -78,7 +78,5 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
paramMaps.map(fit(dataset, _))
}
override def copy(extra: ParamMap): Estimator[M] = {
super.copy(extra).asInstanceOf[Estimator[M]]
}
override def copy(extra: ParamMap): Estimator[M]
}


@ -45,8 +45,5 @@ abstract class Model[M <: Model[M]] extends Transformer {
/** Indicates whether this [[Model]] has a corresponding parent. */
def hasParent: Boolean = parent != null
override def copy(extra: ParamMap): M = {
// The default implementation of Params.copy doesn't work for models.
throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)")
}
override def copy(extra: ParamMap): M
}


@ -63,9 +63,7 @@ abstract class PipelineStage extends Params with Logging {
outputSchema
}
override def copy(extra: ParamMap): PipelineStage = {
super.copy(extra).asInstanceOf[PipelineStage]
}
override def copy(extra: ParamMap): PipelineStage
}
/**
@ -190,6 +188,6 @@ class PipelineModel private[ml] (
}
override def copy(extra: ParamMap): PipelineModel = {
new PipelineModel(uid, stages)
new PipelineModel(uid, stages.map(_.copy(extra)))
}
}


@ -90,9 +90,7 @@ abstract class Predictor[
copyValues(train(dataset).setParent(this))
}
override def copy(extra: ParamMap): Learner = {
super.copy(extra).asInstanceOf[Learner]
}
override def copy(extra: ParamMap): Learner
/**
* Train a model using the given dataset and parameters.


@ -67,9 +67,7 @@ abstract class Transformer extends PipelineStage {
*/
def transform(dataset: DataFrame): DataFrame
override def copy(extra: ParamMap): Transformer = {
super.copy(extra).asInstanceOf[Transformer]
}
override def copy(extra: ParamMap): Transformer
}
/**
@ -120,4 +118,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
dataset.withColumn($(outputCol),
callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol))))
}
override def copy(extra: ParamMap): T = defaultCopy(extra)
}


@ -18,6 +18,7 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.SchemaUtils


@ -86,6 +86,8 @@ final class DecisionTreeClassifier(override val uid: String)
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
subsamplingRate = 1.0)
}
override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra)
}
@Experimental


@ -141,6 +141,8 @@ final class GBTClassifier(override val uid: String)
val oldModel = oldGBT.run(oldDataset)
GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures)
}
override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
}
@Experimental


@ -220,6 +220,8 @@ class LogisticRegression(override val uid: String)
new LogisticRegressionModel(uid, weights.compressed, intercept)
}
override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
}
/**


@ -24,7 +24,7 @@ import scala.language.existentials
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{DataFrame, Row}
@ -133,6 +133,12 @@ final class OneVsRestModel private[ml] (
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
.drop(accColName)
}
override def copy(extra: ParamMap): OneVsRestModel = {
val copied = new OneVsRestModel(
uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
copyValues(copied, extra)
}
}
/**
@ -209,4 +215,12 @@ final class OneVsRest(override val uid: String)
val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this)
copyValues(model)
}
override def copy(extra: ParamMap): OneVsRest = {
val copied = defaultCopy(extra).asInstanceOf[OneVsRest]
if (isDefined(classifier)) {
copied.setClassifier($(classifier).copy(extra))
}
copied
}
}


@ -97,6 +97,8 @@ final class RandomForestClassifier(override val uid: String)
oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures)
}
override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
}
@Experimental


@ -79,4 +79,6 @@ class BinaryClassificationEvaluator(override val uid: String)
metrics.unpersist()
metric
}
override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
}


@ -46,7 +46,5 @@ abstract class Evaluator extends Params {
*/
def evaluate(dataset: DataFrame): Double
override def copy(extra: ParamMap): Evaluator = {
super.copy(extra).asInstanceOf[Evaluator]
}
override def copy(extra: ParamMap): Evaluator
}


@ -18,7 +18,7 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.param.{Param, ParamValidators}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
@ -80,4 +80,6 @@ final class RegressionEvaluator(override val uid: String)
}
metric
}
override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
}


@ -83,4 +83,6 @@ final class Binarizer(override val uid: String)
val outputFields = inputFields :+ attr.toStructField()
StructType(outputFields)
}
override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
}


@ -89,6 +89,8 @@ final class Bucketizer(override val uid: String)
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
SchemaUtils.appendColumn(schema, prepOutputField(schema))
}
override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra)
}
private[feature] object Bucketizer {


@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.param.{ParamMap, Param}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}


@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.{IntParam, ParamValidators}
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
@ -74,4 +74,6 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w
val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
SchemaUtils.appendColumn(schema, attrGroup.toStructField())
}
override def copy(extra: ParamMap): HashingTF = defaultCopy(extra)
}


@ -45,9 +45,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
/** @group getParam */
def getMinDocFreq: Int = $(minDocFreq)
/** @group setParam */
def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
/**
* Validate and transform the input schema.
*/
@ -72,6 +69,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
/** @group setParam */
def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
override def fit(dataset: DataFrame): IDFModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
@ -82,6 +82,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
override def copy(extra: ParamMap): IDF = defaultCopy(extra)
}
/**
@ -109,4 +111,9 @@ class IDFModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
override def copy(extra: ParamMap): IDFModel = {
val copied = new IDFModel(uid, idfModel)
copyValues(copied, extra)
}
}


@ -165,4 +165,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata))
}
override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra)
}


@ -21,7 +21,7 @@ import scala.collection.mutable
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{IntParam, ParamValidators}
import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg._
import org.apache.spark.sql.types.DataType
@ -61,6 +61,8 @@ class PolynomialExpansion(override val uid: String)
}
override protected def outputDataType: DataType = new VectorUDT()
override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra)
}
/**


@ -92,6 +92,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
StructType(outputFields)
}
override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
}
/**
@ -125,4 +127,9 @@ class StandardScalerModel private[ml] (
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
StructType(outputFields)
}
override def copy(extra: ParamMap): StandardScalerModel = {
val copied = new StandardScalerModel(uid, scaler)
copyValues(copied, extra)
}
}


@ -83,6 +83,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra)
}
/**
@ -144,4 +146,9 @@ class StringIndexerModel private[ml] (
schema
}
}
override def copy(extra: ParamMap): StringIndexerModel = {
val copied = new StringIndexerModel(uid, labels)
copyValues(copied, extra)
}
}


@ -43,6 +43,8 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S
}
override protected def outputDataType: DataType = new ArrayType(StringType, false)
override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
}
/**
@ -112,4 +114,6 @@ class RegexTokenizer(override val uid: String)
}
override protected def outputDataType: DataType = new ArrayType(StringType, false)
override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
}


@ -23,6 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
@ -117,6 +118,8 @@ class VectorAssembler(override val uid: String)
}
StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
}
override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
}
private object VectorAssembler {


@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
@ -131,6 +131,8 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
SchemaUtils.checkColumnType(schema, $(inputCol), dataType)
SchemaUtils.appendColumn(schema, $(outputCol), dataType)
}
override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra)
}
private object VectorIndexer {
@ -399,4 +401,9 @@ class VectorIndexerModel private[ml] (
val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes)
newAttributeGroup.toStructField()
}
override def copy(extra: ParamMap): VectorIndexerModel = {
val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps)
copyValues(copied, extra)
}
}


@ -132,6 +132,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra)
}
/**
@ -180,4 +182,9 @@ class Word2VecModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
override def copy(extra: ParamMap): Word2VecModel = {
val copied = new Word2VecModel(uid, wordVectors)
copyValues(copied, extra)
}
}


@ -492,13 +492,20 @@ trait Params extends Identifiable with Serializable {
/**
* Creates a copy of this instance with the same UID and some extra params.
* The default implementation tries to create a new instance with the same UID.
* Then it copies the embedded and extra parameters over and returns the new instance.
* Subclasses should override this method if the default approach is not sufficient.
* Subclasses should implement this method and set the return type properly.
*
* @see [[defaultCopy()]]
*/
def copy(extra: ParamMap): Params = {
def copy(extra: ParamMap): Params
/**
* Default implementation of copy with extra params.
* It tries to create a new instance with the same UID.
* Then it copies the embedded and extra parameters over and returns the new instance.
*/
protected final def defaultCopy[T <: Params](extra: ParamMap): T = {
val that = this.getClass.getConstructor(classOf[String]).newInstance(uid)
copyValues(that, extra)
copyValues(that, extra).asInstanceOf[T]
}
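With the abstract `copy` plus `defaultCopy`, any component whose constructor takes only a uid can implement `copy` in one line; components that hold other objects (fitted models, or meta-estimators such as `OneVsRest` and `CrossValidator` in this patch) still need a hand-written `copy` that also copies what they hold. A minimal sketch of the one-line case, with hypothetical names (`MyScaler`, `factor`) not taken from this patch:

```scala
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{DoubleParam, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{DataType, DoubleType}

// Hypothetical transformer that multiplies a Double column by a configurable factor.
class MyScaler(override val uid: String)
  extends UnaryTransformer[Double, Double, MyScaler] {

  def this() = this(Identifiable.randomUID("myScaler"))

  val factor = new DoubleParam(this, "factor", "multiplier applied to the input")
  def setFactor(value: Double): this.type = set(factor, value)
  setDefault(factor -> 1.0)

  override protected def createTransformFunc: Double => Double = _ * $(factor)
  override protected def outputDataType: DataType = DoubleType

  // defaultCopy works because the class has a single-String (uid) constructor.
  // UnaryTransformer now supplies this override already; it is spelled out here
  // to show the pattern the other components in this patch follow.
  override def copy(extra: ParamMap): MyScaler = defaultCopy(extra)
}
```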
/**


@ -216,6 +216,11 @@ class ALSModel private[ml] (
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}
override def copy(extra: ParamMap): ALSModel = {
val copied = new ALSModel(uid, rank, userFactors, itemFactors)
copyValues(copied, extra)
}
}
@ -330,6 +335,8 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
override def copy(extra: ParamMap): ALS = defaultCopy(extra)
}
/**


@ -76,6 +76,8 @@ final class DecisionTreeRegressor(override val uid: String)
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
subsamplingRate = 1.0)
}
override def copy(extra: ParamMap): DecisionTreeRegressor = defaultCopy(extra)
}
@Experimental


@ -131,6 +131,8 @@ final class GBTRegressor(override val uid: String)
val oldModel = oldGBT.run(oldDataset)
GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures)
}
override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra)
}
@Experimental


@ -186,6 +186,8 @@ class LinearRegression(override val uid: String)
// TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
copyValues(new LinearRegressionModel(uid, weights.compressed, intercept))
}
override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
}
/**


@ -86,6 +86,8 @@ final class RandomForestRegressor(override val uid: String)
oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures)
}
override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
}
@Experimental


@ -149,6 +149,17 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
est.copy(paramMap).validateParams()
}
}
override def copy(extra: ParamMap): CrossValidator = {
val copied = defaultCopy(extra).asInstanceOf[CrossValidator]
if (copied.isDefined(estimator)) {
copied.setEstimator(copied.getEstimator.copy(extra))
}
if (copied.isDefined(evaluator)) {
copied.setEvaluator(copied.getEvaluator.copy(extra))
}
copied
}
}
/**


@ -159,7 +159,7 @@ private object IDF {
* Represents an IDF model that can transform term frequency vectors.
*/
@Experimental
class IDFModel private[mllib] (val idf: Vector) extends Serializable {
class IDFModel private[spark] (val idf: Vector) extends Serializable {
/**
* Transforms term frequency (TF) vectors to TF-IDF vectors.


@ -428,7 +428,7 @@ class Word2Vec extends Serializable with Logging {
* Word2Vec model
*/
@Experimental
class Word2VecModel private[mllib] (
class Word2VecModel private[spark] (
model: Map[String, Array[Float]]) extends Serializable with Saveable {
// wordList: Ordered list of words obtained from model.


@ -102,4 +102,9 @@ public class JavaTestParams extends JavaParams {
setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0}));
}
@Override
public JavaTestParams copy(ParamMap extra) {
return defaultCopy(extra);
}
}


@ -22,6 +22,7 @@ import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar.mock
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame
@ -81,4 +82,13 @@ class PipelineSuite extends SparkFunSuite {
pipeline.fit(dataset)
}
}
test("PipelineModel.copy") {
val hashingTF = new HashingTF()
.setNumFeatures(100)
val model = new PipelineModel("pipeline", Array[Transformer](hashingTF))
val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10))
require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
"copy should handle extra stage params")
}
}


@ -19,15 +19,15 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
import DecisionTreeClassifierSuite.compareAPIs
@ -55,6 +55,12 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures())
}
test("params") {
ParamsSuite.checkParams(new DecisionTreeClassifier)
val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))
ParamsSuite.checkParams(model)
}
/////////////////////////////////////////////////////////////////////////////
// Tests calling train()
/////////////////////////////////////////////////////////////////////////////


@ -19,6 +19,9 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@ -51,6 +54,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
}
test("params") {
ParamsSuite.checkParams(new GBTClassifier)
val model = new GBTClassificationModel("gbtc",
Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))),
Array(1.0))
ParamsSuite.checkParams(model)
}
test("Binary classification with continuous features: Log Loss") {
val categoricalFeatures = Map.empty[Int, Int]
testCombinations.foreach {


@ -18,8 +18,9 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
@ -62,6 +63,12 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
test("params") {
ParamsSuite.checkParams(new LogisticRegression)
val model = new LogisticRegressionModel("logReg", Vectors.dense(0.0), 0.0)
ParamsSuite.checkParams(model)
}
test("logistic regression: default params") {
val lr = new LogisticRegression
assert(lr.getLabelCol === "label")


@ -19,15 +19,18 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.Metadata
class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
@ -52,6 +55,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
dataset = sqlContext.createDataFrame(rdd)
}
test("params") {
ParamsSuite.checkParams(new OneVsRest)
val lrModel = new LogisticRegressionModel("lr", Vectors.dense(0.0), 0.0)
val model = new OneVsRestModel("ovr", Metadata.empty, Array(lrModel))
ParamsSuite.checkParams(model)
}
test("one-vs-rest: default params") {
val numClasses = 3
val ova = new OneVsRest()
@ -102,6 +112,26 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
val output = ovr.fit(dataset).transform(dataset)
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
}
test("OneVsRest.copy and OneVsRestModel.copy") {
val lr = new LogisticRegression()
.setMaxIter(1)
val ovr = new OneVsRest()
withClue("copy with classifier unset should work") {
ovr.copy(ParamMap(lr.maxIter -> 10))
}
ovr.setClassifier(lr)
val ovr1 = ovr.copy(ParamMap(lr.maxIter -> 10))
require(ovr.getClassifier.getOrDefault(lr.maxIter) === 1, "copy should have no side-effects")
require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10,
"copy should handle extra classifier params")
val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1))
ovrModel.models.foreach { case m: LogisticRegressionModel =>
require(m.getThreshold === 0.1, "copy should handle extra model params")
}
}
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {


@ -19,6 +19,8 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
@ -27,7 +29,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
/**
* Test suite for [[RandomForestClassifier]].
*/
@ -62,6 +63,13 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses)
}
test("params") {
ParamsSuite.checkParams(new RandomForestClassifier)
val model = new RandomForestClassificationModel("rfc",
Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))))
ParamsSuite.checkParams(model)
}
test("Binary classification with continuous features:" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val rf = new RandomForestClassifier()


@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
class BinaryClassificationEvaluatorSuite extends SparkFunSuite {
test("params") {
ParamsSuite.checkParams(new BinaryClassificationEvaluator)
}
}


@ -18,12 +18,17 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
ParamsSuite.checkParams(new RegressionEvaluator)
}
test("Regression Evaluator: default params") {
/**
* Here is the instruction describing how to export the test data into CSV format


@ -18,6 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@ -30,6 +31,10 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
}
test("params") {
ParamsSuite.checkParams(new Binarizer)
}
test("Binarize continuous features with default parameter") {
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(


@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
import scala.util.Random
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@ -27,6 +28,10 @@ import org.apache.spark.sql.{DataFrame, Row}
class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
ParamsSuite.checkParams(new Bucketizer)
}
test("Bucket continuous features, without -inf,inf") {
// Check a set of valid feature values.
val splits = Array(-0.5, 0.0, 0.5)


@ -28,8 +28,7 @@ import org.apache.spark.util.Utils
class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
val hashingTF = new HashingTF
ParamsSuite.checkParams(hashingTF, 3)
ParamsSuite.checkParams(new HashingTF)
}
test("hashingTF") {


@ -18,6 +18,8 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@ -38,6 +40,12 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
test("params") {
ParamsSuite.checkParams(new IDF)
val model = new IDFModel("idf", new OldIDFModel(Vectors.dense(1.0)))
ParamsSuite.checkParams(model)
}
test("compute IDF with default parameter") {
val numOfFeatures = 4
val data = Array(


@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
@ -36,6 +37,10 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
indexer.transform(df)
}
test("params") {
ParamsSuite.checkParams(new OneHotEncoder)
}
test("OneHotEncoder dropLast = false") {
val transformed = stringIndexed()
val encoder = new OneHotEncoder()


@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.ml.param.ParamsSuite
import org.scalatest.exceptions.TestFailedException
import org.apache.spark.SparkFunSuite
@ -27,6 +28,10 @@ import org.apache.spark.sql.Row
class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
ParamsSuite.checkParams(new PolynomialExpansion)
}
test("Polynomial expansion with default parameter") {
val data = Array(
Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),


@ -19,10 +19,17 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
ParamsSuite.checkParams(new StringIndexer)
val model = new StringIndexerModel("indexer", Array("a", "b"))
ParamsSuite.checkParams(model)
}
test("StringIndexer") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")


@ -20,15 +20,27 @@ package org.apache.spark.ml.feature
import scala.beans.BeanInfo
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
class TokenizerSuite extends SparkFunSuite {
test("params") {
ParamsSuite.checkParams(new Tokenizer)
}
}
class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.ml.feature.RegexTokenizerSuite._
test("params") {
ParamsSuite.checkParams(new RegexTokenizer)
}
test("RegexTokenizer") {
val tokenizer0 = new RegexTokenizer()
.setGaps(false)


@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
@ -26,6 +27,10 @@ import org.apache.spark.sql.functions.col
class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
ParamsSuite.checkParams(new VectorAssembler)
}
test("assemble") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))


@ -21,6 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
@ -91,6 +92,12 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
private def getIndexer: VectorIndexer =
new VectorIndexer().setInputCol("features").setOutputCol("indexed")
test("params") {
ParamsSuite.checkParams(new VectorIndexer)
val model = new VectorIndexerModel("indexer", 1, Map.empty)
ParamsSuite.checkParams(model)
}
test("Cannot fit an empty DataFrame") {
val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
val vectorIndexer = getIndexer


@ -18,13 +18,21 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
ParamsSuite.checkParams(new Word2Vec)
val model = new Word2VecModel("w2v", new OldWord2VecModel(Map("a" -> Array(0.0f))))
ParamsSuite.checkParams(model)
}
test("Word2Vec") {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._


@ -205,19 +205,27 @@ class ParamsSuite extends SparkFunSuite {
object ParamsSuite extends SparkFunSuite {
/**
* Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered
* by names; 3) param parent has the same UID as the object's UID; 4) param name is the same as
* the param method name.
* Checks common requirements for [[Params.params]]:
* - params are ordered by names
* - param parent has the same UID as the object's UID
* - param name is the same as the param method name
* - obj.copy should return the same type as the obj
*/
def checkParams(obj: Params, expectedNumParams: Int): Unit = {
def checkParams(obj: Params): Unit = {
val clazz = obj.getClass
val params = obj.params
require(params.length === expectedNumParams,
s"Expect $expectedNumParams params but got ${params.length}: ${params.map(_.name).toSeq}.")
val paramNames = params.map(_.name)
require(paramNames === paramNames.sorted)
require(paramNames === paramNames.sorted, "params must be ordered by names")
params.foreach { p =>
assert(p.parent === obj.uid)
assert(obj.getParam(p.name) === p)
// TODO: Check that setters return self, which needs special handling for generic types.
}
val copyMethod = clazz.getMethod("copy", classOf[ParamMap])
val copyReturnType = copyMethod.getReturnType
require(copyReturnType === obj.getClass,
s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.")
}
}


@ -38,7 +38,5 @@ class TestParams(override val uid: String) extends Params with HasMaxIter with H
require(isDefined(inputCol))
}
override def copy(extra: ParamMap): TestParams = {
super.copy(extra).asInstanceOf[TestParams]
}
override def copy(extra: ParamMap): TestParams = defaultCopy(extra)
}


@ -18,13 +18,15 @@
package org.apache.spark.ml.param.shared
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.param.{ParamMap, Params}
class SharedParamsSuite extends SparkFunSuite {
test("outputCol") {
class Obj(override val uid: String) extends Params with HasOutputCol
class Obj(override val uid: String) extends Params with HasOutputCol {
override def copy(extra: ParamMap): Obj = defaultCopy(extra)
}
val obj = new Obj("obj")


@ -96,6 +96,8 @@ object CrossValidatorSuite {
override def transformSchema(schema: StructType): StructType = {
throw new UnsupportedOperationException
}
override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
}
class MyEvaluator extends Evaluator {
@ -105,5 +107,7 @@ object CrossValidatorSuite {
}
override val uid: String = "eval"
override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
}
}