[SPARK-11612][ML] Pipeline and PipelineModel persistence
Pipeline and PipelineModel extend Readable and Writable. Persistence succeeds only when all stages are Writable.

Note: This PR reinstates tests for other read/write functionality. It should probably not get merged until [https://issues.apache.org/jira/browse/SPARK-11672] gets fixed.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9674 from jkbradley/pipeline-io.
parent bd10eb81c9
commit 1c5475f140
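For context, a minimal usage sketch of the API this commit adds. It uses the WritableStage test helper defined in PipelineSuite below and a hypothetical save path; any production stage that implements Writable works the same way:

    import org.apache.spark.ml.Pipeline

    // Build a Pipeline whose only stage implements Writable.
    val stage = new WritableStage("stage1").setIntParam(7)
    val pipeline = new Pipeline().setStages(Array(stage))

    // write checks up front that every stage is Writable (otherwise it throws
    // UnsupportedOperationException); save then writes metadata plus one
    // sub-directory per stage under <path>/stages/.
    pipeline.write.save("/tmp/pipeline-demo")  // hypothetical path

    // The companion object exposes the matching Reader; load(path) is a
    // shortcut for read.load(path).
    val loaded: Pipeline = Pipeline.load("/tmp/pipeline-demo")
    assert(loaded.getStages.head.asInstanceOf[WritableStage].getIntParam == 7)

A fitted PipelineModel persists the same way via model.write.save(path) and PipelineModel.load(path).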
@@ -22,12 +22,19 @@ import java.{util => ju}
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

import org.apache.spark.Logging
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util.Reader
import org.apache.spark.ml.util.Writer
import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

/**
 * :: DeveloperApi ::
@@ -82,7 +89,7 @@ abstract class PipelineStage extends Params with Logging {
 * an identity transformer.
 */
@Experimental
class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable {

  def this() = this(Identifiable.randomUID("pipeline"))

@@ -166,6 +173,131 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
      "Cannot have duplicate components in a pipeline.")
    theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
  }

  override def write: Writer = new Pipeline.PipelineWriter(this)
}

object Pipeline extends Readable[Pipeline] {

  override def read: Reader[Pipeline] = new PipelineReader

  override def load(path: String): Pipeline = read.load(path)

  private[ml] class PipelineWriter(instance: Pipeline) extends Writer {

    SharedReadWrite.validateStages(instance.getStages)

    override protected def saveImpl(path: String): Unit =
      SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
  }

  private[ml] class PipelineReader extends Reader[Pipeline] {

    /** Checked against metadata when loading model */
    private val className = "org.apache.spark.ml.Pipeline"

    override def load(path: String): Pipeline = {
      val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
      new Pipeline(uid).setStages(stages)
    }
  }

  /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */
  private[ml] object SharedReadWrite {

    import org.json4s.JsonDSL._

    /** Check that all stages are Writable */
    def validateStages(stages: Array[PipelineStage]): Unit = {
      stages.foreach {
        case stage: Writable => // good
        case other =>
          throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" +
            s" because it contains a stage which does not implement Writable. Non-Writable stage:" +
            s" ${other.uid} of type ${other.getClass}")
      }
    }

    /**
     * Save metadata and stages for a [[Pipeline]] or [[PipelineModel]]
     *  - save metadata to path/metadata
     *  - save stages to stages/IDX_UID
     */
    def saveImpl(
        instance: Params,
        stages: Array[PipelineStage],
        sc: SparkContext,
        path: String): Unit = {
      // Copied and edited from DefaultParamsWriter.saveMetadata
      // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication
      val uid = instance.uid
      val cls = instance.getClass.getName
      val stageUids = stages.map(_.uid)
      val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq))))
      val metadata = ("class" -> cls) ~
        ("timestamp" -> System.currentTimeMillis()) ~
        ("sparkVersion" -> sc.version) ~
        ("uid" -> uid) ~
        ("paramMap" -> jsonParams)
      val metadataPath = new Path(path, "metadata").toString
      val metadataJson = compact(render(metadata))
      sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)

      // Save stages
      val stagesDir = new Path(path, "stages").toString
      stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) =>
        stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir))
      }
    }

    /**
     * Load metadata and stages for a [[Pipeline]] or [[PipelineModel]]
     * @return (UID, list of stages)
     */
    def load(
        expectedClassName: String,
        sc: SparkContext,
        path: String): (String, Array[PipelineStage]) = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)

      implicit val format = DefaultFormats
      val stagesDir = new Path(path, "stages").toString
      val stageUids: Array[String] = metadata.params match {
        case JObject(pairs) =>
          if (pairs.length != 1) {
            // Should not happen unless file is corrupted or we have a bug.
            throw new RuntimeException(
              s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.")
          }
          pairs.head match {
            case ("stageUids", jsonValue) =>
              jsonValue.extract[Seq[String]].toArray
            case (paramName, jsonValue) =>
              // Should not happen unless file is corrupted or we have a bug.
              throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" +
                s" in metadata: ${metadata.metadataStr}")
          }
        case _ =>
          throw new IllegalArgumentException(
            s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
      }
      val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
        val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
        val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc)
        val cls = Utils.classForName(stageMetadata.className)
        cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath)
      }
      (metadata.uid, stages)
    }

    /** Get path for saving the given stage. */
    def getStagePath(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = {
      val stageIdxDigits = numStages.toString.length
      val idxFormat = s"%0${stageIdxDigits}d"
      val stageDir = idxFormat.format(stageIdx) + "_" + stageUid
      new Path(stagesDir, stageDir).toString
    }
  }
}

/**
@@ -176,7 +308,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
class PipelineModel private[ml] (
    override val uid: String,
    val stages: Array[Transformer])
  extends Model[PipelineModel] with Logging {
  extends Model[PipelineModel] with Writable with Logging {

  /** A Java/Python-friendly auxiliary constructor. */
  private[ml] def this(uid: String, stages: ju.List[Transformer]) = {
@@ -200,4 +332,39 @@ class PipelineModel private[ml] (
  override def copy(extra: ParamMap): PipelineModel = {
    new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
  }

  override def write: Writer = new PipelineModel.PipelineModelWriter(this)
}

object PipelineModel extends Readable[PipelineModel] {

  import Pipeline.SharedReadWrite

  override def read: Reader[PipelineModel] = new PipelineModelReader

  override def load(path: String): PipelineModel = read.load(path)

  private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer {

    SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])

    override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance,
      instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
  }

  private[ml] class PipelineModelReader extends Reader[PipelineModel] {

    /** Checked against metadata when loading model */
    private val className = "org.apache.spark.ml.PipelineModel"

    override def load(path: String): PipelineModel = {
      val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
      val transformers = stages map {
        case stage: Transformer => stage
        case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" +
          s" was not a Transformer. Bad stage ${other.uid} of type ${other.getClass}")
      }
      new PipelineModel(uid, transformers)
    }
  }
}

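Taken together, the SharedReadWrite helpers above lay a saved Pipeline or PipelineModel out as <path>/metadata (a single JSON line recording class, timestamp, sparkVersion, uid, and a paramMap holding the ordered stage UIDs) plus one directory per stage under <path>/stages/. A standalone sketch of the zero-padded stage directory naming, mirroring getStagePath and the expectations in the PipelineSuite test below (the helper name is hypothetical, not part of this PR):

    // Hypothetical helper: reproduces only the directory-name portion of
    // SharedReadWrite.getStagePath; the real method also joins it to stagesDir.
    def stageDirName(stageUid: String, stageIdx: Int, numStages: Int): String = {
      val digits = numStages.toString.length  // pad to the width of the largest possible index
      s"%0${digits}d".format(stageIdx) + "_" + stageUid
    }

    stageDirName("myStage", 0, 1)     // "0_myStage"
    stageDirName("myStage", 1, 10)    // "01_myStage"
    stageDirName("myStage", 12, 999)  // "012_myStage"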
@@ -164,6 +164,8 @@ trait Readable[T] {

  /**
   * Reads an ML instance from the input path, a shortcut of `read.load(path)`.
   *
   * Note: Implementing classes should override this to be Java-friendly.
   */
  @Since("1.6.0")
  def load(path: String): T = read.load(path)
@@ -190,7 +192,7 @@ private[ml] object DefaultParamsWriter {
   * - timestamp
   * - sparkVersion
   * - uid
   * - paramMap
   * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]].
   */
  def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = {
    val uid = instance.uid

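To make the jsonEncode requirement above concrete: the Pipeline metadata produced by SharedReadWrite.saveImpl in this PR is a single JSON line of roughly the shape below; every value here is illustrative, not taken from a real run:

    {"class":"org.apache.spark.ml.Pipeline","timestamp":1447000000000,"sparkVersion":"1.6.0","uid":"pipeline_1234","paramMap":{"stageUids":["myStage_abcd1234"]}}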
@@ -17,19 +17,25 @@

package org.apache.spark.ml

import java.io.File

import scala.collection.JavaConverters._

import org.apache.hadoop.fs.{FileSystem, Path}
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar.mock

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.Pipeline.SharedReadWrite
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

class PipelineSuite extends SparkFunSuite {
class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  abstract class MyModel extends Model[MyModel]

@@ -111,4 +117,112 @@ class PipelineSuite extends SparkFunSuite {
    assert(pipelineModel1.uid === "pipeline1")
    assert(pipelineModel1.stages === stages)
  }

  test("Pipeline read/write") {
    val writableStage = new WritableStage("writableStage").setIntParam(56)
    val pipeline = new Pipeline().setStages(Array(writableStage))

    val pipeline2 = testDefaultReadWrite(pipeline, testParams = false)
    assert(pipeline2.getStages.length === 1)
    assert(pipeline2.getStages(0).isInstanceOf[WritableStage])
    val writableStage2 = pipeline2.getStages(0).asInstanceOf[WritableStage]
    assert(writableStage.getIntParam === writableStage2.getIntParam)
  }

  test("Pipeline read/write with non-Writable stage") {
    val unWritableStage = new UnWritableStage("unwritableStage")
    val unWritablePipeline = new Pipeline().setStages(Array(unWritableStage))
    withClue("Pipeline.write should fail when Pipeline contains non-Writable stage") {
      intercept[UnsupportedOperationException] {
        unWritablePipeline.write
      }
    }
  }

  test("PipelineModel read/write") {
    val writableStage = new WritableStage("writableStage").setIntParam(56)
    val pipeline =
      new PipelineModel("pipeline_89329327", Array(writableStage.asInstanceOf[Transformer]))

    val pipeline2 = testDefaultReadWrite(pipeline, testParams = false)
    assert(pipeline2.stages.length === 1)
    assert(pipeline2.stages(0).isInstanceOf[WritableStage])
    val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage]
    assert(writableStage.getIntParam === writableStage2.getIntParam)

    val path = new File(tempDir, pipeline.uid).getPath
    val stagesDir = new Path(path, "stages").toString
    val expectedStagePath = SharedReadWrite.getStagePath(writableStage.uid, 0, 1, stagesDir)
    assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)),
      s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with uid ${pipeline.uid}" +
        s" to be saved to path: $expectedStagePath")
  }

  test("PipelineModel read/write: getStagePath") {
    val stageUid = "myStage"
    val stagesDir = new Path("pipeline", "stages").toString
    def testStage(stageIdx: Int, numStages: Int, expectedPrefix: String): Unit = {
      val path = SharedReadWrite.getStagePath(stageUid, stageIdx, numStages, stagesDir)
      val expected = new Path(stagesDir, expectedPrefix + "_" + stageUid).toString
      assert(path === expected)
    }
    testStage(0, 1, "0")
    testStage(0, 9, "0")
    testStage(0, 10, "00")
    testStage(1, 10, "01")
    testStage(12, 999, "012")
  }

  test("PipelineModel read/write with non-Writable stage") {
    val unWritableStage = new UnWritableStage("unwritableStage")
    val unWritablePipeline =
      new PipelineModel("pipeline_328957", Array(unWritableStage.asInstanceOf[Transformer]))
    withClue("PipelineModel.write should fail when PipelineModel contains non-Writable stage") {
      intercept[UnsupportedOperationException] {
        unWritablePipeline.write
      }
    }
  }
}

/** Used to test [[Pipeline]] with [[Writable]] stages */
class WritableStage(override val uid: String) extends Transformer with Writable {

  final val intParam: IntParam = new IntParam(this, "intParam", "doc")

  def getIntParam: Int = $(intParam)

  def setIntParam(value: Int): this.type = set(intParam, value)

  setDefault(intParam -> 0)

  override def copy(extra: ParamMap): WritableStage = defaultCopy(extra)

  override def write: Writer = new DefaultParamsWriter(this)

  override def transform(dataset: DataFrame): DataFrame = dataset

  override def transformSchema(schema: StructType): StructType = schema
}

object WritableStage extends Readable[WritableStage] {

  override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage]

  override def load(path: String): WritableStage = read.load(path)
}

/** Used to test [[Pipeline]] with non-[[Writable]] stages */
class UnWritableStage(override val uid: String) extends Transformer {

  final val intParam: IntParam = new IntParam(this, "intParam", "doc")

  setDefault(intParam -> 0)

  override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra)

  override def transform(dataset: DataFrame): DataFrame = dataset

  override def transformSchema(schema: StructType): StructType = schema
}

@@ -30,10 +30,13 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
  /**
   * Checks "overwrite" option and params.
   * @param instance ML instance to test saving/loading
   * @param testParams If true, then test values of Params. Otherwise, just test overwrite option.
   * @tparam T ML instance type
   * @return Instance loaded from file
   */
  def testDefaultReadWrite[T <: Params with Writable](instance: T): T = {
  def testDefaultReadWrite[T <: Params with Writable](
      instance: T,
      testParams: Boolean = true): T = {
    val uid = instance.uid
    val path = new File(tempDir, uid).getPath

@@ -46,16 +49,18 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
    val newInstance = loader.load(path)

    assert(newInstance.uid === instance.uid)
    instance.params.foreach { p =>
      if (instance.isDefined(p)) {
        (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
          case (Array(values), Array(newValues)) =>
            assert(values === newValues, s"Values do not match on param ${p.name}.")
          case (value, newValue) =>
            assert(value === newValue, s"Values do not match on param ${p.name}.")
    if (testParams) {
      instance.params.foreach { p =>
        if (instance.isDefined(p)) {
          (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
            case (Array(values), Array(newValues)) =>
              assert(values === newValues, s"Values do not match on param ${p.name}.")
            case (value, newValue) =>
              assert(value === newValue, s"Values do not match on param ${p.name}.")
          }
        } else {
          assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
        }
      } else {
        assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
      }
    }