[SPARK-29967][ML][PYTHON] KMeans support instance weighting

### What changes were proposed in this pull request?
Add instance weight support to `KMeans` via a new `weightCol` param and `setWeightCol` setter.
### Why are the changes needed?
`KMeans` should support instance weighting so that individual training points contribute to the fitted cluster centers in proportion to their weights.
### Does this PR introduce any user-facing change?
Yes. ```KMeans.setWeightCol``` is added to both the Scala/Java and Python APIs.
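
For reference, a minimal usage sketch of the new setter (spark-shell style; the data values and the `weight` column name below are illustrative, not taken from this PR):

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("weighted-kmeans").getOrCreate()
import spark.implicits._

// Each row carries a feature vector and an instance weight.
val df = Seq(
  (Vectors.dense(0.0, 0.0), 1.0),
  (Vectors.dense(0.1, 0.1), 3.0),
  (Vectors.dense(9.0, 9.0), 1.0),
  (Vectors.dense(9.2, 9.2), 2.0)
).toDF("features", "weight")

val model = new KMeans()
  .setK(2)
  .setSeed(1L)
  .setWeightCol("weight")  // new in this PR; when unset, every instance gets weight 1.0
  .fit(df)

model.clusterCenters.foreach(println)
```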

### How was this patch tested?
Unit Tests

Closes #26739 from huaxingao/spark-29967.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
Authored by Huaxin Gao on 2019-12-10 09:33:06 -06:00, committed by Sean Owen.
Commit 1cac9b2cc6, parent aa9da9365f.
6 changed files with 332 additions and 45 deletions

@ -31,9 +31,10 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils.majorVersion
@ -41,7 +42,7 @@ import org.apache.spark.util.VersionUtils.majorVersion
* Common params for KMeans and KMeansModel
*/
private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure {
with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure with HasWeightCol {
/**
* The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than
@ -319,12 +320,31 @@ class KMeans @Since("1.5.0") (
@Since("1.5.0")
def setSeed(value: Long): this.type = set(seed, value)
/**
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
* @group setParam
*/
@Since("3.0.0")
def setWeightCol(value: String): this.type = set(weightCol, value)
@Since("2.0.0")
override def fit(dataset: Dataset[_]): KMeansModel = instrumented { instr =>
transformSchema(dataset.schema, logging = true)
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
val instances = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
col($(weightCol)).cast(DoubleType)
} else {
lit(1.0)
}
val instances: RDD[(OldVector, Double)] = dataset
.select(DatasetUtils.columnToVector(dataset, getFeaturesCol), w).rdd.map {
case Row(point: Vector, weight: Double) => (OldVectors.fromML(point), weight)
}
if (handlePersistence) {
instances.persist(StorageLevel.MEMORY_AND_DISK)
@ -333,7 +353,7 @@ class KMeans @Since("1.5.0") (
instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logParams(this, featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure,
maxIter, seed, tol)
maxIter, seed, tol, weightCol)
val algo = new MLlibKMeans()
.setK($(k))
.setInitializationMode($(initMode))
@ -342,7 +362,7 @@ class KMeans @Since("1.5.0") (
.setSeed($(seed))
.setEpsilon($(tol))
.setDistanceMeasure($(distanceMeasure))
val parentModel = algo.run(instances, Option(instr))
val parentModel = algo.runWithWeight(instances, Option(instr))
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
val summary = new KMeansSummary(
model.transform(dataset),

@ -123,6 +123,10 @@ private[spark] class Instrumentation private () extends Logging with MLEvents {
logNamedValue(Instrumentation.loggerTags.numExamples, num)
}
def logSumOfWeights(num: Double): Unit = {
logNamedValue(Instrumentation.loggerTags.sumOfWeights, num)
}
/**
* Logs the value with customized name field.
*/
@ -179,6 +183,7 @@ private[spark] object Instrumentation {
val numExamples = "numExamples"
val meanOfLabels = "meanOfLabels"
val varianceOfLabels = "varianceOfLabels"
val sumOfWeights = "sumOfWeights"
}
def instrumented[T](body: (Instrumentation => T)): T = {

@ -84,8 +84,8 @@ private[spark] abstract class DistanceMeasure extends Serializable {
* @param point a `VectorWithNorm` to be added to `sum` of a cluster
* @param sum the `sum` for a cluster to be updated
*/
def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
axpy(1.0, point.vector, sum)
def updateClusterSum(point: VectorWithNorm, sum: Vector, weight: Double = 1.0): Unit = {
axpy(weight, point.vector, sum)
}
/**
@ -100,6 +100,18 @@ private[spark] abstract class DistanceMeasure extends Serializable {
new VectorWithNorm(sum)
}
/**
* Returns a centroid for a cluster given its `sum` vector and the weightSum of points.
*
* @param sum the `sum` for a cluster
* @param weightSum the weightSum of points in the cluster
* @return the centroid of the cluster
*/
def centroid(sum: Vector, weightSum: Double): VectorWithNorm = {
scal(1.0 / weightSum, sum)
new VectorWithNorm(sum)
}
/**
* Returns two new centroids symmetric to the specified centroid, applying `noise` with the
* specified `level`.
@ -249,9 +261,9 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {
* @param point a `VectorWithNorm` to be added to `sum` of a cluster
* @param sum the `sum` for a cluster to be updated
*/
override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
override def updateClusterSum(point: VectorWithNorm, sum: Vector, weight: Double = 1.0): Unit = {
assert(point.norm > 0, "Cosine distance is not defined for zero-length vectors.")
axpy(1.0 / point.norm, point.vector, sum)
axpy(weight / point.norm, point.vector, sum)
}
/**

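As a side note, the weighted `updateClusterSum`/`centroid` arithmetic above boils down to the following standalone sketch (plain Scala, no Spark; the sample values are chosen to mirror the "Two centers with weightCol" test later in this commit):

```scala
// Weighted cluster sum and centroid for the Euclidean case.
val points  = Seq(Array(9.0, 0.0), Array(9.0, 0.2), Array(9.2, 0.0))
val weights = Seq(2.0, 2.0, 2.0)

val dims = points.head.length
val sum  = Array.fill(dims)(0.0)
points.zip(weights).foreach { case (p, w) =>
  var i = 0
  while (i < dims) { sum(i) += w * p(i); i += 1 }  // axpy(weight, point.vector, sum)
}
val weightSum = weights.sum
val centroid  = sum.map(_ / weightSum)             // scal(1.0 / weightSum, sum)
// centroid ≈ (9.0667, 0.0667), the cluster center the weighted test below expects
```
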
@ -23,7 +23,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.ml.util.Instrumentation
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.axpy
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@ -209,11 +209,14 @@ class KMeans private (
*/
@Since("0.8.0")
def run(data: RDD[Vector]): KMeansModel = {
run(data, None)
val instances: RDD[(Vector, Double)] = data.map {
case (point) => (point, 1.0)
}
runWithWeight(instances, None)
}
private[spark] def run(
data: RDD[Vector],
private[spark] def runWithWeight(
data: RDD[(Vector, Double)],
instr: Option[Instrumentation]): KMeansModel = {
if (data.getStorageLevel == StorageLevel.NONE) {
@ -222,12 +225,15 @@ class KMeans private (
}
// Compute squared norms and cache them.
val norms = data.map(Vectors.norm(_, 2.0))
val zippedData = data.zip(norms).map { case (v, norm) =>
new VectorWithNorm(v, norm)
val norms = data.map { case (v, _) =>
Vectors.norm(v, 2.0)
}
val zippedData = data.zip(norms).map { case ((v, w), norm) =>
(new VectorWithNorm(v, norm), w)
}
zippedData.persist()
val model = runAlgorithm(zippedData, instr)
val model = runAlgorithmWithWeight(zippedData, instr)
zippedData.unpersist()
// Warn at the end of the run as well, for increased visibility.
@ -241,8 +247,8 @@ class KMeans private (
/**
* Implementation of K-Means algorithm.
*/
private def runAlgorithm(
data: RDD[VectorWithNorm],
private def runAlgorithmWithWeight(
data: RDD[(VectorWithNorm, Double)],
instr: Option[Instrumentation]): KMeansModel = {
val sc = data.sparkContext
@ -251,14 +257,17 @@ class KMeans private (
val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure)
val dataVectorWithNorm = data.map(d => d._1)
val weights = data.map(d => d._2)
val centers = initialModel match {
case Some(kMeansCenters) =>
kMeansCenters.clusterCenters.map(new VectorWithNorm(_))
case None =>
if (initializationMode == KMeans.RANDOM) {
initRandom(data)
initRandom(dataVectorWithNorm)
} else {
initKMeansParallel(data, distanceMeasureInstance)
initKMeansParallel(dataVectorWithNorm, distanceMeasureInstance)
}
}
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
@ -275,35 +284,43 @@ class KMeans private (
// Execute iterations of Lloyd's algorithm until converged
while (iteration < maxIterations && !converged) {
val costAccum = sc.doubleAccumulator
val countAccum = sc.longAccumulator
val bcCenters = sc.broadcast(centers)
// Find the new centers
val collected = data.mapPartitions { points =>
val collected = data.mapPartitions { pointsAndWeights =>
val thisCenters = bcCenters.value
val dims = thisCenters.head.vector.size
val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
val counts = Array.fill(thisCenters.length)(0L)
points.foreach { point =>
// clusterWeightSum is needed to calculate cluster center
// cluster center =
// sample1 * weight1/clusterWeightSum + sample2 * weight2/clusterWeightSum + ...
val clusterWeightSum = Array.ofDim[Double](thisCenters.length)
pointsAndWeights.foreach { case (point, weight) =>
val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
costAccum.add(cost)
distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
counts(bestCenter) += 1
costAccum.add(cost * weight)
countAccum.add(1)
distanceMeasureInstance.updateClusterSum(point, sums(bestCenter), weight)
clusterWeightSum(bestCenter) += weight
}
counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
}.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
clusterWeightSum.indices.filter(clusterWeightSum(_) > 0)
.map(j => (j, (sums(j), clusterWeightSum(j)))).iterator
}.reduceByKey { case ((sum1, clusterWeightSum1), (sum2, clusterWeightSum2)) =>
axpy(1.0, sum2, sum1)
(sum1, count1 + count2)
(sum1, clusterWeightSum1 + clusterWeightSum2)
}.collectAsMap()
if (iteration == 0) {
instr.foreach(_.logNumExamples(collected.values.map(_._2).sum))
instr.foreach(_.logNumExamples(countAccum.value))
instr.foreach(_.logSumOfWeights(collected.values.map(_._2).sum))
}
val newCenters = collected.mapValues { case (sum, count) =>
distanceMeasureInstance.centroid(sum, count)
val newCenters = collected.mapValues { case (sum, weightSum) =>
distanceMeasureInstance.centroid(sum, weightSum)
}
bcCenters.destroy()

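For readers following the Lloyd-iteration changes above, here is a compact, self-contained sketch (plain Scala, no Spark; the object and method names are illustrative only) of what one weighted iteration computes: each point adds `weight * point` to the running sum of its closest cluster and `weight` to that cluster's weight total, the cost accumulates the weighted distance, and each new center is the weighted sum divided by the accumulated weight, as in `runAlgorithmWithWeight`.

```scala
object WeightedLloydSketch {
  type Point = Array[Double]

  private def squaredDistance(a: Point, b: Point): Double =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  /** One Lloyd step over weighted points; returns (new centers, weighted cost). */
  def step(points: Seq[(Point, Double)], centers: Array[Point]): (Array[Point], Double) = {
    val dims       = centers.head.length
    val sums       = Array.fill(centers.length)(Array.fill(dims)(0.0))
    val weightSums = Array.fill(centers.length)(0.0)   // plays the role of clusterWeightSum
    var cost = 0.0

    points.foreach { case (p, w) =>
      // Assign the point to its closest center (findClosest in the real code).
      val (best, dist) = centers.indices.map(i => (i, squaredDistance(p, centers(i)))).minBy(_._2)
      cost += dist * w                                  // costAccum.add(cost * weight)
      weightSums(best) += w                             // clusterWeightSum(bestCenter) += weight
      var i = 0
      while (i < dims) { sums(best)(i) += w * p(i); i += 1 }  // updateClusterSum(point, sum, weight)
    }

    // New center = weighted sum / accumulated weight; empty clusters keep their old center here.
    val newCenters = centers.indices.map { j =>
      if (weightSums(j) > 0) sums(j).map(_ / weightSums(j)) else centers(j)
    }.toArray
    (newCenters, cost)
  }
}
```

In the actual change the per-cluster sums and weight totals are computed per partition with `mapPartitions` and merged with `reduceByKey`, so only k small (sum, weightSum) pairs are shuffled per iteration.
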
@ -217,7 +217,6 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes
assert(trueCost ~== floatArrayCost absTol 1e-6)
}
test("read/write") {
def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
assert(model.clusterCenters === model2.clusterCenters)
@ -254,6 +253,231 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes
testClusteringModelSinglePrediction(model, model.predict, dataset,
model.getFeaturesCol, model.getPredictionCol)
}
test("compare with weightCol and without weightCol") {
val df1 = spark.createDataFrame(spark.sparkContext.parallelize(Array(
Vectors.dense(1.0, 1.0),
Vectors.dense(10.0, 10.0), Vectors.dense(10.0, 10.0),
Vectors.dense(1.0, 0.5),
Vectors.dense(10.0, 4.4), Vectors.dense(10.0, 4.4),
Vectors.dense(-1.0, 1.0),
Vectors.dense(-100.0, 90.0), Vectors.dense(-100.0, 90.0)
)).map(v => TestRow(v)))
val model1 = new KMeans()
.setK(3)
.setSeed(42)
.setInitMode(MLlibKMeans.RANDOM)
.setTol(1e-6)
.setDistanceMeasure(DistanceMeasure.COSINE)
.fit(df1)
val predictionDf1 = model1.transform(df1)
assert(predictionDf1.select("prediction").distinct().count() == 3)
val predictionsMap1 = predictionDf1.collect().map(row =>
row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap
assert(predictionsMap1(Vectors.dense(1.0, 1.0)) ==
predictionsMap1(Vectors.dense(10.0, 10.0)))
assert(predictionsMap1(Vectors.dense(1.0, 0.5)) ==
predictionsMap1(Vectors.dense(10.0, 4.4)))
assert(predictionsMap1(Vectors.dense(-1.0, 1.0)) ==
predictionsMap1(Vectors.dense(-100.0, 90.0)))
model1.clusterCenters.forall(Vectors.norm(_, 2) == 1.0)
val df2 = spark.createDataFrame(spark.sparkContext.parallelize(Array(
(Vectors.dense(1.0, 1.0), 1.0),
(Vectors.dense(10.0, 10.0), 2.0),
(Vectors.dense(1.0, 0.5), 1.0),
(Vectors.dense(10.0, 4.4), 2.0),
(Vectors.dense(-1.0, 1.0), 1.0),
(Vectors.dense(-100.0, 90.0), 2.0)))).toDF("features", "weightCol")
val model2 = new KMeans()
.setK(3)
.setSeed(42)
.setInitMode(MLlibKMeans.RANDOM)
.setTol(1e-6)
.setDistanceMeasure(DistanceMeasure.COSINE)
.setWeightCol("weightCol")
.fit(df2)
val predictionDf2 = model2.transform(df2)
assert(predictionDf2.select("prediction").distinct().count() == 3)
val predictionsMap2 = predictionDf2.collect().map(row =>
row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap
assert(predictionsMap2(Vectors.dense(1.0, 1.0)) ==
predictionsMap2(Vectors.dense(10.0, 10.0)))
assert(predictionsMap2(Vectors.dense(1.0, 0.5)) ==
predictionsMap2(Vectors.dense(10.0, 4.4)))
assert(predictionsMap2(Vectors.dense(-1.0, 1.0)) ==
predictionsMap2(Vectors.dense(-100.0, 90.0)))
model2.clusterCenters.forall(Vectors.norm(_, 2) == 1.0)
// compare if model1 and model2 have the same cluster centers
assert(model1.clusterCenters.length === model2.clusterCenters.length)
assert(model1.clusterCenters.toSet.subsetOf((model2.clusterCenters.toSet)))
}
test("Two centers with weightCol") {
// use the same weight for all samples.
val df1 = spark.createDataFrame(spark.sparkContext.parallelize(Array(
(Vectors.dense(0.0, 0.0), 2.0),
(Vectors.dense(0.0, 0.1), 2.0),
(Vectors.dense(0.1, 0.0), 2.0),
(Vectors.dense(9.0, 0.0), 2.0),
(Vectors.dense(9.0, 0.2), 2.0),
(Vectors.dense(9.2, 0.0), 2.0)))).toDF("features", "weightCol")
val model1 = new KMeans()
.setK(2)
.setInitMode(MLlibKMeans.RANDOM)
.setWeightCol("weightCol")
.setMaxIter(10)
.fit(df1)
val predictionDf1 = model1.transform(df1)
assert(predictionDf1.select("prediction").distinct().count() == 2)
val predictionsMap1 = predictionDf1.collect().map(row =>
row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap
assert(predictionsMap1(Vectors.dense(0.0, 0.0)) ==
predictionsMap1(Vectors.dense(0.0, 0.1)))
assert(predictionsMap1(Vectors.dense(0.0, 0.0)) ==
predictionsMap1(Vectors.dense(0.1, 0.0)))
assert(predictionsMap1(Vectors.dense(9.0, 0.0)) ==
predictionsMap1(Vectors.dense(9.0, 0.2)))
assert(predictionsMap1(Vectors.dense(9.0, 0.2)) ==
predictionsMap1(Vectors.dense(9.2, 0.0)))
model1.clusterCenters.forall(Vectors.norm(_, 2) == 1.0)
// center 1:
// total weights in cluster 1: 2.0 + 2.0 + 2.0 = 6.0
// x: 9.0 * (2.0/6.0) + 9.0 * (2.0/6.0) + 9.2 * (2.0/6.0) = 9.066666666666666
// y: 0.0 * (2.0/6.0) + 0.2 * (2.0/6.0) + 0.0 * (2.0/6.0) = 0.06666666666666667
// center 2:
// total weights in cluster 2: 2.0 + 2.0 + 2.0 = 6.0
// x: 0.0 * (2.0/6.0) + 0.0 * (2.0/6.0) + 0.1 * (2.0/6.0) = 0.03333333333333333
// y: 0.0 * (2.0/6.0) + 0.1 * (2.0/6.0) + 0.0 * (2.0/6.0) = 0.03333333333333333
val model1_center1 = Vectors.dense(9.066666666666666, 0.06666666666666667)
val model1_center2 = Vectors.dense(0.03333333333333333, 0.03333333333333333)
assert(model1.clusterCenters(0) === model1_center1)
assert(model1.clusterCenters(1) === model1_center2)
// use different weight
val df2 = spark.createDataFrame(spark.sparkContext.parallelize(Array(
(Vectors.dense(0.0, 0.0), 1.0),
(Vectors.dense(0.0, 0.1), 2.0),
(Vectors.dense(0.1, 0.0), 3.0),
(Vectors.dense(9.0, 0.0), 2.5),
(Vectors.dense(9.0, 0.2), 1.0),
(Vectors.dense(9.2, 0.0), 2.0)))).toDF("features", "weightCol")
val model2 = new KMeans()
.setK(2)
.setInitMode(MLlibKMeans.RANDOM)
.setWeightCol("weightCol")
.setMaxIter(10)
.fit(df2)
val predictionDf2 = model2.transform(df2)
assert(predictionDf2.select("prediction").distinct().count() == 2)
val predictionsMap2 = predictionDf2.collect().map(row =>
row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap
assert(predictionsMap2(Vectors.dense(0.0, 0.0)) ==
predictionsMap2(Vectors.dense(0.0, 0.1)))
assert(predictionsMap2(Vectors.dense(0.0, 0.0)) ==
predictionsMap2(Vectors.dense(0.1, 0.0)))
assert(predictionsMap2(Vectors.dense(9.0, 0.0)) ==
predictionsMap2(Vectors.dense(9.0, 0.2)))
assert(predictionsMap2(Vectors.dense(9.0, 0.2)) ==
predictionsMap2(Vectors.dense(9.2, 0.0)))
model2.clusterCenters.forall(Vectors.norm(_, 2) == 1.0)
// center 1:
// total weights in cluster 1: 2.5 + 1.0 + 2.0 = 5.5
// x: 9.0 * (2.5/5.5) + 9.0 * (1.0/5.5) + 9.2 * (2.0/5.5) = 9.072727272727272
// y: 0.0 * (2.5/5.5) + 0.2 * (1.0/5.5) + 0.0 * (2.0/5.5) = 0.03636363636363637
// center 2:
// total weights in cluster 2: 1.0 + 2.0 + 3.0 = 6.0
// x: 0.0 * (1.0/6.0) + 0.0 * (2.0/6.0) + 0.1 * (3.0/6.0) = 0.05
// y: 0.0 * (1.0/6.0) + 0.1 * (2.0/6.0) + 0.0 * (3.0/6.0) = 0.03333333333333333
val model2_center1 = Vectors.dense(9.072727272727272, 0.03636363636363637)
val model2_center2 = Vectors.dense(0.05, 0.03333333333333333)
assert(model2.clusterCenters(0) === model2_center1)
assert(model2.clusterCenters(1) === model2_center2)
}
test("Four centers with weightCol") {
// no weight
val df1 = spark.createDataFrame(spark.sparkContext.parallelize(Array(
Vectors.dense(0.1, 0.1),
Vectors.dense(5.0, 0.2),
Vectors.dense(10.0, 0.0),
Vectors.dense(15.0, 0.5),
Vectors.dense(32.0, 18.0),
Vectors.dense(30.1, 20.0),
Vectors.dense(-6.0, -6.0),
Vectors.dense(-10.0, -10.0))).map(v => TestRow(v)))
val model1 = new KMeans()
.setK(4)
.setInitMode(MLlibKMeans.K_MEANS_PARALLEL)
.setMaxIter(10)
.fit(df1)
val predictionDf1 = model1.transform(df1)
assert(predictionDf1.select("prediction").distinct().count() == 4)
val predictionsMap1 = predictionDf1.collect().map(row =>
row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap
assert(predictionsMap1(Vectors.dense(0.1, 0.1)) ==
predictionsMap1(Vectors.dense(5.0, 0.2)) )
assert(predictionsMap1(Vectors.dense(10.0, 0.0)) ==
predictionsMap1(Vectors.dense(15.0, 0.5)) )
assert(predictionsMap1(Vectors.dense(32.0, 18.0)) ==
predictionsMap1(Vectors.dense(30.1, 20.0)))
assert(predictionsMap1(Vectors.dense(-6.0, -6.0)) ==
predictionsMap1(Vectors.dense(-10.0, -10.0)))
model1.clusterCenters.forall(Vectors.norm(_, 2) == 1.0)
// use same weight, should have the same result as no weight
val df2 = spark.createDataFrame(spark.sparkContext.parallelize(Array(
(Vectors.dense(0.1, 0.1), 2.0),
(Vectors.dense(5.0, 0.2), 2.0),
(Vectors.dense(10.0, 0.0), 2.0),
(Vectors.dense(15.0, 0.5), 2.0),
(Vectors.dense(32.0, 18.0), 2.0),
(Vectors.dense(30.1, 20.0), 2.0),
(Vectors.dense(-6.0, -6.0), 2.0),
(Vectors.dense(-10.0, -10.0), 2.0)))).toDF("features", "weightCol")
val model2 = new KMeans()
.setK(4)
.setInitMode(MLlibKMeans.K_MEANS_PARALLEL)
.setWeightCol("weightCol")
.setMaxIter(10)
.fit(df2)
val predictionDf2 = model2.transform(df2)
assert(predictionDf2.select("prediction").distinct().count() == 4)
val predictionsMap2 = predictionDf2.collect().map(row =>
row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap
assert(predictionsMap2(Vectors.dense(0.1, 0.1)) ==
predictionsMap2(Vectors.dense(5.0, 0.2)))
assert(predictionsMap2(Vectors.dense(10.0, 0.0)) ==
predictionsMap2(Vectors.dense(15.0, 0.5)))
assert(predictionsMap2(Vectors.dense(32.0, 18.0)) ==
predictionsMap2(Vectors.dense(30.1, 20.0)))
assert(predictionsMap2(Vectors.dense(-6.0, -6.0)) ==
predictionsMap2(Vectors.dense(-10.0, -10.0)))
model2.clusterCenters.forall(Vectors.norm(_, 2) == 1.0)
assert(model1.clusterCenters === model2.clusterCenters)
}
}
object KMeansSuite {

@ -423,7 +423,7 @@ class KMeansSummary(ClusteringSummary):
@inherit_doc
class _KMeansParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol, HasTol,
HasDistanceMeasure):
HasDistanceMeasure, HasWeightCol):
"""
Params for :py:class:`KMeans` and :py:class:`KMeansModel`.
@ -517,12 +517,14 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable):
(the k-means|| algorithm by Bahmani et al).
>>> from pyspark.ml.linalg import Vectors
>>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
>>> df = spark.createDataFrame(data, ["features"])
>>> data = [(Vectors.dense([0.0, 0.0]), 2.0), (Vectors.dense([1.0, 1.0]), 2.0),
... (Vectors.dense([9.0, 8.0]), 2.0), (Vectors.dense([8.0, 9.0]), 2.0)]
>>> df = spark.createDataFrame(data, ["features", "weighCol"])
>>> kmeans = KMeans(k=2)
>>> kmeans.setSeed(1)
KMeans...
>>> kmeans.setWeightCol("weighCol")
KMeans...
>>> kmeans.setMaxIter(10)
KMeans...
>>> kmeans.getMaxIter()
@ -552,7 +554,7 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable):
>>> summary.clusterSizes
[2, 2]
>>> summary.trainingCost
2.0
4.0
>>> kmeans_path = temp_path + "/kmeans"
>>> kmeans.save(kmeans_path)
>>> kmeans2 = KMeans.load(kmeans_path)
@ -574,11 +576,11 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable):
@keyword_only
def __init__(self, featuresCol="features", predictionCol="prediction", k=2,
initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None,
distanceMeasure="euclidean"):
distanceMeasure="euclidean", weightCol=None):
"""
__init__(self, featuresCol="features", predictionCol="prediction", k=2, \
initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \
distanceMeasure="euclidean")
distanceMeasure="euclidean", weightCol=None)
"""
super(KMeans, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid)
@ -594,11 +596,11 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable):
@since("1.5.0")
def setParams(self, featuresCol="features", predictionCol="prediction", k=2,
initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None,
distanceMeasure="euclidean"):
distanceMeasure="euclidean", weightCol=None):
"""
setParams(self, featuresCol="features", predictionCol="prediction", k=2, \
initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \
distanceMeasure="euclidean")
distanceMeasure="euclidean", weightCol=None)
Sets params for KMeans.
"""
@ -668,6 +670,13 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable):
"""
return self._set(tol=value)
@since("3.0.0")
def setWeightCol(self, value):
"""
Sets the value of :py:attr:`weightCol`.
"""
return self._set(weightCol=value)
@inherit_doc
class _BisectingKMeansParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol,