[SPARK-11618][ML] Minor refactoring of basic ML import/export
Refactoring:
* separated overwrite and param save logic in DefaultParamsWriter
* added sparkVersion to DefaultParamsWriter

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9587 from jkbradley/logreg-io.
This commit is contained in:
parent
f14e95115c
commit
18350a5700
|
@ -51,6 +51,9 @@ private[util] sealed trait BaseReadWrite {
|
|||
protected final def sqlContext: SQLContext = optionSQLContext.getOrElse {
|
||||
SQLContext.getOrCreate(SparkContext.getOrCreate())
|
||||
}
|
||||
|
||||
/** Returns the [[SparkContext]] underlying [[sqlContext]] */
|
||||
protected final def sc: SparkContext = sqlContext.sparkContext
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -58,7 +61,7 @@ private[util] sealed trait BaseReadWrite {
|
|||
*/
|
||||
@Experimental
|
||||
@Since("1.6.0")
|
||||
abstract class Writer extends BaseReadWrite {
|
||||
abstract class Writer extends BaseReadWrite with Logging {
|
||||
|
||||
protected var shouldOverwrite: Boolean = false
|
||||
|
||||
|
@ -67,7 +70,29 @@ abstract class Writer extends BaseReadWrite {
|
|||
*/
|
||||
@Since("1.6.0")
|
||||
@throws[IOException]("If the input path already exists but overwrite is not enabled.")
|
||||
def save(path: String): Unit
|
||||
def save(path: String): Unit = {
  // Handles the "path already exists" policy, then delegates the actual
  // writing to saveImpl(path).
  //   - path exists and shouldOverwrite is set: the path is deleted recursively first.
  //   - path exists and shouldOverwrite is not set: an IOException is thrown.
  //   - path does not exist: saveImpl is called directly.
  val hadoopConf = sc.hadoopConfiguration
  val p = new Path(path)
  // Resolve the file system from the path itself rather than from the default
  // file system of the configuration. FileSystem.get(hadoopConf) returns the
  // fs.defaultFS instance, which rejects paths that carry a different scheme or
  // authority (e.g. a file:// or s3a:// path while the default is HDFS).
  val fs = p.getFileSystem(hadoopConf)
  if (fs.exists(p)) {
    if (shouldOverwrite) {
      logInfo(s"Path $path already exists. It will be overwritten.")
      // TODO: Revert back to the original content if save is not successful.
      fs.delete(p, true)
    } else {
      throw new IOException(
        s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
    }
  }
  saveImpl(path)
}
|
||||
|
||||
/**
|
||||
* [[save()]] handles overwriting and then calls this method. Subclasses should override this
|
||||
* method to implement the actual saving of the instance.
|
||||
*/
|
||||
@Since("1.6.0")
|
||||
protected def saveImpl(path: String): Unit
|
||||
|
||||
/**
|
||||
* Overwrites if the output path already exists.
|
||||
|
@ -147,28 +172,9 @@ trait Readable[T] {
|
|||
* data (e.g., models with coefficients).
|
||||
* @param instance object to save
|
||||
*/
|
||||
private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging {
|
||||
|
||||
/**
|
||||
* Saves the ML component to the input path.
|
||||
*/
|
||||
override def save(path: String): Unit = {
|
||||
val sc = sqlContext.sparkContext
|
||||
|
||||
val hadoopConf = sc.hadoopConfiguration
|
||||
val fs = FileSystem.get(hadoopConf)
|
||||
val p = new Path(path)
|
||||
if (fs.exists(p)) {
|
||||
if (shouldOverwrite) {
|
||||
logInfo(s"Path $path already exists. It will be overwritten.")
|
||||
// TODO: Revert back to the original content if save is not successful.
|
||||
fs.delete(p, true)
|
||||
} else {
|
||||
throw new IOException(
|
||||
s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
|
||||
}
|
||||
}
|
||||
private[ml] class DefaultParamsWriter(instance: Params) extends Writer {
|
||||
|
||||
override protected def saveImpl(path: String): Unit = {
|
||||
val uid = instance.uid
|
||||
val cls = instance.getClass.getName
|
||||
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
|
||||
|
@ -177,6 +183,7 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg
|
|||
}.toList
|
||||
val metadata = ("class" -> cls) ~
|
||||
("timestamp" -> System.currentTimeMillis()) ~
|
||||
("sparkVersion" -> sc.version) ~
|
||||
("uid" -> uid) ~
|
||||
("paramMap" -> jsonParams)
|
||||
val metadataPath = new Path(path, "metadata").toString
|
||||
|
@ -193,12 +200,8 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg
|
|||
*/
|
||||
private[ml] class DefaultParamsReader[T] extends Reader[T] {
|
||||
|
||||
/**
|
||||
* Loads the ML component from the input path.
|
||||
*/
|
||||
override def load(path: String): T = {
|
||||
implicit val format = DefaultFormats
|
||||
val sc = sqlContext.sparkContext
|
||||
val metadataPath = new Path(path, "metadata").toString
|
||||
val metadataStr = sc.textFile(metadataPath, 1).first()
|
||||
val metadata = parse(metadataStr)
|
||||
|
|
Loading…
Reference in a new issue