Streaming KMeans [MLLIB][SPARK-3254]
This adds a Streaming KMeans algorithm to MLlib. It uses an update rule that generalizes the mini-batch KMeans update to incorporate a decay factor, which allows past data to be forgotten. The decay factor can be specified explicitly, or via a more intuitive "fractional decay" setting, in units of either data points or batches. The PR includes: - StreamingKMeans algorithm with decay factor settings - Usage example - Additions to documentation clustering page - Unit tests of basic behavior and decay behaviors tdas mengxr rezazadeh Author: freeman <the.freeman.lab@gmail.com> Author: Jeremy Freeman <the.freeman.lab@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Closes #2942 from freeman-lab/streaming-kmeans and squashes the following commits: b2e5b4a [freeman] Fixes to docs / examples 078617c [Jeremy Freeman] Merge pull request #1 from mengxr/SPARK-3254 2e682c0 [Xiangrui Meng] take discount on previous weights; use BLAS; detect dying clusters 0411bf5 [freeman] Change decay parameterization 9f7aea9 [freeman] Style fixes 374a706 [freeman] Formatting ad9bdc2 [freeman] Use labeled points and predictOnValues in examples 77dbd3f [freeman] Make initialization check an assertion 9cfc301 [freeman] Make random seed an argument 44050a9 [freeman] Simpler constructor c7050d5 [freeman] Fix spacing 2899623 [freeman] Use pattern matching for clarity a4a316b [freeman] Use collect 1472ec5 [freeman] Doc formatting ea22ec8 [freeman] Fix imports 2086bdc [freeman] Log cluster center updates ea9877c [freeman] More documentation 9facbe3 [freeman] Bug fix 5db7074 [freeman] Example usage for StreamingKMeans f33684b [freeman] Add explanation and example to docs b5b5f8d [freeman] Add better documentation a0fd790 [freeman] Merge remote-tracking branch 'upstream/master' into streaming-kmeans 9fd9c15 [freeman] Merge remote-tracking branch 'upstream/master' into streaming-kmeans b93350f [freeman] Streaming KMeans with decay
This commit is contained in:
parent
8602195510
commit
98c556ebbc
|
@ -34,7 +34,7 @@ a given dataset, the algorithm returns the best clustering result).
|
|||
* *initializationSteps* determines the number of steps in the k-means\|\| algorithm.
|
||||
* *epsilon* determines the distance threshold within which we consider k-means to have converged.
|
||||
|
||||
## Examples
|
||||
### Examples
|
||||
|
||||
<div class="codetabs">
|
||||
<div data-lang="scala" markdown="1">
|
||||
|
@ -153,3 +153,97 @@ provided in the [Self-Contained Applications](quick-start.html#self-contained-ap
|
|||
section of the Spark
|
||||
Quick Start guide. Be sure to also include *spark-mllib* to your build file as
|
||||
a dependency.
|
||||
|
||||
## Streaming clustering
|
||||
|
||||
When data arrive in a stream, we may want to estimate clusters dynamically,
|
||||
updating them as new data arrive. MLlib provides support for streaming k-means clustering,
|
||||
with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm
|
||||
uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign
|
||||
all points to their nearest cluster, compute new cluster centers, then update each cluster using:
|
||||
|
||||
`\begin{equation}
|
||||
c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t}
|
||||
\end{equation}`
|
||||
`\begin{equation}
|
||||
n_{t+1} = n_t + m_t
|
||||
\end{equation}`
|
||||
|
||||
Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned
|
||||
to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$`
|
||||
is the number of points added to the cluster in the current batch. The decay factor `$\alpha$`
|
||||
can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning;
|
||||
with `$\alpha$=0` only the most recent data will be used. This is analogous to an
|
||||
exponentially-weighted moving average.
|
||||
|
||||
The decay can be specified using a `halfLife` parameter, which determines the
|
||||
correct decay factor `a` such that, for data acquired
|
||||
at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5.
|
||||
The unit of time can be specified either as `batches` or `points` and the update rule
|
||||
will be adjusted accordingly.
|
||||
|
||||
### Examples
|
||||
|
||||
This example shows how to estimate clusters on streaming data.
|
||||
|
||||
<div class="codetabs">
|
||||
|
||||
<div data-lang="scala" markdown="1">
|
||||
|
||||
First we import the neccessary classes.
|
||||
|
||||
{% highlight scala %}
|
||||
|
||||
import org.apache.spark.mllib.linalg.Vectors
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.mllib.clustering.StreamingKMeans
|
||||
|
||||
{% endhighlight %}
|
||||
|
||||
Then we make an input stream of vectors for training, as well as a stream of labeled data
|
||||
points for testing. We assume a StreamingContext `ssc` has been created, see
|
||||
[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info.
|
||||
|
||||
{% highlight scala %}
|
||||
|
||||
val trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse)
|
||||
val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse)
|
||||
|
||||
{% endhighlight %}
|
||||
|
||||
We create a model with random clusters and specify the number of clusters to find
|
||||
|
||||
{% highlight scala %}
|
||||
|
||||
val numDimensions = 3
|
||||
val numClusters = 2
|
||||
val model = new StreamingKMeans()
|
||||
.setK(numClusters)
|
||||
.setDecayFactor(1.0)
|
||||
.setRandomCenters(numDimensions, 0.0)
|
||||
|
||||
{% endhighlight %}
|
||||
|
||||
Now register the streams for training and testing and start the job, printing
|
||||
the predicted cluster assignments on new data points as they arrive.
|
||||
|
||||
{% highlight scala %}
|
||||
|
||||
model.trainOn(trainingData)
|
||||
model.predictOnValues(testData).print()
|
||||
|
||||
ssc.start()
|
||||
ssc.awaitTermination()
|
||||
|
||||
{% endhighlight %}
|
||||
|
||||
As you add new text files with data the cluster centers will update. Each training
|
||||
point should be formatted as `[x1, x2, x3]`, and each test data point
|
||||
should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier
|
||||
(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir`
|
||||
the model will update. Anytime a text file is placed in `/testing/data/dir`
|
||||
you will see predictions. With new data, the cluster centers will change!
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.examples.mllib
|
||||
|
||||
import org.apache.spark.mllib.linalg.Vectors
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.mllib.clustering.StreamingKMeans
|
||||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.streaming.{Seconds, StreamingContext}
|
||||
|
||||
/**
|
||||
* Estimate clusters on one stream of data and make predictions
|
||||
* on another stream, where the data streams arrive as text files
|
||||
* into two different directories.
|
||||
*
|
||||
* The rows of the training text files must be vector data in the form
|
||||
* `[x1,x2,x3,...,xn]`
|
||||
* Where n is the number of dimensions.
|
||||
*
|
||||
* The rows of the test text files must be labeled data in the form
|
||||
* `(y,[x1,x2,x3,...,xn])`
|
||||
* Where y is some identifier. n must be the same for train and test.
|
||||
*
|
||||
* Usage: StreamingKmeans <trainingDir> <testDir> <batchDuration> <numClusters> <numDimensions>
|
||||
*
|
||||
* To run on your local machine using the two directories `trainingDir` and `testDir`,
|
||||
* with updates every 5 seconds, 2 dimensions per data point, and 3 clusters, call:
|
||||
* $ bin/run-example \
|
||||
* org.apache.spark.examples.mllib.StreamingKMeans trainingDir testDir 5 3 2
|
||||
*
|
||||
* As you add text files to `trainingDir` the clusters will continuously update.
|
||||
* Anytime you add text files to `testDir`, you'll see predicted labels using the current model.
|
||||
*
|
||||
*/
|
||||
object StreamingKMeans {
|
||||
|
||||
def main(args: Array[String]) {
|
||||
if (args.length != 5) {
|
||||
System.err.println(
|
||||
"Usage: StreamingKMeans " +
|
||||
"<trainingDir> <testDir> <batchDuration> <numClusters> <numDimensions>")
|
||||
System.exit(1)
|
||||
}
|
||||
|
||||
val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression")
|
||||
val ssc = new StreamingContext(conf, Seconds(args(2).toLong))
|
||||
|
||||
val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse)
|
||||
val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)
|
||||
|
||||
val model = new StreamingKMeans()
|
||||
.setK(args(3).toInt)
|
||||
.setDecayFactor(1.0)
|
||||
.setRandomCenters(args(4).toInt, 0.0)
|
||||
|
||||
model.trainOn(trainingData)
|
||||
model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
|
||||
|
||||
ssc.start()
|
||||
ssc.awaitTermination()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,268 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.mllib.clustering
|
||||
|
||||
import scala.reflect.ClassTag
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.SparkContext._
|
||||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.streaming.StreamingContext._
|
||||
import org.apache.spark.streaming.dstream.DStream
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.util.random.XORShiftRandom
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
* StreamingKMeansModel extends MLlib's KMeansModel for streaming
|
||||
* algorithms, so it can keep track of a continuously updated weight
|
||||
* associated with each cluster, and also update the model by
|
||||
* doing a single iteration of the standard k-means algorithm.
|
||||
*
|
||||
* The update algorithm uses the "mini-batch" KMeans rule,
|
||||
* generalized to incorporate forgetfullness (i.e. decay).
|
||||
* The update rule (for each cluster) is:
|
||||
*
|
||||
* c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t]
|
||||
* n_t+t = n_t * a + m_t
|
||||
*
|
||||
* Where c_t is the previously estimated centroid for that cluster,
|
||||
* n_t is the number of points assigned to it thus far, x_t is the centroid
|
||||
* estimated on the current batch, and m_t is the number of points assigned
|
||||
* to that centroid in the current batch.
|
||||
*
|
||||
* The decay factor 'a' scales the contribution of the clusters as estimated thus far,
|
||||
* by applying a as a discount weighting on the current point when evaluating
|
||||
* new incoming data. If a=1, all batches are weighted equally. If a=0, new centroids
|
||||
* are determined entirely by recent data. Lower values correspond to
|
||||
* more forgetting.
|
||||
*
|
||||
* Decay can optionally be specified by a half life and associated
|
||||
* time unit. The time unit can either be a batch of data or a single
|
||||
* data point. Considering data arrived at time t, the half life h is defined
|
||||
* such that at time t + h the discount applied to the data from t is 0.5.
|
||||
* The definition remains the same whether the time unit is given
|
||||
* as batches or points.
|
||||
*
|
||||
*/
|
||||
@DeveloperApi
|
||||
class StreamingKMeansModel(
|
||||
override val clusterCenters: Array[Vector],
|
||||
val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging {
|
||||
|
||||
/** Perform a k-means update on a batch of data. */
|
||||
def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
|
||||
|
||||
// find nearest cluster to each point
|
||||
val closest = data.map(point => (this.predict(point), (point, 1L)))
|
||||
|
||||
// get sums and counts for updating each cluster
|
||||
val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
|
||||
BLAS.axpy(1.0, p2._1, p1._1)
|
||||
(p1._1, p1._2 + p2._2)
|
||||
}
|
||||
val dim = clusterCenters(0).size
|
||||
val pointStats: Array[(Int, (Vector, Long))] = closest
|
||||
.aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
|
||||
.collect()
|
||||
|
||||
val discount = timeUnit match {
|
||||
case StreamingKMeans.BATCHES => decayFactor
|
||||
case StreamingKMeans.POINTS =>
|
||||
val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
|
||||
n
|
||||
}.sum
|
||||
math.pow(decayFactor, numNewPoints)
|
||||
}
|
||||
|
||||
// apply discount to weights
|
||||
BLAS.scal(discount, Vectors.dense(clusterWeights))
|
||||
|
||||
// implement update rule
|
||||
pointStats.foreach { case (label, (sum, count)) =>
|
||||
val centroid = clusterCenters(label)
|
||||
|
||||
val updatedWeight = clusterWeights(label) + count
|
||||
val lambda = count / math.max(updatedWeight, 1e-16)
|
||||
|
||||
clusterWeights(label) = updatedWeight
|
||||
BLAS.scal(1.0 - lambda, centroid)
|
||||
BLAS.axpy(lambda / count, sum, centroid)
|
||||
|
||||
// display the updated cluster centers
|
||||
val display = clusterCenters(label).size match {
|
||||
case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
|
||||
case _ => centroid.toArray.mkString("[", ",", "]")
|
||||
}
|
||||
|
||||
logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
|
||||
}
|
||||
|
||||
// Check whether the smallest cluster is dying. If so, split the largest cluster.
|
||||
val weightsWithIndex = clusterWeights.view.zipWithIndex
|
||||
val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
|
||||
val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
|
||||
if (minWeight < 1e-8 * maxWeight) {
|
||||
logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
|
||||
val weight = (maxWeight + minWeight) / 2.0
|
||||
clusterWeights(largest) = weight
|
||||
clusterWeights(smallest) = weight
|
||||
val largestClusterCenter = clusterCenters(largest)
|
||||
val smallestClusterCenter = clusterCenters(smallest)
|
||||
var j = 0
|
||||
while (j < dim) {
|
||||
val x = largestClusterCenter(j)
|
||||
val p = 1e-14 * math.max(math.abs(x), 1.0)
|
||||
largestClusterCenter.toBreeze(j) = x + p
|
||||
smallestClusterCenter.toBreeze(j) = x - p
|
||||
j += 1
|
||||
}
|
||||
}
|
||||
|
||||
this
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
* StreamingKMeans provides methods for configuring a
|
||||
* streaming k-means analysis, training the model on streaming,
|
||||
* and using the model to make predictions on streaming data.
|
||||
* See KMeansModel for details on algorithm and update rules.
|
||||
*
|
||||
* Use a builder pattern to construct a streaming k-means analysis
|
||||
* in an application, like:
|
||||
*
|
||||
* val model = new StreamingKMeans()
|
||||
* .setDecayFactor(0.5)
|
||||
* .setK(3)
|
||||
* .setRandomCenters(5, 100.0)
|
||||
* .trainOn(DStream)
|
||||
*/
|
||||
@DeveloperApi
|
||||
class StreamingKMeans(
|
||||
var k: Int,
|
||||
var decayFactor: Double,
|
||||
var timeUnit: String) extends Logging {
|
||||
|
||||
def this() = this(2, 1.0, StreamingKMeans.BATCHES)
|
||||
|
||||
protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
|
||||
|
||||
/** Set the number of clusters. */
|
||||
def setK(k: Int): this.type = {
|
||||
this.k = k
|
||||
this
|
||||
}
|
||||
|
||||
/** Set the decay factor directly (for forgetful algorithms). */
|
||||
def setDecayFactor(a: Double): this.type = {
|
||||
this.decayFactor = decayFactor
|
||||
this
|
||||
}
|
||||
|
||||
/** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */
|
||||
def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
|
||||
if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) {
|
||||
throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
|
||||
}
|
||||
this.decayFactor = math.exp(math.log(0.5) / halfLife)
|
||||
logInfo("Setting decay factor to: %g ".format (this.decayFactor))
|
||||
this.timeUnit = timeUnit
|
||||
this
|
||||
}
|
||||
|
||||
/** Specify initial centers directly. */
|
||||
def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
|
||||
model = new StreamingKMeansModel(centers, weights)
|
||||
this
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize random centers, requiring only the number of dimensions.
|
||||
*
|
||||
* @param dim Number of dimensions
|
||||
* @param weight Weight for each center
|
||||
* @param seed Random seed
|
||||
*/
|
||||
def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
|
||||
val random = new XORShiftRandom(seed)
|
||||
val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
|
||||
val weights = Array.fill(k)(weight)
|
||||
model = new StreamingKMeansModel(centers, weights)
|
||||
this
|
||||
}
|
||||
|
||||
/** Return the latest model. */
|
||||
def latestModel(): StreamingKMeansModel = {
|
||||
model
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the clustering model by training on batches of data from a DStream.
|
||||
* This operation registers a DStream for training the model,
|
||||
* checks whether the cluster centers have been initialized,
|
||||
* and updates the model using each batch of data from the stream.
|
||||
*
|
||||
* @param data DStream containing vector data
|
||||
*/
|
||||
def trainOn(data: DStream[Vector]) {
|
||||
assertInitialized()
|
||||
data.foreachRDD { (rdd, time) =>
|
||||
model = model.update(rdd, decayFactor, timeUnit)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Use the clustering model to make predictions on batches of data from a DStream.
|
||||
*
|
||||
* @param data DStream containing vector data
|
||||
* @return DStream containing predictions
|
||||
*/
|
||||
def predictOn(data: DStream[Vector]): DStream[Int] = {
|
||||
assertInitialized()
|
||||
data.map(model.predict)
|
||||
}
|
||||
|
||||
/**
|
||||
* Use the model to make predictions on the values of a DStream and carry over its keys.
|
||||
*
|
||||
* @param data DStream containing (key, feature vector) pairs
|
||||
* @tparam K key type
|
||||
* @return DStream containing the input keys and the predictions as values
|
||||
*/
|
||||
def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
|
||||
assertInitialized()
|
||||
data.mapValues(model.predict)
|
||||
}
|
||||
|
||||
/** Check whether cluster centers have been initialized. */
|
||||
private[this] def assertInitialized(): Unit = {
|
||||
if (model.clusterCenters == null) {
|
||||
throw new IllegalStateException(
|
||||
"Initial cluster centers must be set before starting predictions")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private[clustering] object StreamingKMeans {
|
||||
final val BATCHES = "batches"
|
||||
final val POINTS = "points"
|
||||
}
|
|
@ -0,0 +1,157 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.mllib.clustering
|
||||
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.mllib.util.TestingUtils._
|
||||
import org.apache.spark.streaming.TestSuiteBase
|
||||
import org.apache.spark.streaming.dstream.DStream
|
||||
import org.apache.spark.util.random.XORShiftRandom
|
||||
|
||||
class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
|
||||
|
||||
override def maxWaitTimeMillis = 30000
|
||||
|
||||
test("accuracy for single center and equivalence to grand average") {
|
||||
// set parameters
|
||||
val numBatches = 10
|
||||
val numPoints = 50
|
||||
val k = 1
|
||||
val d = 5
|
||||
val r = 0.1
|
||||
|
||||
// create model with one cluster
|
||||
val model = new StreamingKMeans()
|
||||
.setK(1)
|
||||
.setDecayFactor(1.0)
|
||||
.setInitialCenters(Array(Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0)), Array(0.0))
|
||||
|
||||
// generate random data for k-means
|
||||
val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
|
||||
|
||||
// setup and run the model training
|
||||
val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
|
||||
model.trainOn(inputDStream)
|
||||
inputDStream.count()
|
||||
})
|
||||
runStreams(ssc, numBatches, numBatches)
|
||||
|
||||
// estimated center should be close to true center
|
||||
assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1)
|
||||
|
||||
// estimated center from streaming should exactly match the arithmetic mean of all data points
|
||||
// because the decay factor is set to 1.0
|
||||
val grandMean =
|
||||
input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble
|
||||
assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5)
|
||||
}
|
||||
|
||||
test("accuracy for two centers") {
|
||||
val numBatches = 10
|
||||
val numPoints = 5
|
||||
val k = 2
|
||||
val d = 5
|
||||
val r = 0.1
|
||||
|
||||
// create model with two clusters
|
||||
val kMeans = new StreamingKMeans()
|
||||
.setK(2)
|
||||
.setHalfLife(2, "batches")
|
||||
.setInitialCenters(
|
||||
Array(Vectors.dense(-0.1, 0.1, -0.2, -0.3, -0.1),
|
||||
Vectors.dense(0.1, -0.2, 0.0, 0.2, 0.1)),
|
||||
Array(5.0, 5.0))
|
||||
|
||||
// generate random data for k-means
|
||||
val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
|
||||
|
||||
// setup and run the model training
|
||||
val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
|
||||
kMeans.trainOn(inputDStream)
|
||||
inputDStream.count()
|
||||
})
|
||||
runStreams(ssc, numBatches, numBatches)
|
||||
|
||||
// check that estimated centers are close to true centers
|
||||
// NOTE exact assignment depends on the initialization!
|
||||
assert(centers(0) ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1)
|
||||
assert(centers(1) ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1)
|
||||
}
|
||||
|
||||
test("detecting dying clusters") {
|
||||
val numBatches = 10
|
||||
val numPoints = 5
|
||||
val k = 1
|
||||
val d = 1
|
||||
val r = 1.0
|
||||
|
||||
// create model with two clusters
|
||||
val kMeans = new StreamingKMeans()
|
||||
.setK(2)
|
||||
.setHalfLife(0.5, "points")
|
||||
.setInitialCenters(
|
||||
Array(Vectors.dense(0.0), Vectors.dense(1000.0)),
|
||||
Array(1.0, 1.0))
|
||||
|
||||
// new data are all around the first cluster 0.0
|
||||
val (input, _) =
|
||||
StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0)))
|
||||
|
||||
// setup and run the model training
|
||||
val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
|
||||
kMeans.trainOn(inputDStream)
|
||||
inputDStream.count()
|
||||
})
|
||||
runStreams(ssc, numBatches, numBatches)
|
||||
|
||||
// check that estimated centers are close to true centers
|
||||
// NOTE exact assignment depends on the initialization!
|
||||
val model = kMeans.latestModel()
|
||||
val c0 = model.clusterCenters(0)(0)
|
||||
val c1 = model.clusterCenters(1)(0)
|
||||
|
||||
assert(c0 * c1 < 0.0, "should have one positive center and one negative center")
|
||||
// 0.8 is the mean of half-normal distribution
|
||||
assert(math.abs(c0) ~== 0.8 absTol 0.6)
|
||||
assert(math.abs(c1) ~== 0.8 absTol 0.6)
|
||||
}
|
||||
|
||||
def StreamingKMeansDataGenerator(
|
||||
numPoints: Int,
|
||||
numBatches: Int,
|
||||
k: Int,
|
||||
d: Int,
|
||||
r: Double,
|
||||
seed: Int,
|
||||
initCenters: Array[Vector] = null): (IndexedSeq[IndexedSeq[Vector]], Array[Vector]) = {
|
||||
val rand = new XORShiftRandom(seed)
|
||||
val centers = initCenters match {
|
||||
case null => Array.fill(k)(Vectors.dense(Array.fill(d)(rand.nextGaussian())))
|
||||
case _ => initCenters
|
||||
}
|
||||
val data = (0 until numBatches).map { i =>
|
||||
(0 until numPoints).map { idx =>
|
||||
val center = centers(idx % k)
|
||||
Vectors.dense(Array.tabulate(d)(x => center(x) + rand.nextGaussian() * r))
|
||||
}
|
||||
}
|
||||
(data, centers)
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue