[SPARK-11612][ML] Pipeline and PipelineModel persistence

Pipeline and PipelineModel now extend Writable, and their companion objects extend Readable.  Persistence succeeds only when every stage in the pipeline is itself Writable.
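
For reference, a minimal usage sketch of the new API (an illustration, not part of the patch: it assumes an active SparkContext, reuses the toy WritableStage transformer defined in PipelineSuite.scala below, and the save path is hypothetical):

    import org.apache.spark.ml.Pipeline

    // Any stage that implements Writable works; WritableStage is the test helper shown below.
    val stage = new WritableStage("stage0").setIntParam(56)
    val pipeline = new Pipeline().setStages(Array(stage))

    // Obtaining the Writer validates the stages eagerly: write throws
    // UnsupportedOperationException if any stage does not implement Writable.
    pipeline.write.save("/tmp/pipeline-example")

    // The companion object implements Readable, so loading is symmetric;
    // a fitted PipelineModel is saved and restored the same way via PipelineModel.load.
    val restored: Pipeline = Pipeline.load("/tmp/pipeline-example")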

Note: This PR reinstates tests for other read/write functionality.  It should probably not get merged until [https://issues.apache.org/jira/browse/SPARK-11672] gets fixed.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9674 from jkbradley/pipeline-io.
Joseph K. Bradley 2015-11-16 17:12:39 -08:00
parent bd10eb81c9
commit 1c5475f140
4 changed files with 306 additions and 18 deletions

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

@@ -22,12 +22,19 @@ import java.{util => ju}
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
 
-import org.apache.spark.Logging
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.{SparkContext, Logging}
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.Reader
+import org.apache.spark.ml.util.Writer
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
 
 /**
  * :: DeveloperApi ::
@@ -82,7 +89,7 @@ abstract class PipelineStage extends Params with Logging {
  * an identity transformer.
  */
 @Experimental
-class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
+class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable {
 
   def this() = this(Identifiable.randomUID("pipeline"))
 
@@ -166,6 +173,131 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
       "Cannot have duplicate components in a pipeline.")
     theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
   }
+
+  override def write: Writer = new Pipeline.PipelineWriter(this)
+}
+
+object Pipeline extends Readable[Pipeline] {
+
+  override def read: Reader[Pipeline] = new PipelineReader
+
+  override def load(path: String): Pipeline = read.load(path)
+
+  private[ml] class PipelineWriter(instance: Pipeline) extends Writer {
+
+    SharedReadWrite.validateStages(instance.getStages)
+
+    override protected def saveImpl(path: String): Unit =
+      SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
+  }
+
+  private[ml] class PipelineReader extends Reader[Pipeline] {
+
+    /** Checked against metadata when loading model */
+    private val className = "org.apache.spark.ml.Pipeline"
+
+    override def load(path: String): Pipeline = {
+      val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
+      new Pipeline(uid).setStages(stages)
+    }
+  }
+
+  /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */
+  private[ml] object SharedReadWrite {
+
+    import org.json4s.JsonDSL._
+
+    /** Check that all stages are Writable */
+    def validateStages(stages: Array[PipelineStage]): Unit = {
+      stages.foreach {
+        case stage: Writable => // good
+        case other =>
+          throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" +
+            s" because it contains a stage which does not implement Writable. Non-Writable stage:" +
+            s" ${other.uid} of type ${other.getClass}")
+      }
+    }
+
+    /**
+     * Save metadata and stages for a [[Pipeline]] or [[PipelineModel]]
+     *  - save metadata to path/metadata
+     *  - save stages to stages/IDX_UID
+     */
+    def saveImpl(
+        instance: Params,
+        stages: Array[PipelineStage],
+        sc: SparkContext,
+        path: String): Unit = {
+      // Copied and edited from DefaultParamsWriter.saveMetadata
+      // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication
+      val uid = instance.uid
+      val cls = instance.getClass.getName
+      val stageUids = stages.map(_.uid)
+      val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq))))
+      val metadata = ("class" -> cls) ~
+        ("timestamp" -> System.currentTimeMillis()) ~
+        ("sparkVersion" -> sc.version) ~
+        ("uid" -> uid) ~
+        ("paramMap" -> jsonParams)
+      val metadataPath = new Path(path, "metadata").toString
+      val metadataJson = compact(render(metadata))
+      sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+
+      // Save stages
+      val stagesDir = new Path(path, "stages").toString
+      stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) =>
+        stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir))
+      }
+    }
+
+    /**
+     * Load metadata and stages for a [[Pipeline]] or [[PipelineModel]]
+     * @return  (UID, list of stages)
+     */
+    def load(
+        expectedClassName: String,
+        sc: SparkContext,
+        path: String): (String, Array[PipelineStage]) = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+
+      implicit val format = DefaultFormats
+      val stagesDir = new Path(path, "stages").toString
+      val stageUids: Array[String] = metadata.params match {
+        case JObject(pairs) =>
+          if (pairs.length != 1) {
+            // Should not happen unless file is corrupted or we have a bug.
+            throw new RuntimeException(
+              s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.")
+          }
+          pairs.head match {
+            case ("stageUids", jsonValue) =>
+              jsonValue.extract[Seq[String]].toArray
+            case (paramName, jsonValue) =>
+              // Should not happen unless file is corrupted or we have a bug.
+              throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" +
+                s" in metadata: ${metadata.metadataStr}")
+          }
+        case _ =>
+          throw new IllegalArgumentException(
+            s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
+      }
+      val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
+        val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
+        val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc)
+        val cls = Utils.classForName(stageMetadata.className)
+        cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath)
+      }
+      (metadata.uid, stages)
+    }
+
+    /** Get path for saving the given stage. */
+    def getStagePath(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = {
+      val stageIdxDigits = numStages.toString.length
+      val idxFormat = s"%0${stageIdxDigits}d"
+      val stageDir = idxFormat.format(stageIdx) + "_" + stageUid
+      new Path(stagesDir, stageDir).toString
+    }
+  }
 }
 
 /**
@@ -176,7 +308,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
 class PipelineModel private[ml] (
     override val uid: String,
     val stages: Array[Transformer])
-  extends Model[PipelineModel] with Logging {
+  extends Model[PipelineModel] with Writable with Logging {
 
   /** A Java/Python-friendly auxiliary constructor. */
   private[ml] def this(uid: String, stages: ju.List[Transformer]) = {
@@ -200,4 +332,39 @@ class PipelineModel private[ml] (
   override def copy(extra: ParamMap): PipelineModel = {
     new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
   }
+
+  override def write: Writer = new PipelineModel.PipelineModelWriter(this)
+}
+
+object PipelineModel extends Readable[PipelineModel] {
+
+  import Pipeline.SharedReadWrite
+
+  override def read: Reader[PipelineModel] = new PipelineModelReader
+
+  override def load(path: String): PipelineModel = read.load(path)
+
+  private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer {
+
+    SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])
+
+    override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance,
+      instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
+  }
+
+  private[ml] class PipelineModelReader extends Reader[PipelineModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = "org.apache.spark.ml.PipelineModel"
+
+    override def load(path: String): PipelineModel = {
+      val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
+      val transformers = stages map {
+        case stage: Transformer => stage
+        case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" +
+          s" was not a Transformer. Bad stage ${other.uid} of type ${other.getClass}")
+      }
+      new PipelineModel(uid, transformers)
+    }
+  }
 }
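
For orientation, a sketch of the on-disk layout that SharedReadWrite.saveImpl produces (the stage UID and part-file names below are hypothetical; the metadata JSON is written with saveAsTextFile, so it lands in Hadoop part files, and each stage directory is written by that stage's own Writer):

    <path>/
      metadata/part-00000        JSON with class, timestamp, sparkVersion, uid, and paramMap.stageUids
      stages/
        0_writableStage/         stage index zero-padded to the digit count of numStages, then "_" + stage UID
          metadata/part-00000    written by the stage's own Writer (e.g. DefaultParamsWriter)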

mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala

@@ -164,6 +164,8 @@ trait Readable[T] {
   /**
    * Reads an ML instance from the input path, a shortcut of `read.load(path)`.
+   *
+   * Note: Implementing classes should override this to be Java-friendly.
    */
   @Since("1.6.0")
   def load(path: String): T = read.load(path)
 
@@ -190,7 +192,7 @@ private[ml] object DefaultParamsWriter {
    *  - timestamp
    *  - sparkVersion
    *  - uid
-   *  - paramMap
+   *  - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]].
    */
   def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = {
     val uid = instance.uid

mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala

@@ -17,19 +17,25 @@
 
 package org.apache.spark.ml
 
+import java.io.File
+
 import scala.collection.JavaConverters._
 
+import org.apache.hadoop.fs.{FileSystem, Path}
 import org.mockito.Matchers.{any, eq => meq}
 import org.mockito.Mockito.when
 import org.scalatest.mock.MockitoSugar.mock
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.Pipeline.SharedReadWrite
 import org.apache.spark.ml.feature.HashingTF
-import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
 
-class PipelineSuite extends SparkFunSuite {
+class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   abstract class MyModel extends Model[MyModel]
 
@@ -111,4 +117,112 @@ class PipelineSuite extends SparkFunSuite {
     assert(pipelineModel1.uid === "pipeline1")
     assert(pipelineModel1.stages === stages)
   }
+
+  test("Pipeline read/write") {
+    val writableStage = new WritableStage("writableStage").setIntParam(56)
+    val pipeline = new Pipeline().setStages(Array(writableStage))
+    val pipeline2 = testDefaultReadWrite(pipeline, testParams = false)
+    assert(pipeline2.getStages.length === 1)
+    assert(pipeline2.getStages(0).isInstanceOf[WritableStage])
+    val writableStage2 = pipeline2.getStages(0).asInstanceOf[WritableStage]
+    assert(writableStage.getIntParam === writableStage2.getIntParam)
+  }
+
+  test("Pipeline read/write with non-Writable stage") {
+    val unWritableStage = new UnWritableStage("unwritableStage")
+    val unWritablePipeline = new Pipeline().setStages(Array(unWritableStage))
+    withClue("Pipeline.write should fail when Pipeline contains non-Writable stage") {
+      intercept[UnsupportedOperationException] {
+        unWritablePipeline.write
+      }
+    }
+  }
+
+  test("PipelineModel read/write") {
+    val writableStage = new WritableStage("writableStage").setIntParam(56)
+    val pipeline =
+      new PipelineModel("pipeline_89329327", Array(writableStage.asInstanceOf[Transformer]))
+    val pipeline2 = testDefaultReadWrite(pipeline, testParams = false)
+    assert(pipeline2.stages.length === 1)
+    assert(pipeline2.stages(0).isInstanceOf[WritableStage])
+    val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage]
+    assert(writableStage.getIntParam === writableStage2.getIntParam)
+
+    val path = new File(tempDir, pipeline.uid).getPath
+    val stagesDir = new Path(path, "stages").toString
+    val expectedStagePath = SharedReadWrite.getStagePath(writableStage.uid, 0, 1, stagesDir)
+    assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)),
+      s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with uid ${pipeline.uid}" +
+        s" to be saved to path: $expectedStagePath")
+  }
+
+  test("PipelineModel read/write: getStagePath") {
+    val stageUid = "myStage"
+    val stagesDir = new Path("pipeline", "stages").toString
+    def testStage(stageIdx: Int, numStages: Int, expectedPrefix: String): Unit = {
+      val path = SharedReadWrite.getStagePath(stageUid, stageIdx, numStages, stagesDir)
+      val expected = new Path(stagesDir, expectedPrefix + "_" + stageUid).toString
+      assert(path === expected)
+    }
+    testStage(0, 1, "0")
+    testStage(0, 9, "0")
+    testStage(0, 10, "00")
+    testStage(1, 10, "01")
+    testStage(12, 999, "012")
+  }
+
+  test("PipelineModel read/write with non-Writable stage") {
+    val unWritableStage = new UnWritableStage("unwritableStage")
+    val unWritablePipeline =
+      new PipelineModel("pipeline_328957", Array(unWritableStage.asInstanceOf[Transformer]))
+    withClue("PipelineModel.write should fail when PipelineModel contains non-Writable stage") {
+      intercept[UnsupportedOperationException] {
+        unWritablePipeline.write
+      }
+    }
+  }
+}
+
+
+/** Used to test [[Pipeline]] with [[Writable]] stages */
+class WritableStage(override val uid: String) extends Transformer with Writable {
+
+  final val intParam: IntParam = new IntParam(this, "intParam", "doc")
+
+  def getIntParam: Int = $(intParam)
+
+  def setIntParam(value: Int): this.type = set(intParam, value)
+
+  setDefault(intParam -> 0)
+
+  override def copy(extra: ParamMap): WritableStage = defaultCopy(extra)
+
+  override def write: Writer = new DefaultParamsWriter(this)
+
+  override def transform(dataset: DataFrame): DataFrame = dataset
+
+  override def transformSchema(schema: StructType): StructType = schema
+}
+
+object WritableStage extends Readable[WritableStage] {
+
+  override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage]
+
+  override def load(path: String): WritableStage = read.load(path)
+}
+
+/** Used to test [[Pipeline]] with non-[[Writable]] stages */
+class UnWritableStage(override val uid: String) extends Transformer {
+
+  final val intParam: IntParam = new IntParam(this, "intParam", "doc")
+
+  setDefault(intParam -> 0)
+
+  override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra)
+
+  override def transform(dataset: DataFrame): DataFrame = dataset
+
+  override def transformSchema(schema: StructType): StructType = schema
 }

mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala

@@ -30,10 +30,13 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
   /**
    * Checks "overwrite" option and params.
    * @param instance ML instance to test saving/loading
+   * @param testParams If true, then test values of Params. Otherwise, just test overwrite option.
    * @tparam T ML instance type
    * @return Instance loaded from file
    */
-  def testDefaultReadWrite[T <: Params with Writable](instance: T): T = {
+  def testDefaultReadWrite[T <: Params with Writable](
+      instance: T,
+      testParams: Boolean = true): T = {
     val uid = instance.uid
 
     val path = new File(tempDir, uid).getPath
@@ -46,6 +49,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     val newInstance = loader.load(path)
 
     assert(newInstance.uid === instance.uid)
+    if (testParams) {
     instance.params.foreach { p =>
       if (instance.isDefined(p)) {
         (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
@@ -58,6 +62,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
         assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
       }
     }
+    }
 
     val load = instance.getClass.getMethod("load", classOf[String])
     val another = load.invoke(instance, path).asInstanceOf[T]