[SPARK-25793][ML] call SaveLoadV2_0.load for classNameV2_0
## What changes were proposed in this pull request? The following code in BisectingKMeansModel.load calls the wrong version of load. ``` case (SaveLoadV2_0.thisClassName, SaveLoadV2_0.thisFormatVersion) => val model = SaveLoadV1_0.load(sc, path) ``` Closes #22790 from huaxingao/spark-25793. Authored-by: Huaxin Gao <huaxing@us.ibm.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
72a23a6c43
commit
dc9b320807
|
@ -109,10 +109,10 @@ class BisectingKMeansModel private[clustering] (
|
|||
|
||||
@Since("2.0.0")
|
||||
override def save(sc: SparkContext, path: String): Unit = {
|
||||
BisectingKMeansModel.SaveLoadV1_0.save(sc, this, path)
|
||||
BisectingKMeansModel.SaveLoadV2_0.save(sc, this, path)
|
||||
}
|
||||
|
||||
override protected def formatVersion: String = "1.0"
|
||||
override protected def formatVersion: String = "2.0"
|
||||
}
|
||||
|
||||
@Since("2.0.0")
|
||||
|
@ -126,7 +126,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
|
|||
val model = SaveLoadV1_0.load(sc, path)
|
||||
model
|
||||
case (SaveLoadV2_0.thisClassName, SaveLoadV2_0.thisFormatVersion) =>
|
||||
val model = SaveLoadV1_0.load(sc, path)
|
||||
val model = SaveLoadV2_0.load(sc, path)
|
||||
model
|
||||
case _ => throw new Exception(
|
||||
s"BisectingKMeansModel.load did not recognize model with (className, format version):" +
|
||||
|
|
|
@ -187,11 +187,12 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
|
||||
val points = (1 until 8).map(i => Vectors.dense(i))
|
||||
val data = sc.parallelize(points, 2)
|
||||
val model = new BisectingKMeans().run(data)
|
||||
val model = new BisectingKMeans().setDistanceMeasure(DistanceMeasure.COSINE).run(data)
|
||||
try {
|
||||
model.save(sc, path)
|
||||
val sameModel = BisectingKMeansModel.load(sc, path)
|
||||
assert(model.k === sameModel.k)
|
||||
assert(model.distanceMeasure === sameModel.distanceMeasure)
|
||||
model.clusterCenters.zip(sameModel.clusterCenters).foreach(c => c._1 === c._2)
|
||||
} finally {
|
||||
Utils.deleteRecursively(tempDir)
|
||||
|
|
Loading…
Reference in a new issue