[SPARK-18808][ML][MLLIB] ml.KMeansModel.transform is very inefficient

## What changes were proposed in this pull request?

mllib.KMeansModel.clusterCentersWithNorm is a method than ends up being called every time `predict` is called on a single vector, which is bad news for now the ml.KMeansModel Transformer works, which necessarily transforms one vector at a time.

This causes the model to just store the vectors with norms upfront. The extra norm should be small compared to the vectors. This would avoid this form of overhead on this and other code paths.

## How was this patch tested?

Existing tests.

Author: Sean Owen <sowen@cloudera.com>

Closes #16328 from srowen/SPARK-18808.
This commit is contained in:
Sean Owen 2016-12-30 10:40:17 +00:00
parent 63036aee22
commit 56d3a7eb83
No known key found for this signature in database
GPG key ID: BEB3956D6717BDDC
2 changed files with 9 additions and 10 deletions

View file

@ -39,6 +39,9 @@ import org.apache.spark.sql.{Row, SparkSession}
class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector])
extends Saveable with Serializable with PMMLExportable {
private val clusterCentersWithNorm =
if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_))
/**
* A Java-friendly constructor that takes an Iterable of Vectors.
*/
@ -49,7 +52,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
* Total number of clusters.
*/
@Since("0.8.0")
def k: Int = clusterCenters.length
def k: Int = clusterCentersWithNorm.length
/**
* Returns the cluster index that a given point belongs to.
@ -64,8 +67,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
*/
@Since("1.0.0")
def predict(points: RDD[Vector]): RDD[Int] = {
val centersWithNorm = clusterCentersWithNorm
val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm)
points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
}
@ -82,13 +84,10 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
*/
@Since("0.8.0")
def computeCost(data: RDD[Vector]): Double = {
val centersWithNorm = clusterCentersWithNorm
val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm)
data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum()
}
private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
clusterCenters.map(new VectorWithNorm(_))
@Since("1.4.0")
override def save(sc: SparkContext, path: String): Unit = {
@ -127,8 +126,8 @@ object KMeansModel extends Loader[KMeansModel] {
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
Cluster(id, point)
val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) =>
Cluster(id, p.vector)
}
spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
}

View file

@ -145,7 +145,7 @@ class StreamingKMeansModel @Since("1.2.0") (
}
}
this
new StreamingKMeansModel(clusterCenters, clusterWeights)
}
}