[SPARK-23412][ML] Add cosine distance to BisectingKMeans
## What changes were proposed in this pull request?

The PR adds the option to specify a distance measure in BisectingKMeans. Moreover, it introduces the ability to use the cosine distance measure in it.

## How was this patch tested?

added UTs + existing UTs

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #20600 from mgaido91/SPARK-23412.
parent d5b41aea62
commit 567bd31e0a
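For reference, a minimal sketch of how the new option can be used from the DataFrame-based API once this patch is in (assumptions: a local SparkSession, the default `features` column, and toy data mirroring the new cosine test added below):

```scala
import org.apache.spark.ml.clustering.BisectingKMeans
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object BisectingKMeansCosineExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("BisectingKMeansCosineExample")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Pairs of vectors that point in the same direction but differ in magnitude:
    // with cosine distance each pair should land in the same cluster.
    val df = Seq(
      Vectors.dense(1.0, 1.0), Vectors.dense(10.0, 10.0),
      Vectors.dense(1.0, 0.5), Vectors.dense(10.0, 4.4),
      Vectors.dense(-1.0, 1.0), Vectors.dense(-100.0, 90.0)
    ).map(Tuple1.apply).toDF("features")

    // distanceMeasure is the new expert param: "euclidean" (default) or "cosine".
    val model = new BisectingKMeans()
      .setK(3)
      .setDistanceMeasure("cosine")
      .setSeed(1L)
      .fit(df)

    model.transform(df).show(truncate = false)
    spark.stop()
  }
}
```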
@@ -26,7 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
+import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans,
+  BisectingKMeansModel => MLlibBisectingKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.rdd.RDD
@@ -38,8 +39,8 @@ import org.apache.spark.sql.types.{IntegerType, StructType}
 /**
  * Common params for BisectingKMeans and BisectingKMeansModel
  */
-private[clustering] trait BisectingKMeansParams extends Params
-  with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol {
+private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter
+  with HasFeaturesCol with HasSeed with HasPredictionCol with HasDistanceMeasure {
 
   /**
    * The desired number of leaf clusters. Must be > 1. Default: 4.
@@ -104,6 +105,10 @@ class BisectingKMeansModel private[ml] (
   @Since("2.1.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
+  /** @group expertSetParam */
+  @Since("2.4.0")
+  def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -248,6 +253,10 @@ class BisectingKMeans @Since("2.0.0") (
   @Since("2.0.0")
   def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value)
 
+  /** @group expertSetParam */
+  @Since("2.4.0")
+  def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
     transformSchema(dataset.schema, logging = true)
@@ -263,6 +272,7 @@ class BisectingKMeans @Since("2.0.0") (
       .setMaxIterations($(maxIter))
       .setMinDivisibleClusterSize($(minDivisibleClusterSize))
       .setSeed($(seed))
+      .setDistanceMeasure($(distanceMeasure))
     val parentModel = bkm.run(rdd)
     val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
     val summary = new BisectingKMeansSummary(
@@ -40,7 +40,7 @@ import org.apache.spark.util.VersionUtils.majorVersion
  * Common params for KMeans and KMeansModel
  */
 private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
-  with HasSeed with HasPredictionCol with HasTol {
+  with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure {
 
   /**
    * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than
@@ -71,15 +71,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
   @Since("1.5.0")
   def getInitMode: String = $(initMode)
 
-  @Since("2.4.0")
-  final val distanceMeasure = new Param[String](this, "distanceMeasure", "The distance measure. " +
-    "Supported options: 'euclidean' and 'cosine'.",
-    (value: String) => MLlibKMeans.validateDistanceMeasure(value))
-
-  /** @group expertGetParam */
-  @Since("2.4.0")
-  def getDistanceMeasure: String = $(distanceMeasure)
-
   /**
    * Param for the number of steps for the k-means|| initialization mode. This is an advanced
    * setting -- the default of 2 is almost always enough. Must be > 0. Default: 2.
@@ -91,7 +91,11 @@ private[shared] object SharedParamsCodeGen {
         "after fitting. If set to true, then all sub-models will be available. Warning: For " +
         "large models, collecting all sub-models can cause OOMs on the Spark driver",
         Some("false"), isExpertParam = true),
-      ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false)
+      ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false),
+      ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" +
+        " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"),
+        isValid = "(value: String) => " +
+          "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)")
     )
 
     val code = genSharedParams(params)
@@ -504,4 +504,23 @@ trait HasLoss extends Params {
   /** @group getParam */
   final def getLoss: String = $(loss)
 }
+
+/**
+ * Trait for shared param distanceMeasure (default: org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN). This trait may be changed or
+ * removed between minor versions.
+ */
+@DeveloperApi
+trait HasDistanceMeasure extends Params {
+
+  /**
+   * Param for The distance measure. Supported options: 'euclidean' and 'cosine'.
+   * @group param
+   */
+  final val distanceMeasure: Param[String] = new Param[String](this, "distanceMeasure", "The distance measure. Supported options: 'euclidean' and 'cosine'", (value: String) => org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value))
+
+  setDefault(distanceMeasure, org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN)
+
+  /** @group getParam */
+  final def getDistanceMeasure: String = $(distanceMeasure)
+}
 // scalastyle:on
@ -25,7 +25,7 @@ import scala.collection.mutable
|
|||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.api.java.JavaRDD
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
|
||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
|
@ -57,7 +57,8 @@ class BisectingKMeans private (
|
|||
private var k: Int,
|
||||
private var maxIterations: Int,
|
||||
private var minDivisibleClusterSize: Double,
|
||||
private var seed: Long) extends Logging {
|
||||
private var seed: Long,
|
||||
private var distanceMeasure: String) extends Logging {
|
||||
|
||||
import BisectingKMeans._
|
||||
|
||||
|
@ -65,7 +66,7 @@ class BisectingKMeans private (
|
|||
* Constructs with the default configuration
|
||||
*/
|
||||
@Since("1.6.0")
|
||||
def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##)
|
||||
def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##, DistanceMeasure.EUCLIDEAN)
|
||||
|
||||
/**
|
||||
* Sets the desired number of leaf clusters (default: 4).
|
||||
|
@ -134,6 +135,22 @@ class BisectingKMeans private (
|
|||
@Since("1.6.0")
|
||||
def getSeed: Long = this.seed
|
||||
|
||||
/**
|
||||
* The distance suite used by the algorithm.
|
||||
*/
|
||||
@Since("2.4.0")
|
||||
def getDistanceMeasure: String = distanceMeasure
|
||||
|
||||
/**
|
||||
* Set the distance suite used by the algorithm.
|
||||
*/
|
||||
@Since("2.4.0")
|
||||
def setDistanceMeasure(distanceMeasure: String): this.type = {
|
||||
DistanceMeasure.validateDistanceMeasure(distanceMeasure)
|
||||
this.distanceMeasure = distanceMeasure
|
||||
this
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the bisecting k-means algorithm.
|
||||
* @param input RDD of vectors
|
||||
|
@ -147,11 +164,13 @@ class BisectingKMeans private (
|
|||
}
|
||||
val d = input.map(_.size).first()
|
||||
logInfo(s"Feature dimension: $d.")
|
||||
|
||||
val dMeasure: DistanceMeasure = DistanceMeasure.decodeFromString(this.distanceMeasure)
|
||||
// Compute and cache vector norms for fast distance computation.
|
||||
val norms = input.map(v => Vectors.norm(v, 2.0)).persist(StorageLevel.MEMORY_AND_DISK)
|
||||
val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) }
|
||||
var assignments = vectors.map(v => (ROOT_INDEX, v))
|
||||
var activeClusters = summarize(d, assignments)
|
||||
var activeClusters = summarize(d, assignments, dMeasure)
|
||||
val rootSummary = activeClusters(ROOT_INDEX)
|
||||
val n = rootSummary.size
|
||||
logInfo(s"Number of points: $n.")
|
||||
|
@ -184,24 +203,25 @@ class BisectingKMeans private (
|
|||
val divisibleIndices = divisibleClusters.keys.toSet
|
||||
logInfo(s"Dividing ${divisibleIndices.size} clusters on level $level.")
|
||||
var newClusterCenters = divisibleClusters.flatMap { case (index, summary) =>
|
||||
val (left, right) = splitCenter(summary.center, random)
|
||||
val (left, right) = splitCenter(summary.center, random, dMeasure)
|
||||
Iterator((leftChildIndex(index), left), (rightChildIndex(index), right))
|
||||
}.map(identity) // workaround for a Scala bug (SI-7005) that produces a not serializable map
|
||||
var newClusters: Map[Long, ClusterSummary] = null
|
||||
var newAssignments: RDD[(Long, VectorWithNorm)] = null
|
||||
for (iter <- 0 until maxIterations) {
|
||||
newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters)
|
||||
newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters,
|
||||
dMeasure)
|
||||
.filter { case (index, _) =>
|
||||
divisibleIndices.contains(parentIndex(index))
|
||||
}
|
||||
newClusters = summarize(d, newAssignments)
|
||||
newClusters = summarize(d, newAssignments, dMeasure)
|
||||
newClusterCenters = newClusters.mapValues(_.center).map(identity)
|
||||
}
|
||||
if (preIndices != null) {
|
||||
preIndices.unpersist(false)
|
||||
}
|
||||
preIndices = indices
|
||||
indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys
|
||||
indices = updateAssignments(assignments, divisibleIndices, newClusterCenters, dMeasure).keys
|
||||
.persist(StorageLevel.MEMORY_AND_DISK)
|
||||
assignments = indices.zip(vectors)
|
||||
inactiveClusters ++= activeClusters
|
||||
|
@ -222,8 +242,8 @@ class BisectingKMeans private (
|
|||
}
|
||||
norms.unpersist(false)
|
||||
val clusters = activeClusters ++ inactiveClusters
|
||||
val root = buildTree(clusters)
|
||||
new BisectingKMeansModel(root)
|
||||
val root = buildTree(clusters, dMeasure)
|
||||
new BisectingKMeansModel(root, this.distanceMeasure)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -266,8 +286,9 @@ private object BisectingKMeans extends Serializable {
|
|||
*/
|
||||
private def summarize(
|
||||
d: Int,
|
||||
assignments: RDD[(Long, VectorWithNorm)]): Map[Long, ClusterSummary] = {
|
||||
assignments.aggregateByKey(new ClusterSummaryAggregator(d))(
|
||||
assignments: RDD[(Long, VectorWithNorm)],
|
||||
distanceMeasure: DistanceMeasure): Map[Long, ClusterSummary] = {
|
||||
assignments.aggregateByKey(new ClusterSummaryAggregator(d, distanceMeasure))(
|
||||
seqOp = (agg, v) => agg.add(v),
|
||||
combOp = (agg1, agg2) => agg1.merge(agg2)
|
||||
).mapValues(_.summary)
|
||||
|
@ -278,7 +299,8 @@ private object BisectingKMeans extends Serializable {
|
|||
* Cluster summary aggregator.
|
||||
* @param d feature dimension
|
||||
*/
|
||||
private class ClusterSummaryAggregator(val d: Int) extends Serializable {
|
||||
private class ClusterSummaryAggregator(val d: Int, val distanceMeasure: DistanceMeasure)
|
||||
extends Serializable {
|
||||
private var n: Long = 0L
|
||||
private val sum: Vector = Vectors.zeros(d)
|
||||
private var sumSq: Double = 0.0
|
||||
|
@ -288,7 +310,7 @@ private object BisectingKMeans extends Serializable {
|
|||
n += 1L
|
||||
// TODO: use a numerically stable approach to estimate cost
|
||||
sumSq += v.norm * v.norm
|
||||
BLAS.axpy(1.0, v.vector, sum)
|
||||
distanceMeasure.updateClusterSum(v, sum)
|
||||
this
|
||||
}
|
||||
|
||||
|
@ -296,19 +318,15 @@ private object BisectingKMeans extends Serializable {
|
|||
def merge(other: ClusterSummaryAggregator): this.type = {
|
||||
n += other.n
|
||||
sumSq += other.sumSq
|
||||
BLAS.axpy(1.0, other.sum, sum)
|
||||
distanceMeasure.updateClusterSum(new VectorWithNorm(other.sum), sum)
|
||||
this
|
||||
}
|
||||
|
||||
/** Returns the summary. */
|
||||
def summary: ClusterSummary = {
|
||||
val mean = sum.copy
|
||||
if (n > 0L) {
|
||||
BLAS.scal(1.0 / n, mean)
|
||||
}
|
||||
val center = new VectorWithNorm(mean)
|
||||
val cost = math.max(sumSq - n * center.norm * center.norm, 0.0)
|
||||
new ClusterSummary(n, center, cost)
|
||||
val center = distanceMeasure.centroid(sum.copy, n)
|
||||
val cost = distanceMeasure.clusterCost(center, new VectorWithNorm(sum), n, sumSq)
|
||||
ClusterSummary(n, center, cost)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -321,16 +339,13 @@ private object BisectingKMeans extends Serializable {
|
|||
*/
|
||||
private def splitCenter(
|
||||
center: VectorWithNorm,
|
||||
random: Random): (VectorWithNorm, VectorWithNorm) = {
|
||||
random: Random,
|
||||
distanceMeasure: DistanceMeasure): (VectorWithNorm, VectorWithNorm) = {
|
||||
val d = center.vector.size
|
||||
val norm = center.norm
|
||||
val level = 1e-4 * norm
|
||||
val noise = Vectors.dense(Array.fill(d)(random.nextDouble()))
|
||||
val left = center.vector.copy
|
||||
BLAS.axpy(-level, noise, left)
|
||||
val right = center.vector.copy
|
||||
BLAS.axpy(level, noise, right)
|
||||
(new VectorWithNorm(left), new VectorWithNorm(right))
|
||||
distanceMeasure.symmetricCentroids(level, noise, center.vector)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -343,16 +358,20 @@ private object BisectingKMeans extends Serializable {
|
|||
private def updateAssignments(
|
||||
assignments: RDD[(Long, VectorWithNorm)],
|
||||
divisibleIndices: Set[Long],
|
||||
newClusterCenters: Map[Long, VectorWithNorm]): RDD[(Long, VectorWithNorm)] = {
|
||||
newClusterCenters: Map[Long, VectorWithNorm],
|
||||
distanceMeasure: DistanceMeasure): RDD[(Long, VectorWithNorm)] = {
|
||||
assignments.map { case (index, v) =>
|
||||
if (divisibleIndices.contains(index)) {
|
||||
val children = Seq(leftChildIndex(index), rightChildIndex(index))
|
||||
val newClusterChildren = children.filter(newClusterCenters.contains(_))
|
||||
val newClusterChildren = children.filter(newClusterCenters.contains)
|
||||
val newClusterChildrenCenterToId =
|
||||
newClusterChildren.map(id => newClusterCenters(id) -> id).toMap
|
||||
val newClusterChildrenCenters = newClusterChildrenCenterToId.keys.toArray
|
||||
if (newClusterChildren.nonEmpty) {
|
||||
val selected = newClusterChildren.minBy { child =>
|
||||
EuclideanDistanceMeasure.fastSquaredDistance(newClusterCenters(child), v)
|
||||
}
|
||||
(selected, v)
|
||||
val selected = distanceMeasure.findClosest(newClusterChildrenCenters, v)._1
|
||||
val center = newClusterChildrenCenters(selected)
|
||||
val id = newClusterChildrenCenterToId(center)
|
||||
(id, v)
|
||||
} else {
|
||||
(index, v)
|
||||
}
|
||||
|
@ -367,7 +386,9 @@ private object BisectingKMeans extends Serializable {
|
|||
* @param clusters a map from cluster indices to corresponding cluster summaries
|
||||
* @return the root node of the clustering tree
|
||||
*/
|
||||
private def buildTree(clusters: Map[Long, ClusterSummary]): ClusteringTreeNode = {
|
||||
private def buildTree(
|
||||
clusters: Map[Long, ClusterSummary],
|
||||
distanceMeasure: DistanceMeasure): ClusteringTreeNode = {
|
||||
var leafIndex = 0
|
||||
var internalIndex = -1
|
||||
|
||||
|
@ -385,11 +406,11 @@ private object BisectingKMeans extends Serializable {
|
|||
internalIndex -= 1
|
||||
val leftIndex = leftChildIndex(rawIndex)
|
||||
val rightIndex = rightChildIndex(rawIndex)
|
||||
val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
|
||||
val height = math.sqrt(indexes.map { childIndex =>
|
||||
EuclideanDistanceMeasure.fastSquaredDistance(center, clusters(childIndex).center)
|
||||
}.max)
|
||||
val children = indexes.map(buildSubTree(_)).toArray
|
||||
val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains)
|
||||
val height = indexes.map { childIndex =>
|
||||
distanceMeasure.distance(center, clusters(childIndex).center)
|
||||
}.max
|
||||
val children = indexes.map(buildSubTree).toArray
|
||||
new ClusteringTreeNode(index, size, center, cost, height, children)
|
||||
} else {
|
||||
val index = leafIndex
|
||||
|
@ -441,42 +462,45 @@ private[clustering] class ClusteringTreeNode private[clustering] (
|
|||
def center: Vector = centerWithNorm.vector
|
||||
|
||||
/** Predicts the leaf cluster node index that the input point belongs to. */
|
||||
def predict(point: Vector): Int = {
|
||||
val (index, _) = predict(new VectorWithNorm(point))
|
||||
def predict(point: Vector, distanceMeasure: DistanceMeasure): Int = {
|
||||
val (index, _) = predict(new VectorWithNorm(point), distanceMeasure)
|
||||
index
|
||||
}
|
||||
|
||||
/** Returns the full prediction path from root to leaf. */
|
||||
def predictPath(point: Vector): Array[ClusteringTreeNode] = {
|
||||
predictPath(new VectorWithNorm(point)).toArray
|
||||
def predictPath(point: Vector, distanceMeasure: DistanceMeasure): Array[ClusteringTreeNode] = {
|
||||
predictPath(new VectorWithNorm(point), distanceMeasure).toArray
|
||||
}
|
||||
|
||||
/** Returns the full prediction path from root to leaf. */
|
||||
private def predictPath(pointWithNorm: VectorWithNorm): List[ClusteringTreeNode] = {
|
||||
private def predictPath(
|
||||
pointWithNorm: VectorWithNorm,
|
||||
distanceMeasure: DistanceMeasure): List[ClusteringTreeNode] = {
|
||||
if (isLeaf) {
|
||||
this :: Nil
|
||||
} else {
|
||||
val selected = children.minBy { child =>
|
||||
EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm)
|
||||
distanceMeasure.distance(child.centerWithNorm, pointWithNorm)
|
||||
}
|
||||
selected :: selected.predictPath(pointWithNorm)
|
||||
selected :: selected.predictPath(pointWithNorm, distanceMeasure)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the cost (squared distance to the predicted leaf cluster center) of the input point.
|
||||
* Computes the cost of the input point.
|
||||
*/
|
||||
def computeCost(point: Vector): Double = {
|
||||
val (_, cost) = predict(new VectorWithNorm(point))
|
||||
def computeCost(point: Vector, distanceMeasure: DistanceMeasure): Double = {
|
||||
val (_, cost) = predict(new VectorWithNorm(point), distanceMeasure)
|
||||
cost
|
||||
}
|
||||
|
||||
/**
|
||||
* Predicts the cluster index and the cost of the input point.
|
||||
*/
|
||||
private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = {
|
||||
predict(pointWithNorm,
|
||||
EuclideanDistanceMeasure.fastSquaredDistance(centerWithNorm, pointWithNorm))
|
||||
private def predict(
|
||||
pointWithNorm: VectorWithNorm,
|
||||
distanceMeasure: DistanceMeasure): (Int, Double) = {
|
||||
predict(pointWithNorm, distanceMeasure.cost(centerWithNorm, pointWithNorm), distanceMeasure)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -486,14 +510,17 @@ private[clustering] class ClusteringTreeNode private[clustering] (
|
|||
* @return (predicted leaf cluster index, cost)
|
||||
*/
|
||||
@tailrec
|
||||
private def predict(pointWithNorm: VectorWithNorm, cost: Double): (Int, Double) = {
|
||||
private def predict(
|
||||
pointWithNorm: VectorWithNorm,
|
||||
cost: Double,
|
||||
distanceMeasure: DistanceMeasure): (Int, Double) = {
|
||||
if (isLeaf) {
|
||||
(index, cost)
|
||||
} else {
|
||||
val (selectedChild, minCost) = children.map { child =>
|
||||
(child, EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm))
|
||||
(child, distanceMeasure.cost(child.centerWithNorm, pointWithNorm))
|
||||
}.minBy(_._2)
|
||||
selectedChild.predict(pointWithNorm, minCost)
|
||||
selectedChild.predict(pointWithNorm, minCost, distanceMeasure)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -40,9 +40,16 @@ import org.apache.spark.sql.{Row, SparkSession}
|
|||
*/
|
||||
@Since("1.6.0")
|
||||
class BisectingKMeansModel private[clustering] (
|
||||
private[clustering] val root: ClusteringTreeNode
|
||||
private[clustering] val root: ClusteringTreeNode,
|
||||
@Since("2.4.0") val distanceMeasure: String
|
||||
) extends Serializable with Saveable with Logging {
|
||||
|
||||
@Since("1.6.0")
|
||||
def this(root: ClusteringTreeNode) = this(root, DistanceMeasure.EUCLIDEAN)
|
||||
|
||||
private val distanceMeasureInstance: DistanceMeasure =
|
||||
DistanceMeasure.decodeFromString(distanceMeasure)
|
||||
|
||||
/**
|
||||
* Leaf cluster centers.
|
||||
*/
|
||||
|
@ -59,7 +66,7 @@ class BisectingKMeansModel private[clustering] (
|
|||
*/
|
||||
@Since("1.6.0")
|
||||
def predict(point: Vector): Int = {
|
||||
root.predict(point)
|
||||
root.predict(point, distanceMeasureInstance)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -67,7 +74,7 @@ class BisectingKMeansModel private[clustering] (
|
|||
*/
|
||||
@Since("1.6.0")
|
||||
def predict(points: RDD[Vector]): RDD[Int] = {
|
||||
points.map { p => root.predict(p) }
|
||||
points.map { p => root.predict(p, distanceMeasureInstance) }
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -82,7 +89,7 @@ class BisectingKMeansModel private[clustering] (
|
|||
*/
|
||||
@Since("1.6.0")
|
||||
def computeCost(point: Vector): Double = {
|
||||
root.computeCost(point)
|
||||
root.computeCost(point, distanceMeasureInstance)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -91,7 +98,7 @@ class BisectingKMeansModel private[clustering] (
|
|||
*/
|
||||
@Since("1.6.0")
|
||||
def computeCost(data: RDD[Vector]): Double = {
|
||||
data.map(root.computeCost).sum()
|
||||
data.map(root.computeCost(_, distanceMeasureInstance)).sum()
|
||||
}
|
||||
|
||||
/**
|
||||
|
@@ -113,18 +120,19 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
 
   @Since("2.0.0")
   override def load(sc: SparkContext, path: String): BisectingKMeansModel = {
-    val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path)
-    implicit val formats = DefaultFormats
-    val rootId = (metadata \ "rootId").extract[Int]
-    val classNameV1_0 = SaveLoadV1_0.thisClassName
+    val (loadedClassName, formatVersion, __) = Loader.loadMetadata(sc, path)
     (loadedClassName, formatVersion) match {
-      case (classNameV1_0, "1.0") =>
-        val model = SaveLoadV1_0.load(sc, path, rootId)
+      case (SaveLoadV1_0.thisClassName, SaveLoadV1_0.thisFormatVersion) =>
+        val model = SaveLoadV1_0.load(sc, path)
+        model
+      case (SaveLoadV2_0.thisClassName, SaveLoadV2_0.thisFormatVersion) =>
+        val model = SaveLoadV2_0.load(sc, path)
         model
       case _ => throw new Exception(
         s"BisectingKMeansModel.load did not recognize model with (className, format version):" +
         s"($loadedClassName, $formatVersion). Supported:\n" +
-        s" ($classNameV1_0, 1.0)")
+        s" (${SaveLoadV1_0.thisClassName}, ${SaveLoadV1_0.thisFormatVersion}\n" +
+        s" (${SaveLoadV2_0.thisClassName}, ${SaveLoadV2_0.thisFormatVersion})")
     }
   }
 
@ -136,8 +144,28 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
|
|||
r.getDouble(4), r.getDouble(5), r.getSeq[Int](6))
|
||||
}
|
||||
|
||||
private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = {
|
||||
if (node.children.isEmpty) {
|
||||
Array(node)
|
||||
} else {
|
||||
node.children.flatMap(getNodes) ++ Array(node)
|
||||
}
|
||||
}
|
||||
|
||||
private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = {
|
||||
val root = nodes(rootId)
|
||||
if (root.children.isEmpty) {
|
||||
new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
|
||||
root.cost, root.height, new Array[ClusteringTreeNode](0))
|
||||
} else {
|
||||
val children = root.children.map(c => buildTree(c, nodes))
|
||||
new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
|
||||
root.cost, root.height, children.toArray)
|
||||
}
|
||||
}
|
||||
|
||||
private[clustering] object SaveLoadV1_0 {
|
||||
private val thisFormatVersion = "1.0"
|
||||
private[clustering] val thisFormatVersion = "1.0"
|
||||
|
||||
private[clustering]
|
||||
val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
|
||||
|
@ -155,34 +183,55 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
|
|||
spark.createDataFrame(data).write.parquet(Loader.dataPath(path))
|
||||
}
|
||||
|
||||
private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = {
|
||||
if (node.children.isEmpty) {
|
||||
Array(node)
|
||||
} else {
|
||||
node.children.flatMap(getNodes(_)) ++ Array(node)
|
||||
}
|
||||
}
|
||||
|
||||
def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = {
|
||||
def load(sc: SparkContext, path: String): BisectingKMeansModel = {
|
||||
implicit val formats: DefaultFormats = DefaultFormats
|
||||
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
|
||||
assert(className == thisClassName)
|
||||
assert(formatVersion == thisFormatVersion)
|
||||
val rootId = (metadata \ "rootId").extract[Int]
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val rows = spark.read.parquet(Loader.dataPath(path))
|
||||
Loader.checkSchema[Data](rows.schema)
|
||||
val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")
|
||||
val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap
|
||||
val rootNode = buildTree(rootId, nodes)
|
||||
new BisectingKMeansModel(rootNode)
|
||||
new BisectingKMeansModel(rootNode, DistanceMeasure.EUCLIDEAN)
|
||||
}
|
||||
}
|
||||
|
||||
private[clustering] object SaveLoadV2_0 {
|
||||
private[clustering] val thisFormatVersion = "2.0"
|
||||
|
||||
private[clustering]
|
||||
val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
|
||||
|
||||
def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = {
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val metadata = compact(render(
|
||||
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
|
||||
~ ("rootId" -> model.root.index) ~ ("distanceMeasure" -> model.distanceMeasure)))
|
||||
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
|
||||
|
||||
val data = getNodes(model.root).map(node => Data(node.index, node.size,
|
||||
node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height,
|
||||
node.children.map(_.index)))
|
||||
spark.createDataFrame(data).write.parquet(Loader.dataPath(path))
|
||||
}
|
||||
|
||||
private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = {
|
||||
val root = nodes.get(rootId).get
|
||||
if (root.children.isEmpty) {
|
||||
new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
|
||||
root.cost, root.height, new Array[ClusteringTreeNode](0))
|
||||
} else {
|
||||
val children = root.children.map(c => buildTree(c, nodes))
|
||||
new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
|
||||
root.cost, root.height, children.toArray)
|
||||
}
|
||||
def load(sc: SparkContext, path: String): BisectingKMeansModel = {
|
||||
implicit val formats: DefaultFormats = DefaultFormats
|
||||
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
|
||||
assert(className == thisClassName)
|
||||
assert(formatVersion == thisFormatVersion)
|
||||
val rootId = (metadata \ "rootId").extract[Int]
|
||||
val distanceMeasure = (metadata \ "distanceMeasure").extract[String]
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val rows = spark.read.parquet(Loader.dataPath(path))
|
||||
Loader.checkSchema[Data](rows.schema)
|
||||
val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")
|
||||
val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap
|
||||
val rootNode = buildTree(rootId, nodes)
|
||||
new BisectingKMeansModel(rootNode, distanceMeasure)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,303 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.mllib.clustering
|
||||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
|
||||
private[spark] abstract class DistanceMeasure extends Serializable {
|
||||
|
||||
/**
|
||||
* @return the index of the closest center to the given point, as well as the cost.
|
||||
*/
|
||||
def findClosest(
|
||||
centers: TraversableOnce[VectorWithNorm],
|
||||
point: VectorWithNorm): (Int, Double) = {
|
||||
var bestDistance = Double.PositiveInfinity
|
||||
var bestIndex = 0
|
||||
var i = 0
|
||||
centers.foreach { center =>
|
||||
val currentDistance = distance(center, point)
|
||||
if (currentDistance < bestDistance) {
|
||||
bestDistance = currentDistance
|
||||
bestIndex = i
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
(bestIndex, bestDistance)
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the K-means cost of a given point against the given cluster centers.
|
||||
*/
|
||||
def pointCost(
|
||||
centers: TraversableOnce[VectorWithNorm],
|
||||
point: VectorWithNorm): Double = {
|
||||
findClosest(centers, point)._2
|
||||
}
|
||||
|
||||
/**
|
||||
* @return whether a center converged or not, given the epsilon parameter.
|
||||
*/
|
||||
def isCenterConverged(
|
||||
oldCenter: VectorWithNorm,
|
||||
newCenter: VectorWithNorm,
|
||||
epsilon: Double): Boolean = {
|
||||
distance(oldCenter, newCenter) <= epsilon
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the distance between two points.
|
||||
*/
|
||||
def distance(
|
||||
v1: VectorWithNorm,
|
||||
v2: VectorWithNorm): Double
|
||||
|
||||
/**
|
||||
* @return the total cost of the cluster from its aggregated properties
|
||||
*/
|
||||
def clusterCost(
|
||||
centroid: VectorWithNorm,
|
||||
pointsSum: VectorWithNorm,
|
||||
numberOfPoints: Long,
|
||||
pointsSquaredNorm: Double): Double
|
||||
|
||||
/**
|
||||
* Updates the value of `sum` adding the `point` vector.
|
||||
* @param point a `VectorWithNorm` to be added to `sum` of a cluster
|
||||
* @param sum the `sum` for a cluster to be updated
|
||||
*/
|
||||
def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
|
||||
axpy(1.0, point.vector, sum)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a centroid for a cluster given its `sum` vector and its `count` of points.
|
||||
*
|
||||
* @param sum the `sum` for a cluster
|
||||
* @param count the number of points in the cluster
|
||||
* @return the centroid of the cluster
|
||||
*/
|
||||
def centroid(sum: Vector, count: Long): VectorWithNorm = {
|
||||
scal(1.0 / count, sum)
|
||||
new VectorWithNorm(sum)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns two new centroids symmetric to the specified centroid, applying `noise` with the
* specified `level`.
|
||||
*
|
||||
* @param level the level of `noise` to apply to the given centroid.
|
||||
* @param noise a noise vector
|
||||
* @param centroid the parent centroid
|
||||
* @return a left and right centroid symmetric to `centroid`
|
||||
*/
|
||||
def symmetricCentroids(
|
||||
level: Double,
|
||||
noise: Vector,
|
||||
centroid: Vector): (VectorWithNorm, VectorWithNorm) = {
|
||||
val left = centroid.copy
|
||||
axpy(-level, noise, left)
|
||||
val right = centroid.copy
|
||||
axpy(level, noise, right)
|
||||
(new VectorWithNorm(left), new VectorWithNorm(right))
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the cost of a point to be assigned to the cluster centroid
|
||||
*/
|
||||
def cost(
|
||||
point: VectorWithNorm,
|
||||
centroid: VectorWithNorm): Double = distance(point, centroid)
|
||||
}
|
||||
|
||||
@Since("2.4.0")
|
||||
object DistanceMeasure {
|
||||
|
||||
@Since("2.4.0")
|
||||
val EUCLIDEAN = "euclidean"
|
||||
@Since("2.4.0")
|
||||
val COSINE = "cosine"
|
||||
|
||||
private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure =
|
||||
distanceMeasure match {
|
||||
case EUCLIDEAN => new EuclideanDistanceMeasure
|
||||
case COSINE => new CosineDistanceMeasure
|
||||
case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " +
|
||||
s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.")
|
||||
}
|
||||
|
||||
private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = {
|
||||
distanceMeasure match {
|
||||
case DistanceMeasure.EUCLIDEAN => true
|
||||
case DistanceMeasure.COSINE => true
|
||||
case _ => false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
|
||||
/**
|
||||
* @return the index of the closest center to the given point, as well as the squared distance.
|
||||
*/
|
||||
override def findClosest(
|
||||
centers: TraversableOnce[VectorWithNorm],
|
||||
point: VectorWithNorm): (Int, Double) = {
|
||||
var bestDistance = Double.PositiveInfinity
|
||||
var bestIndex = 0
|
||||
var i = 0
|
||||
centers.foreach { center =>
|
||||
// Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
|
||||
// distance computation.
|
||||
var lowerBoundOfSqDist = center.norm - point.norm
|
||||
lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
|
||||
if (lowerBoundOfSqDist < bestDistance) {
|
||||
val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
|
||||
if (distance < bestDistance) {
|
||||
bestDistance = distance
|
||||
bestIndex = i
|
||||
}
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
(bestIndex, bestDistance)
|
||||
}
|
||||
|
||||
/**
|
||||
* @return whether a center converged or not, given the epsilon parameter.
|
||||
*/
|
||||
override def isCenterConverged(
|
||||
oldCenter: VectorWithNorm,
|
||||
newCenter: VectorWithNorm,
|
||||
epsilon: Double): Boolean = {
|
||||
EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon
|
||||
}
|
||||
|
||||
/**
|
||||
* @param v1: first vector
|
||||
* @param v2: second vector
|
||||
* @return the Euclidean distance between the two input vectors
|
||||
*/
|
||||
override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
|
||||
Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2))
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the total cost of the cluster from its aggregated properties
|
||||
*/
|
||||
override def clusterCost(
|
||||
centroid: VectorWithNorm,
|
||||
pointsSum: VectorWithNorm,
|
||||
numberOfPoints: Long,
|
||||
pointsSquaredNorm: Double): Double = {
|
||||
math.max(pointsSquaredNorm - numberOfPoints * centroid.norm * centroid.norm, 0.0)
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the cost of a point to be assigned to the cluster centroid
|
||||
*/
|
||||
override def cost(
|
||||
point: VectorWithNorm,
|
||||
centroid: VectorWithNorm): Double = {
|
||||
EuclideanDistanceMeasure.fastSquaredDistance(point, centroid)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private[spark] object EuclideanDistanceMeasure {
|
||||
/**
|
||||
* @return the squared Euclidean distance between two vectors computed by
|
||||
* [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]].
|
||||
*/
|
||||
private[clustering] def fastSquaredDistance(
|
||||
v1: VectorWithNorm,
|
||||
v2: VectorWithNorm): Double = {
|
||||
MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] class CosineDistanceMeasure extends DistanceMeasure {
|
||||
/**
|
||||
* @param v1: first vector
|
||||
* @param v2: second vector
|
||||
* @return the cosine distance between the two input vectors
|
||||
*/
|
||||
override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
|
||||
assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.")
|
||||
1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the value of `sum` adding the `point` vector.
|
||||
* @param point a `VectorWithNorm` to be added to `sum` of a cluster
|
||||
* @param sum the `sum` for a cluster to be updated
|
||||
*/
|
||||
override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
|
||||
assert(point.norm > 0, "Cosine distance is not defined for zero-length vectors.")
|
||||
axpy(1.0 / point.norm, point.vector, sum)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a centroid for a cluster given its `sum` vector and its `count` of points.
|
||||
*
|
||||
* @param sum the `sum` for a cluster
|
||||
* @param count the number of points in the cluster
|
||||
* @return the centroid of the cluster
|
||||
*/
|
||||
override def centroid(sum: Vector, count: Long): VectorWithNorm = {
|
||||
scal(1.0 / count, sum)
|
||||
val norm = Vectors.norm(sum, 2)
|
||||
scal(1.0 / norm, sum)
|
||||
new VectorWithNorm(sum, 1)
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the total cost of the cluster from its aggregated properties
|
||||
*/
|
||||
override def clusterCost(
|
||||
centroid: VectorWithNorm,
|
||||
pointsSum: VectorWithNorm,
|
||||
numberOfPoints: Long,
|
||||
pointsSquaredNorm: Double): Double = {
|
||||
val costVector = pointsSum.vector.copy
|
||||
math.max(numberOfPoints - dot(centroid.vector, costVector) / centroid.norm, 0.0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns two new centroids symmetric to the specified centroid, applying `noise` with the
* specified `level`.
|
||||
*
|
||||
* @param level the level of `noise` to apply to the given centroid.
|
||||
* @param noise a noise vector
|
||||
* @param centroid the parent centroid
|
||||
* @return a left and right centroid symmetric to `centroid`
|
||||
*/
|
||||
override def symmetricCentroids(
|
||||
level: Double,
|
||||
noise: Vector,
|
||||
centroid: Vector): (VectorWithNorm, VectorWithNorm) = {
|
||||
val (left, right) = super.symmetricCentroids(level, noise, centroid)
|
||||
val leftVector = left.vector
|
||||
val rightVector = right.vector
|
||||
scal(1.0 / left.norm, leftVector)
|
||||
scal(1.0 / right.norm, rightVector)
|
||||
(new VectorWithNorm(leftVector, 1.0), new VectorWithNorm(rightVector, 1.0))
|
||||
}
|
||||
}
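For readers of the new CosineDistanceMeasure above: the distance it computes is 1 - dot(v1, v2) / (norm(v1) * norm(v2)). A standalone, illustrative sketch in plain Scala (not part of this patch; the object and helper names are hypothetical):

```scala
// Illustrative only: mirrors the formula in CosineDistanceMeasure.distance,
// written with plain Scala arrays so it runs without Spark internals.
object CosineDistanceExample {
  def cosineDistance(a: Array[Double], b: Array[Double]): Double = {
    require(a.length == b.length, "vectors must have the same dimension")
    val dot = a.zip(b).map { case (x, y) => x * y }.sum
    val normA = math.sqrt(a.map(x => x * x).sum)
    val normB = math.sqrt(b.map(x => x * x).sum)
    // As in the patch, reject zero-length vectors, for which the distance is undefined.
    require(normA > 0 && normB > 0, "Cosine distance is not defined for zero-length vectors.")
    1.0 - dot / (normA * normB)
  }

  def main(args: Array[String]): Unit = {
    println(cosineDistance(Array(1.0, 1.0), Array(10.0, 10.0))) // ~0.0: same direction
    println(cosineDistance(Array(1.0, 0.0), Array(0.0, 1.0)))   // 1.0: orthogonal
  }
}
```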
|
|
@ -25,8 +25,7 @@ import org.apache.spark.internal.Logging
|
|||
import org.apache.spark.ml.clustering.{KMeans => NewKMeans}
|
||||
import org.apache.spark.ml.util.Instrumentation
|
||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
import org.apache.spark.mllib.linalg.BLAS.axpy
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
import org.apache.spark.util.Utils
|
||||
|
@ -204,7 +203,7 @@ class KMeans private (
|
|||
*/
|
||||
@Since("2.4.0")
|
||||
def setDistanceMeasure(distanceMeasure: String): this.type = {
|
||||
KMeans.validateDistanceMeasure(distanceMeasure)
|
||||
DistanceMeasure.validateDistanceMeasure(distanceMeasure)
|
||||
this.distanceMeasure = distanceMeasure
|
||||
this
|
||||
}
|
||||
|
@ -582,14 +581,6 @@ object KMeans {
|
|||
case _ => false
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = {
|
||||
distanceMeasure match {
|
||||
case DistanceMeasure.EUCLIDEAN => true
|
||||
case DistanceMeasure.COSINE => true
|
||||
case _ => false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -605,186 +596,3 @@ private[clustering] class VectorWithNorm(val vector: Vector, val norm: Double)
|
|||
/** Converts the vector to a dense vector. */
|
||||
def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
|
||||
}
|
||||
|
||||
|
||||
private[spark] abstract class DistanceMeasure extends Serializable {
|
||||
|
||||
/**
|
||||
* @return the index of the closest center to the given point, as well as the cost.
|
||||
*/
|
||||
def findClosest(
|
||||
centers: TraversableOnce[VectorWithNorm],
|
||||
point: VectorWithNorm): (Int, Double) = {
|
||||
var bestDistance = Double.PositiveInfinity
|
||||
var bestIndex = 0
|
||||
var i = 0
|
||||
centers.foreach { center =>
|
||||
val currentDistance = distance(center, point)
|
||||
if (currentDistance < bestDistance) {
|
||||
bestDistance = currentDistance
|
||||
bestIndex = i
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
(bestIndex, bestDistance)
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the K-means cost of a given point against the given cluster centers.
|
||||
*/
|
||||
def pointCost(
|
||||
centers: TraversableOnce[VectorWithNorm],
|
||||
point: VectorWithNorm): Double = {
|
||||
findClosest(centers, point)._2
|
||||
}
|
||||
|
||||
/**
|
||||
* @return whether a center converged or not, given the epsilon parameter.
|
||||
*/
|
||||
def isCenterConverged(
|
||||
oldCenter: VectorWithNorm,
|
||||
newCenter: VectorWithNorm,
|
||||
epsilon: Double): Boolean = {
|
||||
distance(oldCenter, newCenter) <= epsilon
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the cosine distance between two points.
|
||||
*/
|
||||
def distance(
|
||||
v1: VectorWithNorm,
|
||||
v2: VectorWithNorm): Double
|
||||
|
||||
/**
|
||||
* Updates the value of `sum` adding the `point` vector.
|
||||
* @param point a `VectorWithNorm` to be added to `sum` of a cluster
|
||||
* @param sum the `sum` for a cluster to be updated
|
||||
*/
|
||||
def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
|
||||
axpy(1.0, point.vector, sum)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a centroid for a cluster given its `sum` vector and its `count` of points.
|
||||
*
|
||||
* @param sum the `sum` for a cluster
|
||||
* @param count the number of points in the cluster
|
||||
* @return the centroid of the cluster
|
||||
*/
|
||||
def centroid(sum: Vector, count: Long): VectorWithNorm = {
|
||||
scal(1.0 / count, sum)
|
||||
new VectorWithNorm(sum)
|
||||
}
|
||||
}
|
||||
|
||||
@Since("2.4.0")
|
||||
object DistanceMeasure {
|
||||
|
||||
@Since("2.4.0")
|
||||
val EUCLIDEAN = "euclidean"
|
||||
@Since("2.4.0")
|
||||
val COSINE = "cosine"
|
||||
|
||||
private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure =
|
||||
distanceMeasure match {
|
||||
case EUCLIDEAN => new EuclideanDistanceMeasure
|
||||
case COSINE => new CosineDistanceMeasure
|
||||
case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " +
|
||||
s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.")
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
|
||||
/**
|
||||
* @return the index of the closest center to the given point, as well as the squared distance.
|
||||
*/
|
||||
override def findClosest(
|
||||
centers: TraversableOnce[VectorWithNorm],
|
||||
point: VectorWithNorm): (Int, Double) = {
|
||||
var bestDistance = Double.PositiveInfinity
|
||||
var bestIndex = 0
|
||||
var i = 0
|
||||
centers.foreach { center =>
|
||||
// Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
|
||||
// distance computation.
|
||||
var lowerBoundOfSqDist = center.norm - point.norm
|
||||
lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
|
||||
if (lowerBoundOfSqDist < bestDistance) {
|
||||
val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
|
||||
if (distance < bestDistance) {
|
||||
bestDistance = distance
|
||||
bestIndex = i
|
||||
}
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
(bestIndex, bestDistance)
|
||||
}
|
||||
|
||||
/**
|
||||
* @return whether a center converged or not, given the epsilon parameter.
|
||||
*/
|
||||
override def isCenterConverged(
|
||||
oldCenter: VectorWithNorm,
|
||||
newCenter: VectorWithNorm,
|
||||
epsilon: Double): Boolean = {
|
||||
EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon
|
||||
}
|
||||
|
||||
/**
|
||||
* @param v1: first vector
|
||||
* @param v2: second vector
|
||||
* @return the Euclidean distance between the two input vectors
|
||||
*/
|
||||
override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
|
||||
Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private[spark] object EuclideanDistanceMeasure {
|
||||
/**
|
||||
* @return the squared Euclidean distance between two vectors computed by
|
||||
* [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]].
|
||||
*/
|
||||
private[clustering] def fastSquaredDistance(
|
||||
v1: VectorWithNorm,
|
||||
v2: VectorWithNorm): Double = {
|
||||
MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] class CosineDistanceMeasure extends DistanceMeasure {
|
||||
/**
|
||||
* @param v1: first vector
|
||||
* @param v2: second vector
|
||||
* @return the cosine distance between the two input vectors
|
||||
*/
|
||||
override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
|
||||
assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.")
|
||||
1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the value of `sum` adding the `point` vector.
|
||||
* @param point a `VectorWithNorm` to be added to `sum` of a cluster
|
||||
* @param sum the `sum` for a cluster to be updated
|
||||
*/
|
||||
override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
|
||||
axpy(1.0 / point.norm, point.vector, sum)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a centroid for a cluster given its `sum` vector and its `count` of points.
|
||||
*
|
||||
* @param sum the `sum` for a cluster
|
||||
* @param count the number of points in the cluster
|
||||
* @return the centroid of the cluster
|
||||
*/
|
||||
override def centroid(sum: Vector, count: Long): VectorWithNorm = {
|
||||
scal(1.0 / count, sum)
|
||||
val norm = Vectors.norm(sum, 2)
|
||||
scal(1.0 / norm, sum)
|
||||
new VectorWithNorm(sum, 1)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,9 +17,11 @@
|
|||
|
||||
package org.apache.spark.ml.clustering
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.{SparkException, SparkFunSuite}
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||
import org.apache.spark.mllib.clustering.DistanceMeasure
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.sql.Dataset
|
||||
|
||||
|
@ -140,6 +142,46 @@ class BisectingKMeansSuite
|
|||
testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings,
|
||||
BisectingKMeansSuite.allParamSettings, checkModelData)
|
||||
}
|
||||
|
||||
test("BisectingKMeans with cosine distance is not supported for 0-length vectors") {
|
||||
val model = new BisectingKMeans().setK(2).setDistanceMeasure(DistanceMeasure.COSINE).setSeed(1)
|
||||
val df = spark.createDataFrame(spark.sparkContext.parallelize(Array(
|
||||
Vectors.dense(0.0, 0.0),
|
||||
Vectors.dense(10.0, 10.0),
|
||||
Vectors.dense(1.0, 0.5)
|
||||
)).map(v => TestRow(v)))
|
||||
val e = intercept[SparkException](model.fit(df))
|
||||
assert(e.getCause.isInstanceOf[AssertionError])
|
||||
assert(e.getCause.getMessage.contains("Cosine distance is not defined"))
|
||||
}
|
||||
|
||||
test("BisectingKMeans with cosine distance") {
|
||||
val df = spark.createDataFrame(spark.sparkContext.parallelize(Array(
|
||||
Vectors.dense(1.0, 1.0),
|
||||
Vectors.dense(10.0, 10.0),
|
||||
Vectors.dense(1.0, 0.5),
|
||||
Vectors.dense(10.0, 4.4),
|
||||
Vectors.dense(-1.0, 1.0),
|
||||
Vectors.dense(-100.0, 90.0)
|
||||
)).map(v => TestRow(v)))
|
||||
val model = new BisectingKMeans()
|
||||
.setK(3)
|
||||
.setDistanceMeasure(DistanceMeasure.COSINE)
|
||||
.setSeed(1)
|
||||
.fit(df)
|
||||
val predictionDf = model.transform(df)
|
||||
assert(predictionDf.select("prediction").distinct().count() == 3)
|
||||
val predictionsMap = predictionDf.collect().map(row =>
|
||||
row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap
|
||||
assert(predictionsMap(Vectors.dense(1.0, 1.0)) ==
|
||||
predictionsMap(Vectors.dense(10.0, 10.0)))
|
||||
assert(predictionsMap(Vectors.dense(1.0, 0.5)) ==
|
||||
predictionsMap(Vectors.dense(10.0, 4.4)))
|
||||
assert(predictionsMap(Vectors.dense(-1.0, 1.0)) ==
|
||||
predictionsMap(Vectors.dense(-100.0, 90.0)))
|
||||
|
||||
model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
object BisectingKMeansSuite {
|
||||
|
|
|
@ -36,6 +36,12 @@ object MimaExcludes {
|
|||
|
||||
// Exclude rules for 2.4.x
|
||||
lazy val v24excludes = v23excludes ++ Seq(
|
||||
// [SPARK-23412][ML] Add cosine distance measure to BisectingKmeans
|
||||
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.org$apache$spark$ml$param$shared$HasDistanceMeasure$_setter_$distanceMeasure_="),
|
||||
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.getDistanceMeasure"),
|
||||
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.distanceMeasure"),
|
||||
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel#SaveLoadV1_0.load"),
|
||||
|
||||
// [SPARK-20659] Remove StorageStatus, or make it private
|
||||
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"),
|
||||
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"),
|
||||
|
|