[SPARK-32676][3.0][ML] Fix double caching in KMeans/BiKMeans
### What changes were proposed in this pull request?
Fix double caching in KMeans/BiKMeans:
1. let the callers of `runWithWeight` pass whether `handlePersistence` is needed;
2. persist and unpersist inside of `runWithWeight`;
3. persist the `norms` if needed, according to the existing comments.

### Why are the changes needed?
To avoid double caching.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing test suites.

Closes #29501 from zhengruifeng/kmeans_handlePersistence.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
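The double caching bit when the input `Dataset` was already cached: the ml-side `fit()` correctly skipped persisting its derived `instances` RDD, but the mllib-side `runWithWeight()` only checked the storage level of that derived RDD, which is always `StorageLevel.NONE`, so it cached the zipped vectors a second time on top of the cached `Dataset`. The fix threads the caller's knowledge down as a `handlePersistence` flag and concentrates every persist/unpersist call in `runWithWeight`. A minimal sketch of the resulting pattern, not the Spark source itself: `train` is a hypothetical stand-in for the real `runAlgorithmWithWeight`, and a plain tuple replaces the private `VectorWithNorm` class:

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

object CachingSketch {
  // Hypothetical stand-in for runAlgorithmWithWeight: makes several passes
  // over the vectors in the real algorithm; here it just returns a count.
  private def train(vectors: RDD[(Vector, Double, Double)]): Long = vectors.count()

  // All persistence decisions live here, driven by the caller-supplied flag.
  def runWithWeight(
      instances: RDD[(Vector, Double)],
      handlePersistence: Boolean): Long = {
    val norms = instances.map { case (v, _) => Vectors.norm(v, 2.0) }
    val vectors = instances.zip(norms)
      .map { case ((v, w), norm) => (v, w, norm) } // (point, weight, norm)

    if (handlePersistence) {
      // Input is uncached: cache the fully derived RDD, exactly once.
      vectors.persist(StorageLevel.MEMORY_AND_DISK)
    } else {
      // Input is already cached upstream: cache only the norms, so repeated
      // distance computations stay cheap without duplicating the points.
      norms.persist(StorageLevel.MEMORY_AND_DISK)
    }

    val model = train(vectors)
    if (handlePersistence) vectors.unpersist() else norms.unpersist()
    model
  }

  // Caller side, mirroring run()/fit(): request caching only when the
  // source RDD (or Dataset) is not already persisted.
  def run(data: RDD[Vector]): Long = {
    val instances = data.map(point => (point, 1.0))
    val handlePersistence = data.getStorageLevel == StorageLevel.NONE
    runWithWeight(instances, handlePersistence)
  }
}
```

Either way exactly one new RDD is cached: the zipped vectors when the caller's input is uncached, or only the `norms` when it is already cached.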
parent 1c798f973f
commit ac520d4a7c
```diff
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -29,9 +29,8 @@ import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans,
   BisectingKMeansModel => MLlibBisectingKMeansModel}
-import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
+import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
@@ -276,21 +275,6 @@ class BisectingKMeans @Since("2.0.0") (
   override def fit(dataset: Dataset[_]): BisectingKMeansModel = instrumented { instr =>
     transformSchema(dataset.schema, logging = true)

-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
-    } else {
-      lit(1.0)
-    }
-
-    val instances: RDD[(OldVector, Double)] = dataset
-      .select(DatasetUtils.columnToVector(dataset, getFeaturesCol), w).rdd.map {
-        case Row(point: Vector, weight: Double) => (OldVectors.fromML(point), weight)
-      }
-    if (handlePersistence) {
-      instances.persist(StorageLevel.MEMORY_AND_DISK)
-    }
-
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
     instr.logParams(this, featuresCol, predictionCol, k, maxIter, seed,
@@ -302,11 +286,18 @@ class BisectingKMeans @Since("2.0.0") (
       .setMinDivisibleClusterSize($(minDivisibleClusterSize))
       .setSeed($(seed))
       .setDistanceMeasure($(distanceMeasure))
-    val parentModel = bkm.runWithWeight(instances, Some(instr))
-    val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
-    if (handlePersistence) {
-      instances.unpersist()
-    }
+
+    val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
+      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
+    } else {
+      lit(1.0)
+    }
+    val instances = dataset.select(DatasetUtils.columnToVector(dataset, getFeaturesCol), w)
+      .rdd.map { case Row(point: Vector, weight: Double) => (OldVectors.fromML(point), weight) }
+
+    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
+    val parentModel = bkm.runWithWeight(instances, handlePersistence, Some(instr))
+    val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))

     val summary = new BisectingKMeansSummary(
       model.transform(dataset),
```
```diff
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -32,7 +32,6 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
@@ -330,22 +329,6 @@ class KMeans @Since("1.5.0") (
   override def fit(dataset: Dataset[_]): KMeansModel = instrumented { instr =>
     transformSchema(dataset.schema, logging = true)

-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
-    } else {
-      lit(1.0)
-    }
-
-    val instances: RDD[(OldVector, Double)] = dataset
-      .select(DatasetUtils.columnToVector(dataset, getFeaturesCol), w).rdd.map {
-        case Row(point: Vector, weight: Double) => (OldVectors.fromML(point), weight)
-      }
-
-    if (handlePersistence) {
-      instances.persist(StorageLevel.MEMORY_AND_DISK)
-    }
-
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
     instr.logParams(this, featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure,
@@ -358,8 +341,19 @@ class KMeans @Since("1.5.0") (
       .setSeed($(seed))
       .setEpsilon($(tol))
       .setDistanceMeasure($(distanceMeasure))
-    val parentModel = algo.runWithWeight(instances, Option(instr))
+
+    val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
+      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
+    } else {
+      lit(1.0)
+    }
+    val instances = dataset.select(DatasetUtils.columnToVector(dataset, getFeaturesCol), w)
+      .rdd.map { case Row(point: Vector, weight: Double) => (OldVectors.fromML(point), weight) }
+
+    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
+    val parentModel = algo.runWithWeight(instances, handlePersistence, Some(instr))
     val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
+
     val summary = new KMeansSummary(
       model.transform(dataset),
       $(predictionCol),
@@ -370,9 +364,6 @@ class KMeans @Since("1.5.0") (

     model.setSummary(Some(summary))
     instr.logNamedValue("clusterSizes", summary.clusterSizes)
-    if (handlePersistence) {
-      instances.unpersist()
-    }
     model
   }

```
```diff
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -153,30 +153,25 @@ class BisectingKMeans private (
     this
   }

-  private[spark] def run(
-      input: RDD[Vector],
-      instr: Option[Instrumentation]): BisectingKMeansModel = {
-    val instances: RDD[(Vector, Double)] = input.map {
-      case (point) => (point, 1.0)
-    }
-    runWithWeight(instances, None)
-  }
-
   private[spark] def runWithWeight(
-      input: RDD[(Vector, Double)],
+      instances: RDD[(Vector, Double)],
+      handlePersistence: Boolean,
       instr: Option[Instrumentation]): BisectingKMeansModel = {
-    val d = input.map(_._1.size).first
+    val d = instances.map(_._1.size).first
     logInfo(s"Feature dimension: $d.")

-    val dMeasure: DistanceMeasure = DistanceMeasure.decodeFromString(this.distanceMeasure)
-    // Compute and cache vector norms for fast distance computation.
-    val norms = input.map(d => Vectors.norm(d._1, 2.0))
-    val vectors = input.zip(norms).map {
-      case ((x, weight), norm) => new VectorWithNorm(x, norm, weight)
-    }
-    if (input.getStorageLevel == StorageLevel.NONE) {
+    val dMeasure = DistanceMeasure.decodeFromString(this.distanceMeasure)
+    val norms = instances.map(d => Vectors.norm(d._1, 2.0))
+    val vectors = instances.zip(norms)
+      .map { case ((x, weight), norm) => new VectorWithNorm(x, norm, weight) }
+
+    if (handlePersistence) {
       vectors.persist(StorageLevel.MEMORY_AND_DISK)
+    } else {
+      // Compute and cache vector norms for fast distance computation.
+      norms.persist(StorageLevel.MEMORY_AND_DISK)
     }
+
     var assignments = vectors.map(v => (ROOT_INDEX, v))
     var activeClusters = summarize(d, assignments, dMeasure)
     instr.foreach(_.logNumExamples(activeClusters.values.map(_.size).sum))
@@ -244,13 +239,11 @@ class BisectingKMeans private (
       }
       level += 1
     }
-    if (preIndices != null) {
-      preIndices.unpersist()
-    }
-    if (indices != null) {
-      indices.unpersist()
-    }
-    vectors.unpersist()
+
+    if (preIndices != null) { preIndices.unpersist() }
+    if (indices != null) { indices.unpersist() }
+    if (handlePersistence) { vectors.unpersist() } else { norms.unpersist() }
+
     val clusters = activeClusters ++ inactiveClusters
     val root = buildTree(clusters, dMeasure)
     val totalCost = root.leafNodes.map(_.cost).sum
@@ -264,7 +257,9 @@ class BisectingKMeans private (
    */
   @Since("1.6.0")
   def run(input: RDD[Vector]): BisectingKMeansModel = {
-    run(input, None)
+    val instances = input.map(point => (point, 1.0))
+    val handlePersistence = input.getStorageLevel == StorageLevel.NONE
+    runWithWeight(instances, handlePersistence, None)
   }

   /**
```
```diff
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -210,27 +210,26 @@ class KMeans private (
   @Since("0.8.0")
   def run(data: RDD[Vector]): KMeansModel = {
     val instances = data.map(point => (point, 1.0))
-    runWithWeight(instances, None)
+    val handlePersistence = data.getStorageLevel == StorageLevel.NONE
+    runWithWeight(instances, handlePersistence, None)
   }

   private[spark] def runWithWeight(
-      data: RDD[(Vector, Double)],
+      instances: RDD[(Vector, Double)],
+      handlePersistence: Boolean,
       instr: Option[Instrumentation]): KMeansModel = {
+    val norms = instances.map { case (v, _) => Vectors.norm(v, 2.0) }
+    val vectors = instances.zip(norms)
+      .map { case ((v, w), norm) => new VectorWithNorm(v, norm, w) }

-    // Compute squared norms and cache them.
-    val norms = data.map { case (v, _) =>
-      Vectors.norm(v, 2.0)
+    if (handlePersistence) {
+      vectors.persist(StorageLevel.MEMORY_AND_DISK)
+    } else {
+      // Compute squared norms and cache them.
+      norms.persist(StorageLevel.MEMORY_AND_DISK)
     }
-
-    val zippedData = data.zip(norms).map { case ((v, w), norm) =>
-      new VectorWithNorm(v, norm, w)
-    }
-
-    if (data.getStorageLevel == StorageLevel.NONE) {
-      zippedData.persist(StorageLevel.MEMORY_AND_DISK)
-    }
-    val model = runAlgorithmWithWeight(zippedData, instr)
-    zippedData.unpersist()
+    val model = runAlgorithmWithWeight(vectors, instr)
+    if (handlePersistence) { vectors.unpersist() } else { norms.unpersist() }

     model
   }
```
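For reference, a caller-side sketch of the fixed behavior (spark-shell style, assuming an active `sc`; the data is illustrative only):

```scala
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.storage.StorageLevel

val data = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0),
  Vectors.dense(9.0, 8.0), Vectors.dense(8.0, 9.0)))

// Uncached input: run() computes handlePersistence = true, so runWithWeight
// persists (and finally unpersists) the derived vectors exactly once.
val model1 = new KMeans().setK(2).run(data)

// Already-cached input: handlePersistence = false; only the norms are
// persisted, and the caller's own caching is left untouched.
data.persist(StorageLevel.MEMORY_AND_DISK)
val model2 = new KMeans().setK(2).run(data)
data.unpersist()
```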