From 0ede08bcb21266739aab86b8af3228adc8239eb0 Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Fri, 24 Apr 2020 11:24:15 -0500 Subject: [PATCH] [SPARK-31007][ML] KMeans optimization based on triangle-inequality ### What changes were proposed in this pull request? Apply Lemma 1 from [Using the Triangle Inequality to Accelerate K-Means](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf): > Let x be a point, and let b and c be centers. If d(b,c) >= 2d(x,b), then d(x,c) >= d(x,b). It can be applied directly to EuclideanDistance, but not to CosineDistance. For CosineDistance, however, an analogous bound can be derived in the space of radians/angles. ### Why are the changes needed? It helps improve the performance of prediction and training (mostly the latter). ### Does this PR introduce any user-facing change? No ### How was this patch tested? Existing test suites. Closes #27758 from zhengruifeng/km_triangle. Authored-by: zhengruifeng Signed-off-by: Sean Owen --- .../org/apache/spark/ml/impl/Utils.scala | 53 ++++- .../spark/ml/clustering/GaussianMixture.scala | 16 +- .../mllib/clustering/DistanceMeasure.scala | 223 +++++++++++++++++- .../spark/mllib/clustering/KMeans.scala | 52 ++-- .../spark/mllib/clustering/KMeansModel.scala | 14 +- .../clustering/DistanceMeasureSuite.scala | 77 ++++++ 6 files changed, 390 insertions(+), 45 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/clustering/DistanceMeasureSuite.scala diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala b/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala index 112de982e4..ee3e99c0a8 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.impl -private[ml] object Utils { +private[spark] object Utils { lazy val EPSILON = { var eps = 1.0 @@ -27,4 +27,55 @@ private[ml] object Utils { } eps } + + /** + * Convert an n * (n + 1) / 2 dimension array representing the upper triangular part of a matrix + * into an n * n array representing the full symmetric matrix (column major). + * + * @param n The order of the n by n matrix. + * @param triangularValues The upper triangular part of the matrix packed in an array + * (column major). + * @return A dense matrix which represents the symmetric matrix in column major. + */ + def unpackUpperTriangular( + n: Int, + triangularValues: Array[Double]): Array[Double] = { + val symmetricValues = new Array[Double](n * n) + var r = 0 + var i = 0 + while (i < n) { + var j = 0 + while (j <= i) { + symmetricValues(i * n + j) = triangularValues(r) + symmetricValues(j * n + i) = triangularValues(r) + r += 1 + j += 1 + } + i += 1 + } + symmetricValues + } + + /** + * Indexing in an array representing the upper triangular part of a matrix + * into an n * n array representing the full symmetric matrix (column major). + * val symmetricValues = unpackUpperTriangularMatrix(n, triangularValues) + * val matrix = new DenseMatrix(n, n, symmetricValues) + * val index = indexUpperTriangularMatrix(n, i, j) + * then: symmetricValues(index) == matrix(i, j) + * + * @param n The order of the n by n matrix.
+ */ + def indexUpperTriangular( + n: Int, + i: Int, + j: Int): Int = { + require(i >= 0 && i < n, s"Expected 0 <= i < $n, got i = $i.") + require(j >= 0 && j < n, s"Expected 0 <= j < $n, got j = $j.") + if (i <= j) { + j * (j + 1) / 2 + i + } else { + i * (i + 1) / 2 + j + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index f490faf084..1c4560aa5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -22,7 +22,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.impl.Utils.EPSILON +import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON} import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -583,19 +583,7 @@ object GaussianMixture extends DefaultParamsReadable[GaussianMixture] { private[clustering] def unpackUpperTriangularMatrix( n: Int, triangularValues: Array[Double]): DenseMatrix = { - val symmetricValues = new Array[Double](n * n) - var r = 0 - var i = 0 - while (i < n) { - var j = 0 - while (j <= i) { - symmetricValues(i * n + j) = triangularValues(r) - symmetricValues(j * n + i) = triangularValues(r) - r += 1 - j += 1 - } - i += 1 - } + val symmetricValues = unpackUpperTriangular(n, triangularValues) new DenseMatrix(n, n, symmetricValues) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala index e83dd3723b..bffed61c29 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala @@ -17,23 +17,125 @@ package org.apache.spark.mllib.clustering +import org.apache.spark.SparkContext import org.apache.spark.annotation.Since +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.impl.Utils.indexUpperTriangular import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} import org.apache.spark.mllib.util.MLUtils private[spark] abstract class DistanceMeasure extends Serializable { + /** + * Statistics used in triangle inequality to obtain useful bounds to find closest centers. + * @param distance distance between two centers + */ + def computeStatistics(distance: Double): Double + + /** + * Statistics used in triangle inequality to obtain useful bounds to find closest centers. + * + * @return The packed upper triangular part of a symmetric matrix containing statistics, + * matrix(i,j) represents: + * 1, if i != j: a bound r = matrix(i,j) to help avoiding unnecessary distance + * computation. Given point x, let i be current closest center, and d be current best + * distance, if d < f(r), then we no longer need to compute the distance to center j; + * 2, if i == j: a bound r = matrix(i,i) = min_k{matrix(i,k)|k!=i}. If distance + * between point x and center i is less than f(r), then center i is the closest center + * to point x.
+ */ + def computeStatistics(centers: Array[VectorWithNorm]): Array[Double] = { + val k = centers.length + if (k == 1) return Array(Double.NaN) + + val packedValues = Array.ofDim[Double](k * (k + 1) / 2) + val diagValues = Array.fill(k)(Double.PositiveInfinity) + var i = 0 + while (i < k) { + var j = i + 1 + while (j < k) { + val d = distance(centers(i), centers(j)) + val s = computeStatistics(d) + val index = indexUpperTriangular(k, i, j) + packedValues(index) = s + if (s < diagValues(i)) diagValues(i) = s + if (s < diagValues(j)) diagValues(j) = s + j += 1 + } + i += 1 + } + + i = 0 + while (i < k) { + val index = indexUpperTriangular(k, i, i) + packedValues(index) = diagValues(i) + i += 1 + } + packedValues + } + + /** + * Compute distance between centers in a distributed way. + */ + def computeStatisticsDistributedly( + sc: SparkContext, + bcCenters: Broadcast[Array[VectorWithNorm]]): Array[Double] = { + val k = bcCenters.value.length + if (k == 1) return Array(Double.NaN) + + val packedValues = Array.ofDim[Double](k * (k + 1) / 2) + val diagValues = Array.fill(k)(Double.PositiveInfinity) + + val numParts = math.min(k, 1024) + sc.range(0, numParts, 1, numParts) + .mapPartitionsWithIndex { case (pid, _) => + val centers = bcCenters.value + Iterator.range(0, k).flatMap { i => + Iterator.range(i + 1, k).flatMap { j => + val hash = (i, j).hashCode.abs + if (hash % numParts == pid) { + val d = distance(centers(i), centers(j)) + val s = computeStatistics(d) + Iterator.single((i, j, s)) + } else Iterator.empty + } + } + }.collect.foreach { case (i, j, s) => + val index = indexUpperTriangular(k, i, j) + packedValues(index) = s + if (s < diagValues(i)) diagValues(i) = s + if (s < diagValues(j)) diagValues(j) = s + } + + var i = 0 + while (i < k) { + val index = indexUpperTriangular(k, i, i) + packedValues(index) = diagValues(i) + i += 1 + } + packedValues + } + /** * @return the index of the closest center to the given point, as well as the cost. */ def findClosest( - centers: TraversableOnce[VectorWithNorm], + centers: Array[VectorWithNorm], + statistics: Array[Double], + point: VectorWithNorm): (Int, Double) + + /** + * @return the index of the closest center to the given point, as well as the cost. + */ + def findClosest( + centers: Array[VectorWithNorm], point: VectorWithNorm): (Int, Double) = { var bestDistance = Double.PositiveInfinity var bestIndex = 0 var i = 0 - centers.foreach { center => + while (i < centers.length) { + val center = centers(i) val currentDistance = distance(center, point) if (currentDistance < bestDistance) { bestDistance = currentDistance @@ -48,7 +150,7 @@ private[spark] abstract class DistanceMeasure extends Serializable { * @return the K-means cost of a given point against the given cluster centers. */ def pointCost( - centers: TraversableOnce[VectorWithNorm], + centers: Array[VectorWithNorm], point: VectorWithNorm): Double = { findClosest(centers, point)._2 } @@ -154,22 +256,79 @@ object DistanceMeasure { } private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { + + /** + * Statistics used in triangle inequality to obtain useful bounds to find closest centers. + * @see Charles Elkan, + * Using the Triangle Inequality to Accelerate k-Means + * + * @return One element used in statistics matrix to make matrix(i,j) represents: + * 1, if i != j: a bound r = matrix(i,j) to help avoiding unnecessary distance + * computation. 
Given point x, let i be current closest center, and d be current best + * squared distance, if d < r, then we no longer need to compute the distance to center + * j. matrix(i,j) equals the square of half of the Euclidean distance between centers i + * and j; + * 2, if i == j: a bound r = matrix(i,i) = min_k{matrix(i,k)|k!=i}. If squared + * distance between point x and center i is less than r, then center i is the closest + * center to point x. + */ + override def computeStatistics(distance: Double): Double = { + 0.25 * distance * distance + } + + /** + * @return the index of the closest center to the given point, as well as the cost. + */ + override def findClosest( + centers: Array[VectorWithNorm], + statistics: Array[Double], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = EuclideanDistanceMeasure.fastSquaredDistance(centers(0), point) + if (bestDistance < statistics(0)) return (0, bestDistance) + + val k = centers.length + var bestIndex = 0 + var i = 1 + while (i < k) { + val center = centers(i) + // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary + // distance computation. + val normDiff = center.norm - point.norm + val lowerBound = normDiff * normDiff + if (lowerBound < bestDistance) { + val index1 = indexUpperTriangular(k, i, bestIndex) + if (statistics(index1) < bestDistance) { + val d = EuclideanDistanceMeasure.fastSquaredDistance(center, point) + val index2 = indexUpperTriangular(k, i, i) + if (d < statistics(index2)) return (i, d) + if (d < bestDistance) { + bestDistance = d + bestIndex = i + } + } + } + i += 1 + } + (bestIndex, bestDistance) + } + /** * @return the index of the closest center to the given point, as well as the squared distance. */ override def findClosest( - centers: TraversableOnce[VectorWithNorm], + centers: Array[VectorWithNorm], point: VectorWithNorm): (Int, Double) = { var bestDistance = Double.PositiveInfinity var bestIndex = 0 var i = 0 - centers.foreach { center => + while (i < centers.length) { + val center = centers(i) // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary // distance computation. var lowerBoundOfSqDist = center.norm - point.norm lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist if (lowerBoundOfSqDist < bestDistance) { - val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point) + val distance = EuclideanDistanceMeasure.fastSquaredDistance(center, point) if (distance < bestDistance) { bestDistance = distance bestIndex = i @@ -234,6 +393,58 @@ private[spark] object EuclideanDistanceMeasure { } private[spark] class CosineDistanceMeasure extends DistanceMeasure { + + /** + * Statistics used in triangle inequality to obtain useful bounds to find closest centers. + * + * @return One element used in statistics matrix to make matrix(i,j) represents: + * 1, if i != j: a bound r = matrix(i,j) to help avoiding unnecessary distance + * computation. Given point x, let i be current closest center, and d be current best + * squared distance, if d < r, then we no longer need to compute the distance to center + * j. For Cosine distance, it is similar to Euclidean distance. However, radian/angle + * is used instead of Cosine distance to compute matrix(i,j): for centers i and j, + * compute the radian/angle between them, halve it, and convert it back to a Cosine + * distance at the end; + * 2, if i == j: a bound r = matrix(i,i) = min_k{matrix(i,k)|k!=i}.
If Cosine + * distance between point x and center i is less than r, then center i is the closest + * center to point x. + */ + override def computeStatistics(distance: Double): Double = { + // d = 1 - cos(x) + // r = 1 - cos(x/2) = 1 - sqrt((cos(x) + 1) / 2) = 1 - sqrt(1 - d/2) + 1 - math.sqrt(1 - distance / 2) + } + + /** + * @return the index of the closest center to the given point, as well as the cost. + */ + def findClosest( + centers: Array[VectorWithNorm], + statistics: Array[Double], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = distance(centers(0), point) + if (bestDistance < statistics(0)) return (0, bestDistance) + + val k = centers.length + var bestIndex = 0 + var i = 1 + while (i < k) { + val index1 = indexUpperTriangular(k, i, bestIndex) + if (statistics(index1) < bestDistance) { + val center = centers(i) + val d = distance(center, point) + val index2 = indexUpperTriangular(k, i, i) + if (d < statistics(index2)) return (i, d) + if (d < bestDistance) { + bestDistance = d + bestIndex = i + } + } + i += 1 + } + (bestIndex, bestDistance) + } + /** * @param v1: first vector * @param v2: second vector diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index a3cf7f9647..1c5de5a092 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -209,9 +209,7 @@ class KMeans private ( */ @Since("0.8.0") def run(data: RDD[Vector]): KMeansModel = { - val instances: RDD[(Vector, Double)] = data.map { - case (point) => (point, 1.0) - } + val instances = data.map(point => (point, 1.0)) runWithWeight(instances, None) } @@ -260,6 +258,7 @@ class KMeans private ( initKMeansParallel(data, distanceMeasureInstance) } } + val numFeatures = centers.head.vector.size val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.") @@ -269,34 +268,44 @@ class KMeans private ( val iterationStartTime = System.nanoTime() - instr.foreach(_.logNumFeatures(centers.head.vector.size)) + instr.foreach(_.logNumFeatures(numFeatures)) + + val shouldDistributed = centers.length * centers.length * numFeatures.toLong > 1000000L // Execute iterations of Lloyd's algorithm until converged while (iteration < maxIterations && !converged) { - val costAccum = sc.doubleAccumulator val bcCenters = sc.broadcast(centers) + val stats = if (shouldDistributed) { + distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters) + } else { + distanceMeasureInstance.computeStatistics(centers) + } + val bcStats = sc.broadcast(stats) + + val costAccum = sc.doubleAccumulator // Find the new centers val collected = data.mapPartitions { points => - val thisCenters = bcCenters.value - val dims = thisCenters.head.vector.size + val centers = bcCenters.value + val stats = bcStats.value + val dims = centers.head.vector.size - val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims)) + val sums = Array.fill(centers.length)(Vectors.zeros(dims)) // clusterWeightSum is needed to calculate cluster center // cluster center = // sample1 * weight1/clusterWeightSum + sample2 * weight2/clusterWeightSum + ... 
- val clusterWeightSum = Array.ofDim[Double](thisCenters.length) + val clusterWeightSum = Array.ofDim[Double](centers.length) points.foreach { point => - val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point) + val (bestCenter, cost) = distanceMeasureInstance.findClosest(centers, stats, point) costAccum.add(cost * point.weight) distanceMeasureInstance.updateClusterSum(point, sums(bestCenter)) clusterWeightSum(bestCenter) += point.weight } - clusterWeightSum.indices.filter(clusterWeightSum(_) > 0) - .map(j => (j, (sums(j), clusterWeightSum(j)))).iterator + Iterator.tabulate(centers.length)(j => (j, (sums(j), clusterWeightSum(j)))) + .filter(_._2._2 > 0) }.reduceByKey { (sumweight1, sumweight2) => axpy(1.0, sumweight2._1, sumweight1._1) (sumweight1._1, sumweight1._2 + sumweight2._2) @@ -307,15 +316,13 @@ class KMeans private ( instr.foreach(_.logSumOfWeights(collected.values.map(_._2).sum)) } - val newCenters = collected.mapValues { case (sum, weightSum) => - distanceMeasureInstance.centroid(sum, weightSum) - } - bcCenters.destroy() + bcStats.destroy() // Update the cluster centers and costs converged = true - newCenters.foreach { case (j, newCenter) => + collected.foreach { case (j, (sum, weightSum)) => + val newCenter = distanceMeasureInstance.centroid(sum, weightSum) if (converged && !distanceMeasureInstance.isCenterConverged(centers(j), newCenter, epsilon)) { converged = false @@ -324,6 +331,7 @@ class KMeans private ( } cost = costAccum.value + instr.foreach(_.logNamedValue(s"Cost@iter=$iteration", s"$cost")) iteration += 1 } @@ -372,7 +380,7 @@ class KMeans private ( require(sample.nonEmpty, s"No samples available from $data") val centers = ArrayBuffer[VectorWithNorm]() - var newCenters = Seq(sample.head.toDense) + var newCenters = Array(sample.head.toDense) centers ++= newCenters // On each step, sample 2 * k points on average with probability proportional @@ -404,10 +412,10 @@ class KMeans private ( costs.unpersist() bcNewCentersList.foreach(_.destroy()) - val distinctCenters = centers.map(_.vector).distinct.map(new VectorWithNorm(_)) + val distinctCenters = centers.map(_.vector).distinct.map(new VectorWithNorm(_)).toArray - if (distinctCenters.size <= k) { - distinctCenters.toArray + if (distinctCenters.length <= k) { + distinctCenters } else { // Finally, we might have a set of more than k distinct candidate centers; weight each // candidate by the number of points in the dataset mapping to it and run a local k-means++ @@ -420,7 +428,7 @@ class KMeans private ( bcCenters.destroy() val myWeights = distinctCenters.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray - LocalKMeans.kMeansPlusPlus(0, distinctCenters.toArray, myWeights, k, 30) + LocalKMeans.kMeansPlusPlus(0, distinctCenters, myWeights, k, 30) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 0c6570ff81..04a3b6dd41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -48,6 +48,13 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector], @transient private lazy val clusterCentersWithNorm = if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_)) + // TODO: computation of statistics may take seconds, so save it to KMeansModel in training + @transient private lazy val statistics = if (clusterCenters == null) { + null + } else 
{ + distanceMeasureInstance.computeStatistics(clusterCentersWithNorm) + } + @Since("2.4.0") private[spark] def this(clusterCenters: Array[Vector], distanceMeasure: String) = this(clusterCenters: Array[Vector], distanceMeasure, 0.0, -1) @@ -73,7 +80,8 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector], */ @Since("0.8.0") def predict(point: Vector): Int = { - distanceMeasureInstance.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1 + distanceMeasureInstance.findClosest(clusterCentersWithNorm, statistics, + new VectorWithNorm(point))._1 } /** @@ -82,8 +90,10 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector], @Since("1.0.0") def predict(points: RDD[Vector]): RDD[Int] = { val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm) + val bcStatistics = points.context.broadcast(statistics) points.map(p => - distanceMeasureInstance.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1) + distanceMeasureInstance.findClosest(bcCentersWithNorm.value, + bcStatistics.value, new VectorWithNorm(p))._1) } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/DistanceMeasureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/DistanceMeasureSuite.scala new file mode 100644 index 0000000000..73691c4ecb --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/DistanceMeasureSuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.clustering + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class DistanceMeasureSuite extends SparkFunSuite with MLlibTestSparkContext { + + private val seed = 42 + private val k = 10 + private val dim = 8 + + private var centers: Array[VectorWithNorm] = _ + + private var data: Array[VectorWithNorm] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val rng = new Random(seed) + + centers = Array.tabulate(k) { i => + val values = Array.fill(dim)(rng.nextGaussian) + new VectorWithNorm(Vectors.dense(values)) + } + + data = Array.tabulate(1000) { i => + val values = Array.fill(dim)(rng.nextGaussian) + new VectorWithNorm(Vectors.dense(values)) + } + } + + test("predict with statistics") { + Seq(DistanceMeasure.COSINE, DistanceMeasure.EUCLIDEAN).foreach { distanceMeasure => + val distance = DistanceMeasure.decodeFromString(distanceMeasure) + val statistics = distance.computeStatistics(centers) + data.foreach { point => + val (index1, cost1) = distance.findClosest(centers, point) + val (index2, cost2) = distance.findClosest(centers, statistics, point) + assert(index1 == index2) + assert(cost1 ~== cost2 relTol 1E-10) + } + } + } + + test("compute statistics distributedly") { + Seq(DistanceMeasure.COSINE, DistanceMeasure.EUCLIDEAN).foreach { distanceMeasure => + val distance = DistanceMeasure.decodeFromString(distanceMeasure) + val statistics1 = distance.computeStatistics(centers) + val sc = spark.sparkContext + val bcCenters = sc.broadcast(centers) + val statistics2 = distance.computeStatisticsDistributedly(sc, bcCenters) + bcCenters.destroy() + assert(Vectors.dense(statistics1) ~== Vectors.dense(statistics2) relTol 1E-10) + } + } +}
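A note on the Euclidean bound introduced above: `EuclideanDistanceMeasure.computeStatistics` stores r = (d(b, c) / 2)^2, so comparing the current best squared distance against r is exactly Lemma 1 restated for squared distances, and the diagonal entry statistics(i, i) = min_k{statistics(i, k) | k != i} gives the early-exit rule. The following is a minimal standalone sketch of both rules over plain Scala arrays; the object and method names are hypothetical and it does not use the Spark classes in this patch.

```scala
// Standalone sketch of Lemma 1 (Elkan 2003) in squared form, with plain arrays; not part of the patch.
object TriangleInequalitySketch {

  private def squaredDistance(a: Array[Double], b: Array[Double]): Double = {
    var s = 0.0
    var i = 0
    while (i < a.length) { val d = a(i) - b(i); s += d * d; i += 1 }
    s
  }

  /** Finds the closest center, skipping centers ruled out by the precomputed bounds. */
  def findClosest(point: Array[Double], centers: Array[Array[Double]]): (Int, Double) = {
    val k = centers.length
    require(k > 1, "sketch assumes at least two centers")
    // stats(i)(j) = (d(center_i, center_j) / 2)^2, the quantity computeStatistics stores.
    val stats = Array.tabulate(k, k)((i, j) => 0.25 * squaredDistance(centers(i), centers(j)))
    // diag(i) = min over j != i of stats(i)(j): if d(x, center_i)^2 falls below this,
    // no other center can be closer, so the search can stop early.
    val diag = Array.tabulate(k)(i => (0 until k).filter(_ != i).map(j => stats(i)(j)).min)

    var bestIndex = 0
    var bestSqDist = squaredDistance(point, centers(0))
    if (bestSqDist < diag(0)) return (0, bestSqDist)
    var i = 1
    while (i < k) {
      // Lemma 1 in squared form: if d(x, best)^2 <= (d(best, center_i) / 2)^2,
      // then d(x, center_i) >= d(x, best), so center_i cannot win and is skipped.
      if (stats(bestIndex)(i) < bestSqDist) {
        val d = squaredDistance(point, centers(i))
        if (d < diag(i)) return (i, d)
        if (d < bestSqDist) { bestSqDist = d; bestIndex = i }
      }
      i += 1
    }
    (bestIndex, bestSqDist)
  }
}
```

In the patch itself the same pruning additionally sits behind the cheaper (|‖a‖ - ‖b‖|)^2 lower bound, so the statistics lookup is consulted only when that bound fails to prune.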
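For reference, the layout assumed by `Utils.indexUpperTriangular` and `Utils.unpackUpperTriangular` is the upper triangle of a symmetric n x n matrix packed column by column, so entry (i, j) with i <= j lives at offset j * (j + 1) / 2 + i. A small self-contained check of that formula (illustrative only; the sample values are arbitrary):

```scala
// Self-contained check of the packed upper-triangular layout; illustrative only.
object PackedIndexSketch {
  // Same formula as Utils.indexUpperTriangular, without the bounds checks.
  def indexUpperTriangular(n: Int, i: Int, j: Int): Int =
    if (i <= j) j * (j + 1) / 2 + i else i * (i + 1) / 2 + j

  def main(args: Array[String]): Unit = {
    val n = 3
    // Upper triangle of a symmetric 3 x 3 matrix packed column by column:
    // slots hold entries (0,0), (0,1), (1,1), (0,2), (1,2), (2,2) in that order.
    val packed = Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
    assert(packed(indexUpperTriangular(n, 0, 1)) == 2.0) // entry (0, 1)
    assert(packed(indexUpperTriangular(n, 1, 0)) == 2.0) // (1, 0) maps to the same slot by symmetry
    assert(packed(indexUpperTriangular(n, 1, 2)) == 5.0) // entry (1, 2)
    assert(packed(indexUpperTriangular(n, 2, 2)) == 6.0) // diagonal entry (2, 2)
  }
}
```

The helper's visibility is widened from private[ml] to private[spark] so that DistanceMeasure in the mllib package can reuse the same indexing.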
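The Cosine statistic follows from the half-angle identity: with d = 1 - cos(theta) for the angle theta between two centers, 1 - cos(theta / 2) = 1 - sqrt((1 + cos(theta)) / 2) = 1 - sqrt(1 - d / 2), which is what `CosineDistanceMeasure.computeStatistics` returns. If the cosine distance from a point to its current best center is below this value, the angle to that center is at most half the angle between the two centers, so by the triangle inequality on angles the other center cannot be closer. A small numeric sanity check of that implication (illustrative only; the object name, dimension, and seed are arbitrary):

```scala
import scala.util.Random

// Numeric sanity check of the half-angle bound used for cosine distance; illustrative only.
object CosineBoundSketch {

  private def cosineDistance(a: Array[Double], b: Array[Double]): Double = {
    var dot = 0.0
    var na = 0.0
    var nb = 0.0
    var i = 0
    while (i < a.length) { dot += a(i) * b(i); na += a(i) * a(i); nb += b(i) * b(i); i += 1 }
    1.0 - dot / math.sqrt(na * nb)
  }

  // r = 1 - cos(theta / 2), expressed in terms of d = 1 - cos(theta).
  private def statistic(d: Double): Double = 1.0 - math.sqrt(1.0 - d / 2.0)

  def main(args: Array[String]): Unit = {
    val rng = new Random(0)
    def randomVector(): Array[Double] = Array.fill(5)(rng.nextGaussian())
    (1 to 10000).foreach { _ =>
      val x = randomVector()
      val b = randomVector()
      val c = randomVector()
      // If x lies within the half-angle bound of center b, then b is at least as close as c.
      if (cosineDistance(x, b) < statistic(cosineDistance(b, c))) {
        assert(cosineDistance(x, b) <= cosineDistance(x, c) + 1e-12)
      }
    }
  }
}
```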
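`computeStatisticsDistributedly` spreads the k * (k - 1) / 2 center pairs over at most 1024 tasks: every task scans all (i, j) pairs but computes a distance only for the pairs whose hash falls in its bucket, and the driver assembles the packed statistics from the collected triples. A rough local illustration of that assignment scheme, using plain collections instead of an RDD (the object name and numParts value are hypothetical):

```scala
// Local illustration of how center pairs are split across tasks; plain collections, no RDD.
object PairPartitionSketch {
  def main(args: Array[String]): Unit = {
    val k = 10
    val numParts = 4 // hypothetical; the patch uses math.min(k, 1024)
    // Each "task" pid keeps only the pairs (i, j) whose hash lands in its bucket,
    // mirroring the filter inside computeStatisticsDistributedly.
    val assignments =
      for {
        pid <- 0 until numParts
        i <- 0 until k
        j <- i + 1 until k
        if (i, j).hashCode.abs % numParts == pid
      } yield (pid, (i, j))
    // Across all tasks, every pair is handled exactly once.
    assert(assignments.map(_._2).distinct.size == k * (k - 1) / 2)
    assignments.groupBy(_._1).toSeq.sortBy(_._1).foreach { case (pid, pairs) =>
      println(s"task $pid handles ${pairs.size} pairs")
    }
  }
}
```

Because the assignment depends only on (i, j) and numParts, no data beyond the broadcast centers needs to move; each task emits only its own (i, j, statistic) triples.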