[SPARK-13334][ML] ML KMeansModel / BisectingKMeansModel / QuantileDiscretizer should set parent

ML ```KMeansModel / BisectingKMeansModel / QuantileDiscretizer``` should set parent.

cc mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #11214 from yanboliang/spark-13334.
This commit is contained in:
Yanbo Liang 2016-02-22 12:59:50 +02:00 committed by Nick Pentreath
parent e298ac91e3
commit 40e6d40fe7
6 changed files with 8 additions and 4 deletions

View file

@@ -185,7 +185,7 @@ class BisectingKMeans @Since("2.0.0") (
.setSeed($(seed))
val parentModel = bkm.run(rdd)
val model = new BisectingKMeansModel(uid, parentModel)
copyValues(model)
copyValues(model.setParent(this))
}
@Since("2.0.0")

View file

@@ -250,7 +250,7 @@ class KMeans @Since("1.5.0") (
.setEpsilon($(tol))
val parentModel = algo.run(rdd)
val model = new KMeansModel(uid, parentModel)
copyValues(model)
copyValues(model.setParent(this))
}
@Since("1.5.0")

View file

@@ -95,7 +95,7 @@ final class QuantileDiscretizer(override val uid: String)
val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
val splits = QuantileDiscretizer.getSplits(candidates)
val bucketizer = new Bucketizer(uid).setSplits(splits)
copyValues(bucketizer)
copyValues(bucketizer.setParent(this))
}
override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra)

View file

@@ -81,5 +81,6 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(clusters.size === k)
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
}
}

View file

@@ -97,6 +97,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(clusters.size === k)
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
}
test("read/write") {

View file

@@ -94,7 +94,9 @@ private object QuantileDiscretizerSuite extends SparkFunSuite {
val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
.setNumBuckets(numBucket).setSeed(1)
val result = discretizer.fit(df).transform(df)
val model = discretizer.fit(df)
assert(model.hasParent)
val result = model.transform(df)
val transformedFeatures = result.select("result").collect()
.map { case Row(transformedFeature: Double) => transformedFeature }