[SPARK-19985][ML] Fixed copy method for some ML Models

## What changes were proposed in this pull request?
Some ML Models were using `defaultCopy` which expects a default constructor, and others were not setting the parent estimator.  This change fixes these by creating a new instance of the model and explicitly setting values and parent.

## How was this patch tested?
Added `MLTestingUtils.checkCopy` calls to the tests of the offending models to verify that the copy is made and the parent is set.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #17326 from BryanCutler/ml-model-copy-error-SPARK-19985.
This commit is contained in:
Bryan Cutler 2017-04-03 10:56:54 +02:00 committed by Nick Pentreath
parent 93dbfe705f
commit 2a903a1eec
8 changed files with 30 additions and 8 deletions

View file

@@ -329,7 +329,8 @@ class MultilayerPerceptronClassificationModel private[ml] (
@Since("1.5.0")
override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = {
copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra)
val copied = new MultilayerPerceptronClassificationModel(uid, layers, weights).setParent(parent)
copyValues(copied, extra)
}
@Since("2.0.0")

View file

@@ -96,7 +96,10 @@ class BucketedRandomProjectionLSHModel private[ml](
}
@Since("2.1.0")
override def copy(extra: ParamMap): this.type = defaultCopy(extra)
override def copy(extra: ParamMap): BucketedRandomProjectionLSHModel = {
val copied = new BucketedRandomProjectionLSHModel(uid, randUnitVectors).setParent(parent)
copyValues(copied, extra)
}
@Since("2.1.0")
override def write: MLWriter = {

View file

@@ -86,7 +86,10 @@ class MinHashLSHModel private[ml](
}
@Since("2.1.0")
override def copy(extra: ParamMap): this.type = defaultCopy(extra)
override def copy(extra: ParamMap): MinHashLSHModel = {
val copied = new MinHashLSHModel(uid, randCoefficients).setParent(parent)
copyValues(copied, extra)
}
@Since("2.1.0")
override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)

View file

@@ -268,8 +268,10 @@ class RFormulaModel private[feature](
}
@Since("1.5.0")
override def copy(extra: ParamMap): RFormulaModel = copyValues(
new RFormulaModel(uid, resolvedFormula, pipelineModel))
override def copy(extra: ParamMap): RFormulaModel = {
val copied = new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(parent)
copyValues(copied, extra)
}
@Since("2.0.0")
override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"

View file

@@ -74,6 +74,7 @@ class MultilayerPerceptronClassifierSuite
.setMaxIter(100)
.setSolver("l-bfgs")
val model = trainer.fit(dataset)
MLTestingUtils.checkCopy(model)
val result = model.transform(dataset)
val predictionAndLabels = result.select("prediction", "label").collect()
predictionAndLabels.foreach { case Row(p: Double, l: Double) =>

View file

@@ -23,7 +23,7 @@ import breeze.numerics.constants.Pi
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Dataset
@@ -89,10 +89,12 @@ class BucketedRandomProjectionLSHSuite
.setOutputCol("values")
.setBucketLength(1.0)
.setSeed(12345)
val unitVectors = brp.fit(dataset).randUnitVectors
val brpModel = brp.fit(dataset)
val unitVectors = brpModel.randUnitVectors
unitVectors.foreach { v: Vector =>
assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14)
}
MLTestingUtils.checkCopy(brpModel)
}
test("BucketedRandomProjectionLSH: test of LSH property") {

View file

@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Dataset
@@ -57,6 +57,15 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
}
test("Model copy and uid checks") {
val mh = new MinHashLSH()
.setInputCol("keys")
.setOutputCol("values")
val model = mh.fit(dataset)
assert(mh.uid === model.uid)
MLTestingUtils.checkCopy(model)
}
test("hashFunction") {
val model = new MinHashLSHModel("mh", randCoefficients = Array((0, 1), (1, 2), (3, 0)))
val res = model.hashFunction(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0))))

View file

@@ -37,6 +37,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val formula = new RFormula().setFormula("id ~ v1 + v2")
val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2")
val model = formula.fit(original)
MLTestingUtils.checkCopy(model)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
val expected = Seq(