[SPARK-18276][ML] ML models should copy the training summary and set parent
## What changes were proposed in this pull request? Only some of the models which contain a training summary currently set the summaries in the copy method. Linear/Logistic regression do, GLR, GMM, KM, and BKM do not. Additionally, these copy methods did not set the parent pointer of the copied model. This patch modifies the copy methods of the four models mentioned above to copy the training summary and set the parent. ## How was this patch tested? Add unit tests in Linear/Logistic/GeneralizedLinear regression and GaussianMixture/KMeans/BisectingKMeans to check the parent pointer of the copied model and check that the copied model has a summary. Author: sethah <seth.hendrickson16@gmail.com> Closes #15773 from sethah/SPARK-18276.
This commit is contained in:
parent
15d3926884
commit
23ce0d1e91
|
@ -94,8 +94,9 @@ class BisectingKMeansModel private[ml] (
|
|||
|
||||
@Since("2.0.0")
|
||||
override def copy(extra: ParamMap): BisectingKMeansModel = {
|
||||
val copied = new BisectingKMeansModel(uid, parentModel)
|
||||
copyValues(copied, extra)
|
||||
val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra)
|
||||
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
|
||||
copied.setParent(this.parent)
|
||||
}
|
||||
|
||||
@Since("2.0.0")
|
||||
|
|
|
@ -89,8 +89,9 @@ class GaussianMixtureModel private[ml] (
|
|||
|
||||
@Since("2.0.0")
|
||||
override def copy(extra: ParamMap): GaussianMixtureModel = {
|
||||
val copied = new GaussianMixtureModel(uid, weights, gaussians)
|
||||
copyValues(copied, extra).setParent(this.parent)
|
||||
val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra)
|
||||
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
|
||||
copied.setParent(this.parent)
|
||||
}
|
||||
|
||||
@Since("2.0.0")
|
||||
|
|
|
@ -108,8 +108,9 @@ class KMeansModel private[ml] (
|
|||
|
||||
@Since("1.5.0")
|
||||
override def copy(extra: ParamMap): KMeansModel = {
|
||||
val copied = new KMeansModel(uid, parentModel)
|
||||
copyValues(copied, extra)
|
||||
val copied = copyValues(new KMeansModel(uid, parentModel), extra)
|
||||
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
|
||||
copied.setParent(this.parent)
|
||||
}
|
||||
|
||||
/** @group setParam */
|
||||
|
|
|
@ -776,8 +776,10 @@ class GeneralizedLinearRegressionModel private[ml] (
|
|||
|
||||
@Since("2.0.0")
|
||||
override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = {
|
||||
copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra)
|
||||
.setParent(parent)
|
||||
val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept),
|
||||
extra)
|
||||
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
|
||||
copied.setParent(parent)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -221,7 +221,7 @@ class TrainValidationSplitModel private[ml] (
|
|||
uid,
|
||||
bestModel.copy(extra).asInstanceOf[Model[_]],
|
||||
validationMetrics.clone())
|
||||
copyValues(copied, extra)
|
||||
copyValues(copied, extra).setParent(parent)
|
||||
}
|
||||
|
||||
@Since("2.0.0")
|
||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.ml.attribute.NominalAttribute
|
|||
import org.apache.spark.ml.classification.LogisticRegressionSuite._
|
||||
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||
import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, SparseVector, Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||
import org.apache.spark.ml.util.TestingUtils._
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
|
@ -141,6 +141,12 @@ class LogisticRegressionSuite
|
|||
assert(model.getProbabilityCol === "probability")
|
||||
assert(model.intercept !== 0.0)
|
||||
assert(model.hasParent)
|
||||
|
||||
// copied model must have the same parent.
|
||||
MLTestingUtils.checkCopy(model)
|
||||
assert(model.hasSummary)
|
||||
val copiedModel = model.copy(ParamMap.empty)
|
||||
assert(copiedModel.hasSummary)
|
||||
}
|
||||
|
||||
test("empty probabilityCol") {
|
||||
|
@ -251,9 +257,6 @@ class LogisticRegressionSuite
|
|||
mlr.setFitIntercept(false)
|
||||
val mlrModel = mlr.fit(smallMultinomialDataset)
|
||||
assert(mlrModel.interceptVector === Vectors.sparse(3, Seq()))
|
||||
|
||||
// copied model must have the same parent.
|
||||
MLTestingUtils.checkCopy(model)
|
||||
}
|
||||
|
||||
test("logistic regression with setters") {
|
||||
|
|
|
@ -18,7 +18,8 @@
|
|||
package org.apache.spark.ml.clustering
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.sql.Dataset
|
||||
|
||||
|
@ -41,6 +42,13 @@ class BisectingKMeansSuite
|
|||
assert(bkm.getPredictionCol === "prediction")
|
||||
assert(bkm.getMaxIter === 20)
|
||||
assert(bkm.getMinDivisibleClusterSize === 1.0)
|
||||
val model = bkm.setMaxIter(1).fit(dataset)
|
||||
|
||||
// copied model must have the same parent
|
||||
MLTestingUtils.checkCopy(model)
|
||||
assert(model.hasSummary)
|
||||
val copiedModel = model.copy(ParamMap.empty)
|
||||
assert(copiedModel.hasSummary)
|
||||
}
|
||||
|
||||
test("setter/getter") {
|
||||
|
|
|
@ -18,7 +18,8 @@
|
|||
package org.apache.spark.ml.clustering
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.sql.Dataset
|
||||
|
||||
|
@ -43,6 +44,13 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
|
|||
assert(gm.getPredictionCol === "prediction")
|
||||
assert(gm.getMaxIter === 100)
|
||||
assert(gm.getTol === 0.01)
|
||||
val model = gm.setMaxIter(1).fit(dataset)
|
||||
|
||||
// copied model must have the same parent
|
||||
MLTestingUtils.checkCopy(model)
|
||||
assert(model.hasSummary)
|
||||
val copiedModel = model.copy(ParamMap.empty)
|
||||
assert(copiedModel.hasSummary)
|
||||
}
|
||||
|
||||
test("set parameters") {
|
||||
|
|
|
@ -19,7 +19,8 @@ package org.apache.spark.ml.clustering
|
|||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
|
||||
|
@ -47,6 +48,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
|
|||
assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL)
|
||||
assert(kmeans.getInitSteps === 2)
|
||||
assert(kmeans.getTol === 1e-4)
|
||||
val model = kmeans.setMaxIter(1).fit(dataset)
|
||||
|
||||
// copied model must have the same parent
|
||||
MLTestingUtils.checkCopy(model)
|
||||
assert(model.hasSummary)
|
||||
val copiedModel = model.copy(ParamMap.empty)
|
||||
assert(copiedModel.hasSummary)
|
||||
}
|
||||
|
||||
test("set parameters") {
|
||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.ml.classification.LogisticRegressionSuite._
|
|||
import org.apache.spark.ml.feature.Instance
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||
import org.apache.spark.ml.util.TestingUtils._
|
||||
import org.apache.spark.mllib.random._
|
||||
|
@ -183,6 +183,9 @@ class GeneralizedLinearRegressionSuite
|
|||
|
||||
// copied model must have the same parent.
|
||||
MLTestingUtils.checkCopy(model)
|
||||
assert(model.hasSummary)
|
||||
val copiedModel = model.copy(ParamMap.empty)
|
||||
assert(copiedModel.hasSummary)
|
||||
|
||||
assert(model.getFeaturesCol === "features")
|
||||
assert(model.getPredictionCol === "prediction")
|
||||
|
|
|
@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
|
|||
import org.apache.spark.ml.feature.Instance
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||
import org.apache.spark.ml.util.TestingUtils._
|
||||
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
|
||||
|
@ -143,6 +143,9 @@ class LinearRegressionSuite
|
|||
|
||||
// copied model must have the same parent.
|
||||
MLTestingUtils.checkCopy(model)
|
||||
assert(model.hasSummary)
|
||||
val copiedModel = model.copy(ParamMap.empty)
|
||||
assert(copiedModel.hasSummary)
|
||||
|
||||
model.transform(datasetWithDenseFeature)
|
||||
.select("label", "prediction")
|
||||
|
|
|
@ -22,11 +22,11 @@ import org.apache.spark.ml.{Estimator, Model}
|
|||
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
|
||||
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
|
||||
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
|
||||
import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}
|
||||
import org.apache.spark.ml.linalg.Vectors
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.param.shared.HasInputCol
|
||||
import org.apache.spark.ml.regression.LinearRegression
|
||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
|
||||
import org.apache.spark.sql.Dataset
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
@ -78,6 +78,10 @@ class TrainValidationSplitSuite
|
|||
.setTrainRatio(0.5)
|
||||
.setSeed(42L)
|
||||
val cvModel = cv.fit(dataset)
|
||||
|
||||
// copied model must have the same paren.
|
||||
MLTestingUtils.checkCopy(cvModel)
|
||||
|
||||
val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
|
||||
assert(parent.getRegParam === 0.001)
|
||||
assert(parent.getMaxIter === 10)
|
||||
|
|
Loading…
Reference in a new issue