[SPARK-14569][ML] Log instrumentation in KMeans
## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-14569 Log instrumentation in KMeans: - featuresCol - predictionCol - k - initMode - initSteps - maxIter - seed - tol - summary ## How was this patch tested? Manually test on local machine, by running and checking output of org.apache.spark.examples.ml.KMeansExample Author: Xin Ren <iamshrek@126.com> Closes #12432 from keypointt/SPARK-14569.
This commit is contained in:
parent
411454475a
commit
6d1e4c4a65
|
@ -264,6 +264,9 @@ class KMeans @Since("1.5.0") (
|
|||
override def fit(dataset: Dataset[_]): KMeansModel = {
|
||||
val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
|
||||
|
||||
val instr = Instrumentation.create(this, rdd)
|
||||
instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol)
|
||||
|
||||
val algo = new MLlibKMeans()
|
||||
.setK($(k))
|
||||
.setInitializationMode($(initMode))
|
||||
|
@ -271,11 +274,13 @@ class KMeans @Since("1.5.0") (
|
|||
.setMaxIterations($(maxIter))
|
||||
.setSeed($(seed))
|
||||
.setEpsilon($(tol))
|
||||
val parentModel = algo.run(rdd)
|
||||
val parentModel = algo.run(rdd, Option(instr))
|
||||
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
|
||||
val summary = new KMeansSummary(
|
||||
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
|
||||
model.setSummary(summary)
|
||||
val m = model.setSummary(summary)
|
||||
instr.logSuccess(m)
|
||||
m
|
||||
}
|
||||
|
||||
@Since("1.5.0")
|
||||
|
|
|
@ -39,7 +39,7 @@ import org.apache.spark.sql.Dataset
|
|||
* @param dataset the training dataset
|
||||
* @tparam E the type of the estimator
|
||||
*/
|
||||
private[ml] class Instrumentation[E <: Estimator[_]] private (
|
||||
private[spark] class Instrumentation[E <: Estimator[_]] private (
|
||||
estimator: E, dataset: RDD[_]) extends Logging {
|
||||
|
||||
private val id = Instrumentation.counter.incrementAndGet()
|
||||
|
@ -95,7 +95,7 @@ private[ml] class Instrumentation[E <: Estimator[_]] private (
|
|||
/**
|
||||
* Some common methods for logging information about a training session.
|
||||
*/
|
||||
private[ml] object Instrumentation {
|
||||
private[spark] object Instrumentation {
|
||||
private val counter = new AtomicLong(0)
|
||||
|
||||
/**
|
||||
|
|
|
@ -21,6 +21,8 @@ import scala.collection.mutable.ArrayBuffer
|
|||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.ml.clustering.{KMeans => NewKMeans}
|
||||
import org.apache.spark.ml.util.Instrumentation
|
||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
|
@ -212,6 +214,12 @@ class KMeans private (
|
|||
*/
|
||||
@Since("0.8.0")
|
||||
def run(data: RDD[Vector]): KMeansModel = {
|
||||
run(data, None)
|
||||
}
|
||||
|
||||
private[spark] def run(
|
||||
data: RDD[Vector],
|
||||
instr: Option[Instrumentation[NewKMeans]]): KMeansModel = {
|
||||
|
||||
if (data.getStorageLevel == StorageLevel.NONE) {
|
||||
logWarning("The input data is not directly cached, which may hurt performance if its"
|
||||
|
@ -224,7 +232,7 @@ class KMeans private (
|
|||
val zippedData = data.zip(norms).map { case (v, norm) =>
|
||||
new VectorWithNorm(v, norm)
|
||||
}
|
||||
val model = runAlgorithm(zippedData)
|
||||
val model = runAlgorithm(zippedData, instr)
|
||||
norms.unpersist()
|
||||
|
||||
// Warn at the end of the run as well, for increased visibility.
|
||||
|
@ -238,7 +246,9 @@ class KMeans private (
|
|||
/**
|
||||
* Implementation of K-Means algorithm.
|
||||
*/
|
||||
private def runAlgorithm(data: RDD[VectorWithNorm]): KMeansModel = {
|
||||
private def runAlgorithm(
|
||||
data: RDD[VectorWithNorm],
|
||||
instr: Option[Instrumentation[NewKMeans]]): KMeansModel = {
|
||||
|
||||
val sc = data.sparkContext
|
||||
|
||||
|
@ -274,6 +284,8 @@ class KMeans private (
|
|||
|
||||
val iterationStartTime = System.nanoTime()
|
||||
|
||||
instr.map(_.logNumFeatures(centers(0)(0).vector.size))
|
||||
|
||||
// Execute iterations of Lloyd's algorithm until all runs have converged
|
||||
while (iteration < maxIterations && !activeRuns.isEmpty) {
|
||||
type WeightedPoint = (Vector, Long)
|
||||
|
|
Loading…
Reference in a new issue