[SPARK-6791][ML] Add read/write for CrossValidator and Evaluators

I believe this works for general estimators within CrossValidator, including compound estimators.  (See the complex unit test.)
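For illustration (not part of the patch), a minimal save/load sketch of the new CrossValidator API under a compound estimator; the path is hypothetical and the setup loosely mirrors the complex unit test:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Compound estimator: a Pipeline whose second stage is the LogisticRegression being tuned.
val lr = new LogisticRegression().setMaxIter(3)
val pipeline = new Pipeline().setStages(Array(new HashingTF(), lr))
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(new ParamGridBuilder().addGrid(lr.regParam, Array(0.1, 0.2)).build())
  .setNumFolds(3)

cv.save("/tmp/cv-example")                       // hypothetical path
val cvLoaded = CrossValidator.load("/tmp/cv-example")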

Added read/write for all 3 Evaluators as well.
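Each evaluator now mixes in DefaultParamsWritable and gains a DefaultParamsReadable companion, so a standalone round trip looks roughly like this (path hypothetical, mirroring the new suite tests):

import org.apache.spark.ml.evaluation.RegressionEvaluator

val evaluator = new RegressionEvaluator()
  .setPredictionCol("myPrediction")
  .setLabelCol("myLabel")
  .setMetricName("r2")

evaluator.save("/tmp/regEval")                   // hypothetical path
val loaded = RegressionEvaluator.load("/tmp/regEval")
assert(loaded.getMetricName == "r2")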

CC: mengxr yanboliang

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9848 from jkbradley/cv-io.
Joseph K. Bradley 2015-11-22 21:48:48 -08:00 committed by Xiangrui Meng
parent fe89c1817d
commit a6fda0bfc1
12 changed files with 522 additions and 85 deletions

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

@@ -34,7 +34,6 @@ import org.apache.spark.ml.util.MLWriter
import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -232,20 +231,9 @@ object Pipeline extends MLReadable[Pipeline] {
stages: Array[PipelineStage],
sc: SparkContext,
path: String): Unit = {
// Copied and edited from DefaultParamsWriter.saveMetadata
// TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication
val uid = instance.uid
val cls = instance.getClass.getName
val stageUids = stages.map(_.uid)
val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq))))
val metadata = ("class" -> cls) ~
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
("paramMap" -> jsonParams)
val metadataPath = new Path(path, "metadata").toString
val metadataJson = compact(render(metadata))
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams))
// Save stages
val stagesDir = new Path(path, "stages").toString
@@ -266,30 +254,10 @@ object Pipeline extends MLReadable[Pipeline] {
implicit val format = DefaultFormats
val stagesDir = new Path(path, "stages").toString
val stageUids: Array[String] = metadata.params match {
case JObject(pairs) =>
if (pairs.length != 1) {
// Should not happen unless file is corrupted or we have a bug.
throw new RuntimeException(
s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.")
}
pairs.head match {
case ("stageUids", jsonValue) =>
jsonValue.extract[Seq[String]].toArray
case (paramName, jsonValue) =>
// Should not happen unless file is corrupted or we have a bug.
throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" +
s" in metadata: ${metadata.metadataStr}")
}
case _ =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
}
val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray
val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc)
val cls = Utils.classForName(stageMetadata.className)
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath)
DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc)
}
(metadata.uid, stages)
}

mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala

@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.DoubleType
@Since("1.2.0")
@Experimental
class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Evaluator with HasRawPredictionCol with HasLabelCol {
extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable {
@Since("1.2.0")
def this() = this(Identifiable.randomUID("binEval"))
@@ -105,3 +105,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
@Since("1.4.1")
override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
}
@Since("1.6.0")
object BinaryClassificationEvaluator extends DefaultParamsReadable[BinaryClassificationEvaluator] {
@Since("1.6.0")
override def load(path: String): BinaryClassificationEvaluator = super.load(path)
}

mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala

@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.{Row, DataFrame}
import org.apache.spark.sql.types.DoubleType
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types.DoubleType
@Since("1.5.0")
@Experimental
class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String)
extends Evaluator with HasPredictionCol with HasLabelCol {
extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
@Since("1.5.0")
def this() = this(Identifiable.randomUID("mcEval"))
@@ -101,3 +101,11 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
@Since("1.5.0")
override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
}
@Since("1.6.0")
object MulticlassClassificationEvaluator
extends DefaultParamsReadable[MulticlassClassificationEvaluator] {
@Since("1.6.0")
override def load(path: String): MulticlassClassificationEvaluator = super.load(path)
}

mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala

@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, FloatType}
@Since("1.4.0")
@Experimental
final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Evaluator with HasPredictionCol with HasLabelCol {
extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("regEval"))
@@ -104,3 +104,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
@Since("1.5.0")
override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
}
@Since("1.6.0")
object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] {
@Since("1.6.0")
override def load(path: String): RegressionEvaluator = super.load(path)
}

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

@@ -27,9 +27,8 @@ import scala.util.hashing.byteswap64
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.json4s.{DefaultFormats, JValue}
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.{Logging, Partitioner}
import org.apache.spark.annotation.{Since, DeveloperApi, Experimental}
@@ -240,7 +239,7 @@ object ALSModel extends MLReadable[ALSModel] {
private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
val extraMetadata = render("rank" -> instance.rank)
val extraMetadata = "rank" -> instance.rank
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
val userPath = new Path(path, "userFactors").toString
instance.userFactors.write.format("parquet").save(userPath)
@@ -257,14 +256,7 @@ object ALSModel extends MLReadable[ALSModel] {
override def load(path: String): ALSModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
implicit val format = DefaultFormats
val rank: Int = metadata.extraMetadata match {
case Some(m: JValue) =>
(m \ "rank").extract[Int]
case None =>
throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" +
s" ${metadata.metadataStr}")
}
val rank = (metadata.metadata \ "rank").extract[Int]
val userPath = new Path(path, "userFactors").toString
val userFactors = sqlContext.read.format("parquet").load(userPath)
val itemPath = new Path(path, "itemFactors").toString

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

@@ -18,17 +18,24 @@
package org.apache.spark.ml.tuning
import com.github.fommil.netlib.F2jBLAS
import org.apache.hadoop.fs.Path
import org.json4s.{JObject, DefaultFormats}
import org.json4s.jackson.JsonMethods._
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.classification.OneVsRestParams
import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
@@ -53,7 +60,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
*/
@Experimental
class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel]
with CrossValidatorParams with Logging {
with CrossValidatorParams with MLWritable with Logging {
def this() = this(Identifiable.randomUID("cv"))
@@ -131,6 +138,166 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
}
copied
}
// Currently, this only works if all [[Param]]s in [[estimatorParamMaps]] are simple types.
// E.g., this may fail if a [[Param]] is an instance of an [[Estimator]].
// However, this case should be unusual.
@Since("1.6.0")
override def write: MLWriter = new CrossValidator.CrossValidatorWriter(this)
}
@Since("1.6.0")
object CrossValidator extends MLReadable[CrossValidator] {
@Since("1.6.0")
override def read: MLReader[CrossValidator] = new CrossValidatorReader
@Since("1.6.0")
override def load(path: String): CrossValidator = super.load(path)
private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter {
SharedReadWrite.validateParams(instance)
override protected def saveImpl(path: String): Unit =
SharedReadWrite.saveImpl(path, instance, sc)
}
private class CrossValidatorReader extends MLReader[CrossValidator] {
/** Checked against metadata when loading model */
private val className = classOf[CrossValidator].getName
override def load(path: String): CrossValidator = {
val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
SharedReadWrite.load(path, sc, className)
new CrossValidator(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
.setEstimatorParamMaps(estimatorParamMaps)
.setNumFolds(numFolds)
}
}
private object CrossValidatorReader {
/**
* Examine the given estimator (which may be a compound estimator) and extract a mapping
* from UIDs to corresponding [[Params]] instances.
*/
def getUidMap(instance: Params): Map[String, Params] = {
val uidList = getUidMapImpl(instance)
val uidMap = uidList.toMap
if (uidList.size != uidMap.size) {
throw new RuntimeException("CrossValidator.load found a compound estimator with stages" +
s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}")
}
uidMap
}
def getUidMapImpl(instance: Params): List[(String, Params)] = {
val subStages: Array[Params] = instance match {
case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
case ovr: OneVsRestParams =>
// TODO: SPARK-11892: This case may require special handling.
throw new UnsupportedOperationException("CrossValidator write will fail because it" +
" cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
case rform: RFormulaModel =>
// TODO: SPARK-11891: This case may require special handling.
throw new UnsupportedOperationException("CrossValidator write will fail because it" +
" cannot yet handle an estimator containing an RFormulaModel")
case _: Params => Array()
}
val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
List((instance.uid, instance)) ++ subStageMaps
}
}
private[tuning] object SharedReadWrite {
/**
* Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable.
* This does not check [[CrossValidator.estimatorParamMaps]].
*/
def validateParams(instance: ValidatorParams): Unit = {
def checkElement(elem: Params, name: String): Unit = elem match {
case stage: MLWritable => // good
case other =>
throw new UnsupportedOperationException("CrossValidator write will fail " +
s" because it contains $name which does not implement Writable." +
s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
}
checkElement(instance.getEvaluator, "evaluator")
checkElement(instance.getEstimator, "estimator")
// Check to make sure all Params apply to this estimator. Throw an error if any do not.
// Extraneous Params would cause problems when loading the estimatorParamMaps.
val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance)
instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
pMap.toSeq.foreach { case ParamPair(p, v) =>
require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" +
s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" +
s" Evaluator. An extraneous Param was found: $p")
}
}
}
private[tuning] def saveImpl(
path: String,
instance: CrossValidatorParams,
sc: SparkContext,
extraMetadata: Option[JObject] = None): Unit = {
import org.json4s.JsonDSL._
val estimatorParamMapsJson = compact(render(
instance.getEstimatorParamMaps.map { case paramMap =>
paramMap.toSeq.map { case ParamPair(p, v) =>
Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
}
}.toSeq
))
val jsonParams = List(
"numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)),
"estimatorParamMaps" -> parse(estimatorParamMapsJson)
)
DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
val evaluatorPath = new Path(path, "evaluator").toString
instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
val estimatorPath = new Path(path, "estimator").toString
instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
}
private[tuning] def load[M <: Model[M]](
path: String,
sc: SparkContext,
expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
implicit val format = DefaultFormats
val evaluatorPath = new Path(path, "evaluator").toString
val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
val estimatorPath = new Path(path, "estimator").toString
val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator)
val numFolds = (metadata.params \ "numFolds").extract[Int]
val estimatorParamMaps: Array[ParamMap] =
(metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
pMap =>
val paramPairs = pMap.map { case pInfo: Map[String, String] =>
val est = uidToParams(pInfo("parent"))
val param = est.getParam(pInfo("name"))
val value = param.jsonDecode(pInfo("value"))
param -> value
}
ParamMap(paramPairs: _*)
}.toArray
(metadata, estimator, evaluator, estimatorParamMaps, numFolds)
}
}
}
/**
@@ -139,14 +306,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
*
* @param bestModel The best model selected from k-fold cross validation.
* @param avgMetrics Average cross-validation metrics for each paramMap in
* [[estimatorParamMaps]], in the corresponding order.
* [[CrossValidator.estimatorParamMaps]], in the corresponding order.
*/
@Experimental
class CrossValidatorModel private[ml] (
override val uid: String,
val bestModel: Model[_],
val avgMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams {
extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
override def validateParams(): Unit = {
bestModel.validateParams()
@@ -168,4 +335,54 @@ class CrossValidatorModel private[ml] (
avgMetrics.clone())
copyValues(copied, extra).setParent(parent)
}
@Since("1.6.0")
override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this)
}
@Since("1.6.0")
object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
import CrossValidator.SharedReadWrite
@Since("1.6.0")
override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader
@Since("1.6.0")
override def load(path: String): CrossValidatorModel = super.load(path)
private[CrossValidatorModel]
class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter {
SharedReadWrite.validateParams(instance)
override protected def saveImpl(path: String): Unit = {
import org.json4s.JsonDSL._
val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata))
val bestModelPath = new Path(path, "bestModel").toString
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
}
}
private class CrossValidatorModelReader extends MLReader[CrossValidatorModel] {
/** Checked against metadata when loading model */
private val className = classOf[CrossValidatorModel].getName
override def load(path: String): CrossValidatorModel = {
implicit val format = DefaultFormats
val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
SharedReadWrite.load(path, sc, className)
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
cv.set(cv.estimator, estimator)
.set(cv.evaluator, evaluator)
.set(cv.estimatorParamMaps, estimatorParamMaps)
.set(cv.numFolds, numFolds)
}
}
}

mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala

@@ -202,25 +202,36 @@ private[ml] object DefaultParamsWriter {
* - timestamp
* - sparkVersion
* - uid
* - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]].
* - paramMap
* - (optionally, extra metadata)
* @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
* @param paramMap If given, this is saved in the "paramMap" field.
* Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
* [[org.apache.spark.ml.param.Param.jsonEncode()]].
*/
def saveMetadata(
instance: Params,
path: String,
sc: SparkContext,
extraMetadata: Option[JValue] = None): Unit = {
extraMetadata: Option[JObject] = None,
paramMap: Option[JValue] = None): Unit = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
val jsonParams = params.map { case ParamPair(p, v) =>
val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
}.toList
val metadata = ("class" -> cls) ~
}.toList))
val basicMetadata = ("class" -> cls) ~
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
("paramMap" -> jsonParams) ~
("extraMetadata" -> extraMetadata)
("paramMap" -> jsonParams)
val metadata = extraMetadata match {
case Some(jObject) =>
basicMetadata ~ jObject
case None =>
basicMetadata
}
val metadataPath = new Path(path, "metadata").toString
val metadataJson = compact(render(metadata))
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
@@ -251,8 +262,8 @@ private[ml] object DefaultParamsReader {
/**
* All info from metadata file.
* @param params paramMap, as a [[JValue]]
* @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]]
* @param metadataStr Full metadata file String (for debugging)
* @param metadata All metadata, including the other fields
* @param metadataJson Full metadata file String (for debugging)
*/
case class Metadata(
className: String,
@@ -260,8 +271,8 @@ private[ml] object DefaultParamsReader {
timestamp: Long,
sparkVersion: String,
params: JValue,
extraMetadata: Option[JValue],
metadataStr: String)
metadata: JValue,
metadataJson: String)
/**
* Load metadata from file.
@@ -279,13 +290,12 @@ private[ml] object DefaultParamsReader {
val timestamp = (metadata \ "timestamp").extract[Long]
val sparkVersion = (metadata \ "sparkVersion").extract[String]
val params = metadata \ "paramMap"
val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]]
if (expectedClassName.nonEmpty) {
require(className == expectedClassName, s"Error loading metadata: Expected class name" +
s" $expectedClassName but found class name $className")
}
Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr)
Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
}
/**
@@ -303,7 +313,17 @@ private[ml] object DefaultParamsReader {
}
case _ =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
}
}
/**
* Load a [[Params]] instance from the given path, and return it.
* This assumes the instance implements [[MLReadable]].
*/
def loadParamsInstance[T](path: String, sc: SparkContext): T = {
val metadata = DefaultParamsReader.loadMetadata(path, sc)
val cls = Utils.classForName(metadata.className)
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
}
}

mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala

@@ -17,11 +17,9 @@
package org.apache.spark.ml
import java.io.File
import scala.collection.JavaConverters._
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.fs.Path
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar.mock

mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala

@@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
class BinaryClassificationEvaluatorSuite extends SparkFunSuite {
class BinaryClassificationEvaluatorSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new BinaryClassificationEvaluator)
}
test("read/write") {
val evaluator = new BinaryClassificationEvaluator()
.setRawPredictionCol("myRawPrediction")
.setLabelCol("myLabel")
.setMetricName("areaUnderPR")
testDefaultReadWrite(evaluator)
}
}

mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala

@@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
class MulticlassClassificationEvaluatorSuite extends SparkFunSuite {
class MulticlassClassificationEvaluatorSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new MulticlassClassificationEvaluator)
}
test("read/write") {
val evaluator = new MulticlassClassificationEvaluator()
.setPredictionCol("myPrediction")
.setLabelCol("myLabel")
.setMetricName("recall")
testDefaultReadWrite(evaluator)
}
}

mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala

@@ -20,10 +20,12 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
class RegressionEvaluatorSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new RegressionEvaluator)
@@ -73,4 +75,12 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext
evaluator.setMetricName("mae")
assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
}
test("read/write") {
val evaluator = new RegressionEvaluator()
.setPredictionCol("myPrediction")
.setLabelCol("myLabel")
.setMetricName("r2")
testDefaultReadWrite(evaluator)
}
}

mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala

@@ -18,19 +18,22 @@
package org.apache.spark.ml.tuning
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.{Pipeline, Estimator, Model}
import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.{ParamPair, ParamMap}
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.types.StructType
class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
class CrossValidatorSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var dataset: DataFrame = _
@@ -95,7 +98,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("validateParams should check estimatorParamMaps") {
import CrossValidatorSuite._
import CrossValidatorSuite.{MyEstimator, MyEvaluator}
val est = new MyEstimator("est")
val eval = new MyEvaluator
@@ -116,9 +119,194 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
cv.validateParams()
}
}
test("read/write: CrossValidator with simple estimator") {
val lr = new LogisticRegression().setMaxIter(3)
val evaluator = new BinaryClassificationEvaluator()
.setMetricName("areaUnderPR") // not default metric
val paramMaps = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.1, 0.2))
.build()
val cv = new CrossValidator()
.setEstimator(lr)
.setEvaluator(evaluator)
.setNumFolds(20)
.setEstimatorParamMaps(paramMaps)
val cv2 = testDefaultReadWrite(cv, testParams = false)
assert(cv.uid === cv2.uid)
assert(cv.getNumFolds === cv2.getNumFolds)
assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
assert(evaluator.uid === evaluator2.uid)
assert(evaluator.getMetricName === evaluator2.getMetricName)
cv2.getEstimator match {
case lr2: LogisticRegression =>
assert(lr.uid === lr2.uid)
assert(lr.getMaxIter === lr2.getMaxIter)
case other =>
throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
s" LogisticRegression but found ${other.getClass.getName}")
}
CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
}
test("read/write: CrossValidator with complex estimator") {
// workflow: CrossValidator[Pipeline[HashingTF, CrossValidator[LogisticRegression]]]
val lrEvaluator = new BinaryClassificationEvaluator()
.setMetricName("areaUnderPR") // not default metric
val lr = new LogisticRegression().setMaxIter(3)
val lrParamMaps = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.1, 0.2))
.build()
val lrcv = new CrossValidator()
.setEstimator(lr)
.setEvaluator(lrEvaluator)
.setEstimatorParamMaps(lrParamMaps)
val hashingTF = new HashingTF()
val pipeline = new Pipeline().setStages(Array(hashingTF, lrcv))
val paramMaps = new ParamGridBuilder()
.addGrid(hashingTF.numFeatures, Array(10, 20))
.addGrid(lr.elasticNetParam, Array(0.0, 1.0))
.build()
val evaluator = new BinaryClassificationEvaluator()
val cv = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(evaluator)
.setNumFolds(20)
.setEstimatorParamMaps(paramMaps)
val cv2 = testDefaultReadWrite(cv, testParams = false)
assert(cv.uid === cv2.uid)
assert(cv.getNumFolds === cv2.getNumFolds)
assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
assert(cv.getEvaluator.uid === cv2.getEvaluator.uid)
CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
cv2.getEstimator match {
case pipeline2: Pipeline =>
assert(pipeline.uid === pipeline2.uid)
pipeline2.getStages match {
case Array(hashingTF2: HashingTF, lrcv2: CrossValidator) =>
assert(hashingTF.uid === hashingTF2.uid)
lrcv2.getEstimator match {
case lr2: LogisticRegression =>
assert(lr.uid === lr2.uid)
assert(lr.getMaxIter === lr2.getMaxIter)
case other =>
throw new AssertionError(s"Loaded internal CrossValidator expected to be" +
s" LogisticRegression but found type ${other.getClass.getName}")
}
assert(lrcv.uid === lrcv2.uid)
assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
assert(lrEvaluator.uid === lrcv2.getEvaluator.uid)
CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps)
case other =>
throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" +
" but found: " + other.map(_.getClass.getName).mkString(", "))
}
case other =>
throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
s" CrossValidator but found ${other.getClass.getName}")
}
}
test("read/write: CrossValidator fails for extraneous Param") {
val lr = new LogisticRegression()
val lr2 = new LogisticRegression()
val evaluator = new BinaryClassificationEvaluator()
val paramMaps = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.1, 0.2))
.addGrid(lr2.regParam, Array(0.1, 0.2))
.build()
val cv = new CrossValidator()
.setEstimator(lr)
.setEvaluator(evaluator)
.setEstimatorParamMaps(paramMaps)
withClue("CrossValidator.write failed to catch extraneous Param error") {
intercept[IllegalArgumentException] {
cv.write
}
}
}
test("read/write: CrossValidatorModel") {
val lr = new LogisticRegression()
.setThreshold(0.6)
val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2)
.setThreshold(0.6)
val evaluator = new BinaryClassificationEvaluator()
.setMetricName("areaUnderPR") // not default metric
val paramMaps = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.1, 0.2))
.build()
val cv = new CrossValidatorModel("cvUid", lrModel, Array(0.3, 0.6))
cv.set(cv.estimator, lr)
.set(cv.evaluator, evaluator)
.set(cv.numFolds, 20)
.set(cv.estimatorParamMaps, paramMaps)
val cv2 = testDefaultReadWrite(cv, testParams = false)
assert(cv.uid === cv2.uid)
assert(cv.getNumFolds === cv2.getNumFolds)
assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
assert(evaluator.uid === evaluator2.uid)
assert(evaluator.getMetricName === evaluator2.getMetricName)
cv2.getEstimator match {
case lr2: LogisticRegression =>
assert(lr.uid === lr2.uid)
assert(lr.getThreshold === lr2.getThreshold)
case other =>
throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
s" LogisticRegression but found ${other.getClass.getName}")
}
CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
cv2.bestModel match {
case lrModel2: LogisticRegressionModel =>
assert(lrModel.uid === lrModel2.uid)
assert(lrModel.getThreshold === lrModel2.getThreshold)
assert(lrModel.coefficients === lrModel2.coefficients)
assert(lrModel.intercept === lrModel2.intercept)
case other =>
throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" +
s" LogisticRegressionModel but found ${other.getClass.getName}")
}
assert(cv.avgMetrics === cv2.avgMetrics)
}
}
object CrossValidatorSuite {
object CrossValidatorSuite extends SparkFunSuite {
/**
* Assert sequences of estimatorParamMaps are identical.
* Params must be simple types comparable with `===`.
*/
def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = {
assert(pMaps.length === pMaps2.length)
pMaps.zip(pMaps2).foreach { case (pMap, pMap2) =>
assert(pMap.size === pMap2.size)
pMap.toSeq.foreach { case ParamPair(p, v) =>
assert(pMap2.contains(p))
assert(pMap2(p) === v)
}
}
}
abstract class MyModel extends Model[MyModel]