From a6fda0bfc16a13b28b1cecc96f1ff91363089144 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 22 Nov 2015 21:48:48 -0800 Subject: [PATCH] [SPARK-6791][ML] Add read/write for CrossValidator and Evaluators I believe this works for general estimators within CrossValidator, including compound estimators. (See the complex unit test.) Added read/write for all 3 Evaluators as well. CC: mengxr yanboliang Author: Joseph K. Bradley Closes #9848 from jkbradley/cv-io. --- .../scala/org/apache/spark/ml/Pipeline.scala | 38 +-- .../BinaryClassificationEvaluator.scala | 11 +- .../MulticlassClassificationEvaluator.scala | 12 +- .../ml/evaluation/RegressionEvaluator.scala | 11 +- .../apache/spark/ml/recommendation/ALS.scala | 14 +- .../spark/ml/tuning/CrossValidator.scala | 229 +++++++++++++++++- .../org/apache/spark/ml/util/ReadWrite.scala | 48 ++-- .../org/apache/spark/ml/PipelineSuite.scala | 4 +- .../BinaryClassificationEvaluatorSuite.scala | 13 +- ...lticlassClassificationEvaluatorSuite.scala | 13 +- .../evaluation/RegressionEvaluatorSuite.scala | 12 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 202 ++++++++++++++- 12 files changed, 522 insertions(+), 85 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 6f15b37abc..4b2b3f8489 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -34,7 +34,6 @@ import org.apache.spark.ml.util.MLWriter import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -232,20 +231,9 @@ object Pipeline extends MLReadable[Pipeline] { stages: Array[PipelineStage], sc: SparkContext, path: String): Unit = { - // Copied and edited from DefaultParamsWriter.saveMetadata - // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication - val uid = instance.uid - val cls = instance.getClass.getName val stageUids = stages.map(_.uid) val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq)))) - val metadata = ("class" -> cls) ~ - ("timestamp" -> System.currentTimeMillis()) ~ - ("sparkVersion" -> sc.version) ~ - ("uid" -> uid) ~ - ("paramMap" -> jsonParams) - val metadataPath = new Path(path, "metadata").toString - val metadataJson = compact(render(metadata)) - sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams)) // Save stages val stagesDir = new Path(path, "stages").toString @@ -266,30 +254,10 @@ object Pipeline extends MLReadable[Pipeline] { implicit val format = DefaultFormats val stagesDir = new Path(path, "stages").toString - val stageUids: Array[String] = metadata.params match { - case JObject(pairs) => - if (pairs.length != 1) { - // Should not happen unless file is corrupted or we have a bug. - throw new RuntimeException( - s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.") - } - pairs.head match { - case ("stageUids", jsonValue) => - jsonValue.extract[Seq[String]].toArray - case (paramName, jsonValue) => - // Should not happen unless file is corrupted or we have a bug. - throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" + - s" in metadata: ${metadata.metadataStr}") - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") - } + val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) => val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) - val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) - val cls = Utils.classForName(stageMetadata.className) - cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath) + DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc) } (metadata.uid, stages) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 1fe3abaca8..bfb70963b1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.DoubleType @Since("1.2.0") @Experimental class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasRawPredictionCol with HasLabelCol { + extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.2.0") def this() = this(Identifiable.randomUID("binEval")) @@ -105,3 +105,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va @Since("1.4.1") override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object BinaryClassificationEvaluator extends DefaultParamsReadable[BinaryClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): BinaryClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index df5f04ca5a..c44db0ec59 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.types.DoubleType @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.DoubleType @Since("1.5.0") @Experimental class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("mcEval")) @@ -101,3 +101,11 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid @Since("1.5.0") override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object MulticlassClassificationEvaluator + extends DefaultParamsReadable[MulticlassClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): MulticlassClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index ba012f444d..daaa174a08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, FloatType} @Since("1.4.0") @Experimental final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("regEval")) @@ -104,3 +104,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.5.0") override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] { + + @Since("1.6.0") + override def load(path: String): RegressionEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 4d35177ad9..b798aa1fab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -27,9 +27,8 @@ import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.hadoop.fs.{FileSystem, Path} -import org.json4s.{DefaultFormats, JValue} +import org.json4s.DefaultFormats import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, Partitioner} import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} @@ -240,7 +239,7 @@ object ALSModel extends MLReadable[ALSModel] { private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { - val extraMetadata = render("rank" -> instance.rank) + val extraMetadata = "rank" -> instance.rank DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) val userPath = new Path(path, "userFactors").toString instance.userFactors.write.format("parquet").save(userPath) @@ -257,14 +256,7 @@ object ALSModel extends MLReadable[ALSModel] { override def load(path: String): ALSModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) implicit val format = DefaultFormats - val rank: Int = metadata.extraMetadata match { - case Some(m: JValue) => - (m \ "rank").extract[Int] - case None => - throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" + - s" ${metadata.metadataStr}") - } - + val rank = (metadata.metadata \ "rank").extract[Int] val userPath = new Path(path, "userFactors").toString val userFactors = sqlContext.read.format("parquet").load(userPath) val itemPath = new Path(path, "itemFactors").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 77d9948ed8..83a9048374 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -18,17 +18,24 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS +import org.apache.hadoop.fs.Path +import org.json4s.{JObject, DefaultFormats} +import org.json4s.jackson.JsonMethods._ -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.classification.OneVsRestParams +import org.apache.spark.ml.feature.RFormulaModel +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType + /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ @@ -53,7 +60,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { */ @Experimental class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] - with CrossValidatorParams with Logging { + with CrossValidatorParams with MLWritable with Logging { def this() = this(Identifiable.randomUID("cv")) @@ -131,6 +138,166 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } copied } + + // Currently, this only works if all [[Param]]s in [[estimatorParamMaps]] are simple types. + // E.g., this may fail if a [[Param]] is an instance of an [[Estimator]]. + // However, this case should be unusual. + @Since("1.6.0") + override def write: MLWriter = new CrossValidator.CrossValidatorWriter(this) +} + +@Since("1.6.0") +object CrossValidator extends MLReadable[CrossValidator] { + + @Since("1.6.0") + override def read: MLReader[CrossValidator] = new CrossValidatorReader + + @Since("1.6.0") + override def load(path: String): CrossValidator = super.load(path) + + private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter { + + SharedReadWrite.validateParams(instance) + + override protected def saveImpl(path: String): Unit = + SharedReadWrite.saveImpl(path, instance, sc) + } + + private class CrossValidatorReader extends MLReader[CrossValidator] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidator].getName + + override def load(path: String): CrossValidator = { + val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = + SharedReadWrite.load(path, sc, className) + new CrossValidator(metadata.uid) + .setEstimator(estimator) + .setEvaluator(evaluator) + .setEstimatorParamMaps(estimatorParamMaps) + .setNumFolds(numFolds) + } + } + + private object CrossValidatorReader { + /** + * Examine the given estimator (which may be a compound estimator) and extract a mapping + * from UIDs to corresponding [[Params]] instances. + */ + def getUidMap(instance: Params): Map[String, Params] = { + val uidList = getUidMapImpl(instance) + val uidMap = uidList.toMap + if (uidList.size != uidMap.size) { + throw new RuntimeException("CrossValidator.load found a compound estimator with stages" + + s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}") + } + uidMap + } + + def getUidMapImpl(instance: Params): List[(String, Params)] = { + val subStages: Array[Params] = instance match { + case p: Pipeline => p.getStages.asInstanceOf[Array[Params]] + case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]] + case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator) + case ovr: OneVsRestParams => + // TODO: SPARK-11892: This case may require special handling. + throw new UnsupportedOperationException("CrossValidator write will fail because it" + + " cannot yet handle an estimator containing type: ${ovr.getClass.getName}") + case rform: RFormulaModel => + // TODO: SPARK-11891: This case may require special handling. + throw new UnsupportedOperationException("CrossValidator write will fail because it" + + " cannot yet handle an estimator containing an RFormulaModel") + case _: Params => Array() + } + val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _) + List((instance.uid, instance)) ++ subStageMaps + } + } + + private[tuning] object SharedReadWrite { + + /** + * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable. + * This does not check [[CrossValidator.estimatorParamMaps]]. + */ + def validateParams(instance: ValidatorParams): Unit = { + def checkElement(elem: Params, name: String): Unit = elem match { + case stage: MLWritable => // good + case other => + throw new UnsupportedOperationException("CrossValidator write will fail " + + s" because it contains $name which does not implement Writable." + + s" Non-Writable $name: ${other.uid} of type ${other.getClass}") + } + checkElement(instance.getEvaluator, "evaluator") + checkElement(instance.getEstimator, "estimator") + // Check to make sure all Params apply to this estimator. Throw an error if any do not. + // Extraneous Params would cause problems when loading the estimatorParamMaps. + val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance) + instance.getEstimatorParamMaps.foreach { case pMap: ParamMap => + pMap.toSeq.foreach { case ParamPair(p, v) => + require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" + + s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" + + s" Evaluator. An extraneous Param was found: $p") + } + } + } + + private[tuning] def saveImpl( + path: String, + instance: CrossValidatorParams, + sc: SparkContext, + extraMetadata: Option[JObject] = None): Unit = { + import org.json4s.JsonDSL._ + + val estimatorParamMapsJson = compact(render( + instance.getEstimatorParamMaps.map { case paramMap => + paramMap.toSeq.map { case ParamPair(p, v) => + Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v)) + } + }.toSeq + )) + val jsonParams = List( + "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)), + "estimatorParamMaps" -> parse(estimatorParamMapsJson) + ) + DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) + + val evaluatorPath = new Path(path, "evaluator").toString + instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath) + val estimatorPath = new Path(path, "estimator").toString + instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath) + } + + private[tuning] def load[M <: Model[M]]( + path: String, + sc: SparkContext, + expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = { + + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + + implicit val format = DefaultFormats + val evaluatorPath = new Path(path, "evaluator").toString + val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc) + val estimatorPath = new Path(path, "estimator").toString + val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc) + + val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator) + + val numFolds = (metadata.params \ "numFolds").extract[Int] + val estimatorParamMaps: Array[ParamMap] = + (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map { + pMap => + val paramPairs = pMap.map { case pInfo: Map[String, String] => + val est = uidToParams(pInfo("parent")) + val param = est.getParam(pInfo("name")) + val value = param.jsonDecode(pInfo("value")) + param -> value + } + ParamMap(paramPairs: _*) + }.toArray + (metadata, estimator, evaluator, estimatorParamMaps, numFolds) + } + } } /** @@ -139,14 +306,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM * * @param bestModel The best model selected from k-fold cross validation. * @param avgMetrics Average cross-validation metrics for each paramMap in - * [[estimatorParamMaps]], in the corresponding order. + * [[CrossValidator.estimatorParamMaps]], in the corresponding order. */ @Experimental class CrossValidatorModel private[ml] ( override val uid: String, val bestModel: Model[_], val avgMetrics: Array[Double]) - extends Model[CrossValidatorModel] with CrossValidatorParams { + extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { override def validateParams(): Unit = { bestModel.validateParams() @@ -168,4 +335,54 @@ class CrossValidatorModel private[ml] ( avgMetrics.clone()) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) +} + +@Since("1.6.0") +object CrossValidatorModel extends MLReadable[CrossValidatorModel] { + + import CrossValidator.SharedReadWrite + + @Since("1.6.0") + override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader + + @Since("1.6.0") + override def load(path: String): CrossValidatorModel = super.load(path) + + private[CrossValidatorModel] + class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { + + SharedReadWrite.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + import org.json4s.JsonDSL._ + val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq + SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata)) + val bestModelPath = new Path(path, "bestModel").toString + instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + } + } + + private class CrossValidatorModelReader extends MLReader[CrossValidatorModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidatorModel].getName + + override def load(path: String): CrossValidatorModel = { + implicit val format = DefaultFormats + + val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = + SharedReadWrite.load(path, sc, className) + val bestModelPath = new Path(path, "bestModel").toString + val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) + val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray + val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) + cv.set(cv.estimator, estimator) + .set(cv.evaluator, evaluator) + .set(cv.estimatorParamMaps, estimatorParamMaps) + .set(cv.numFolds, numFolds) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index ff9322dba1..8484b1f801 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -202,25 +202,36 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid - * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]]. + * - paramMap + * - (optionally, extra metadata) + * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc. + * @param paramMap If given, this is saved in the "paramMap" field. + * Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using + * [[org.apache.spark.ml.param.Param.jsonEncode()]]. */ def saveMetadata( instance: Params, path: String, sc: SparkContext, - extraMetadata: Option[JValue] = None): Unit = { + extraMetadata: Option[JObject] = None, + paramMap: Option[JValue] = None): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] - val jsonParams = params.map { case ParamPair(p, v) => + val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) - }.toList - val metadata = ("class" -> cls) ~ + }.toList)) + val basicMetadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) ~ - ("extraMetadata" -> extraMetadata) + ("paramMap" -> jsonParams) + val metadata = extraMetadata match { + case Some(jObject) => + basicMetadata ~ jObject + case None => + basicMetadata + } val metadataPath = new Path(path, "metadata").toString val metadataJson = compact(render(metadata)) sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) @@ -251,8 +262,8 @@ private[ml] object DefaultParamsReader { /** * All info from metadata file. * @param params paramMap, as a [[JValue]] - * @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]] - * @param metadataStr Full metadata file String (for debugging) + * @param metadata All metadata, including the other fields + * @param metadataJson Full metadata file String (for debugging) */ case class Metadata( className: String, @@ -260,8 +271,8 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, - extraMetadata: Option[JValue], - metadataStr: String) + metadata: JValue, + metadataJson: String) /** * Load metadata from file. @@ -279,13 +290,12 @@ private[ml] object DefaultParamsReader { val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] val params = metadata \ "paramMap" - val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]] if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr) + Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) } /** @@ -303,7 +313,17 @@ private[ml] object DefaultParamsReader { } case _ => throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") + s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") } } + + /** + * Load a [[Params]] instance from the given path, and return it. + * This assumes the instance implements [[MLReadable]]. + */ + def loadParamsInstance[T](path: String, sc: SparkContext): T = { + val metadata = DefaultParamsReader.loadMetadata(path, sc) + val cls = Utils.classForName(metadata.className) + cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 12aba6bc6d..8c86767456 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.ml -import java.io.File - import scala.collection.JavaConverters._ -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index def869fe66..a535c1218e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext -class BinaryClassificationEvaluatorSuite extends SparkFunSuite { +class BinaryClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new BinaryClassificationEvaluator) } + + test("read/write") { + val evaluator = new BinaryClassificationEvaluator() + .setRawPredictionCol("myRawPrediction") + .setLabelCol("myLabel") + .setMetricName("areaUnderPR") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala index 6d8412b0b3..7ee65975d2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassClassificationEvaluatorSuite extends SparkFunSuite { +class MulticlassClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new MulticlassClassificationEvaluator) } + + test("read/write") { + val evaluator = new MulticlassClassificationEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("recall") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index aa722da323..60886bf77d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class RegressionEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new RegressionEvaluator) @@ -73,4 +75,12 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext evaluator.setMetricName("mae") assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) } + + test("read/write") { + val evaluator = new RegressionEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("r2") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index cbe09292a0..dd6366050c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,19 +18,22 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.{Pipeline, Estimator, Model} +import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression} import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{ParamPair, ParamMap} import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.types.StructType -class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class CrossValidatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var dataset: DataFrame = _ @@ -95,7 +98,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { } test("validateParams should check estimatorParamMaps") { - import CrossValidatorSuite._ + import CrossValidatorSuite.{MyEstimator, MyEvaluator} val est = new MyEstimator("est") val eval = new MyEvaluator @@ -116,9 +119,194 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { cv.validateParams() } } + + test("read/write: CrossValidator with simple estimator") { + val lr = new LogisticRegression().setMaxIter(3) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + } + + test("read/write: CrossValidator with complex estimator") { + // workflow: CrossValidator[Pipeline[HashingTF, CrossValidator[LogisticRegression]]] + val lrEvaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + + val lr = new LogisticRegression().setMaxIter(3) + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val lrcv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(lrEvaluator) + .setEstimatorParamMaps(lrParamMaps) + + val hashingTF = new HashingTF() + val pipeline = new Pipeline().setStages(Array(hashingTF, lrcv)) + val paramMaps = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures, Array(10, 20)) + .addGrid(lr.elasticNetParam, Array(0.0, 1.0)) + .build() + val evaluator = new BinaryClassificationEvaluator() + + val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(cv.getEvaluator.uid === cv2.getEvaluator.uid) + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.getEstimator match { + case pipeline2: Pipeline => + assert(pipeline.uid === pipeline2.uid) + pipeline2.getStages match { + case Array(hashingTF2: HashingTF, lrcv2: CrossValidator) => + assert(hashingTF.uid === hashingTF2.uid) + lrcv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded internal CrossValidator expected to be" + + s" LogisticRegression but found type ${other.getClass.getName}") + } + assert(lrcv.uid === lrcv2.uid) + assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(lrEvaluator.uid === lrcv2.getEvaluator.uid) + CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps) + case other => + throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" + + " but found: " + other.map(_.getClass.getName).mkString(", ")) + } + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" CrossValidator but found ${other.getClass.getName}") + } + } + + test("read/write: CrossValidator fails for extraneous Param") { + val lr = new LogisticRegression() + val lr2 = new LogisticRegression() + val evaluator = new BinaryClassificationEvaluator() + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .addGrid(lr2.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setEstimatorParamMaps(paramMaps) + withClue("CrossValidator.write failed to catch extraneous Param error") { + intercept[IllegalArgumentException] { + cv.write + } + } + } + + test("read/write: CrossValidatorModel") { + val lr = new LogisticRegression() + .setThreshold(0.6) + val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2) + .setThreshold(0.6) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidatorModel("cvUid", lrModel, Array(0.3, 0.6)) + cv.set(cv.estimator, lr) + .set(cv.evaluator, evaluator) + .set(cv.numFolds, 20) + .set(cv.estimatorParamMaps, paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getThreshold === lr2.getThreshold) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.bestModel match { + case lrModel2: LogisticRegressionModel => + assert(lrModel.uid === lrModel2.uid) + assert(lrModel.getThreshold === lrModel2.getThreshold) + assert(lrModel.coefficients === lrModel2.coefficients) + assert(lrModel.intercept === lrModel2.intercept) + case other => + throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" + + s" LogisticRegressionModel but found ${other.getClass.getName}") + } + assert(cv.avgMetrics === cv2.avgMetrics) + } } -object CrossValidatorSuite { +object CrossValidatorSuite extends SparkFunSuite { + + /** + * Assert sequences of estimatorParamMaps are identical. + * Params must be simple types comparable with `===`. + */ + def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = { + assert(pMaps.length === pMaps2.length) + pMaps.zip(pMaps2).foreach { case (pMap, pMap2) => + assert(pMap.size === pMap2.size) + pMap.toSeq.foreach { case ParamPair(p, v) => + assert(pMap2.contains(p)) + assert(pMap2(p) === v) + } + } + } abstract class MyModel extends Model[MyModel]