[SPARK-6791][ML] Add read/write for CrossValidator and Evaluators
I believe this works for general estimators within CrossValidator, including compound estimators. (See the complex unit test.) Added read/write for all 3 Evaluators as well. CC: mengxr yanboliang Author: Joseph K. Bradley <joseph@databricks.com> Closes #9848 from jkbradley/cv-io.
This commit is contained in:
parent
fe89c1817d
commit
a6fda0bfc1
|
@ -34,7 +34,6 @@ import org.apache.spark.ml.util.MLWriter
|
|||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.sql.DataFrame
|
||||
import org.apache.spark.sql.types.StructType
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
|
@ -232,20 +231,9 @@ object Pipeline extends MLReadable[Pipeline] {
|
|||
stages: Array[PipelineStage],
|
||||
sc: SparkContext,
|
||||
path: String): Unit = {
|
||||
// Copied and edited from DefaultParamsWriter.saveMetadata
|
||||
// TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication
|
||||
val uid = instance.uid
|
||||
val cls = instance.getClass.getName
|
||||
val stageUids = stages.map(_.uid)
|
||||
val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq))))
|
||||
val metadata = ("class" -> cls) ~
|
||||
("timestamp" -> System.currentTimeMillis()) ~
|
||||
("sparkVersion" -> sc.version) ~
|
||||
("uid" -> uid) ~
|
||||
("paramMap" -> jsonParams)
|
||||
val metadataPath = new Path(path, "metadata").toString
|
||||
val metadataJson = compact(render(metadata))
|
||||
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
|
||||
DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams))
|
||||
|
||||
// Save stages
|
||||
val stagesDir = new Path(path, "stages").toString
|
||||
|
@ -266,30 +254,10 @@ object Pipeline extends MLReadable[Pipeline] {
|
|||
|
||||
implicit val format = DefaultFormats
|
||||
val stagesDir = new Path(path, "stages").toString
|
||||
val stageUids: Array[String] = metadata.params match {
|
||||
case JObject(pairs) =>
|
||||
if (pairs.length != 1) {
|
||||
// Should not happen unless file is corrupted or we have a bug.
|
||||
throw new RuntimeException(
|
||||
s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.")
|
||||
}
|
||||
pairs.head match {
|
||||
case ("stageUids", jsonValue) =>
|
||||
jsonValue.extract[Seq[String]].toArray
|
||||
case (paramName, jsonValue) =>
|
||||
// Should not happen unless file is corrupted or we have a bug.
|
||||
throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" +
|
||||
s" in metadata: ${metadata.metadataStr}")
|
||||
}
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(
|
||||
s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
|
||||
}
|
||||
val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray
|
||||
val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
|
||||
val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
|
||||
val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc)
|
||||
val cls = Utils.classForName(stageMetadata.className)
|
||||
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath)
|
||||
DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc)
|
||||
}
|
||||
(metadata.uid, stages)
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
|
|||
import org.apache.spark.annotation.{Experimental, Since}
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared._
|
||||
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
|
||||
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
|
||||
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
|
||||
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
|
||||
import org.apache.spark.sql.{DataFrame, Row}
|
||||
|
@ -33,7 +33,7 @@ import org.apache.spark.sql.types.DoubleType
|
|||
@Since("1.2.0")
|
||||
@Experimental
|
||||
class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
|
||||
extends Evaluator with HasRawPredictionCol with HasLabelCol {
|
||||
extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable {
|
||||
|
||||
@Since("1.2.0")
|
||||
def this() = this(Identifiable.randomUID("binEval"))
|
||||
|
@ -105,3 +105,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
|
|||
@Since("1.4.1")
|
||||
override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
|
||||
}
|
||||
|
||||
/**
 * Companion object of [[BinaryClassificationEvaluator]]; reading is provided by
 * [[DefaultParamsReadable]].
 */
@Since("1.6.0")
object BinaryClassificationEvaluator extends DefaultParamsReadable[BinaryClassificationEvaluator] {

  /**
   * Loads a [[BinaryClassificationEvaluator]] from `path` by delegating to the inherited
   * implementation; redeclared here with the concrete return type.
   */
  @Since("1.6.0")
  override def load(path: String): BinaryClassificationEvaluator = {
    super.load(path)
  }
}
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
|
|||
import org.apache.spark.annotation.{Experimental, Since}
|
||||
import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param}
|
||||
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
|
||||
import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
|
||||
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable}
|
||||
import org.apache.spark.mllib.evaluation.MulticlassMetrics
|
||||
import org.apache.spark.sql.{Row, DataFrame}
|
||||
import org.apache.spark.sql.types.DoubleType
|
||||
|
@ -32,7 +32,7 @@ import org.apache.spark.sql.types.DoubleType
|
|||
@Since("1.5.0")
|
||||
@Experimental
|
||||
class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String)
|
||||
extends Evaluator with HasPredictionCol with HasLabelCol {
|
||||
extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
|
||||
|
||||
@Since("1.5.0")
|
||||
def this() = this(Identifiable.randomUID("mcEval"))
|
||||
|
@ -101,3 +101,11 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
|
|||
@Since("1.5.0")
|
||||
override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
|
||||
}
|
||||
|
||||
/**
 * Companion object of [[MulticlassClassificationEvaluator]]; reading is provided by
 * [[DefaultParamsReadable]].
 */
@Since("1.6.0")
object MulticlassClassificationEvaluator
  extends DefaultParamsReadable[MulticlassClassificationEvaluator] {

  /**
   * Loads a [[MulticlassClassificationEvaluator]] from `path` by delegating to the inherited
   * implementation; redeclared here with the concrete return type.
   */
  @Since("1.6.0")
  override def load(path: String): MulticlassClassificationEvaluator = {
    super.load(path)
  }
}
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
|
|||
import org.apache.spark.annotation.{Experimental, Since}
|
||||
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
|
||||
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
|
||||
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
|
||||
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
|
||||
import org.apache.spark.mllib.evaluation.RegressionMetrics
|
||||
import org.apache.spark.sql.{DataFrame, Row}
|
||||
import org.apache.spark.sql.functions._
|
||||
|
@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, FloatType}
|
|||
@Since("1.4.0")
|
||||
@Experimental
|
||||
final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
|
||||
extends Evaluator with HasPredictionCol with HasLabelCol {
|
||||
extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
|
||||
|
||||
@Since("1.4.0")
|
||||
def this() = this(Identifiable.randomUID("regEval"))
|
||||
|
@ -104,3 +104,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
|
|||
@Since("1.5.0")
|
||||
override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
|
||||
}
|
||||
|
||||
/**
 * Companion object of [[RegressionEvaluator]]; reading is provided by
 * [[DefaultParamsReadable]].
 */
@Since("1.6.0")
object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] {

  /**
   * Loads a [[RegressionEvaluator]] from `path` by delegating to the inherited
   * implementation; redeclared here with the concrete return type.
   */
  @Since("1.6.0")
  override def load(path: String): RegressionEvaluator = {
    super.load(path)
  }
}
|
||||
|
|
|
@ -27,9 +27,8 @@ import scala.util.hashing.byteswap64
|
|||
|
||||
import com.github.fommil.netlib.BLAS.{getInstance => blas}
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
import org.json4s.{DefaultFormats, JValue}
|
||||
import org.json4s.DefaultFormats
|
||||
import org.json4s.JsonDSL._
|
||||
import org.json4s.jackson.JsonMethods._
|
||||
|
||||
import org.apache.spark.{Logging, Partitioner}
|
||||
import org.apache.spark.annotation.{Since, DeveloperApi, Experimental}
|
||||
|
@ -240,7 +239,7 @@ object ALSModel extends MLReadable[ALSModel] {
|
|||
private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter {
|
||||
|
||||
override protected def saveImpl(path: String): Unit = {
|
||||
val extraMetadata = render("rank" -> instance.rank)
|
||||
val extraMetadata = "rank" -> instance.rank
|
||||
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
|
||||
val userPath = new Path(path, "userFactors").toString
|
||||
instance.userFactors.write.format("parquet").save(userPath)
|
||||
|
@ -257,14 +256,7 @@ object ALSModel extends MLReadable[ALSModel] {
|
|||
override def load(path: String): ALSModel = {
|
||||
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
|
||||
implicit val format = DefaultFormats
|
||||
val rank: Int = metadata.extraMetadata match {
|
||||
case Some(m: JValue) =>
|
||||
(m \ "rank").extract[Int]
|
||||
case None =>
|
||||
throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" +
|
||||
s" ${metadata.metadataStr}")
|
||||
}
|
||||
|
||||
val rank = (metadata.metadata \ "rank").extract[Int]
|
||||
val userPath = new Path(path, "userFactors").toString
|
||||
val userFactors = sqlContext.read.format("parquet").load(userPath)
|
||||
val itemPath = new Path(path, "itemFactors").toString
|
||||
|
|
|
@ -18,17 +18,24 @@
|
|||
package org.apache.spark.ml.tuning
|
||||
|
||||
import com.github.fommil.netlib.F2jBLAS
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.json4s.{JObject, DefaultFormats}
|
||||
import org.json4s.jackson.JsonMethods._
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.ml.classification.OneVsRestParams
|
||||
import org.apache.spark.ml.feature.RFormulaModel
|
||||
import org.apache.spark.{SparkContext, Logging}
|
||||
import org.apache.spark.annotation.{Experimental, Since}
|
||||
import org.apache.spark.ml._
|
||||
import org.apache.spark.ml.evaluation.Evaluator
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.util.Identifiable
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
import org.apache.spark.sql.DataFrame
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
|
||||
/**
|
||||
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
|
||||
*/
|
||||
|
@ -53,7 +60,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
|
|||
*/
|
||||
@Experimental
|
||||
class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel]
|
||||
with CrossValidatorParams with Logging {
|
||||
with CrossValidatorParams with MLWritable with Logging {
|
||||
|
||||
def this() = this(Identifiable.randomUID("cv"))
|
||||
|
||||
|
@ -131,6 +138,166 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
|
|||
}
|
||||
copied
|
||||
}
|
||||
|
||||
// Returns the writer used to persist this CrossValidator.
// Limitation: persistence currently only works when every [[Param]] in
// [[estimatorParamMaps]] is a simple type. It may fail, for example, when a [[Param]]
// is itself an instance of an [[Estimator]]; that situation should be unusual.
@Since("1.6.0")
override def write: MLWriter = {
  new CrossValidator.CrossValidatorWriter(this)
}
|
||||
}
|
||||
|
||||
@Since("1.6.0")
|
||||
object CrossValidator extends MLReadable[CrossValidator] {
|
||||
|
||||
/** Returns a new reader that loads [[CrossValidator]] instances. */
@Since("1.6.0")
override def read: MLReader[CrossValidator] = {
  new CrossValidatorReader
}
|
||||
|
||||
/**
 * Loads a [[CrossValidator]] from `path` by delegating to the inherited implementation;
 * redeclared here with the concrete return type.
 */
@Since("1.6.0")
override def load(path: String): CrossValidator = {
  super.load(path)
}
|
||||
|
||||
/**
 * [[MLWriter]] for [[CrossValidator]]. The instance is validated when the writer is
 * constructed, and the actual saving is delegated to [[SharedReadWrite]].
 */
private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter {

  // Runs as part of the constructor, so an unsupported instance is rejected before save().
  SharedReadWrite.validateParams(instance)

  /** Saves `instance` under `path` via the shared CrossValidator save logic. */
  override protected def saveImpl(path: String): Unit = {
    SharedReadWrite.saveImpl(path, instance, sc)
  }
}
|
||||
|
||||
/** [[MLReader]] that rebuilds a [[CrossValidator]] from what [[SharedReadWrite]] loads. */
private class CrossValidatorReader extends MLReader[CrossValidator] {

  /** Class name passed to the shared loader as the expected class of the saved metadata. */
  private val className = classOf[CrossValidator].getName

  /**
   * Loads the saved pieces from `path` and assembles them into a new [[CrossValidator]]
   * that carries the saved UID.
   */
  override def load(path: String): CrossValidator = {
    val (meta, loadedEstimator, loadedEvaluator, loadedParamMaps, loadedNumFolds) =
      SharedReadWrite.load(path, sc, className)
    val cv = new CrossValidator(meta.uid)
    cv.setEstimator(loadedEstimator)
      .setEvaluator(loadedEvaluator)
      .setEstimatorParamMaps(loadedParamMaps)
      .setNumFolds(loadedNumFolds)
  }
}
|
||||
|
||||
private object CrossValidatorReader {

  /**
   * Examine the given estimator (which may be a compound estimator) and extract a mapping
   * from UIDs to corresponding [[Params]] instances.
   *
   * @param instance root instance to traverse
   * @return map from the UID of every visited instance (the root included) to that instance
   * @throws RuntimeException if two visited instances share the same UID
   * @throws UnsupportedOperationException if an unsupported type is encountered
   *                                       (see [[getUidMapImpl]])
   */
  def getUidMap(instance: Params): Map[String, Params] = {
    val uidList = getUidMapImpl(instance)
    val uidMap = uidList.toMap
    // toMap silently collapses duplicate keys, so a size mismatch means duplicate UIDs.
    if (uidList.size != uidMap.size) {
      throw new RuntimeException("CrossValidator.load found a compound estimator with stages" +
        s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}")
    }
    uidMap
  }

  /**
   * Recursively collect (UID, instance) pairs for `instance` and everything nested in it:
   * the stages of a [[Pipeline]] / [[PipelineModel]], and the estimator and evaluator of a
   * [[ValidatorParams]]. Any other [[Params]] is treated as a leaf.
   *
   * @param instance instance to traverse
   * @return pairs in pre-order: `instance` first, then its sub-stages in order
   * @throws UnsupportedOperationException for [[OneVsRestParams]] and [[RFormulaModel]],
   *                                       which are not handled yet
   */
  def getUidMapImpl(instance: Params): List[(String, Params)] = {
    val subStages: Array[Params] = instance match {
      case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
      case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
      case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
      case ovr: OneVsRestParams =>
        // TODO: SPARK-11892: This case may require special handling.
        // The `s` interpolator is required so the offending class name is actually reported.
        throw new UnsupportedOperationException("CrossValidator write will fail because it" +
          s" cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
      case _: RFormulaModel =>
        // TODO: SPARK-11891: This case may require special handling.
        throw new UnsupportedOperationException("CrossValidator write will fail because it" +
          " cannot yet handle an estimator containing an RFormulaModel")
      case _: Params => Array()
    }
    // flatMap keeps the original ordering without repeatedly re-copying the accumulated list.
    (instance.uid -> instance) :: subStages.toList.flatMap(stage => getUidMapImpl(stage))
  }
}
|
||||
|
||||
private[tuning] object SharedReadWrite {

  /**
   * Check that the evaluator and the estimator of `instance` are [[MLWritable]], and that
   * every [[Param]] appearing in its estimatorParamMaps belongs to `instance` or to
   * something nested inside it (as collected by [[CrossValidatorReader.getUidMap]]).
   * This does not check the Param values stored in the estimatorParamMaps.
   *
   * @throws UnsupportedOperationException if the evaluator or estimator is not writable
   * @throws IllegalArgumentException if a Param with an unknown parent UID is found
   */
  def validateParams(instance: ValidatorParams): Unit = {
    def checkElement(elem: Params, name: String): Unit = elem match {
      case _: MLWritable => // good
      case other =>
        throw new UnsupportedOperationException("CrossValidator write will fail " +
          s" because it contains $name which does not implement Writable." +
          s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
    }
    checkElement(instance.getEvaluator, "evaluator")
    checkElement(instance.getEstimator, "estimator")
    // Check to make sure all Params apply to this estimator.  Throw an error if any do not.
    // Extraneous Params would cause problems when loading the estimatorParamMaps.
    val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance)
    instance.getEstimatorParamMaps.foreach { pMap =>
      pMap.toSeq.foreach { case ParamPair(p, _) =>
        require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" +
          s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" +
          s" Evaluator.  An extraneous Param was found: $p")
      }
    }
  }

  /**
   * Save `instance` under `path`: metadata (with numFolds and the estimatorParamMaps stored
   * as the "paramMap" entry), the evaluator under "evaluator", and the estimator under
   * "estimator".
   *
   * @param path directory to save into
   * @param instance instance whose evaluator and estimator are both [[MLWritable]]
   * @param sc SparkContext handed to the metadata writer
   * @param extraMetadata optional extra JSON object passed on to the metadata writer
   */
  private[tuning] def saveImpl(
      path: String,
      instance: CrossValidatorParams,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None): Unit = {
    import org.json4s.JsonDSL._

    // Each Param is stored as a (parent UID, name, JSON-encoded value) triple.
    val estimatorParamMapsJson = compact(render(
      instance.getEstimatorParamMaps.map { case paramMap =>
        paramMap.toSeq.map { case ParamPair(p, v) =>
          Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
        }
      }.toSeq
    ))
    val jsonParams = List(
      "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)),
      "estimatorParamMaps" -> parse(estimatorParamMapsJson)
    )
    DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))

    val evaluatorPath = new Path(path, "evaluator").toString
    instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
    val estimatorPath = new Path(path, "estimator").toString
    instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
  }

  /**
   * Load what [[saveImpl]] wrote under `path`.
   *
   * @param path directory previously written by [[saveImpl]]
   * @param sc SparkContext handed to the metadata reader
   * @param expectedClassName class name passed to the metadata loader for checking
   * @return (metadata, estimator, evaluator, estimatorParamMaps, numFolds)
   * @throws NoSuchElementException if a saved Param refers to a parent UID that is neither
   *                                the evaluator nor reachable from the estimator
   */
  private[tuning] def load[M <: Model[M]](
      path: String,
      sc: SparkContext,
      expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = {

    val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)

    implicit val format = DefaultFormats
    val evaluatorPath = new Path(path, "evaluator").toString
    val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
    val estimatorPath = new Path(path, "estimator").toString
    val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)

    // NOTE(review): validateParams also accepts Params whose parent is the CrossValidator
    // itself, but that UID is not part of this map -- confirm whether that case can occur.
    val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator)

    val numFolds = (metadata.params \ "numFolds").extract[Int]
    val estimatorParamMaps: Array[ParamMap] =
      (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
        pMap =>
          // A plain lambda replaces the erased (unchecked) `case pInfo: Map[String, String]`.
          val paramPairs = pMap.map { pInfo =>
            val parentUid = pInfo("parent")
            // Same exception type as a failed Map lookup, but with an actionable message.
            val est = uidToParams.getOrElse(parentUid, throw new NoSuchElementException(
              s"CrossValidator load found a Param in estimatorParamMaps whose parent UID" +
              s" $parentUid matches neither the evaluator nor the estimator (or its stages)." +
              s" Known UIDs: ${uidToParams.keys.mkString(", ")}"))
            val param = est.getParam(pInfo("name"))
            val value = param.jsonDecode(pInfo("value"))
            param -> value
          }
          ParamMap(paramPairs: _*)
      }.toArray
    (metadata, estimator, evaluator, estimatorParamMaps, numFolds)
  }
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -139,14 +306,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
|
|||
*
|
||||
* @param bestModel The best model selected from k-fold cross validation.
|
||||
* @param avgMetrics Average cross-validation metrics for each paramMap in
|
||||
* [[estimatorParamMaps]], in the corresponding order.
|
||||
* [[CrossValidator.estimatorParamMaps]], in the corresponding order.
|
||||
*/
|
||||
@Experimental
|
||||
class CrossValidatorModel private[ml] (
|
||||
override val uid: String,
|
||||
val bestModel: Model[_],
|
||||
val avgMetrics: Array[Double])
|
||||
extends Model[CrossValidatorModel] with CrossValidatorParams {
|
||||
extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
|
||||
|
||||
override def validateParams(): Unit = {
|
||||
bestModel.validateParams()
|
||||
|
@ -168,4 +335,54 @@ class CrossValidatorModel private[ml] (
|
|||
avgMetrics.clone())
|
||||
copyValues(copied, extra).setParent(parent)
|
||||
}
|
||||
|
||||
/** Returns the writer used to persist this [[CrossValidatorModel]]. */
@Since("1.6.0")
override def write: MLWriter = {
  new CrossValidatorModel.CrossValidatorModelWriter(this)
}
|
||||
}
|
||||
|
||||
/**
 * Companion object of [[CrossValidatorModel]] providing its reader and writer.
 * Saving and loading of the parts shared with [[CrossValidator]] is delegated to
 * `CrossValidator.SharedReadWrite`.
 */
@Since("1.6.0")
object CrossValidatorModel extends MLReadable[CrossValidatorModel] {

  import CrossValidator.SharedReadWrite

  /** Returns a new reader that loads [[CrossValidatorModel]] instances. */
  @Since("1.6.0")
  override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader

  /** Loads a [[CrossValidatorModel]] from `path` via the inherited implementation. */
  @Since("1.6.0")
  override def load(path: String): CrossValidatorModel = super.load(path)

  /**
   * [[MLWriter]] for [[CrossValidatorModel]]. The instance's estimator, evaluator and
   * estimatorParamMaps are validated when the writer is constructed.
   */
  private[CrossValidatorModel]
  class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter {

    SharedReadWrite.validateParams(instance)

    /**
     * Saves the shared CrossValidator parts with avgMetrics as extra metadata, then the
     * best model under "bestModel".
     *
     * @throws UnsupportedOperationException if bestModel does not implement [[MLWritable]]
     */
    override protected def saveImpl(path: String): Unit = {
      import org.json4s.JsonDSL._
      // Check writability before anything is written, so that a non-writable bestModel
      // cannot leave a partially saved directory behind (and fails with a clear message
      // instead of a ClassCastException).
      val writableBestModel = instance.bestModel match {
        case writable: MLWritable => writable
        case other =>
          throw new UnsupportedOperationException("CrossValidatorModel write will fail" +
            s" because its bestModel does not implement MLWritable." +
            s" Non-Writable bestModel: ${other.uid} of type ${other.getClass}")
      }
      val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
      SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata))
      val bestModelPath = new Path(path, "bestModel").toString
      writableBestModel.save(bestModelPath)
    }
  }

  /** [[MLReader]] that rebuilds a [[CrossValidatorModel]] from a saved directory. */
  private class CrossValidatorModelReader extends MLReader[CrossValidatorModel] {

    /** Checked against metadata when loading model */
    private val className = classOf[CrossValidatorModel].getName

    /**
     * Loads the shared parts, the best model and the average metrics from `path`, and
     * returns a model with the saved UID and the loaded params set on it.
     */
    override def load(path: String): CrossValidatorModel = {
      implicit val format = DefaultFormats

      val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
        SharedReadWrite.load(path, sc, className)
      val bestModelPath = new Path(path, "bestModel").toString
      val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
      // avgMetrics was saved at the top level of the metadata JSON, next to uid etc.
      val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
      val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
      cv.set(cv.estimator, estimator)
        .set(cv.evaluator, evaluator)
        .set(cv.estimatorParamMaps, estimatorParamMaps)
        .set(cv.numFolds, numFolds)
    }
  }
}
|
||||
|
|
|
@ -202,25 +202,36 @@ private[ml] object DefaultParamsWriter {
|
|||
* - timestamp
|
||||
* - sparkVersion
|
||||
* - uid
|
||||
* - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]].
|
||||
* - paramMap
|
||||
* - (optionally, extra metadata)
|
||||
* @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
|
||||
* @param paramMap If given, this is saved in the "paramMap" field.
|
||||
* Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
|
||||
* [[org.apache.spark.ml.param.Param.jsonEncode()]].
|
||||
*/
|
||||
def saveMetadata(
|
||||
instance: Params,
|
||||
path: String,
|
||||
sc: SparkContext,
|
||||
extraMetadata: Option[JValue] = None): Unit = {
|
||||
extraMetadata: Option[JObject] = None,
|
||||
paramMap: Option[JValue] = None): Unit = {
|
||||
val uid = instance.uid
|
||||
val cls = instance.getClass.getName
|
||||
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
|
||||
val jsonParams = params.map { case ParamPair(p, v) =>
|
||||
val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
|
||||
p.name -> parse(p.jsonEncode(v))
|
||||
}.toList
|
||||
val metadata = ("class" -> cls) ~
|
||||
}.toList))
|
||||
val basicMetadata = ("class" -> cls) ~
|
||||
("timestamp" -> System.currentTimeMillis()) ~
|
||||
("sparkVersion" -> sc.version) ~
|
||||
("uid" -> uid) ~
|
||||
("paramMap" -> jsonParams) ~
|
||||
("extraMetadata" -> extraMetadata)
|
||||
("paramMap" -> jsonParams)
|
||||
val metadata = extraMetadata match {
|
||||
case Some(jObject) =>
|
||||
basicMetadata ~ jObject
|
||||
case None =>
|
||||
basicMetadata
|
||||
}
|
||||
val metadataPath = new Path(path, "metadata").toString
|
||||
val metadataJson = compact(render(metadata))
|
||||
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
|
||||
|
@ -251,8 +262,8 @@ private[ml] object DefaultParamsReader {
|
|||
/**
|
||||
* All info from metadata file.
|
||||
* @param params paramMap, as a [[JValue]]
|
||||
* @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]]
|
||||
* @param metadataStr Full metadata file String (for debugging)
|
||||
* @param metadata All metadata, including the other fields
|
||||
* @param metadataJson Full metadata file String (for debugging)
|
||||
*/
|
||||
case class Metadata(
|
||||
className: String,
|
||||
|
@ -260,8 +271,8 @@ private[ml] object DefaultParamsReader {
|
|||
timestamp: Long,
|
||||
sparkVersion: String,
|
||||
params: JValue,
|
||||
extraMetadata: Option[JValue],
|
||||
metadataStr: String)
|
||||
metadata: JValue,
|
||||
metadataJson: String)
|
||||
|
||||
/**
|
||||
* Load metadata from file.
|
||||
|
@ -279,13 +290,12 @@ private[ml] object DefaultParamsReader {
|
|||
val timestamp = (metadata \ "timestamp").extract[Long]
|
||||
val sparkVersion = (metadata \ "sparkVersion").extract[String]
|
||||
val params = metadata \ "paramMap"
|
||||
val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]]
|
||||
if (expectedClassName.nonEmpty) {
|
||||
require(className == expectedClassName, s"Error loading metadata: Expected class name" +
|
||||
s" $expectedClassName but found class name $className")
|
||||
}
|
||||
|
||||
Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr)
|
||||
Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -303,7 +313,17 @@ private[ml] object DefaultParamsReader {
|
|||
}
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(
|
||||
s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
|
||||
s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Load a [[Params]] instance from the given path, and return it.
 * The concrete class is taken from the metadata saved at `path`; its static `read`
 * method is looked up and invoked reflectively, so this assumes the instance
 * implements [[MLReadable]].
 */
def loadParamsInstance[T](path: String, sc: SparkContext): T = {
  val savedClassName = DefaultParamsReader.loadMetadata(path, sc).className
  val readMethod = Utils.classForName(savedClassName).getMethod("read")
  // `read` is invoked with a null receiver, i.e. as a static method.
  val reader = readMethod.invoke(null).asInstanceOf[MLReader[T]]
  reader.load(path)
}
|
||||
}
|
||||
|
|
|
@ -17,11 +17,9 @@
|
|||
|
||||
package org.apache.spark.ml
|
||||
|
||||
import java.io.File
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.mockito.Matchers.{any, eq => meq}
|
||||
import org.mockito.Mockito.when
|
||||
import org.scalatest.mock.MockitoSugar.mock
|
||||
|
|
|
@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation
|
|||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
|
||||
class BinaryClassificationEvaluatorSuite extends SparkFunSuite {
|
||||
class BinaryClassificationEvaluatorSuite
|
||||
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
||||
|
||||
test("params") {
|
||||
ParamsSuite.checkParams(new BinaryClassificationEvaluator)
|
||||
}
|
||||
|
||||
test("read/write") {
  // Configure explicit column names and metric, then round-trip through save/load.
  val original = new BinaryClassificationEvaluator()
  original.setRawPredictionCol("myRawPrediction")
  original.setLabelCol("myLabel")
  original.setMetricName("areaUnderPR")
  testDefaultReadWrite(original)
}
|
||||
}
|
||||
|
|
|
@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation
|
|||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
|
||||
class MulticlassClassificationEvaluatorSuite extends SparkFunSuite {
|
||||
class MulticlassClassificationEvaluatorSuite
|
||||
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
||||
|
||||
test("params") {
|
||||
ParamsSuite.checkParams(new MulticlassClassificationEvaluator)
|
||||
}
|
||||
|
||||
test("read/write") {
  // Configure explicit column names and metric, then round-trip through save/load.
  val original = new MulticlassClassificationEvaluator()
  original.setPredictionCol("myPrediction")
  original.setLabelCol("myLabel")
  original.setMetricName("recall")
  testDefaultReadWrite(original)
}
|
||||
}
|
||||
|
|
|
@ -20,10 +20,12 @@ package org.apache.spark.ml.evaluation
|
|||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.regression.LinearRegression
|
||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
|
||||
import org.apache.spark.mllib.util.TestingUtils._
|
||||
|
||||
class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||
class RegressionEvaluatorSuite
|
||||
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
||||
|
||||
test("params") {
|
||||
ParamsSuite.checkParams(new RegressionEvaluator)
|
||||
|
@ -73,4 +75,12 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext
|
|||
evaluator.setMetricName("mae")
|
||||
assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
|
||||
}
|
||||
|
||||
test("read/write") {
  // Configure explicit column names and metric, then round-trip through save/load.
  val original = new RegressionEvaluator()
  original.setPredictionCol("myPrediction")
  original.setLabelCol("myLabel")
  original.setMetricName("r2")
  testDefaultReadWrite(original)
}
|
||||
}
|
||||
|
|
|
@ -18,19 +18,22 @@
|
|||
package org.apache.spark.ml.tuning
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.util.MLTestingUtils
|
||||
import org.apache.spark.ml.{Estimator, Model}
|
||||
import org.apache.spark.ml.classification.LogisticRegression
|
||||
import org.apache.spark.ml.feature.HashingTF
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||
import org.apache.spark.ml.{Pipeline, Estimator, Model}
|
||||
import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression}
|
||||
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.param.{ParamPair, ParamMap}
|
||||
import org.apache.spark.ml.param.shared.HasInputCol
|
||||
import org.apache.spark.ml.regression.LinearRegression
|
||||
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
|
||||
import org.apache.spark.mllib.linalg.Vectors
|
||||
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
|
||||
import org.apache.spark.sql.{DataFrame, SQLContext}
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||
class CrossValidatorSuite
|
||||
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
||||
|
||||
@transient var dataset: DataFrame = _
|
||||
|
||||
|
@ -95,7 +98,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
}
|
||||
|
||||
test("validateParams should check estimatorParamMaps") {
|
||||
import CrossValidatorSuite._
|
||||
import CrossValidatorSuite.{MyEstimator, MyEvaluator}
|
||||
|
||||
val est = new MyEstimator("est")
|
||||
val eval = new MyEvaluator
|
||||
|
@ -116,9 +119,194 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
cv.validateParams()
|
||||
}
|
||||
}
|
||||
|
||||
test("read/write: CrossValidator with simple estimator") {
  // Assemble a CrossValidator around a single LogisticRegression, using a non-default
  // metric and fold count so the save/load round trip has distinguishable values.
  val logReg = new LogisticRegression().setMaxIter(3)
  val prEvaluator = new BinaryClassificationEvaluator()
    .setMetricName("areaUnderPR") // not default metric
  val regParamGrid = new ParamGridBuilder()
    .addGrid(logReg.regParam, Array(0.1, 0.2))
    .build()
  val original = new CrossValidator()
    .setEstimator(logReg)
    .setEvaluator(prEvaluator)
    .setNumFolds(20)
    .setEstimatorParamMaps(regParamGrid)

  // testParams is passed as false; the relevant fields are compared by hand below.
  val loaded = testDefaultReadWrite(original, testParams = false)

  // Top-level identity and simple params.
  assert(original.uid === loaded.uid)
  assert(original.getNumFolds === loaded.getNumFolds)

  // The evaluator must come back with the same concrete type, uid and metric.
  val rawLoadedEvaluator = loaded.getEvaluator
  assert(rawLoadedEvaluator.isInstanceOf[BinaryClassificationEvaluator])
  val loadedEvaluator = rawLoadedEvaluator.asInstanceOf[BinaryClassificationEvaluator]
  assert(prEvaluator.uid === loadedEvaluator.uid)
  assert(prEvaluator.getMetricName === loadedEvaluator.getMetricName)

  // The estimator must come back as a LogisticRegression with the same uid and maxIter.
  loaded.getEstimator match {
    case loadedLogReg: LogisticRegression =>
      assert(logReg.uid === loadedLogReg.uid)
      assert(logReg.getMaxIter === loadedLogReg.getMaxIter)
    case unexpected =>
      val foundType = unexpected.getClass.getName
      throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
        s" LogisticRegression but found ${foundType}")
  }

  // Finally, the param grid itself must be preserved entry by entry.
  CrossValidatorSuite.compareParamMaps(
    original.getEstimatorParamMaps, loaded.getEstimatorParamMaps)
}
|
||||
|
||||
test("read/write: CrossValidator with complex estimator") {
  // workflow: CrossValidator[Pipeline[HashingTF, CrossValidator[LogisticRegression]]]
  val lrEvaluator = new BinaryClassificationEvaluator()
    .setMetricName("areaUnderPR") // not default metric

  // Inner CrossValidator, tuning LogisticRegression over regParam.
  val lr = new LogisticRegression().setMaxIter(3)
  val lrParamMaps = new ParamGridBuilder()
    .addGrid(lr.regParam, Array(0.1, 0.2))
    .build()
  val lrcv = new CrossValidator()
    .setEstimator(lr)
    .setEvaluator(lrEvaluator)
    .setEstimatorParamMaps(lrParamMaps)

  // Outer CrossValidator, tuning a Pipeline whose last stage is the inner CrossValidator.
  // The outer grid references Params of both the HashingTF stage and the nested lr.
  val hashingTF = new HashingTF()
  val pipeline = new Pipeline().setStages(Array(hashingTF, lrcv))
  val paramMaps = new ParamGridBuilder()
    .addGrid(hashingTF.numFeatures, Array(10, 20))
    .addGrid(lr.elasticNetParam, Array(0.0, 1.0))
    .build()
  val evaluator = new BinaryClassificationEvaluator()

  val cv = new CrossValidator()
    .setEstimator(pipeline)
    .setEvaluator(evaluator)
    .setNumFolds(20)
    .setEstimatorParamMaps(paramMaps)

  // testParams is passed as false; the nested structure is compared by hand below.
  val cv2 = testDefaultReadWrite(cv, testParams = false)

  assert(cv.uid === cv2.uid)
  assert(cv.getNumFolds === cv2.getNumFolds)

  assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
  assert(cv.getEvaluator.uid === cv2.getEvaluator.uid)

  CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)

  // Walk the loaded structure level by level: Pipeline -> stages -> inner CrossValidator.
  cv2.getEstimator match {
    case pipeline2: Pipeline =>
      assert(pipeline.uid === pipeline2.uid)
      pipeline2.getStages match {
        case Array(hashingTF2: HashingTF, lrcv2: CrossValidator) =>
          assert(hashingTF.uid === hashingTF2.uid)
          lrcv2.getEstimator match {
            case lr2: LogisticRegression =>
              assert(lr.uid === lr2.uid)
              assert(lr.getMaxIter === lr2.getMaxIter)
            case other =>
              // What is being checked here is the estimator held by the inner CrossValidator.
              throw new AssertionError(s"Loaded internal CrossValidator expected estimator" +
                s" of type LogisticRegression but found ${other.getClass.getName}")
          }
          assert(lrcv.uid === lrcv2.uid)
          assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
          assert(lrEvaluator.uid === lrcv2.getEvaluator.uid)
          CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps)
        case other =>
          throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" +
            " but found: " + other.map(_.getClass.getName).mkString(", "))
      }
    case other =>
      // The outer estimator is a Pipeline, so the message must name Pipeline
      // (it previously said CrossValidator, which misreported the expected type).
      throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
        s" Pipeline but found ${other.getClass.getName}")
  }
}
|
||||
|
||||
test("read/write: CrossValidator fails for extraneous Param") {
  // Two separate LogisticRegression instances: only `tunedLr` is the estimator, while the
  // grid also references a Param owned by `strayLr`.
  val tunedLr = new LogisticRegression()
  val strayLr = new LogisticRegression()
  val binaryEvaluator = new BinaryClassificationEvaluator()
  val gridWithStrayParam = new ParamGridBuilder()
    .addGrid(tunedLr.regParam, Array(0.1, 0.2))
    .addGrid(strayLr.regParam, Array(0.1, 0.2))
    .build()
  val validator = new CrossValidator()
    .setEstimator(tunedLr)
    .setEvaluator(binaryEvaluator)
    .setEstimatorParamMaps(gridWithStrayParam)

  // Requesting the writer is expected to throw IllegalArgumentException.
  withClue("CrossValidator.write failed to catch extraneous Param error") {
    intercept[IllegalArgumentException] {
      validator.write
    }
  }
}
|
||||
|
||||
test("read/write: CrossValidatorModel") {
  // Hand-build a CrossValidatorModel (no fitting) with a known best model and metrics.
  val logReg = new LogisticRegression()
    .setThreshold(0.6)
  val bestLrModel = new LogisticRegressionModel(logReg.uid, Vectors.dense(1.0, 2.0), 1.2)
    .setThreshold(0.6)
  val prEvaluator = new BinaryClassificationEvaluator()
    .setMetricName("areaUnderPR") // not default metric
  val regParamGrid = new ParamGridBuilder()
    .addGrid(logReg.regParam, Array(0.1, 0.2))
    .build()
  val original = new CrossValidatorModel("cvUid", bestLrModel, Array(0.3, 0.6))
  original.set(original.estimator, logReg)
    .set(original.evaluator, prEvaluator)
    .set(original.numFolds, 20)
    .set(original.estimatorParamMaps, regParamGrid)

  // testParams is passed as false; the relevant fields are compared by hand below.
  val loaded = testDefaultReadWrite(original, testParams = false)

  // Top-level identity and simple params.
  assert(original.uid === loaded.uid)
  assert(original.getNumFolds === loaded.getNumFolds)

  // Evaluator: same concrete type, uid and metric.
  val rawLoadedEvaluator = loaded.getEvaluator
  assert(rawLoadedEvaluator.isInstanceOf[BinaryClassificationEvaluator])
  val loadedEvaluator = rawLoadedEvaluator.asInstanceOf[BinaryClassificationEvaluator]
  assert(prEvaluator.uid === loadedEvaluator.uid)
  assert(prEvaluator.getMetricName === loadedEvaluator.getMetricName)

  // Estimator: same type, uid and threshold.
  loaded.getEstimator match {
    case loadedLogReg: LogisticRegression =>
      assert(logReg.uid === loadedLogReg.uid)
      assert(logReg.getThreshold === loadedLogReg.getThreshold)
    case unexpected =>
      val foundType = unexpected.getClass.getName
      throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
        s" LogisticRegression but found ${foundType}")
  }

  // Param grid: preserved entry by entry.
  CrossValidatorSuite.compareParamMaps(
    original.getEstimatorParamMaps, loaded.getEstimatorParamMaps)

  // Best model: same type, uid, threshold, coefficients and intercept.
  loaded.bestModel match {
    case loadedLrModel: LogisticRegressionModel =>
      assert(bestLrModel.uid === loadedLrModel.uid)
      assert(bestLrModel.getThreshold === loadedLrModel.getThreshold)
      assert(bestLrModel.coefficients === loadedLrModel.coefficients)
      assert(bestLrModel.intercept === loadedLrModel.intercept)
    case unexpected =>
      val foundType = unexpected.getClass.getName
      throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" +
        s" LogisticRegressionModel but found ${foundType}")
  }

  // Per-grid-point average metrics.
  assert(original.avgMetrics === loaded.avgMetrics)
}
|
||||
}
|
||||
|
||||
object CrossValidatorSuite {
|
||||
object CrossValidatorSuite extends SparkFunSuite {
|
||||
|
||||
/**
 * Assert that two arrays of estimator ParamMaps match position by position.
 *
 * The arrays must have the same length; each pair of maps must have the same size, and
 * every (param, value) entry of the first map must be present in the second with a value
 * equal under `===`. Param values must therefore be simple types comparable with `===`.
 *
 * @param pMaps  reference ParamMaps
 * @param pMaps2 ParamMaps checked against the reference
 */
def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = {
  assert(pMaps.length === pMaps2.length)
  // Lengths are equal past this point, so indexing both arrays by the same index is safe.
  pMaps.indices.foreach { idx =>
    val reference = pMaps(idx)
    val candidate = pMaps2(idx)
    assert(reference.size === candidate.size)
    reference.toSeq.foreach { case ParamPair(param, referenceValue) =>
      assert(candidate.contains(param))
      assert(candidate(param) === referenceValue)
    }
  }
}
|
||||
|
||||
abstract class MyModel extends Model[MyModel]
|
||||
|
||||
|
|
Loading…
Reference in a new issue