[SPARK-22060][ML] Fix CrossValidator/TrainValidationSplit param persist/load bug
## What changes were proposed in this pull request? Currently the param of CrossValidator/TrainValidationSplit persist/loading is hardcoding, which is different with other ML estimators. This cause persist bug for new added `parallelism` param. I refactor related code, avoid hardcoding persist/load param. And in the same time, it solve the `parallelism` persisting bug. This refactoring is very useful because we will add more new params in #19208 , hardcoding param persisting/loading making the thing adding new params very troublesome. ## How was this patch tested? Test added. Author: WeichenXu <weichen.xu@databricks.com> Closes #19278 from WeichenXu123/fix-tuning-param-bug.
This commit is contained in:
parent
3e6a714c9e
commit
f180b65343
|
@ -212,14 +212,13 @@ object CrossValidator extends MLReadable[CrossValidator] {
|
||||||
|
|
||||||
val (metadata, estimator, evaluator, estimatorParamMaps) =
|
val (metadata, estimator, evaluator, estimatorParamMaps) =
|
||||||
ValidatorParams.loadImpl(path, sc, className)
|
ValidatorParams.loadImpl(path, sc, className)
|
||||||
val numFolds = (metadata.params \ "numFolds").extract[Int]
|
val cv = new CrossValidator(metadata.uid)
|
||||||
val seed = (metadata.params \ "seed").extract[Long]
|
|
||||||
new CrossValidator(metadata.uid)
|
|
||||||
.setEstimator(estimator)
|
.setEstimator(estimator)
|
||||||
.setEvaluator(evaluator)
|
.setEvaluator(evaluator)
|
||||||
.setEstimatorParamMaps(estimatorParamMaps)
|
.setEstimatorParamMaps(estimatorParamMaps)
|
||||||
.setNumFolds(numFolds)
|
DefaultParamsReader.getAndSetParams(cv, metadata,
|
||||||
.setSeed(seed)
|
skipParams = Option(List("estimatorParamMaps")))
|
||||||
|
cv
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -302,17 +301,17 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
|
||||||
|
|
||||||
val (metadata, estimator, evaluator, estimatorParamMaps) =
|
val (metadata, estimator, evaluator, estimatorParamMaps) =
|
||||||
ValidatorParams.loadImpl(path, sc, className)
|
ValidatorParams.loadImpl(path, sc, className)
|
||||||
val numFolds = (metadata.params \ "numFolds").extract[Int]
|
|
||||||
val seed = (metadata.params \ "seed").extract[Long]
|
|
||||||
val bestModelPath = new Path(path, "bestModel").toString
|
val bestModelPath = new Path(path, "bestModel").toString
|
||||||
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
|
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
|
||||||
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
|
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
|
||||||
|
|
||||||
val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
|
val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
|
||||||
model.set(model.estimator, estimator)
|
model.set(model.estimator, estimator)
|
||||||
.set(model.evaluator, evaluator)
|
.set(model.evaluator, evaluator)
|
||||||
.set(model.estimatorParamMaps, estimatorParamMaps)
|
.set(model.estimatorParamMaps, estimatorParamMaps)
|
||||||
.set(model.numFolds, numFolds)
|
DefaultParamsReader.getAndSetParams(model, metadata,
|
||||||
.set(model.seed, seed)
|
skipParams = Option(List("estimatorParamMaps")))
|
||||||
|
model
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
package org.apache.spark.ml.tuning
|
package org.apache.spark.ml.tuning
|
||||||
|
|
||||||
|
import java.io.IOException
|
||||||
import java.util.{List => JList}
|
import java.util.{List => JList}
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
@ -207,14 +208,13 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
|
||||||
|
|
||||||
val (metadata, estimator, evaluator, estimatorParamMaps) =
|
val (metadata, estimator, evaluator, estimatorParamMaps) =
|
||||||
ValidatorParams.loadImpl(path, sc, className)
|
ValidatorParams.loadImpl(path, sc, className)
|
||||||
val trainRatio = (metadata.params \ "trainRatio").extract[Double]
|
val tvs = new TrainValidationSplit(metadata.uid)
|
||||||
val seed = (metadata.params \ "seed").extract[Long]
|
|
||||||
new TrainValidationSplit(metadata.uid)
|
|
||||||
.setEstimator(estimator)
|
.setEstimator(estimator)
|
||||||
.setEvaluator(evaluator)
|
.setEvaluator(evaluator)
|
||||||
.setEstimatorParamMaps(estimatorParamMaps)
|
.setEstimatorParamMaps(estimatorParamMaps)
|
||||||
.setTrainRatio(trainRatio)
|
DefaultParamsReader.getAndSetParams(tvs, metadata,
|
||||||
.setSeed(seed)
|
skipParams = Option(List("estimatorParamMaps")))
|
||||||
|
tvs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -295,17 +295,17 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
|
||||||
|
|
||||||
val (metadata, estimator, evaluator, estimatorParamMaps) =
|
val (metadata, estimator, evaluator, estimatorParamMaps) =
|
||||||
ValidatorParams.loadImpl(path, sc, className)
|
ValidatorParams.loadImpl(path, sc, className)
|
||||||
val trainRatio = (metadata.params \ "trainRatio").extract[Double]
|
|
||||||
val seed = (metadata.params \ "seed").extract[Long]
|
|
||||||
val bestModelPath = new Path(path, "bestModel").toString
|
val bestModelPath = new Path(path, "bestModel").toString
|
||||||
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
|
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
|
||||||
val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
|
val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
|
||||||
|
|
||||||
val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
|
val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
|
||||||
model.set(model.estimator, estimator)
|
model.set(model.estimator, estimator)
|
||||||
.set(model.evaluator, evaluator)
|
.set(model.evaluator, evaluator)
|
||||||
.set(model.estimatorParamMaps, estimatorParamMaps)
|
.set(model.estimatorParamMaps, estimatorParamMaps)
|
||||||
.set(model.trainRatio, trainRatio)
|
DefaultParamsReader.getAndSetParams(model, metadata,
|
||||||
.set(model.seed, seed)
|
skipParams = Option(List("estimatorParamMaps")))
|
||||||
|
model
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -150,20 +150,14 @@ private[ml] object ValidatorParams {
|
||||||
}.toSeq
|
}.toSeq
|
||||||
))
|
))
|
||||||
|
|
||||||
val validatorSpecificParams = instance match {
|
val params = instance.extractParamMap().toSeq
|
||||||
case cv: CrossValidatorParams =>
|
val skipParams = List("estimator", "evaluator", "estimatorParamMaps")
|
||||||
List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds)))
|
val jsonParams = render(params
|
||||||
case tvs: TrainValidationSplitParams =>
|
.filter { case ParamPair(p, v) => !skipParams.contains(p.name)}
|
||||||
List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio)))
|
.map { case ParamPair(p, v) =>
|
||||||
case _ =>
|
p.name -> parse(p.jsonEncode(v))
|
||||||
// This should not happen.
|
}.toList ++ List("estimatorParamMaps" -> parse(estimatorParamMapsJson))
|
||||||
throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " +
|
)
|
||||||
instance.getClass.getCanonicalName)
|
|
||||||
}
|
|
||||||
|
|
||||||
val jsonParams = validatorSpecificParams ++ List(
|
|
||||||
"estimatorParamMaps" -> parse(estimatorParamMapsJson),
|
|
||||||
"seed" -> parse(instance.seed.jsonEncode(instance.getSeed)))
|
|
||||||
|
|
||||||
DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
|
DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
|
||||||
|
|
||||||
|
|
|
@ -396,17 +396,27 @@ private[ml] object DefaultParamsReader {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Extract Params from metadata, and set them in the instance.
|
* Extract Params from metadata, and set them in the instance.
|
||||||
* This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
|
* This works if all Params (except params included by `skipParams` list) implement
|
||||||
|
* [[org.apache.spark.ml.param.Param.jsonDecode()]].
|
||||||
|
*
|
||||||
|
* @param skipParams The params included in `skipParams` won't be set. This is useful if some
|
||||||
|
* params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]]
|
||||||
|
* and need special handling.
|
||||||
* TODO: Move to [[Metadata]] method
|
* TODO: Move to [[Metadata]] method
|
||||||
*/
|
*/
|
||||||
def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
|
def getAndSetParams(
|
||||||
|
instance: Params,
|
||||||
|
metadata: Metadata,
|
||||||
|
skipParams: Option[List[String]] = None): Unit = {
|
||||||
implicit val format = DefaultFormats
|
implicit val format = DefaultFormats
|
||||||
metadata.params match {
|
metadata.params match {
|
||||||
case JObject(pairs) =>
|
case JObject(pairs) =>
|
||||||
pairs.foreach { case (paramName, jsonValue) =>
|
pairs.foreach { case (paramName, jsonValue) =>
|
||||||
val param = instance.getParam(paramName)
|
if (skipParams == None || !skipParams.get.contains(paramName)) {
|
||||||
val value = param.jsonDecode(compact(render(jsonValue)))
|
val param = instance.getParam(paramName)
|
||||||
instance.set(param, value)
|
val value = param.jsonDecode(compact(render(jsonValue)))
|
||||||
|
instance.set(param, value)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case _ =>
|
case _ =>
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
|
|
|
@ -159,12 +159,15 @@ class CrossValidatorSuite
|
||||||
.setEvaluator(evaluator)
|
.setEvaluator(evaluator)
|
||||||
.setNumFolds(20)
|
.setNumFolds(20)
|
||||||
.setEstimatorParamMaps(paramMaps)
|
.setEstimatorParamMaps(paramMaps)
|
||||||
|
.setSeed(42L)
|
||||||
|
.setParallelism(2)
|
||||||
|
|
||||||
val cv2 = testDefaultReadWrite(cv, testParams = false)
|
val cv2 = testDefaultReadWrite(cv, testParams = false)
|
||||||
|
|
||||||
assert(cv.uid === cv2.uid)
|
assert(cv.uid === cv2.uid)
|
||||||
assert(cv.getNumFolds === cv2.getNumFolds)
|
assert(cv.getNumFolds === cv2.getNumFolds)
|
||||||
assert(cv.getSeed === cv2.getSeed)
|
assert(cv.getSeed === cv2.getSeed)
|
||||||
|
assert(cv.getParallelism === cv2.getParallelism)
|
||||||
|
|
||||||
assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
|
assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
|
||||||
val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
|
val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
|
||||||
|
|
|
@ -23,7 +23,7 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio
|
||||||
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
|
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
|
||||||
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
|
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
|
||||||
import org.apache.spark.ml.linalg.Vectors
|
import org.apache.spark.ml.linalg.Vectors
|
||||||
import org.apache.spark.ml.param.{ParamMap}
|
import org.apache.spark.ml.param.ParamMap
|
||||||
import org.apache.spark.ml.param.shared.HasInputCol
|
import org.apache.spark.ml.param.shared.HasInputCol
|
||||||
import org.apache.spark.ml.regression.LinearRegression
|
import org.apache.spark.ml.regression.LinearRegression
|
||||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||||
|
@ -160,11 +160,13 @@ class TrainValidationSplitSuite
|
||||||
.setTrainRatio(0.5)
|
.setTrainRatio(0.5)
|
||||||
.setEstimatorParamMaps(paramMaps)
|
.setEstimatorParamMaps(paramMaps)
|
||||||
.setSeed(42L)
|
.setSeed(42L)
|
||||||
|
.setParallelism(2)
|
||||||
|
|
||||||
val tvs2 = testDefaultReadWrite(tvs, testParams = false)
|
val tvs2 = testDefaultReadWrite(tvs, testParams = false)
|
||||||
|
|
||||||
assert(tvs.getTrainRatio === tvs2.getTrainRatio)
|
assert(tvs.getTrainRatio === tvs2.getTrainRatio)
|
||||||
assert(tvs.getSeed === tvs2.getSeed)
|
assert(tvs.getSeed === tvs2.getSeed)
|
||||||
|
assert(tvs.getParallelism === tvs2.getParallelism)
|
||||||
|
|
||||||
ValidatorParamsSuiteHelpers
|
ValidatorParamsSuiteHelpers
|
||||||
.compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps)
|
.compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps)
|
||||||
|
|
Loading…
Reference in a new issue