[SPARK-22119][FOLLOWUP][ML] Use spherical KMeans with cosine distance

## What changes were proposed in this pull request?

In #19340 some comments considered needed to use spherical KMeans when cosine distance measure is specified, as Matlab does; instead of the implementation based on the behavior of other tools/libraries like Rapidminer, nltk and ELKI, ie. the centroids are computed as the mean of all the points in the clusters.

The PR introduce the approach used in spherical KMeans. This behavior has the nice feature to minimize the within-cluster cosine distance.

## How was this patch tested?

existing/improved UTs

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #20518 from mgaido91/SPARK-22119_followup.
This commit is contained in:
Marco Gaido 2018-02-11 20:15:30 -06:00 committed by Sean Owen
parent 4bbd7443eb
commit c0c902aedc
2 changed files with 62 additions and 7 deletions

View file

@ -310,8 +310,7 @@ class KMeans private (
points.foreach { point =>
val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
costAccum.add(cost)
val sum = sums(bestCenter)
axpy(1.0, point.vector, sum)
distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
counts(bestCenter) += 1
}
@ -319,10 +318,9 @@ class KMeans private (
}.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
axpy(1.0, sum2, sum1)
(sum1, count1 + count2)
}.mapValues { case (sum, count) =>
scal(1.0 / count, sum)
new VectorWithNorm(sum)
}.collectAsMap()
}.collectAsMap().mapValues { case (sum, count) =>
distanceMeasureInstance.centroid(sum, count)
}
bcCenters.destroy(blocking = false)
@ -657,6 +655,26 @@ private[spark] abstract class DistanceMeasure extends Serializable {
v1: VectorWithNorm,
v2: VectorWithNorm): Double
/**
* Updates the value of `sum` adding the `point` vector.
* @param point a `VectorWithNorm` to be added to `sum` of a cluster
* @param sum the `sum` for a cluster to be updated
*/
def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
axpy(1.0, point.vector, sum)
}
/**
* Returns a centroid for a cluster given its `sum` vector and its `count` of points.
*
* @param sum the `sum` for a cluster
* @param count the number of points in the cluster
* @return the centroid of the cluster
*/
def centroid(sum: Vector, count: Long): VectorWithNorm = {
scal(1.0 / count, sum)
new VectorWithNorm(sum)
}
}
@Since("2.4.0")
@ -743,6 +761,30 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {
* @return the cosine distance between the two input vectors
*/
override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.")
1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm
}
/**
* Updates the value of `sum` adding the `point` vector.
* @param point a `VectorWithNorm` to be added to `sum` of a cluster
* @param sum the `sum` for a cluster to be updated
*/
override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
axpy(1.0 / point.norm, point.vector, sum)
}
/**
* Returns a centroid for a cluster given its `sum` vector and its `count` of points.
*
* @param sum the `sum` for a cluster
* @param count the number of points in the cluster
* @return the centroid of the cluster
*/
override def centroid(sum: Vector, count: Long): VectorWithNorm = {
scal(1.0 / count, sum)
val norm = Vectors.norm(sum, 2)
scal(1.0 / norm, sum)
new VectorWithNorm(sum, 1)
}
}

View file

@ -19,7 +19,7 @@ package org.apache.spark.ml.clustering
import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
@ -179,6 +179,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(predictionsMap(Vectors.dense(-1.0, 1.0)) ==
predictionsMap(Vectors.dense(-100.0, 90.0)))
model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0)
}
test("KMeans with cosine distance is not supported for 0-length vectors") {
val model = new KMeans().setDistanceMeasure(DistanceMeasure.COSINE).setK(2)
val df = spark.createDataFrame(spark.sparkContext.parallelize(Array(
Vectors.dense(0.0, 0.0),
Vectors.dense(10.0, 10.0),
Vectors.dense(1.0, 0.5)
)).map(v => TestRow(v)))
val e = intercept[SparkException](model.fit(df))
assert(e.getCause.isInstanceOf[AssertionError])
assert(e.getCause.getMessage.contains("Cosine distance is not defined"))
}
test("read/write") {