[SPARK-31007][ML] KMeans optimization based on triangle-inequality

### What changes were proposed in this pull request?
Apply Lemma 1 from [Using the Triangle Inequality to Accelerate K-Means](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf):

> Let x be a point, and let b and c be centers. If d(b,c) >= 2d(x,b), then d(x,c) >= d(x,b).

The lemma can be applied directly with EuclideanDistance, but not with CosineDistance.
However, for CosineDistance we can derive a similar bound by working in the space of radians/angles.
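
In the Euclidean case, the cached statistic for a pair of centers b and c is (d(b,c)/2)^2: by Lemma 1, any center c whose statistic is at least the current best squared distance can be skipped without ever computing d(x,c). For CosineDistance the bound is stored via the half-angle instead; the following simply restates the derivation from the `CosineDistanceMeasure.computeStatistics` comment in the diff below (θ is the angle between two centers and d their cosine distance):

```latex
d = 1 - \cos\theta
\quad\Longrightarrow\quad
r = 1 - \cos\tfrac{\theta}{2}
  = 1 - \sqrt{\tfrac{\cos\theta + 1}{2}}
  = 1 - \sqrt{1 - \tfrac{d}{2}}
```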

### Why are the changes needed?
It helps improve the performance of prediction and, mostly, training.

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Existing test suites, plus the new `DistanceMeasureSuite` added in this PR.

Closes #27758 from zhengruifeng/km_triangle.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
zhengruifeng 2020-04-24 11:24:15 -05:00 committed by Sean Owen
parent b10263b8e5
commit 0ede08bcb2
6 changed files with 390 additions and 45 deletions


```diff
@@ -18,7 +18,7 @@
 package org.apache.spark.ml.impl
 
-private[ml] object Utils {
+private[spark] object Utils {
 
   lazy val EPSILON = {
     var eps = 1.0
@@ -27,4 +27,55 @@ private[ml] object Utils {
     }
     eps
   }
+
+  /**
+   * Convert an n * (n + 1) / 2 dimension array representing the upper triangular part of a matrix
+   * into an n * n array representing the full symmetric matrix (column major).
+   *
+   * @param n The order of the n by n matrix.
+   * @param triangularValues The upper triangular part of the matrix packed in an array
+   *                         (column major).
+   * @return A dense matrix which represents the symmetric matrix in column major.
+   */
+  def unpackUpperTriangular(
+      n: Int,
+      triangularValues: Array[Double]): Array[Double] = {
+    val symmetricValues = new Array[Double](n * n)
+    var r = 0
+    var i = 0
+    while (i < n) {
+      var j = 0
+      while (j <= i) {
+        symmetricValues(i * n + j) = triangularValues(r)
+        symmetricValues(j * n + i) = triangularValues(r)
+        r += 1
+        j += 1
+      }
+      i += 1
+    }
+    symmetricValues
+  }
+
+  /**
+   * Index into an array representing the upper triangular part of a matrix (column major),
+   * given coordinates (i, j) of the full n * n symmetric matrix:
+   *   val symmetricValues = unpackUpperTriangular(n, triangularValues)
+   *   val matrix = new DenseMatrix(n, n, symmetricValues)
+   *   val index = indexUpperTriangular(n, i, j)
+   *   then: triangularValues(index) == matrix(i, j)
+   *
+   * @param n The order of the n by n matrix.
+   */
+  def indexUpperTriangular(
+      n: Int,
+      i: Int,
+      j: Int): Int = {
+    require(i >= 0 && i < n, s"Expected 0 <= i < $n, got i = $i.")
+    require(j >= 0 && j < n, s"Expected 0 <= j < $n, got j = $j.")
+    if (i <= j) {
+      j * (j + 1) / 2 + i
+    } else {
+      i * (i + 1) / 2 + j
+    }
+  }
 }
```
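
A minimal sketch (not part of the diff, and ignoring visibility since `Utils` is `private[spark]`) of how the two helpers above fit together, following the doc comment; the 3 x 3 values are made up for illustration:

```scala
import org.apache.spark.ml.impl.Utils.{indexUpperTriangular, unpackUpperTriangular}
import org.apache.spark.ml.linalg.DenseMatrix

// Packed upper triangular part of a 3 x 3 symmetric matrix, column major:
// entries (0,0), (0,1), (1,1), (0,2), (1,2), (2,2).
val n = 3
val triangularValues = Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)

// Expand to the full n * n column-major array and wrap it as a matrix.
val matrix = new DenseMatrix(n, n, unpackUpperTriangular(n, triangularValues))

// indexUpperTriangular(n, i, j) points back into the packed array for any (i, j).
for (i <- 0 until n; j <- 0 until n) {
  assert(triangularValues(indexUpperTriangular(n, i, j)) == matrix(i, j))
}
```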


```diff
@@ -22,7 +22,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.Since
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.impl.Utils.EPSILON
+import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON}
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -583,19 +583,7 @@ object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {
   private[clustering] def unpackUpperTriangularMatrix(
       n: Int,
       triangularValues: Array[Double]): DenseMatrix = {
-    val symmetricValues = new Array[Double](n * n)
-    var r = 0
-    var i = 0
-    while (i < n) {
-      var j = 0
-      while (j <= i) {
-        symmetricValues(i * n + j) = triangularValues(r)
-        symmetricValues(j * n + i) = triangularValues(r)
-        r += 1
-        j += 1
-      }
-      i += 1
-    }
+    val symmetricValues = unpackUpperTriangular(n, triangularValues)
     new DenseMatrix(n, n, symmetricValues)
   }
```


```diff
@@ -17,23 +17,125 @@
 package org.apache.spark.mllib.clustering
 
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.impl.Utils.indexUpperTriangular
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
 import org.apache.spark.mllib.util.MLUtils
 
 private[spark] abstract class DistanceMeasure extends Serializable {
 
+  /**
+   * Statistics used in triangle inequality to obtain useful bounds to find closest centers.
+   * @param distance distance between two centers
+   */
+  def computeStatistics(distance: Double): Double
+
+  /**
+   * Statistics used in triangle inequality to obtain useful bounds to find closest centers.
+   *
+   * @return The packed upper triangular part of a symmetric matrix containing statistics,
+   *         where matrix(i,j) represents:
+   *         1, if i != j: a bound r = matrix(i,j) to help avoid unnecessary distance
+   *         computation. Given point x, let i be the current closest center and d the current
+   *         best distance; if d < f(r), then we no longer need to compute the distance to
+   *         center j;
+   *         2, if i == j: a bound r = matrix(i,i) = min_k{matrix(i,k)|k!=i}. If the distance
+   *         between point x and center i is less than f(r), then center i is the closest
+   *         center to point x.
+   */
+  def computeStatistics(centers: Array[VectorWithNorm]): Array[Double] = {
+    val k = centers.length
+    if (k == 1) return Array(Double.NaN)
+
+    val packedValues = Array.ofDim[Double](k * (k + 1) / 2)
+    val diagValues = Array.fill(k)(Double.PositiveInfinity)
+    var i = 0
+    while (i < k) {
+      var j = i + 1
+      while (j < k) {
+        val d = distance(centers(i), centers(j))
+        val s = computeStatistics(d)
+        val index = indexUpperTriangular(k, i, j)
+        packedValues(index) = s
+        if (s < diagValues(i)) diagValues(i) = s
+        if (s < diagValues(j)) diagValues(j) = s
+        j += 1
+      }
+      i += 1
+    }
+
+    i = 0
+    while (i < k) {
+      val index = indexUpperTriangular(k, i, i)
+      packedValues(index) = diagValues(i)
+      i += 1
+    }
+    packedValues
+  }
+
+  /**
+   * Compute statistics between centers in a distributed way.
+   */
+  def computeStatisticsDistributedly(
+      sc: SparkContext,
+      bcCenters: Broadcast[Array[VectorWithNorm]]): Array[Double] = {
+    val k = bcCenters.value.length
+    if (k == 1) return Array(Double.NaN)
+
+    val packedValues = Array.ofDim[Double](k * (k + 1) / 2)
+    val diagValues = Array.fill(k)(Double.PositiveInfinity)
+
+    val numParts = math.min(k, 1024)
+    sc.range(0, numParts, 1, numParts)
+      .mapPartitionsWithIndex { case (pid, _) =>
+        val centers = bcCenters.value
+        Iterator.range(0, k).flatMap { i =>
+          Iterator.range(i + 1, k).flatMap { j =>
+            val hash = (i, j).hashCode.abs
+            if (hash % numParts == pid) {
+              val d = distance(centers(i), centers(j))
+              val s = computeStatistics(d)
+              Iterator.single((i, j, s))
+            } else Iterator.empty
+          }
+        }
+      }.collect.foreach { case (i, j, s) =>
+        val index = indexUpperTriangular(k, i, j)
+        packedValues(index) = s
+        if (s < diagValues(i)) diagValues(i) = s
+        if (s < diagValues(j)) diagValues(j) = s
+      }
+
+    var i = 0
+    while (i < k) {
+      val index = indexUpperTriangular(k, i, i)
+      packedValues(index) = diagValues(i)
+      i += 1
+    }
+    packedValues
+  }
+
   /**
    * @return the index of the closest center to the given point, as well as the cost.
    */
   def findClosest(
-      centers: TraversableOnce[VectorWithNorm],
+      centers: Array[VectorWithNorm],
+      statistics: Array[Double],
+      point: VectorWithNorm): (Int, Double)
+
+  /**
+   * @return the index of the closest center to the given point, as well as the cost.
+   */
+  def findClosest(
+      centers: Array[VectorWithNorm],
       point: VectorWithNorm): (Int, Double) = {
     var bestDistance = Double.PositiveInfinity
     var bestIndex = 0
     var i = 0
-    centers.foreach { center =>
+    while (i < centers.length) {
+      val center = centers(i)
       val currentDistance = distance(center, point)
       if (currentDistance < bestDistance) {
         bestDistance = currentDistance
@@ -48,7 +150,7 @@ private[spark] abstract class DistanceMeasure extends Serializable {
    * @return the K-means cost of a given point against the given cluster centers.
    */
   def pointCost(
-      centers: TraversableOnce[VectorWithNorm],
+      centers: Array[VectorWithNorm],
       point: VectorWithNorm): Double = {
     findClosest(centers, point)._2
   }
@@ -154,22 +256,79 @@
 }
 
 private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
 
+  /**
+   * Statistics used in triangle inequality to obtain useful bounds to find closest centers.
+   * @see <a href="https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf">Charles Elkan,
+   *      Using the Triangle Inequality to Accelerate k-Means</a>
+   *
+   * @return One element used in the statistics matrix so that matrix(i,j) represents:
+   *         1, if i != j: a bound r = matrix(i,j) to help avoid unnecessary distance
+   *         computation. Given point x, let i be the current closest center and d the current
+   *         best squared distance; if d < r, then we no longer need to compute the distance to
+   *         center j. matrix(i,j) equals the square of half the Euclidean distance between
+   *         centers i and j;
+   *         2, if i == j: a bound r = matrix(i,i) = min_k{matrix(i,k)|k!=i}. If the squared
+   *         distance between point x and center i is less than r, then center i is the closest
+   *         center to point x.
+   */
+  override def computeStatistics(distance: Double): Double = {
+    0.25 * distance * distance
+  }
+
+  /**
+   * @return the index of the closest center to the given point, as well as the cost.
+   */
+  override def findClosest(
+      centers: Array[VectorWithNorm],
+      statistics: Array[Double],
+      point: VectorWithNorm): (Int, Double) = {
+    var bestDistance = EuclideanDistanceMeasure.fastSquaredDistance(centers(0), point)
+    if (bestDistance < statistics(0)) return (0, bestDistance)
+
+    val k = centers.length
+    var bestIndex = 0
+    var i = 1
+    while (i < k) {
+      val center = centers(i)
+      // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
+      // distance computation.
+      val normDiff = center.norm - point.norm
+      val lowerBound = normDiff * normDiff
+      if (lowerBound < bestDistance) {
+        val index1 = indexUpperTriangular(k, i, bestIndex)
+        if (statistics(index1) < bestDistance) {
+          val d = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
+          val index2 = indexUpperTriangular(k, i, i)
+          if (d < statistics(index2)) return (i, d)
+          if (d < bestDistance) {
+            bestDistance = d
+            bestIndex = i
+          }
+        }
+      }
+      i += 1
+    }
+    (bestIndex, bestDistance)
+  }
+
   /**
    * @return the index of the closest center to the given point, as well as the squared distance.
    */
   override def findClosest(
-      centers: TraversableOnce[VectorWithNorm],
+      centers: Array[VectorWithNorm],
       point: VectorWithNorm): (Int, Double) = {
     var bestDistance = Double.PositiveInfinity
     var bestIndex = 0
     var i = 0
-    centers.foreach { center =>
+    while (i < centers.length) {
+      val center = centers(i)
       // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
       // distance computation.
       var lowerBoundOfSqDist = center.norm - point.norm
       lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
       if (lowerBoundOfSqDist < bestDistance) {
-        val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
+        val distance = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
         if (distance < bestDistance) {
           bestDistance = distance
           bestIndex = i
@@ -234,6 +393,58 @@
 }
 
 private[spark] class CosineDistanceMeasure extends DistanceMeasure {
 
+  /**
+   * Statistics used in triangle inequality to obtain useful bounds to find closest centers.
+   *
+   * @return One element used in the statistics matrix so that matrix(i,j) represents:
+   *         1, if i != j: a bound r = matrix(i,j) to help avoid unnecessary distance
+   *         computation. Given point x, let i be the current closest center and d the current
+   *         best distance; if d < r, then we no longer need to compute the distance to center
+   *         j. This is similar to the Euclidean case; however, the radian/angle is used instead
+   *         of the cosine distance to compute matrix(i,j): for centers i and j, compute the
+   *         angle between them, halve it, and convert it back to a cosine distance at the end;
+   *         2, if i == j: a bound r = matrix(i,i) = min_k{matrix(i,k)|k!=i}. If the cosine
+   *         distance between point x and center i is less than r, then center i is the closest
+   *         center to point x.
+   */
+  override def computeStatistics(distance: Double): Double = {
+    // d = 1 - cos(x)
+    // r = 1 - cos(x/2) = 1 - sqrt((cos(x) + 1) / 2) = 1 - sqrt(1 - d/2)
+    1 - math.sqrt(1 - distance / 2)
+  }
+
+  /**
+   * @return the index of the closest center to the given point, as well as the cost.
+   */
+  def findClosest(
+      centers: Array[VectorWithNorm],
+      statistics: Array[Double],
+      point: VectorWithNorm): (Int, Double) = {
+    var bestDistance = distance(centers(0), point)
+    if (bestDistance < statistics(0)) return (0, bestDistance)
+
+    val k = centers.length
+    var bestIndex = 0
+    var i = 1
+    while (i < k) {
+      val index1 = indexUpperTriangular(k, i, bestIndex)
+      if (statistics(index1) < bestDistance) {
+        val center = centers(i)
+        val d = distance(center, point)
+        val index2 = indexUpperTriangular(k, i, i)
+        if (d < statistics(index2)) return (i, d)
+        if (d < bestDistance) {
+          bestDistance = d
+          bestIndex = i
+        }
+      }
+      i += 1
+    }
+    (bestIndex, bestDistance)
+  }
+
   /**
    * @param v1: first vector
    * @param v2: second vector
```
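
A minimal usage sketch of the new bounded search (not part of the diff; it assumes package-internal access, since `DistanceMeasure` and `VectorWithNorm` are `private[spark]`): the statistics are computed once per set of centers and then reused for every point, which is how the training loop and `KMeansModel.predict` below use them.

```scala
import org.apache.spark.mllib.linalg.Vectors

// Two well-separated centers, Euclidean distance 10 apart.
val measure = DistanceMeasure.decodeFromString(DistanceMeasure.EUCLIDEAN)
val centers = Array(
  new VectorWithNorm(Vectors.dense(0.0, 0.0)),
  new VectorWithNorm(Vectors.dense(10.0, 0.0)))

// O(k^2) one-off cost: for Euclidean distance, statistics(i,j) = (d(c_i, c_j) / 2)^2.
val stats = measure.computeStatistics(centers)

// O(k) per point, with distance computations skipped via the cached bounds.
val point = new VectorWithNorm(Vectors.dense(1.0, 1.0))
val (bestIndex, cost) = measure.findClosest(centers, stats, point)
// bestIndex == 0 and cost is the squared distance (2.0 here): since it is already
// below stats(0) = (10 / 2)^2 = 25, the distance to center 1 is never computed.
```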


```diff
@@ -209,9 +209,7 @@
    */
   @Since("0.8.0")
   def run(data: RDD[Vector]): KMeansModel = {
-    val instances: RDD[(Vector, Double)] = data.map {
-      case (point) => (point, 1.0)
-    }
+    val instances = data.map(point => (point, 1.0))
     runWithWeight(instances, None)
   }
@@ -260,6 +258,7 @@
         initKMeansParallel(data, distanceMeasureInstance)
       }
     }
+    val numFeatures = centers.head.vector.size
     val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
     logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.")
@@ -269,34 +268,44 @@
     val iterationStartTime = System.nanoTime()
 
-    instr.foreach(_.logNumFeatures(centers.head.vector.size))
+    instr.foreach(_.logNumFeatures(numFeatures))
+
+    val shouldDistributed = centers.length * centers.length * numFeatures.toLong > 1000000L
 
     // Execute iterations of Lloyd's algorithm until converged
     while (iteration < maxIterations && !converged) {
-      val costAccum = sc.doubleAccumulator
       val bcCenters = sc.broadcast(centers)
+      val stats = if (shouldDistributed) {
+        distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters)
+      } else {
+        distanceMeasureInstance.computeStatistics(centers)
+      }
+      val bcStats = sc.broadcast(stats)
+      val costAccum = sc.doubleAccumulator
 
       // Find the new centers
       val collected = data.mapPartitions { points =>
-        val thisCenters = bcCenters.value
-        val dims = thisCenters.head.vector.size
+        val centers = bcCenters.value
+        val stats = bcStats.value
+        val dims = centers.head.vector.size
 
-        val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
+        val sums = Array.fill(centers.length)(Vectors.zeros(dims))
 
         // clusterWeightSum is needed to calculate cluster center
         // cluster center =
         //     sample1 * weight1/clusterWeightSum + sample2 * weight2/clusterWeightSum + ...
-        val clusterWeightSum = Array.ofDim[Double](thisCenters.length)
+        val clusterWeightSum = Array.ofDim[Double](centers.length)
 
         points.foreach { point =>
-          val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
+          val (bestCenter, cost) = distanceMeasureInstance.findClosest(centers, stats, point)
           costAccum.add(cost * point.weight)
           distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
           clusterWeightSum(bestCenter) += point.weight
         }
 
-        clusterWeightSum.indices.filter(clusterWeightSum(_) > 0)
-          .map(j => (j, (sums(j), clusterWeightSum(j)))).iterator
+        Iterator.tabulate(centers.length)(j => (j, (sums(j), clusterWeightSum(j))))
+          .filter(_._2._2 > 0)
       }.reduceByKey { (sumweight1, sumweight2) =>
         axpy(1.0, sumweight2._1, sumweight1._1)
         (sumweight1._1, sumweight1._2 + sumweight2._2)
@@ -307,15 +316,13 @@
         instr.foreach(_.logSumOfWeights(collected.values.map(_._2).sum))
       }
 
-      val newCenters = collected.mapValues { case (sum, weightSum) =>
-        distanceMeasureInstance.centroid(sum, weightSum)
-      }
-
       bcCenters.destroy()
+      bcStats.destroy()
 
       // Update the cluster centers and costs
      converged = true
-      newCenters.foreach { case (j, newCenter) =>
+      collected.foreach { case (j, (sum, weightSum)) =>
+        val newCenter = distanceMeasureInstance.centroid(sum, weightSum)
        if (converged &&
          !distanceMeasureInstance.isCenterConverged(centers(j), newCenter, epsilon)) {
          converged = false
@@ -324,6 +331,7 @@
      }
 
      cost = costAccum.value
+      instr.foreach(_.logNamedValue(s"Cost@iter=$iteration", s"$cost"))
      iteration += 1
    }
@@ -372,7 +380,7 @@
    require(sample.nonEmpty, s"No samples available from $data")
 
    val centers = ArrayBuffer[VectorWithNorm]()
-    var newCenters = Seq(sample.head.toDense)
+    var newCenters = Array(sample.head.toDense)
    centers ++= newCenters
 
    // On each step, sample 2 * k points on average with probability proportional
@@ -404,10 +412,10 @@
    costs.unpersist()
    bcNewCentersList.foreach(_.destroy())
 
-    val distinctCenters = centers.map(_.vector).distinct.map(new VectorWithNorm(_))
+    val distinctCenters = centers.map(_.vector).distinct.map(new VectorWithNorm(_)).toArray
 
-    if (distinctCenters.size <= k) {
-      distinctCenters.toArray
+    if (distinctCenters.length <= k) {
+      distinctCenters
    } else {
      // Finally, we might have a set of more than k distinct candidate centers; weight each
      // candidate by the number of points in the dataset mapping to it and run a local k-means++
@@ -420,7 +428,7 @@
      bcCenters.destroy()
 
      val myWeights = distinctCenters.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray
-      LocalKMeans.kMeansPlusPlus(0, distinctCenters.toArray, myWeights, k, 30)
+      LocalKMeans.kMeansPlusPlus(0, distinctCenters, myWeights, k, 30)
    }
  }
 }
```
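
A note on the `shouldDistributed` switch above: building the statistics takes O(k^2) center-to-center distance computations per iteration, so the work is shipped to executors only when `centers.length * centers.length * numFeatures` exceeds 1,000,000. For example (sizes are hypothetical), k = 500 centers with 20 features gives 500 * 500 * 20 = 5,000,000, so `computeStatisticsDistributedly` is used, while k = 100 centers with 50 features gives 100 * 100 * 50 = 500,000 and the statistics stay on the driver.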


```diff
@@ -48,6 +48,13 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector],
   @transient private lazy val clusterCentersWithNorm =
     if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_))
 
+  // TODO: computation of statistics may take seconds, so save it to KMeansModel in training
+  @transient private lazy val statistics = if (clusterCenters == null) {
+    null
+  } else {
+    distanceMeasureInstance.computeStatistics(clusterCentersWithNorm)
+  }
+
   @Since("2.4.0")
   private[spark] def this(clusterCenters: Array[Vector], distanceMeasure: String) =
     this(clusterCenters: Array[Vector], distanceMeasure, 0.0, -1)
@@ -73,7 +80,8 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector],
    */
   @Since("0.8.0")
   def predict(point: Vector): Int = {
-    distanceMeasureInstance.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1
+    distanceMeasureInstance.findClosest(clusterCentersWithNorm, statistics,
+      new VectorWithNorm(point))._1
   }
 
   /**
@@ -82,8 +90,10 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector],
   @Since("1.0.0")
   def predict(points: RDD[Vector]): RDD[Int] = {
     val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm)
+    val bcStatistics = points.context.broadcast(statistics)
     points.map(p =>
-      distanceMeasureInstance.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
+      distanceMeasureInstance.findClosest(bcCentersWithNorm.value,
+        bcStatistics.value, new VectorWithNorm(p))._1)
   }
 
   /**
```


```diff
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class DistanceMeasureSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  private val seed = 42
+  private val k = 10
+  private val dim = 8
+
+  private var centers: Array[VectorWithNorm] = _
+  private var data: Array[VectorWithNorm] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val rng = new Random(seed)
+    centers = Array.tabulate(k) { i =>
+      val values = Array.fill(dim)(rng.nextGaussian)
+      new VectorWithNorm(Vectors.dense(values))
+    }
+    data = Array.tabulate(1000) { i =>
+      val values = Array.fill(dim)(rng.nextGaussian)
+      new VectorWithNorm(Vectors.dense(values))
+    }
+  }
+
+  test("predict with statistics") {
+    Seq(DistanceMeasure.COSINE, DistanceMeasure.EUCLIDEAN).foreach { distanceMeasure =>
+      val distance = DistanceMeasure.decodeFromString(distanceMeasure)
+      val statistics = distance.computeStatistics(centers)
+      data.foreach { point =>
+        val (index1, cost1) = distance.findClosest(centers, point)
+        val (index2, cost2) = distance.findClosest(centers, statistics, point)
+        assert(index1 == index2)
+        assert(cost1 ~== cost2 relTol 1E-10)
+      }
+    }
+  }
+
+  test("compute statistics distributedly") {
+    Seq(DistanceMeasure.COSINE, DistanceMeasure.EUCLIDEAN).foreach { distanceMeasure =>
+      val distance = DistanceMeasure.decodeFromString(distanceMeasure)
+      val statistics1 = distance.computeStatistics(centers)
+      val sc = spark.sparkContext
+      val bcCenters = sc.broadcast(centers)
+      val statistics2 = distance.computeStatisticsDistributedly(sc, bcCenters)
+      bcCenters.destroy()
+      assert(Vectors.dense(statistics1) ~== Vectors.dense(statistics2) relTol 1E-10)
+    }
+  }
+}
```