[SPARK-17462][MLLIB] Use VersionUtils to parse Spark version strings
## What changes were proposed in this pull request? Several places in MLlib use custom regexes or other approaches to parse Spark versions. Those should be fixed to use the VersionUtils. This PR replaces custom regexes with VersionUtils to get Spark version numbers. ## How was this patch tested? Existing tests. Signed-off-by: VinceShieh vincent.xie@intel.com Author: VinceShieh <vincent.xie@intel.com> Closes #15055 from VinceShieh/SPARK-17462.
This commit is contained in:
parent
49b6f456ac
commit
de77c67750
|
@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
|
|||
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||
import org.apache.spark.sql.functions.{col, udf}
|
||||
import org.apache.spark.sql.types.{IntegerType, StructType}
|
||||
import org.apache.spark.util.VersionUtils.majorVersion
|
||||
|
||||
/**
|
||||
* Common params for KMeans and KMeansModel
|
||||
|
@ -232,10 +233,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
|
|||
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
|
||||
val dataPath = new Path(path, "data").toString
|
||||
|
||||
val versionRegex = "([0-9]+)\\.(.+)".r
|
||||
val versionRegex(major, _) = metadata.sparkVersion
|
||||
|
||||
val clusterCenters = if (major.toInt >= 2) {
|
||||
val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
|
||||
val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
|
||||
data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
|
||||
} else {
|
||||
|
|
|
@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
|
|||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.{StructField, StructType}
|
||||
import org.apache.spark.util.VersionUtils.majorVersion
|
||||
|
||||
/**
|
||||
* Params for [[PCA]] and [[PCAModel]].
|
||||
|
@ -204,11 +205,8 @@ object PCAModel extends MLReadable[PCAModel] {
|
|||
override def load(path: String): PCAModel = {
|
||||
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
|
||||
|
||||
val versionRegex = "([0-9]+)\\.(.+)".r
|
||||
val versionRegex(major, _) = metadata.sparkVersion
|
||||
|
||||
val dataPath = new Path(path, "data").toString
|
||||
val model = if (major.toInt >= 2) {
|
||||
val model = if (majorVersion(metadata.sparkVersion) >= 2) {
|
||||
val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
|
||||
sparkSession.read.parquet(dataPath)
|
||||
.select("pc", "explainedVariance")
|
||||
|
|
Loading…
Reference in a new issue