[SPARK-3424][MLLIB] cache point distances during k-means|| init
This PR ports the following feature implemented in #2634 by derrickburns:
* During k-means|| initialization, we should cache costs (squared distances) previously computed instead of recomputing them on every step; a minimal sketch of the idea follows the list below.
It also contains the following optimizations:
* aggregate `sumCosts` directly instead of going through `reduceByKey`/`collectAsMap`
* run the multiple (#runs) local k-means++ instances in parallel
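To make the caching idea concrete, here is a minimal standalone sketch in plain Scala (no Spark, single run): each point keeps its best-known squared distance to the centers chosen so far, and a step only measures distances to the centers added in that step. The names `squaredDistance` and `updateCosts` are illustrative, not MLlib API.

```scala
// Minimal sketch of the cost-caching idea (plain Scala, no Spark, single run).
object CachedCostsSketch {
  type Point = Array[Double]

  def squaredDistance(a: Point, b: Point): Double = {
    var s = 0.0
    var i = 0
    while (i < a.length) { val d = a(i) - b(i); s += d * d; i += 1 }
    s
  }

  /** Fold only the *new* centers into the cached per-point costs. */
  def updateCosts(points: Seq[Point], costs: Array[Double], newCenters: Seq[Point]): Unit = {
    var i = 0
    while (i < points.length) {
      newCenters.foreach { c =>
        costs(i) = math.min(costs(i), squaredDistance(points(i), c))
      }
      i += 1
    }
  }

  def main(args: Array[String]): Unit = {
    val points = Seq(Array(0.0, 0.0), Array(1.0, 1.0), Array(4.0, 4.0))
    val costs = Array.fill(points.length)(Double.PositiveInfinity)
    updateCosts(points, costs, Seq(points(0)))  // step 0: first sampled center
    updateCosts(points, costs, Seq(points(2)))  // step 1: only the newly added center is measured
    println(costs.mkString(", "))               // 0.0, 2.0, 0.0
  }
}
```

In the patch itself, the same `min` rule is applied to a cached RDD of per-run cost vectors, so each k-means|| step does O(new centers) distance computations per point rather than O(all centers so far).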
I compared the performance locally on the mnist-digit dataset. Before this patch:
![before](https://cloud.githubusercontent.com/assets/829644/5845647/93080862-a172-11e4-9a35-044ec711afc4.png)
with this patch:
![after](https://cloud.githubusercontent.com/assets/829644/5845653/a47c29e8-a172-11e4-8e9f-08db57fe3502.png)
It is clear that each k-means|| iteration takes about the same amount of time with this patch, as expected now that each step only computes distances against the newly added centers rather than against all centers chosen so far.
Authors:
Derrick Burns <derrickburns@gmail.com>
Xiangrui Meng <meng@databricks.com>
Closes #4144 from mengxr/SPARK-3424-kmeans-parallel and squashes the following commits:
0a875ec [Xiangrui Meng] address comments
4341bb8 [Xiangrui Meng] do not re-compute point distances during k-means||
parent 27bccc5ea9
commit ca7910d6dd
```diff
@@ -279,45 +279,80 @@ class KMeans private (
    */
   private def initKMeansParallel(data: RDD[VectorWithNorm])
   : Array[Array[VectorWithNorm]] = {
-    // Initialize each run's center to a random point
+    // Initialize empty centers and point costs.
+    val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
+    var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache()
+
+    // Initialize each run's first center to a random point.
     val seed = new XORShiftRandom(this.seed).nextInt()
     val sample = data.takeSample(true, runs, seed).toSeq
-    val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+    val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+
+    /** Merges new centers to centers. */
+    def mergeNewCenters(): Unit = {
+      var r = 0
+      while (r < runs) {
+        centers(r) ++= newCenters(r)
+        newCenters(r).clear()
+        r += 1
+      }
+    }
 
     // On each step, sample 2 * k points on average for each run with probability proportional
-    // to their squared distance from that run's current centers
+    // to their squared distance from that run's centers. Note that only distances between points
+    // and new centers are computed in each iteration.
     var step = 0
     while (step < initializationSteps) {
-      val bcCenters = data.context.broadcast(centers)
-      val sumCosts = data.flatMap { point =>
-        (0 until runs).map { r =>
-          (r, KMeans.pointCost(bcCenters.value(r), point))
-        }
-      }.reduceByKey(_ + _).collectAsMap()
-      val chosen = data.mapPartitionsWithIndex { (index, points) =>
+      val bcNewCenters = data.context.broadcast(newCenters)
+      val preCosts = costs
+      costs = data.zip(preCosts).map { case (point, cost) =>
+        Vectors.dense(
+          Array.tabulate(runs) { r =>
+            math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
+          })
+      }.cache()
+      val sumCosts = costs
+        .aggregate(Vectors.zeros(runs))(
+          seqOp = (s, v) => {
+            // s += v
+            axpy(1.0, v, s)
+            s
+          },
+          combOp = (s0, s1) => {
+            // s0 += s1
+            axpy(1.0, s1, s0)
+            s0
+          }
+        )
+      preCosts.unpersist(blocking = false)
+      val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) =>
         val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
-        points.flatMap { p =>
+        pointsWithCosts.flatMap { case (p, c) =>
           (0 until runs).filter { r =>
-            rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r)
+            rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r)
           }.map((_, p))
         }
       }.collect()
+      mergeNewCenters()
       chosen.foreach { case (r, p) =>
-        centers(r) += p.toDense
+        newCenters(r) += p.toDense
       }
       step += 1
     }
 
+    mergeNewCenters()
+    costs.unpersist(blocking = false)
+
     // Finally, we might have a set of more than k candidate centers for each run; weigh each
     // candidate by the number of points in the dataset mapping to it and run a local k-means++
     // on the weighted centers to pick just k of them
     val bcCenters = data.context.broadcast(centers)
     val weightMap = data.flatMap { p =>
-      (0 until runs).map { r =>
+      Iterator.tabulate(runs) { r =>
         ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
       }
     }.reduceByKey(_ + _).collectAsMap()
-    val finalCenters = (0 until runs).map { r =>
+    val finalCenters = (0 until runs).par.map { r =>
       val myCenters = centers(r).toArray
       val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
       LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)
```
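One non-obvious piece of the diff is the `aggregate` call that replaces the old `reduceByKey(_ + _).collectAsMap()`: it folds the per-point cost vectors into a single per-run sum, with `seqOp` accumulating within a partition and `combOp` merging partition sums; `axpy(a, x, y)` computes `y += a * x` (an internal MLlib BLAS helper). Below is a rough plain-Scala stand-in, under the assumption of two fake partitions; the partition split and the local `axpy` helper are illustrative:

```scala
// Plain-Scala stand-in for the RDD.aggregate call in the diff above.
object SumCostsSketch {
  // Same semantics as MLlib's internal BLAS.axpy: y += a * x.
  def axpy(a: Double, x: Array[Double], y: Array[Double]): Unit = {
    var i = 0
    while (i < x.length) { y(i) += a * x(i); i += 1 }
  }

  def main(args: Array[String]): Unit = {
    val runs = 2
    // Per-point cost vectors (one entry per run), split into two fake "partitions".
    val partitions = Seq(
      Seq(Array(1.0, 4.0)),
      Seq(Array(2.0, 0.5), Array(3.0, 1.5)))
    // seqOp: fold each partition's vectors into a local running sum.
    val partials = partitions.map { part =>
      part.foldLeft(Array.fill(runs)(0.0)) { (s, v) => axpy(1.0, v, s); s }
    }
    // combOp: merge the per-partition sums.
    val sumCosts = partials.reduce { (s0, s1) => axpy(1.0, s1, s0); s0 }
    println(sumCosts.mkString(", "))  // 6.0, 6.0
  }
}
```

Doing the sum with one `aggregate` avoids the shuffle that `reduceByKey` would incur: only the per-partition partial sums (one dense vector of length #runs each) travel to the driver.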