[SPARK-16245][ML] model loading backward compatibility for ml.feature.PCA

## What changes were proposed in this pull request?
model loading backward compatibility for ml.feature.PCA.

## How was this patch tested?
Existing unit tests, plus a manual test for loading models saved by Spark 1.6.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #13937 from yanboliang/spark-16245.
This commit is contained in:
Yanbo Liang 2016-06-28 19:53:07 -07:00 committed by Xiangrui Meng
parent 363bcedeea
commit 0df5ce1bc1

View file

@@ -206,24 +206,22 @@ object PCAModel extends MLReadable[PCAModel] {
     override def load(path: String): PCAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
-      // explainedVariance field is not present in Spark <= 1.6
-      val versionRegex = "([0-9]+)\\.([0-9]+).*".r
-      val hasExplainedVariance = metadata.sparkVersion match {
-        case versionRegex(major, minor) =>
-          major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6)
-        case _ => false
-      }
+      val versionRegex = "([0-9]+)\\.(.+)".r
+      val versionRegex(major, _) = metadata.sparkVersion
       val dataPath = new Path(path, "data").toString
-      val model = if (hasExplainedVariance) {
+      val model = if (major.toInt >= 2) {
         val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
           sparkSession.read.parquet(dataPath)
             .select("pc", "explainedVariance")
             .head()
         new PCAModel(metadata.uid, pc, explainedVariance)
       } else {
-        val Row(pc: DenseMatrix) = sparkSession.read.parquet(dataPath).select("pc").head()
-        new PCAModel(metadata.uid, pc, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector])
+        // pc field is the old matrix format in Spark <= 1.6
+        // explainedVariance field is not present in Spark <= 1.6
+        val Row(pc: OldDenseMatrix) = sparkSession.read.parquet(dataPath).select("pc").head()
+        new PCAModel(metadata.uid, pc.asML,
+          Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector])
       }
       DefaultParamsReader.getAndSetParams(model, metadata)
       model