[SPARK-19985][ML] Fixed copy method for some ML Models
## What changes were proposed in this pull request? Some ML Models were using `defaultCopy` which expects a default constructor, and others were not setting the parent estimator. This change fixes these by creating a new instance of the model and explicitly setting values and parent. ## How was this patch tested? Added `MLTestingUtils.checkCopy` to the offending models to tests to verify the copy is made and parent is set. Author: Bryan Cutler <cutlerb@gmail.com> Closes #17326 from BryanCutler/ml-model-copy-error-SPARK-19985.
This commit is contained in:
parent
93dbfe705f
commit
2a903a1eec
|
@ -329,7 +329,8 @@ class MultilayerPerceptronClassificationModel private[ml] (
|
|||
|
||||
@Since("1.5.0")
|
||||
override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = {
|
||||
copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra)
|
||||
val copied = new MultilayerPerceptronClassificationModel(uid, layers, weights).setParent(parent)
|
||||
copyValues(copied, extra)
|
||||
}
|
||||
|
||||
@Since("2.0.0")
|
||||
|
|
|
@ -96,7 +96,10 @@ class BucketedRandomProjectionLSHModel private[ml](
|
|||
}
|
||||
|
||||
@Since("2.1.0")
|
||||
override def copy(extra: ParamMap): this.type = defaultCopy(extra)
|
||||
override def copy(extra: ParamMap): BucketedRandomProjectionLSHModel = {
|
||||
val copied = new BucketedRandomProjectionLSHModel(uid, randUnitVectors).setParent(parent)
|
||||
copyValues(copied, extra)
|
||||
}
|
||||
|
||||
@Since("2.1.0")
|
||||
override def write: MLWriter = {
|
||||
|
|
|
@ -86,7 +86,10 @@ class MinHashLSHModel private[ml](
|
|||
}
|
||||
|
||||
@Since("2.1.0")
|
||||
override def copy(extra: ParamMap): this.type = defaultCopy(extra)
|
||||
override def copy(extra: ParamMap): MinHashLSHModel = {
|
||||
val copied = new MinHashLSHModel(uid, randCoefficients).setParent(parent)
|
||||
copyValues(copied, extra)
|
||||
}
|
||||
|
||||
@Since("2.1.0")
|
||||
override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
|
||||
|
|
|
@ -268,8 +268,10 @@ class RFormulaModel private[feature](
|
|||
}
|
||||
|
||||
@Since("1.5.0")
|
||||
override def copy(extra: ParamMap): RFormulaModel = copyValues(
|
||||
new RFormulaModel(uid, resolvedFormula, pipelineModel))
|
||||
override def copy(extra: ParamMap): RFormulaModel = {
|
||||
val copied = new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(parent)
|
||||
copyValues(copied, extra)
|
||||
}
|
||||
|
||||
@Since("2.0.0")
|
||||
override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
|
||||
|
|
|
@ -74,6 +74,7 @@ class MultilayerPerceptronClassifierSuite
|
|||
.setMaxIter(100)
|
||||
.setSolver("l-bfgs")
|
||||
val model = trainer.fit(dataset)
|
||||
MLTestingUtils.checkCopy(model)
|
||||
val result = model.transform(dataset)
|
||||
val predictionAndLabels = result.select("prediction", "label").collect()
|
||||
predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
|
||||
|
|
|
@ -23,7 +23,7 @@ import breeze.numerics.constants.Pi
|
|||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||
import org.apache.spark.ml.util.TestingUtils._
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.sql.Dataset
|
||||
|
@ -89,10 +89,12 @@ class BucketedRandomProjectionLSHSuite
|
|||
.setOutputCol("values")
|
||||
.setBucketLength(1.0)
|
||||
.setSeed(12345)
|
||||
val unitVectors = brp.fit(dataset).randUnitVectors
|
||||
val brpModel = brp.fit(dataset)
|
||||
val unitVectors = brpModel.randUnitVectors
|
||||
unitVectors.foreach { v: Vector =>
|
||||
assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14)
|
||||
}
|
||||
MLTestingUtils.checkCopy(brpModel)
|
||||
}
|
||||
|
||||
test("BucketedRandomProjectionLSH: test of LSH property") {
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
|
|||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.sql.Dataset
|
||||
|
||||
|
@ -57,6 +57,15 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
|
|||
testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
|
||||
}
|
||||
|
||||
test("Model copy and uid checks") {
|
||||
val mh = new MinHashLSH()
|
||||
.setInputCol("keys")
|
||||
.setOutputCol("values")
|
||||
val model = mh.fit(dataset)
|
||||
assert(mh.uid === model.uid)
|
||||
MLTestingUtils.checkCopy(model)
|
||||
}
|
||||
|
||||
test("hashFunction") {
|
||||
val model = new MinHashLSHModel("mh", randCoefficients = Array((0, 1), (1, 2), (3, 0)))
|
||||
val res = model.hashFunction(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0))))
|
||||
|
|
|
@ -37,6 +37,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
|||
val formula = new RFormula().setFormula("id ~ v1 + v2")
|
||||
val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2")
|
||||
val model = formula.fit(original)
|
||||
MLTestingUtils.checkCopy(model)
|
||||
val result = model.transform(original)
|
||||
val resultSchema = model.transformSchema(original.schema)
|
||||
val expected = Seq(
|
||||
|
|
Loading…
Reference in a new issue