[SPARK-18808][ML][MLLIB] ml.KMeansModel.transform is very inefficient

## What changes were proposed in this pull request?

mllib.KMeansModel.clusterCentersWithNorm is a method that ends up being called every time `predict` is called on a single vector, which is bad news for how the ml.KMeansModel Transformer works, since it necessarily transforms one vector at a time.

This change makes the model store the vectors with their norms upfront; the extra norm is small compared to the vectors themselves. Caching them avoids this overhead on this and other code paths.
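As a rough sketch of the idea (simplified stand-in types, not the actual Spark classes), the change amounts to computing the norm-wrapped centers once in a `val` at construction time instead of rebuilding them inside a `def` on every per-point call:

```scala
// Minimal sketch, assuming Array[Double] as a stand-in for Vector and a
// (vector, norm) pair as a stand-in for VectorWithNorm.
class CentersSketch(centers: Array[Array[Double]]) {

  // Before: a def, so every call re-wraps every center with its norm.
  // (Kept here only for comparison.)
  private def centersWithNormPerCall: Array[(Array[Double], Double)] =
    centers.map(c => (c, math.sqrt(c.map(x => x * x).sum)))

  // After: a val, so the wrapping cost is paid once when the model is built.
  private val centersWithNormCached: Array[(Array[Double], Double)] =
    centers.map(c => (c, math.sqrt(c.map(x => x * x).sum)))

  // A single-vector predict, as the ml Transformer issues per row, now reads
  // the cached val rather than re-deriving the norms on each call.
  def predict(p: Array[Double]): Int =
    centersWithNormCached.indices.minBy { i =>
      val (c, _) = centersWithNormCached(i)
      c.zip(p).map { case (a, b) => (a - b) * (a - b) }.sum
    }
}
```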

## How was this patch tested?

Existing tests.

Author: Sean Owen <sowen@cloudera.com>

Closes #16328 from srowen/SPARK-18808.
Sean Owen 2016-12-30 10:40:17 +00:00
parent 63036aee22
commit 56d3a7eb83
2 changed files with 9 additions and 10 deletions

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -39,6 +39,9 @@ import org.apache.spark.sql.{Row, SparkSession}
 class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector])
   extends Saveable with Serializable with PMMLExportable {
 
+  private val clusterCentersWithNorm =
+    if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_))
+
   /**
    * A Java-friendly constructor that takes an Iterable of Vectors.
    */
@@ -49,7 +52,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    * Total number of clusters.
    */
   @Since("0.8.0")
-  def k: Int = clusterCenters.length
+  def k: Int = clusterCentersWithNorm.length
 
   /**
    * Returns the cluster index that a given point belongs to.
@@ -64,8 +67,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    */
   @Since("1.0.0")
   def predict(points: RDD[Vector]): RDD[Int] = {
-    val centersWithNorm = clusterCentersWithNorm
-    val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
+    val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm)
     points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
   }
 
@@ -82,13 +84,10 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    */
   @Since("0.8.0")
   def computeCost(data: RDD[Vector]): Double = {
-    val centersWithNorm = clusterCentersWithNorm
-    val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
+    val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm)
     data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum()
   }
 
-  private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
-    clusterCenters.map(new VectorWithNorm(_))
 
   @Since("1.4.0")
   override def save(sc: SparkContext, path: String): Unit = {
@@ -127,8 +126,8 @@ object KMeansModel extends Loader[KMeansModel] {
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
-      val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
-        Cluster(id, point)
+      val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) =>
+        Cluster(id, p.vector)
       }
       spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
     }

mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -145,7 +145,7 @@ class StreamingKMeansModel @Since("1.2.0") (
       }
     }
 
-    this
+    new StreamingKMeansModel(clusterCenters, clusterWeights)
   }
 }
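A hedged note on the StreamingKMeans hunk above (my reading of the diff, not stated in the PR description): once the norms are cached in a `val` at construction time, an update that mutates the cluster centers in place can no longer simply return `this`, because the cached norms would be stale; returning a fresh model rebuilds the cache against the updated centers. A toy illustration with hypothetical names:

```scala
// Toy sketch (hypothetical names) of why a construction-time cache means
// update-style code should return a fresh instance rather than `this`.
class CachedModel(val centers: Array[Double]) {
  // Computed once at construction; it does not observe later in-place mutation.
  val centersWithNorm: Seq[(Double, Double)] = centers.toSeq.map(c => (c, math.abs(c)))
}

object StaleCacheDemo extends App {
  val m = new CachedModel(Array(1.0, -2.0))
  m.centers(0) = 5.0                                    // mutate in place, like a streaming update
  println(m.centersWithNorm)                            // still reflects the old centers
  println(new CachedModel(m.centers).centersWithNorm)   // a fresh instance recomputes the cache
}
```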