[SPARK-11618][ML] Minor refactoring of basic ML import/export

Refactoring
* separated overwrite and param save logic in DefaultParamsWriter
* added sparkVersion to DefaultParamsWriter

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9587 from jkbradley/logreg-io.
This commit is contained in:
Joseph K. Bradley 2015-11-10 11:36:43 -08:00 committed by Xiangrui Meng
parent f14e95115c
commit 18350a5700

View file

@ -51,6 +51,9 @@ private[util] sealed trait BaseReadWrite {
protected final def sqlContext: SQLContext = optionSQLContext.getOrElse {
SQLContext.getOrCreate(SparkContext.getOrCreate())
}
/** Returns the [[SparkContext]] underlying [[sqlContext]] */
protected final def sc: SparkContext = sqlContext.sparkContext
}
/**
@ -58,7 +61,7 @@ private[util] sealed trait BaseReadWrite {
*/
@Experimental
@Since("1.6.0")
abstract class Writer extends BaseReadWrite {
abstract class Writer extends BaseReadWrite with Logging {
protected var shouldOverwrite: Boolean = false
@ -67,7 +70,29 @@ abstract class Writer extends BaseReadWrite {
*/
@Since("1.6.0")
@throws[IOException]("If the input path already exists but overwrite is not enabled.")
def save(path: String): Unit
def save(path: String): Unit = {
val hadoopConf = sc.hadoopConfiguration
val fs = FileSystem.get(hadoopConf)
val p = new Path(path)
if (fs.exists(p)) {
if (shouldOverwrite) {
logInfo(s"Path $path already exists. It will be overwritten.")
// TODO: Revert back to the original content if save is not successful.
fs.delete(p, true)
} else {
throw new IOException(
s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
}
}
saveImpl(path)
}
/**
* [[save()]] handles overwriting and then calls this method. Subclasses should override this
* method to implement the actual saving of the instance.
*/
@Since("1.6.0")
protected def saveImpl(path: String): Unit
/**
* Overwrites if the output path already exists.
@ -147,28 +172,9 @@ trait Readable[T] {
* data (e.g., models with coefficients).
* @param instance object to save
*/
private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging {
/**
* Saves the ML component to the input path.
*/
override def save(path: String): Unit = {
val sc = sqlContext.sparkContext
val hadoopConf = sc.hadoopConfiguration
val fs = FileSystem.get(hadoopConf)
val p = new Path(path)
if (fs.exists(p)) {
if (shouldOverwrite) {
logInfo(s"Path $path already exists. It will be overwritten.")
// TODO: Revert back to the original content if save is not successful.
fs.delete(p, true)
} else {
throw new IOException(
s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
}
}
private[ml] class DefaultParamsWriter(instance: Params) extends Writer {
override protected def saveImpl(path: String): Unit = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@ -177,6 +183,7 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg
}.toList
val metadata = ("class" -> cls) ~
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
("paramMap" -> jsonParams)
val metadataPath = new Path(path, "metadata").toString
@ -193,12 +200,8 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg
*/
private[ml] class DefaultParamsReader[T] extends Reader[T] {
/**
* Loads the ML component from the input path.
*/
override def load(path: String): T = {
implicit val format = DefaultFormats
val sc = sqlContext.sparkContext
val metadataPath = new Path(path, "metadata").toString
val metadataStr = sc.textFile(metadataPath, 1).first()
val metadata = parse(metadataStr)