[SPARK-31007][ML] KMeans optimization based on triangle-inequality
### What changes were proposed in this pull request? apply Lemma 1 in [Using the Triangle Inequality to Accelerate K-Means](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf): > Let x be a point, and let b and c be centers. If d(b,c)>=2d(x,b) then d(x,c) >= d(x,b); It can be directly applied in EuclideanDistance, but not in CosineDistance. However, for CosineDistance we can luckily get a variant in the space of radian/angle. ### Why are the changes needed? It help improving the performance of prediction and training (mostly) ### Does this PR introduce any user-facing change? No ### How was this patch tested? existing testsuites Closes #27758 from zhengruifeng/km_triangle. Authored-by: zhengruifeng <ruifengz@foxmail.com> Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
parent
b10263b8e5
commit
0ede08bcb2
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.ml.impl
|
||||
|
||||
|
||||
private[ml] object Utils {
|
||||
private[spark] object Utils {
|
||||
|
||||
lazy val EPSILON = {
|
||||
var eps = 1.0
|
||||
|
@ -27,4 +27,55 @@ private[ml] object Utils {
|
|||
}
|
||||
eps
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert an n * (n + 1) / 2 dimension array representing the upper triangular part of a matrix
|
||||
* into an n * n array representing the full symmetric matrix (column major).
|
||||
*
|
||||
* @param n The order of the n by n matrix.
|
||||
* @param triangularValues The upper triangular part of the matrix packed in an array
|
||||
* (column major).
|
||||
* @return A dense matrix which represents the symmetric matrix in column major.
|
||||
*/
|
||||
def unpackUpperTriangular(
|
||||
n: Int,
|
||||
triangularValues: Array[Double]): Array[Double] = {
|
||||
val symmetricValues = new Array[Double](n * n)
|
||||
var r = 0
|
||||
var i = 0
|
||||
while (i < n) {
|
||||
var j = 0
|
||||
while (j <= i) {
|
||||
symmetricValues(i * n + j) = triangularValues(r)
|
||||
symmetricValues(j * n + i) = triangularValues(r)
|
||||
r += 1
|
||||
j += 1
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
symmetricValues
|
||||
}
|
||||
|
||||
/**
|
||||
* Indexing in an array representing the upper triangular part of a matrix
|
||||
* into an n * n array representing the full symmetric matrix (column major).
|
||||
* val symmetricValues = unpackUpperTriangularMatrix(n, triangularValues)
|
||||
* val matrix = new DenseMatrix(n, n, symmetricValues)
|
||||
* val index = indexUpperTriangularMatrix(n, i, j)
|
||||
* then: symmetricValues(index) == matrix(i, j)
|
||||
*
|
||||
* @param n The order of the n by n matrix.
|
||||
*/
|
||||
def indexUpperTriangular(
|
||||
n: Int,
|
||||
i: Int,
|
||||
j: Int): Int = {
|
||||
require(i >= 0 && i < n, s"Expected 0 <= i < $n, got i = $i.")
|
||||
require(j >= 0 && j < n, s"Expected 0 <= j < $n, got j = $j.")
|
||||
if (i <= j) {
|
||||
j * (j + 1) / 2 + i
|
||||
} else {
|
||||
i * (i + 1) / 2 + j
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ import org.apache.hadoop.fs.Path
|
|||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.broadcast.Broadcast
|
||||
import org.apache.spark.ml.{Estimator, Model}
|
||||
import org.apache.spark.ml.impl.Utils.EPSILON
|
||||
import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON}
|
||||
import org.apache.spark.ml.linalg._
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared._
|
||||
|
@ -583,19 +583,7 @@ object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {
|
|||
private[clustering] def unpackUpperTriangularMatrix(
|
||||
n: Int,
|
||||
triangularValues: Array[Double]): DenseMatrix = {
|
||||
val symmetricValues = new Array[Double](n * n)
|
||||
var r = 0
|
||||
var i = 0
|
||||
while (i < n) {
|
||||
var j = 0
|
||||
while (j <= i) {
|
||||
symmetricValues(i * n + j) = triangularValues(r)
|
||||
symmetricValues(j * n + i) = triangularValues(r)
|
||||
r += 1
|
||||
j += 1
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
val symmetricValues = unpackUpperTriangular(n, triangularValues)
|
||||
new DenseMatrix(n, n, symmetricValues)
|
||||
}
|
||||
|
||||
|
|
|
@ -17,23 +17,125 @@
|
|||
|
||||
package org.apache.spark.mllib.clustering
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.broadcast.Broadcast
|
||||
import org.apache.spark.ml.impl.Utils.indexUpperTriangular
|
||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
|
||||
private[spark] abstract class DistanceMeasure extends Serializable {
|
||||
|
||||
/**
|
||||
* Statistics used in triangle inequality to obtain useful bounds to find closest centers.
|
||||
* @param distance distance between two centers
|
||||
*/
|
||||
def computeStatistics(distance: Double): Double
|
||||
|
||||
/**
|
||||
* Statistics used in triangle inequality to obtain useful bounds to find closest centers.
|
||||
*
|
||||
* @return The packed upper triangular part of a symmetric matrix containing statistics,
|
||||
* matrix(i,j) represents:
|
||||
* 1, if i != j: a bound r = matrix(i,j) to help avoiding unnecessary distance
|
||||
* computation. Given point x, let i be current closest center, and d be current best
|
||||
* distance, if d < f(r), then we no longer need to compute the distance to center j;
|
||||
* 2, if i == j: a bound r = matrix(i,i) = min_k{maxtrix(i,k)|k!=i}. If distance
|
||||
* between point x and center i is less than f(r), then center i is the closest center
|
||||
* to point x.
|
||||
*/
|
||||
def computeStatistics(centers: Array[VectorWithNorm]): Array[Double] = {
|
||||
val k = centers.length
|
||||
if (k == 1) return Array(Double.NaN)
|
||||
|
||||
val packedValues = Array.ofDim[Double](k * (k + 1) / 2)
|
||||
val diagValues = Array.fill(k)(Double.PositiveInfinity)
|
||||
var i = 0
|
||||
while (i < k) {
|
||||
var j = i + 1
|
||||
while (j < k) {
|
||||
val d = distance(centers(i), centers(j))
|
||||
val s = computeStatistics(d)
|
||||
val index = indexUpperTriangular(k, i, j)
|
||||
packedValues(index) = s
|
||||
if (s < diagValues(i)) diagValues(i) = s
|
||||
if (s < diagValues(j)) diagValues(j) = s
|
||||
j += 1
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
|
||||
i = 0
|
||||
while (i < k) {
|
||||
val index = indexUpperTriangular(k, i, i)
|
||||
packedValues(index) = diagValues(i)
|
||||
i += 1
|
||||
}
|
||||
packedValues
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute distance between centers in a distributed way.
|
||||
*/
|
||||
def computeStatisticsDistributedly(
|
||||
sc: SparkContext,
|
||||
bcCenters: Broadcast[Array[VectorWithNorm]]): Array[Double] = {
|
||||
val k = bcCenters.value.length
|
||||
if (k == 1) return Array(Double.NaN)
|
||||
|
||||
val packedValues = Array.ofDim[Double](k * (k + 1) / 2)
|
||||
val diagValues = Array.fill(k)(Double.PositiveInfinity)
|
||||
|
||||
val numParts = math.min(k, 1024)
|
||||
sc.range(0, numParts, 1, numParts)
|
||||
.mapPartitionsWithIndex { case (pid, _) =>
|
||||
val centers = bcCenters.value
|
||||
Iterator.range(0, k).flatMap { i =>
|
||||
Iterator.range(i + 1, k).flatMap { j =>
|
||||
val hash = (i, j).hashCode.abs
|
||||
if (hash % numParts == pid) {
|
||||
val d = distance(centers(i), centers(j))
|
||||
val s = computeStatistics(d)
|
||||
Iterator.single((i, j, s))
|
||||
} else Iterator.empty
|
||||
}
|
||||
}
|
||||
}.collect.foreach { case (i, j, s) =>
|
||||
val index = indexUpperTriangular(k, i, j)
|
||||
packedValues(index) = s
|
||||
if (s < diagValues(i)) diagValues(i) = s
|
||||
if (s < diagValues(j)) diagValues(j) = s
|
||||
}
|
||||
|
||||
var i = 0
|
||||
while (i < k) {
|
||||
val index = indexUpperTriangular(k, i, i)
|
||||
packedValues(index) = diagValues(i)
|
||||
i += 1
|
||||
}
|
||||
packedValues
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the index of the closest center to the given point, as well as the cost.
|
||||
*/
|
||||
def findClosest(
|
||||
centers: TraversableOnce[VectorWithNorm],
|
||||
centers: Array[VectorWithNorm],
|
||||
statistics: Array[Double],
|
||||
point: VectorWithNorm): (Int, Double)
|
||||
|
||||
/**
|
||||
* @return the index of the closest center to the given point, as well as the cost.
|
||||
*/
|
||||
def findClosest(
|
||||
centers: Array[VectorWithNorm],
|
||||
point: VectorWithNorm): (Int, Double) = {
|
||||
var bestDistance = Double.PositiveInfinity
|
||||
var bestIndex = 0
|
||||
var i = 0
|
||||
centers.foreach { center =>
|
||||
while (i < centers.length) {
|
||||
val center = centers(i)
|
||||
val currentDistance = distance(center, point)
|
||||
if (currentDistance < bestDistance) {
|
||||
bestDistance = currentDistance
|
||||
|
@ -48,7 +150,7 @@ private[spark] abstract class DistanceMeasure extends Serializable {
|
|||
* @return the K-means cost of a given point against the given cluster centers.
|
||||
*/
|
||||
def pointCost(
|
||||
centers: TraversableOnce[VectorWithNorm],
|
||||
centers: Array[VectorWithNorm],
|
||||
point: VectorWithNorm): Double = {
|
||||
findClosest(centers, point)._2
|
||||
}
|
||||
|
@ -154,22 +256,79 @@ object DistanceMeasure {
|
|||
}
|
||||
|
||||
private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
|
||||
|
||||
/**
|
||||
* Statistics used in triangle inequality to obtain useful bounds to find closest centers.
|
||||
* @see <a href="https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf">Charles Elkan,
|
||||
* Using the Triangle Inequality to Accelerate k-Means</a>
|
||||
*
|
||||
* @return One element used in statistics matrix to make matrix(i,j) represents:
|
||||
* 1, if i != j: a bound r = matrix(i,j) to help avoiding unnecessary distance
|
||||
* computation. Given point x, let i be current closest center, and d be current best
|
||||
* squared distance, if d < r, then we no longer need to compute the distance to center
|
||||
* j. matrix(i,j) equals to squared of half of Euclidean distance between centers i
|
||||
* and j;
|
||||
* 2, if i == j: a bound r = matrix(i,i) = min_k{maxtrix(i,k)|k!=i}. If squared
|
||||
* distance between point x and center i is less than r, then center i is the closest
|
||||
* center to point x.
|
||||
*/
|
||||
override def computeStatistics(distance: Double): Double = {
|
||||
0.25 * distance * distance
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the index of the closest center to the given point, as well as the cost.
|
||||
*/
|
||||
override def findClosest(
|
||||
centers: Array[VectorWithNorm],
|
||||
statistics: Array[Double],
|
||||
point: VectorWithNorm): (Int, Double) = {
|
||||
var bestDistance = EuclideanDistanceMeasure.fastSquaredDistance(centers(0), point)
|
||||
if (bestDistance < statistics(0)) return (0, bestDistance)
|
||||
|
||||
val k = centers.length
|
||||
var bestIndex = 0
|
||||
var i = 1
|
||||
while (i < k) {
|
||||
val center = centers(i)
|
||||
// Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
|
||||
// distance computation.
|
||||
val normDiff = center.norm - point.norm
|
||||
val lowerBound = normDiff * normDiff
|
||||
if (lowerBound < bestDistance) {
|
||||
val index1 = indexUpperTriangular(k, i, bestIndex)
|
||||
if (statistics(index1) < bestDistance) {
|
||||
val d = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
|
||||
val index2 = indexUpperTriangular(k, i, i)
|
||||
if (d < statistics(index2)) return (i, d)
|
||||
if (d < bestDistance) {
|
||||
bestDistance = d
|
||||
bestIndex = i
|
||||
}
|
||||
}
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
(bestIndex, bestDistance)
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the index of the closest center to the given point, as well as the squared distance.
|
||||
*/
|
||||
override def findClosest(
|
||||
centers: TraversableOnce[VectorWithNorm],
|
||||
centers: Array[VectorWithNorm],
|
||||
point: VectorWithNorm): (Int, Double) = {
|
||||
var bestDistance = Double.PositiveInfinity
|
||||
var bestIndex = 0
|
||||
var i = 0
|
||||
centers.foreach { center =>
|
||||
while (i < centers.length) {
|
||||
val center = centers(i)
|
||||
// Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
|
||||
// distance computation.
|
||||
var lowerBoundOfSqDist = center.norm - point.norm
|
||||
lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
|
||||
if (lowerBoundOfSqDist < bestDistance) {
|
||||
val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
|
||||
val distance = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
|
||||
if (distance < bestDistance) {
|
||||
bestDistance = distance
|
||||
bestIndex = i
|
||||
|
@ -234,6 +393,58 @@ private[spark] object EuclideanDistanceMeasure {
|
|||
}
|
||||
|
||||
private[spark] class CosineDistanceMeasure extends DistanceMeasure {
|
||||
|
||||
/**
|
||||
* Statistics used in triangle inequality to obtain useful bounds to find closest centers.
|
||||
*
|
||||
* @return One element used in statistics matrix to make matrix(i,j) represents:
|
||||
* 1, if i != j: a bound r = matrix(i,j) to help avoiding unnecessary distance
|
||||
* computation. Given point x, let i be current closest center, and d be current best
|
||||
* squared distance, if d < r, then we no longer need to compute the distance to center
|
||||
* j. For Cosine distance, it is similar to Euclidean distance. However, radian/angle
|
||||
* is used instead of Cosine distance to compute matrix(i,j): for centers i and j,
|
||||
* compute the radian/angle between them, halving it, and converting it back to Cosine
|
||||
* distance at the end;
|
||||
* 2, if i == j: a bound r = matrix(i,i) = min_k{maxtrix(i,k)|k!=i}. If Cosine
|
||||
* distance between point x and center i is less than r, then center i is the closest
|
||||
* center to point x.
|
||||
*/
|
||||
override def computeStatistics(distance: Double): Double = {
|
||||
// d = 1 - cos(x)
|
||||
// r = 1 - cos(x/2) = 1 - sqrt((cos(x) + 1) / 2) = 1 - sqrt(1 - d/2)
|
||||
1 - math.sqrt(1 - distance / 2)
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the index of the closest center to the given point, as well as the cost.
|
||||
*/
|
||||
def findClosest(
|
||||
centers: Array[VectorWithNorm],
|
||||
statistics: Array[Double],
|
||||
point: VectorWithNorm): (Int, Double) = {
|
||||
var bestDistance = distance(centers(0), point)
|
||||
if (bestDistance < statistics(0)) return (0, bestDistance)
|
||||
|
||||
val k = centers.length
|
||||
var bestIndex = 0
|
||||
var i = 1
|
||||
while (i < k) {
|
||||
val index1 = indexUpperTriangular(k, i, bestIndex)
|
||||
if (statistics(index1) < bestDistance) {
|
||||
val center = centers(i)
|
||||
val d = distance(center, point)
|
||||
val index2 = indexUpperTriangular(k, i, i)
|
||||
if (d < statistics(index2)) return (i, d)
|
||||
if (d < bestDistance) {
|
||||
bestDistance = d
|
||||
bestIndex = i
|
||||
}
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
(bestIndex, bestDistance)
|
||||
}
|
||||
|
||||
/**
|
||||
* @param v1: first vector
|
||||
* @param v2: second vector
|
||||
|
|
|
@ -209,9 +209,7 @@ class KMeans private (
|
|||
*/
|
||||
@Since("0.8.0")
|
||||
def run(data: RDD[Vector]): KMeansModel = {
|
||||
val instances: RDD[(Vector, Double)] = data.map {
|
||||
case (point) => (point, 1.0)
|
||||
}
|
||||
val instances = data.map(point => (point, 1.0))
|
||||
runWithWeight(instances, None)
|
||||
}
|
||||
|
||||
|
@ -260,6 +258,7 @@ class KMeans private (
|
|||
initKMeansParallel(data, distanceMeasureInstance)
|
||||
}
|
||||
}
|
||||
val numFeatures = centers.head.vector.size
|
||||
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
|
||||
logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.")
|
||||
|
||||
|
@ -269,34 +268,44 @@ class KMeans private (
|
|||
|
||||
val iterationStartTime = System.nanoTime()
|
||||
|
||||
instr.foreach(_.logNumFeatures(centers.head.vector.size))
|
||||
instr.foreach(_.logNumFeatures(numFeatures))
|
||||
|
||||
val shouldDistributed = centers.length * centers.length * numFeatures.toLong > 1000000L
|
||||
|
||||
// Execute iterations of Lloyd's algorithm until converged
|
||||
while (iteration < maxIterations && !converged) {
|
||||
val costAccum = sc.doubleAccumulator
|
||||
val bcCenters = sc.broadcast(centers)
|
||||
val stats = if (shouldDistributed) {
|
||||
distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters)
|
||||
} else {
|
||||
distanceMeasureInstance.computeStatistics(centers)
|
||||
}
|
||||
val bcStats = sc.broadcast(stats)
|
||||
|
||||
val costAccum = sc.doubleAccumulator
|
||||
|
||||
// Find the new centers
|
||||
val collected = data.mapPartitions { points =>
|
||||
val thisCenters = bcCenters.value
|
||||
val dims = thisCenters.head.vector.size
|
||||
val centers = bcCenters.value
|
||||
val stats = bcStats.value
|
||||
val dims = centers.head.vector.size
|
||||
|
||||
val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
|
||||
val sums = Array.fill(centers.length)(Vectors.zeros(dims))
|
||||
|
||||
// clusterWeightSum is needed to calculate cluster center
|
||||
// cluster center =
|
||||
// sample1 * weight1/clusterWeightSum + sample2 * weight2/clusterWeightSum + ...
|
||||
val clusterWeightSum = Array.ofDim[Double](thisCenters.length)
|
||||
val clusterWeightSum = Array.ofDim[Double](centers.length)
|
||||
|
||||
points.foreach { point =>
|
||||
val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
|
||||
val (bestCenter, cost) = distanceMeasureInstance.findClosest(centers, stats, point)
|
||||
costAccum.add(cost * point.weight)
|
||||
distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
|
||||
clusterWeightSum(bestCenter) += point.weight
|
||||
}
|
||||
|
||||
clusterWeightSum.indices.filter(clusterWeightSum(_) > 0)
|
||||
.map(j => (j, (sums(j), clusterWeightSum(j)))).iterator
|
||||
Iterator.tabulate(centers.length)(j => (j, (sums(j), clusterWeightSum(j))))
|
||||
.filter(_._2._2 > 0)
|
||||
}.reduceByKey { (sumweight1, sumweight2) =>
|
||||
axpy(1.0, sumweight2._1, sumweight1._1)
|
||||
(sumweight1._1, sumweight1._2 + sumweight2._2)
|
||||
|
@ -307,15 +316,13 @@ class KMeans private (
|
|||
instr.foreach(_.logSumOfWeights(collected.values.map(_._2).sum))
|
||||
}
|
||||
|
||||
val newCenters = collected.mapValues { case (sum, weightSum) =>
|
||||
distanceMeasureInstance.centroid(sum, weightSum)
|
||||
}
|
||||
|
||||
bcCenters.destroy()
|
||||
bcStats.destroy()
|
||||
|
||||
// Update the cluster centers and costs
|
||||
converged = true
|
||||
newCenters.foreach { case (j, newCenter) =>
|
||||
collected.foreach { case (j, (sum, weightSum)) =>
|
||||
val newCenter = distanceMeasureInstance.centroid(sum, weightSum)
|
||||
if (converged &&
|
||||
!distanceMeasureInstance.isCenterConverged(centers(j), newCenter, epsilon)) {
|
||||
converged = false
|
||||
|
@ -324,6 +331,7 @@ class KMeans private (
|
|||
}
|
||||
|
||||
cost = costAccum.value
|
||||
instr.foreach(_.logNamedValue(s"Cost@iter=$iteration", s"$cost"))
|
||||
iteration += 1
|
||||
}
|
||||
|
||||
|
@ -372,7 +380,7 @@ class KMeans private (
|
|||
require(sample.nonEmpty, s"No samples available from $data")
|
||||
|
||||
val centers = ArrayBuffer[VectorWithNorm]()
|
||||
var newCenters = Seq(sample.head.toDense)
|
||||
var newCenters = Array(sample.head.toDense)
|
||||
centers ++= newCenters
|
||||
|
||||
// On each step, sample 2 * k points on average with probability proportional
|
||||
|
@ -404,10 +412,10 @@ class KMeans private (
|
|||
costs.unpersist()
|
||||
bcNewCentersList.foreach(_.destroy())
|
||||
|
||||
val distinctCenters = centers.map(_.vector).distinct.map(new VectorWithNorm(_))
|
||||
val distinctCenters = centers.map(_.vector).distinct.map(new VectorWithNorm(_)).toArray
|
||||
|
||||
if (distinctCenters.size <= k) {
|
||||
distinctCenters.toArray
|
||||
if (distinctCenters.length <= k) {
|
||||
distinctCenters
|
||||
} else {
|
||||
// Finally, we might have a set of more than k distinct candidate centers; weight each
|
||||
// candidate by the number of points in the dataset mapping to it and run a local k-means++
|
||||
|
@ -420,7 +428,7 @@ class KMeans private (
|
|||
bcCenters.destroy()
|
||||
|
||||
val myWeights = distinctCenters.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray
|
||||
LocalKMeans.kMeansPlusPlus(0, distinctCenters.toArray, myWeights, k, 30)
|
||||
LocalKMeans.kMeansPlusPlus(0, distinctCenters, myWeights, k, 30)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,6 +48,13 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector],
|
|||
@transient private lazy val clusterCentersWithNorm =
|
||||
if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_))
|
||||
|
||||
// TODO: computation of statistics may take seconds, so save it to KMeansModel in training
|
||||
@transient private lazy val statistics = if (clusterCenters == null) {
|
||||
null
|
||||
} else {
|
||||
distanceMeasureInstance.computeStatistics(clusterCentersWithNorm)
|
||||
}
|
||||
|
||||
@Since("2.4.0")
|
||||
private[spark] def this(clusterCenters: Array[Vector], distanceMeasure: String) =
|
||||
this(clusterCenters: Array[Vector], distanceMeasure, 0.0, -1)
|
||||
|
@ -73,7 +80,8 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector],
|
|||
*/
|
||||
@Since("0.8.0")
|
||||
def predict(point: Vector): Int = {
|
||||
distanceMeasureInstance.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1
|
||||
distanceMeasureInstance.findClosest(clusterCentersWithNorm, statistics,
|
||||
new VectorWithNorm(point))._1
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -82,8 +90,10 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector],
|
|||
@Since("1.0.0")
|
||||
def predict(points: RDD[Vector]): RDD[Int] = {
|
||||
val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm)
|
||||
val bcStatistics = points.context.broadcast(statistics)
|
||||
points.map(p =>
|
||||
distanceMeasureInstance.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
|
||||
distanceMeasureInstance.findClosest(bcCentersWithNorm.value,
|
||||
bcStatistics.value, new VectorWithNorm(p))._1)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.mllib.clustering
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.mllib.linalg.Vectors
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.mllib.util.TestingUtils._
|
||||
|
||||
class DistanceMeasureSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||
|
||||
private val seed = 42
|
||||
private val k = 10
|
||||
private val dim = 8
|
||||
|
||||
private var centers: Array[VectorWithNorm] = _
|
||||
|
||||
private var data: Array[VectorWithNorm] = _
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
|
||||
val rng = new Random(seed)
|
||||
|
||||
centers = Array.tabulate(k) { i =>
|
||||
val values = Array.fill(dim)(rng.nextGaussian)
|
||||
new VectorWithNorm(Vectors.dense(values))
|
||||
}
|
||||
|
||||
data = Array.tabulate(1000) { i =>
|
||||
val values = Array.fill(dim)(rng.nextGaussian)
|
||||
new VectorWithNorm(Vectors.dense(values))
|
||||
}
|
||||
}
|
||||
|
||||
test("predict with statistics") {
|
||||
Seq(DistanceMeasure.COSINE, DistanceMeasure.EUCLIDEAN).foreach { distanceMeasure =>
|
||||
val distance = DistanceMeasure.decodeFromString(distanceMeasure)
|
||||
val statistics = distance.computeStatistics(centers)
|
||||
data.foreach { point =>
|
||||
val (index1, cost1) = distance.findClosest(centers, point)
|
||||
val (index2, cost2) = distance.findClosest(centers, statistics, point)
|
||||
assert(index1 == index2)
|
||||
assert(cost1 ~== cost2 relTol 1E-10)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("compute statistics distributedly") {
|
||||
Seq(DistanceMeasure.COSINE, DistanceMeasure.EUCLIDEAN).foreach { distanceMeasure =>
|
||||
val distance = DistanceMeasure.decodeFromString(distanceMeasure)
|
||||
val statistics1 = distance.computeStatistics(centers)
|
||||
val sc = spark.sparkContext
|
||||
val bcCenters = sc.broadcast(centers)
|
||||
val statistics2 = distance.computeStatisticsDistributedly(sc, bcCenters)
|
||||
bcCenters.destroy()
|
||||
assert(Vectors.dense(statistics1) ~== Vectors.dense(statistics2) relTol 1E-10)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue