[SPARK-22060][ML] Fix CrossValidator/TrainValidationSplit param persist/load bug

## What changes were proposed in this pull request?

Currently, param persistence/loading in CrossValidator/TrainValidationSplit is hardcoded, which differs from other ML estimators. This causes a persistence bug for the newly added `parallelism` param.

I refactored the related code to avoid hardcoding param persistence/loading; at the same time, this fixes the `parallelism` persistence bug.

This refactoring is also worthwhile because we will add more new params in #19208, and hardcoded param persisting/loading makes adding new params very troublesome.
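As a minimal illustration of the bug (not code from this patch; assumes a running Spark application, and the save path is a hypothetical location):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression()
val grid = new ParamGridBuilder().addGrid(lr.regParam, Array(0.1, 0.01)).build()
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setEstimatorParamMaps(grid)
  .setParallelism(2)  // not in the hardcoded persist list, so it was silently dropped

val path = "/tmp/cv-parallelism-demo"  // illustrative location
cv.write.overwrite().save(path)
val loaded = CrossValidator.load(path)
// Before this patch the loaded instance fell back to the default parallelism of 1:
assert(loaded.getParallelism == 2)
```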

## How was this patch tested?

Test added.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19278 from WeichenXu123/fix-tuning-param-bug.
Commit f180b65343 (parent 3e6a714c9e), authored by WeichenXu, committed 2017-09-22 18:15:01 -07:00 by Joseph K. Bradley.

6 changed files with 46 additions and 38 deletions

#### mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

```diff
@@ -212,14 +212,13 @@ object CrossValidator extends MLReadable[CrossValidator] {
       val (metadata, estimator, evaluator, estimatorParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
-      val numFolds = (metadata.params \ "numFolds").extract[Int]
-      val seed = (metadata.params \ "seed").extract[Long]
-      new CrossValidator(metadata.uid)
+      val cv = new CrossValidator(metadata.uid)
         .setEstimator(estimator)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(estimatorParamMaps)
-        .setNumFolds(numFolds)
-        .setSeed(seed)
+      DefaultParamsReader.getAndSetParams(cv, metadata,
+        skipParams = Option(List("estimatorParamMaps")))
+      cv
     }
   }
 }
@@ -302,17 +301,17 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
       val (metadata, estimator, evaluator, estimatorParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
-      val numFolds = (metadata.params \ "numFolds").extract[Int]
-      val seed = (metadata.params \ "seed").extract[Long]
       val bestModelPath = new Path(path, "bestModel").toString
       val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
       val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
       val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
       model.set(model.estimator, estimator)
         .set(model.evaluator, evaluator)
         .set(model.estimatorParamMaps, estimatorParamMaps)
-        .set(model.numFolds, numFolds)
-        .set(model.seed, seed)
+      DefaultParamsReader.getAndSetParams(model, metadata,
+        skipParams = Option(List("estimatorParamMaps")))
+      model
     }
   }
 }
```

#### mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala

```diff
@@ -17,6 +17,7 @@
 package org.apache.spark.ml.tuning
 
+import java.io.IOException
 import java.util.{List => JList}
 
 import scala.collection.JavaConverters._
@@ -207,14 +208,13 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
       val (metadata, estimator, evaluator, estimatorParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
-      val trainRatio = (metadata.params \ "trainRatio").extract[Double]
-      val seed = (metadata.params \ "seed").extract[Long]
-      new TrainValidationSplit(metadata.uid)
+      val tvs = new TrainValidationSplit(metadata.uid)
         .setEstimator(estimator)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(estimatorParamMaps)
-        .setTrainRatio(trainRatio)
-        .setSeed(seed)
+      DefaultParamsReader.getAndSetParams(tvs, metadata,
+        skipParams = Option(List("estimatorParamMaps")))
+      tvs
     }
   }
 }
@@ -295,17 +295,17 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
       val (metadata, estimator, evaluator, estimatorParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
-      val trainRatio = (metadata.params \ "trainRatio").extract[Double]
-      val seed = (metadata.params \ "seed").extract[Long]
       val bestModelPath = new Path(path, "bestModel").toString
       val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
       val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
       val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
       model.set(model.estimator, estimator)
         .set(model.evaluator, evaluator)
         .set(model.estimatorParamMaps, estimatorParamMaps)
-        .set(model.trainRatio, trainRatio)
-        .set(model.seed, seed)
+      DefaultParamsReader.getAndSetParams(model, metadata,
+        skipParams = Option(List("estimatorParamMaps")))
+      model
     }
   }
 }
```

#### mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala

```diff
@@ -150,20 +150,14 @@ private[ml] object ValidatorParams {
       }.toSeq
     ))
 
-    val validatorSpecificParams = instance match {
-      case cv: CrossValidatorParams =>
-        List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds)))
-      case tvs: TrainValidationSplitParams =>
-        List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio)))
-      case _ =>
-        // This should not happen.
-        throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " +
-          instance.getClass.getCanonicalName)
-    }
-    val jsonParams = validatorSpecificParams ++ List(
-      "estimatorParamMaps" -> parse(estimatorParamMapsJson),
-      "seed" -> parse(instance.seed.jsonEncode(instance.getSeed)))
+    val params = instance.extractParamMap().toSeq
+    val skipParams = List("estimator", "evaluator", "estimatorParamMaps")
+    val jsonParams = render(params
+      .filter { case ParamPair(p, v) => !skipParams.contains(p.name)}
+      .map { case ParamPair(p, v) =>
+        p.name -> parse(p.jsonEncode(v))
+      }.toList ++ List("estimatorParamMaps" -> parse(estimatorParamMapsJson))
+    )
 
     DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
```
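Because `saveImpl` now serializes every entry from `extractParamMap()` (minus the three complex params handled specially), a newly declared simple param needs no validator-specific persistence code at all. A hedged sketch, using a hypothetical `collectSubModels` param in the spirit of the params planned for #19208:

```scala
import org.apache.spark.ml.param.{BooleanParam, Params}

trait HasCollectSubModels extends Params {
  // Hypothetical future param: once mixed into a validator, it round-trips
  // through save/load automatically, since saveImpl picks it up from
  // extractParamMap() and getAndSetParams restores it on load.
  final val collectSubModels: BooleanParam =
    new BooleanParam(this, "collectSubModels", "whether to collect sub-models during fitting")
  setDefault(collectSubModels -> false)
}
```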

#### mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala

```diff
@@ -396,17 +396,27 @@ private[ml] object DefaultParamsReader {
   /**
    * Extract Params from metadata, and set them in the instance.
-   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
+   * This works if all Params (except params included by `skipParams` list) implement
+   * [[org.apache.spark.ml.param.Param.jsonDecode()]].
+   *
+   * @param skipParams The params included in `skipParams` won't be set. This is useful if some
+   *                   params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]]
+   *                   and need special handling.
    * TODO: Move to [[Metadata]] method
    */
-  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
+  def getAndSetParams(
+      instance: Params,
+      metadata: Metadata,
+      skipParams: Option[List[String]] = None): Unit = {
     implicit val format = DefaultFormats
     metadata.params match {
       case JObject(pairs) =>
         pairs.foreach { case (paramName, jsonValue) =>
-          val param = instance.getParam(paramName)
-          val value = param.jsonDecode(compact(render(jsonValue)))
-          instance.set(param, value)
+          if (skipParams == None || !skipParams.get.contains(paramName)) {
+            val param = instance.getParam(paramName)
+            val value = param.jsonDecode(compact(render(jsonValue)))
+            instance.set(param, value)
+          }
         }
       case _ =>
         throw new IllegalArgumentException(
```
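From the caller's side, a minimal usage sketch of the new `skipParams` argument (hedged: `DefaultParamsReader` is `private[ml]`, so this only applies to readers inside Spark ML, and the wrapper function here is illustrative):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.DefaultParamsReader

// Restore all JSON-decodable params from saved metadata, skipping ones (like
// estimatorParamMaps) that lack jsonDecode and must be set by dedicated
// loading code afterwards.
def restoreSimpleParams(instance: Params, path: String, sc: SparkContext): Unit = {
  val metadata = DefaultParamsReader.loadMetadata(path, sc)
  DefaultParamsReader.getAndSetParams(instance, metadata,
    skipParams = Option(List("estimatorParamMaps")))
}
```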

#### mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala

```diff
@@ -159,12 +159,15 @@ class CrossValidatorSuite
       .setEvaluator(evaluator)
       .setNumFolds(20)
       .setEstimatorParamMaps(paramMaps)
+      .setSeed(42L)
+      .setParallelism(2)
 
     val cv2 = testDefaultReadWrite(cv, testParams = false)
 
     assert(cv.uid === cv2.uid)
     assert(cv.getNumFolds === cv2.getNumFolds)
     assert(cv.getSeed === cv2.getSeed)
+    assert(cv.getParallelism === cv2.getParallelism)
     assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
     val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
```

#### mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala

```diff
@@ -23,7 +23,7 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio
 import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
 import org.apache.spark.ml.linalg.Vectors
-import org.apache.spark.ml.param.{ParamMap}
+import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
@@ -160,11 +160,13 @@ class TrainValidationSplitSuite
       .setTrainRatio(0.5)
       .setEstimatorParamMaps(paramMaps)
       .setSeed(42L)
+      .setParallelism(2)
 
     val tvs2 = testDefaultReadWrite(tvs, testParams = false)
 
     assert(tvs.getTrainRatio === tvs2.getTrainRatio)
     assert(tvs.getSeed === tvs2.getSeed)
+    assert(tvs.getParallelism === tvs2.getParallelism)
     ValidatorParamsSuiteHelpers
       .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps)
```