[SPARK-19953][ML] Random Forest Models use parent UID when being fit

## What changes were proposed in this pull request?

The ML `RandomForestClassificationModel` and `RandomForestRegressionModel` were not using the estimator parent UID when being fit.  This change fixes that so the models can be properly be identified with their parents.

## How was this patch tested?Existing tests.

Added check to verify that model uid matches that of the parent, then renamed `checkCopy` to `checkCopyAndUids` and verified that it was called by one test for each ML algorithm.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #17296 from BryanCutler/rfmodels-use-parent-uid-SPARK-19953.
This commit is contained in:
Bryan Cutler 2017-04-06 09:41:32 +02:00 committed by Nick Pentreath
parent 5142e5d4e0
commit e156b5dd39
41 changed files with 98 additions and 100 deletions

View file

@ -140,7 +140,7 @@ class RandomForestClassifier @Since("1.4.0") (
.map(_.asInstanceOf[DecisionTreeClassificationModel])
val numFeatures = oldDataset.first().features.size
val m = new RandomForestClassificationModel(trees, numFeatures, numClasses)
val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
instr.logSuccess(m)
m
}

View file

@ -131,7 +131,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
.map(_.asInstanceOf[DecisionTreeRegressionModel])
val numFeatures = oldDataset.first().features.size
val m = new RandomForestRegressionModel(trees, numFeatures)
val m = new RandomForestRegressionModel(uid, trees, numFeatures)
instr.logSuccess(m)
m
}

View file

@ -79,7 +79,7 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
.setStages(Array(estimator0, transformer1, estimator2, transformer3))
val pipelineModel = pipeline.fit(dataset0)
MLTestingUtils.checkCopy(pipelineModel)
MLTestingUtils.checkCopyAndUids(pipeline, pipelineModel)
assert(pipelineModel.stages.length === 4)
assert(pipelineModel.stages(0).eq(model0))

View file

@ -249,8 +249,7 @@ class DecisionTreeClassifierSuite
val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
val newTree = dt.fit(newData)
// copied model must have the same parent.
MLTestingUtils.checkCopy(newTree)
MLTestingUtils.checkCopyAndUids(dt, newTree)
val predictions = newTree.transform(newData)
.select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol)

View file

@ -97,8 +97,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
assert(model.getProbabilityCol === "probability")
assert(model.hasParent)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(gbt, model)
}
test("setThreshold, getThreshold") {
@ -261,8 +260,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
.setSeed(123)
val model = gbt.fit(df)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(gbt, model)
sc.checkpointDir = None
Utils.deleteRecursively(tempDir)

View file

@ -124,8 +124,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
assert(model.hasParent)
assert(model.numFeatures === 2)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(lsvc, model)
}
test("linear svc doesn't fit intercept when fitIntercept is off") {

View file

@ -142,8 +142,7 @@ class LogisticRegressionSuite
assert(model.intercept !== 0.0)
assert(model.hasParent)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(lr, model)
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)

View file

@ -74,8 +74,8 @@ class MultilayerPerceptronClassifierSuite
.setMaxIter(100)
.setSolver("l-bfgs")
val model = trainer.fit(dataset)
MLTestingUtils.checkCopy(model)
val result = model.transform(dataset)
MLTestingUtils.checkCopyAndUids(trainer, model)
val predictionAndLabels = result.select("prediction", "label").collect()
predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
assert(p == l)

View file

@ -149,6 +149,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
validateModelFit(pi, theta, model)
assert(model.hasParent)
MLTestingUtils.checkCopyAndUids(nb, model)
val validationDataset =
generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF()

View file

@ -76,8 +76,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
assert(ova.getPredictionCol === "prediction")
val ovaModel = ova.fit(dataset)
// copied model must have the same parent.
MLTestingUtils.checkCopy(ovaModel)
MLTestingUtils.checkCopyAndUids(ova, ovaModel)
assert(ovaModel.models.length === numClasses)
val transformedDataset = ovaModel.transform(dataset)

View file

@ -141,8 +141,7 @@ class RandomForestClassifierSuite
val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
val model = rf.fit(df)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(rf, model)
val predictions = model.transform(df)
.select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol)

View file

@ -47,8 +47,7 @@ class BisectingKMeansSuite
assert(bkm.getMinDivisibleClusterSize === 1.0)
val model = bkm.setMaxIter(1).fit(dataset)
// copied model must have the same parent
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(bkm, model)
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)

View file

@ -77,8 +77,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
assert(gm.getTol === 0.01)
val model = gm.setMaxIter(1).fit(dataset)
// copied model must have the same parent
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(gm, model)
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)

View file

@ -52,8 +52,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(kmeans.getTol === 1e-4)
val model = kmeans.setMaxIter(1).fit(dataset)
// copied model must have the same parent
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(kmeans, model)
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)

View file

@ -176,7 +176,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val lda = new LDA().setK(k).setSeed(1).setOptimizer("online").setMaxIter(2)
val model = lda.fit(dataset)
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(lda, model)
assert(model.isInstanceOf[LocalLDAModel])
assert(model.vocabSize === vocabSize)
@ -221,7 +221,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val lda = new LDA().setK(k).setSeed(1).setOptimizer("em").setMaxIter(2)
val model_ = lda.fit(dataset)
MLTestingUtils.checkCopy(model_)
MLTestingUtils.checkCopyAndUids(lda, model_)
assert(model_.isInstanceOf[DistributedLDAModel])
val model = model_.asInstanceOf[DistributedLDAModel]

View file

@ -94,7 +94,8 @@ class BucketedRandomProjectionLSHSuite
unitVectors.foreach { v: Vector =>
assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14)
}
MLTestingUtils.checkCopy(brpModel)
MLTestingUtils.checkCopyAndUids(brp, brpModel)
}
test("BucketedRandomProjectionLSH: test of LSH property") {

View file

@ -119,7 +119,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
test("Test Chi-Square selector: numTopFeatures") {
val selector = new ChiSqSelector()
.setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1)
ChiSqSelectorSuite.testSelector(selector, dataset)
val model = ChiSqSelectorSuite.testSelector(selector, dataset)
MLTestingUtils.checkCopyAndUids(selector, model)
}
test("Test Chi-Square selector: percentile") {
@ -166,11 +167,13 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
object ChiSqSelectorSuite {
private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): Unit = {
selector.fit(dataset).transform(dataset).select("filtered", "topFeature").collect()
private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): ChiSqSelectorModel = {
val selectorModel = selector.fit(dataset)
selectorModel.transform(dataset).select("filtered", "topFeature").collect()
.foreach { case Row(vec1: Vector, vec2: Vector) =>
assert(vec1 ~== vec2 absTol 1e-1)
}
selectorModel
}
/**

View file

@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
@ -68,10 +68,11 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
val cv = new CountVectorizer()
.setInputCol("words")
.setOutputCol("features")
.fit(df)
assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e"))
val cvm = cv.fit(df)
MLTestingUtils.checkCopyAndUids(cv, cvm)
assert(cvm.vocabulary.toSet === Set("a", "b", "c", "d", "e"))
cv.transform(df).select("features", "expected").collect().foreach {
cvm.transform(df).select("features", "expected").collect().foreach {
case Row(features: Vector, expected: Vector) =>
assert(features ~== expected absTol 1e-14)
}

View file

@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel}
import org.apache.spark.mllib.linalg.VectorImplicits._
@ -65,10 +65,12 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val df = data.zip(expected).toSeq.toDF("features", "expected")
val idfModel = new IDF()
val idfEst = new IDF()
.setInputCol("features")
.setOutputCol("idfValue")
.fit(df)
val idfModel = idfEst.fit(df)
MLTestingUtils.checkCopyAndUids(idfEst, idfModel)
idfModel.transform(df).select("idfValue", "expected").collect().foreach {
case Row(x: Vector, y: Vector) =>

View file

@ -18,7 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.ml.util.{MLTestingUtils, SchemaUtils}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
@ -58,6 +58,8 @@ private[ml] object LSHTest {
val outputCol = model.getOutputCol
val transformedData = model.transform(dataset)
MLTestingUtils.checkCopyAndUids(lsh, model)
// Check output column type
SchemaUtils.checkColumnType(
transformedData.schema, model.getOutputCol, DataTypes.createArrayType(new VectorUDT))

View file

@ -50,8 +50,7 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
assert(vector1.equals(vector2), s"MaxAbsScaler ut error: $vector2 should be $vector1")
}
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(scaler, model)
}
test("MaxAbsScaler read/write") {

View file

@ -63,7 +63,7 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCol("values")
val model = mh.fit(dataset)
assert(mh.uid === model.uid)
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(mh, model)
}
test("hashFunction") {

View file

@ -53,8 +53,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
assert(vector1.equals(vector2), "Transformed vector is different with expected.")
}
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(scaler, model)
}
test("MinMaxScaler arguments max must be larger than min") {

View file

@ -58,12 +58,12 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
.setInputCol("features")
.setOutputCol("pca_features")
.setK(3)
.fit(df)
// copied model must have the same parent.
MLTestingUtils.checkCopy(pca)
val pcaModel = pca.fit(df)
pca.transform(df).select("pca_features", "expected").collect().foreach {
MLTestingUtils.checkCopyAndUids(pca, pcaModel)
pcaModel.transform(df).select("pca_features", "expected").collect().foreach {
case Row(x: Vector, y: Vector) =>
assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
}

View file

@ -37,7 +37,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val formula = new RFormula().setFormula("id ~ v1 + v2")
val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2")
val model = formula.fit(original)
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(formula, model)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
val expected = Seq(

View file

@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@ -77,10 +77,11 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
test("Standardization with default parameter") {
val df0 = data.zip(resWithStd).toSeq.toDF("features", "expected")
val standardScaler0 = new StandardScaler()
val standardScalerEst0 = new StandardScaler()
.setInputCol("features")
.setOutputCol("standardized_features")
.fit(df0)
val standardScaler0 = standardScalerEst0.fit(df0)
MLTestingUtils.checkCopyAndUids(standardScalerEst0, standardScaler0)
assertResult(standardScaler0.transform(df0))
}

View file

@ -45,12 +45,11 @@ class StringIndexerSuite
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
val indexerModel = indexer.fit(df)
// copied model must have the same parent.
MLTestingUtils.checkCopy(indexer)
MLTestingUtils.checkCopyAndUids(indexer, indexerModel)
val transformed = indexer.transform(df)
val transformed = indexerModel.transform(df)
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attr.values.get === Array("a", "c", "b"))

View file

@ -114,8 +114,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
val vectorIndexer = getIndexer
val model = vectorIndexer.fit(densePoints1) // vectors of length 3
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(vectorIndexer, model)
model.transform(densePoints1) // should work
model.transform(sparsePoints1) // should work

View file

@ -57,15 +57,14 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val docDF = doc.zip(expected).toDF("text", "expected")
val model = new Word2Vec()
val w2v = new Word2Vec()
.setVectorSize(3)
.setInputCol("text")
.setOutputCol("result")
.setSeed(42L)
.fit(docDF)
val model = w2v.fit(docDF)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(w2v, model)
// These expectations are just magic values, characterizing the current
// behavior. The test needs to be updated to be more general, see SPARK-11502

View file

@ -17,9 +17,10 @@
package org.apache.spark.ml.fpm
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@ -121,7 +122,9 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
.setMinConfidence(0.5678)
assert(fpGrowth.getMinSupport === 0.4567)
assert(model.getMinConfidence === 0.5678)
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(fpGrowth, model)
ParamsSuite.checkParams(fpGrowth)
ParamsSuite.checkParams(model)
}
test("read/write") {

View file

@ -409,8 +409,7 @@ class ALSSuite
logInfo(s"Test RMSE is $rmse.")
assert(rmse < targetRMSE)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(als, model)
}
test("exact rank-1 matrix") {

View file

@ -83,8 +83,7 @@ class AFTSurvivalRegressionSuite
.setQuantilesCol("quantiles")
.fit(datasetUnivariate)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(aftr, model)
model.transform(datasetUnivariate)
.select("label", "prediction", "quantiles")

View file

@ -69,11 +69,12 @@ class DecisionTreeRegressorSuite
test("copied model must have the same parent") {
val categoricalFeatures = Map(0 -> 2, 1 -> 2)
val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
val model = new DecisionTreeRegressor()
val dtr = new DecisionTreeRegressor()
.setImpurity("variance")
.setMaxDepth(2)
.setMaxBins(8).fit(df)
MLTestingUtils.checkCopy(model)
.setMaxBins(8)
val model = dtr.fit(df)
MLTestingUtils.checkCopyAndUids(dtr, model)
}
test("predictVariance") {

View file

@ -90,8 +90,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
.setMaxIter(2)
val model = gbt.fit(df)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(gbt, model)
val preds = model.transform(df)
val predictions = preds.select("prediction").rdd.map(_.getDouble(0))
// Checks based on SPARK-8736 (to ensure it is not doing classification)

View file

@ -197,8 +197,7 @@ class GeneralizedLinearRegressionSuite
val model = glr.setFamily("gaussian").setLink("identity")
.fit(datasetGaussianIdentity)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(glr, model)
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)

View file

@ -93,8 +93,7 @@ class IsotonicRegressionSuite
val model = ir.fit(dataset)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(ir, model)
model.transform(dataset)
.select("label", "features", "prediction", "weight")

View file

@ -148,8 +148,7 @@ class LinearRegressionSuite
assert(lir.getSolver == "auto")
val model = lir.fit(datasetWithDenseFeature)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
MLTestingUtils.checkCopyAndUids(lir, model)
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)

View file

@ -90,6 +90,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
val model = rf.fit(df)
MLTestingUtils.checkCopyAndUids(rf, model)
val importances = model.featureImportances
val mostImportantFeature = importances.argmax
assert(mostImportantFeature === 1)

View file

@ -58,8 +58,7 @@ class CrossValidatorSuite
.setNumFolds(3)
val cvModel = cv.fit(dataset)
// copied model must have the same paren.
MLTestingUtils.checkCopy(cvModel)
MLTestingUtils.checkCopyAndUids(cv, cvModel)
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(parent.getRegParam === 0.001)

View file

@ -45,18 +45,18 @@ class TrainValidationSplitSuite
.addGrid(lr.maxIter, Array(0, 10))
.build()
val eval = new BinaryClassificationEvaluator
val cv = new TrainValidationSplit()
val tvs = new TrainValidationSplit()
.setEstimator(lr)
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setTrainRatio(0.5)
.setSeed(42L)
val cvModel = cv.fit(dataset)
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(cv.getTrainRatio === 0.5)
val tvsModel = tvs.fit(dataset)
val parent = tvsModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(tvs.getTrainRatio === 0.5)
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
assert(cvModel.validationMetrics.length === lrParamMaps.length)
assert(tvsModel.validationMetrics.length === lrParamMaps.length)
}
test("train validation with linear regression") {
@ -71,28 +71,27 @@ class TrainValidationSplitSuite
.addGrid(trainer.maxIter, Array(0, 10))
.build()
val eval = new RegressionEvaluator()
val cv = new TrainValidationSplit()
val tvs = new TrainValidationSplit()
.setEstimator(trainer)
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setTrainRatio(0.5)
.setSeed(42L)
val cvModel = cv.fit(dataset)
val tvsModel = tvs.fit(dataset)
// copied model must have the same paren.
MLTestingUtils.checkCopy(cvModel)
MLTestingUtils.checkCopyAndUids(tvs, tvsModel)
val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
val parent = tvsModel.bestModel.parent.asInstanceOf[LinearRegression]
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
assert(cvModel.validationMetrics.length === lrParamMaps.length)
assert(tvsModel.validationMetrics.length === lrParamMaps.length)
eval.setMetricName("r2")
val cvModel2 = cv.fit(dataset)
val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression]
val tvsModel2 = tvs.fit(dataset)
val parent2 = tvsModel2.bestModel.parent.asInstanceOf[LinearRegression]
assert(parent2.getRegParam === 0.001)
assert(parent2.getMaxIter === 10)
assert(cvModel2.validationMetrics.length === lrParamMaps.length)
assert(tvsModel2.validationMetrics.length === lrParamMaps.length)
}
test("transformSchema should check estimatorParamMaps") {
@ -104,17 +103,17 @@ class TrainValidationSplitSuite
.addGrid(est.inputCol, Array("input1", "input2"))
.build()
val cv = new TrainValidationSplit()
val tvs = new TrainValidationSplit()
.setEstimator(est)
.setEstimatorParamMaps(paramMaps)
.setEvaluator(eval)
.setTrainRatio(0.5)
cv.transformSchema(new StructType()) // This should pass.
tvs.transformSchema(new StructType()) // This should pass.
val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
cv.setEstimatorParamMaps(invalidParamMaps)
tvs.setEstimatorParamMaps(invalidParamMaps)
intercept[IllegalArgumentException] {
cv.transformSchema(new StructType())
tvs.transformSchema(new StructType())
}
}

View file

@ -31,11 +31,15 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
object MLTestingUtils extends SparkFunSuite {
def checkCopy(model: Model[_]): Unit = {
def checkCopyAndUids[T <: Estimator[_]](estimator: T, model: Model[_]): Unit = {
assert(estimator.uid === model.uid, "Model uid does not match parent estimator")
// copied model must have the same parent
val copied = model.copy(ParamMap.empty)
.asInstanceOf[Model[_]]
assert(copied.parent.uid == model.parent.uid)
assert(copied.parent == model.parent)
assert(copied.parent.uid == model.parent.uid)
}
def checkNumericTypes[M <: Model[M], T <: Estimator[M]](