[SPARK-17462][MLLIB]use VersionUtils to parse Spark version strings

## What changes were proposed in this pull request?

Several places in MLlib use custom regexes or other ad-hoc approaches to parse Spark version strings.
Those should be fixed to use VersionUtils. This PR replaces the custom regexes with
VersionUtils to extract Spark version numbers.
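
As a reference, a minimal sketch of the shared helper this PR adopts (`org.apache.spark.util.VersionUtils`; the version strings below are illustrative):

```scala
import org.apache.spark.util.VersionUtils

// Parse the major version number out of a Spark version string.
val major: Int = VersionUtils.majorVersion("2.0.1")          // 2
// Parse major and minor components together.
val (maj, minor) = VersionUtils.majorMinorVersion("2.0.1")   // (2, 0)
```
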
## How was this patch tested?

Existing tests.

Signed-off-by: VinceShieh <vincent.xie@intel.com>

Author: VinceShieh <vincent.xie@intel.com>

Closes #15055 from VinceShieh/SPARK-17462.
Authored by VinceShieh on 2016-11-17 13:37:42 +00:00, committed by Sean Owen
parent 49b6f456ac
commit de77c67750
2 changed files with 4 additions and 8 deletions

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

```diff
@@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion
 
 /**
  * Common params for KMeans and KMeansModel
@@ -232,10 +233,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
 
-      val versionRegex = "([0-9]+)\\.(.+)".r
-      val versionRegex(major, _) = metadata.sparkVersion
-
-      val clusterCenters = if (major.toInt >= 2) {
+      val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
         val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
         data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
       } else {
```
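
For context, a self-contained before/after sketch of the change in both files (the `sparkVersion` value here is illustrative; in the loaders it comes from `metadata.sparkVersion`):

```scala
import org.apache.spark.util.VersionUtils.majorVersion

val sparkVersion = "2.0.2"  // stands in for metadata.sparkVersion

// Before: each loader rolled its own regex and extractor pattern.
val versionRegex = "([0-9]+)\\.(.+)".r
val versionRegex(major, _) = sparkVersion
val loadedWithV2Regex = major.toInt >= 2

// After: one shared utility does the parsing.
val loadedWithV2Utils = majorVersion(sparkVersion) >= 2

assert(loadedWithV2Regex == loadedWithV2Utils)
```
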

mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala

```diff
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion
 
 /**
  * Params for [[PCA]] and [[PCAModel]].
@@ -204,11 +205,8 @@ object PCAModel extends MLReadable[PCAModel] {
     override def load(path: String): PCAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
-      val versionRegex = "([0-9]+)\\.(.+)".r
-      val versionRegex(major, _) = metadata.sparkVersion
-
       val dataPath = new Path(path, "data").toString
-      val model = if (major.toInt >= 2) {
+      val model = if (majorVersion(metadata.sparkVersion) >= 2) {
         val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
           sparkSession.read.parquet(dataPath)
             .select("pc", "explainedVariance")
```
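
One behavioral note, hedged: based on `VersionUtils` as it exists in Spark 2.x, a malformed version string fails with a descriptive `IllegalArgumentException`, whereas the removed ad-hoc extractor would throw a bare `scala.MatchError`. A small sanity check:

```scala
import org.apache.spark.util.VersionUtils.majorVersion

// Well-formed version strings, including snapshot qualifiers, parse cleanly.
assert(majorVersion("2.0.1") == 2)
assert(majorVersion("2.1.0-SNAPSHOT") == 2)

// A malformed string fails fast with a descriptive error message.
try {
  majorVersion("not-a-version")
} catch {
  case e: IllegalArgumentException => println(e.getMessage)
}
```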