[SPARK-14569][ML] Log instrumentation in KMeans
## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-14569 Log instrumentation in KMeans for the following parameters: - featuresCol - predictionCol - k - initMode - initSteps - maxIter - seed - tol - summary ## How was this patch tested? Manually tested on a local machine by running org.apache.spark.examples.ml.KMeansExample and checking its output. Author: Xin Ren <iamshrek@126.com> Closes #12432 from keypointt/SPARK-14569.
This commit is contained in:
parent
411454475a
commit
6d1e4c4a65
|
@ -264,6 +264,9 @@ class KMeans @Since("1.5.0") (
|
||||||
override def fit(dataset: Dataset[_]): KMeansModel = {
|
override def fit(dataset: Dataset[_]): KMeansModel = {
|
||||||
val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
|
val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
|
||||||
|
|
||||||
|
val instr = Instrumentation.create(this, rdd)
|
||||||
|
instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol)
|
||||||
|
|
||||||
val algo = new MLlibKMeans()
|
val algo = new MLlibKMeans()
|
||||||
.setK($(k))
|
.setK($(k))
|
||||||
.setInitializationMode($(initMode))
|
.setInitializationMode($(initMode))
|
||||||
|
@ -271,11 +274,13 @@ class KMeans @Since("1.5.0") (
|
||||||
.setMaxIterations($(maxIter))
|
.setMaxIterations($(maxIter))
|
||||||
.setSeed($(seed))
|
.setSeed($(seed))
|
||||||
.setEpsilon($(tol))
|
.setEpsilon($(tol))
|
||||||
val parentModel = algo.run(rdd)
|
val parentModel = algo.run(rdd, Option(instr))
|
||||||
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
|
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
|
||||||
val summary = new KMeansSummary(
|
val summary = new KMeansSummary(
|
||||||
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
|
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
|
||||||
model.setSummary(summary)
|
val m = model.setSummary(summary)
|
||||||
|
instr.logSuccess(m)
|
||||||
|
m
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
|
|
|
@ -39,7 +39,7 @@ import org.apache.spark.sql.Dataset
|
||||||
* @param dataset the training dataset
|
* @param dataset the training dataset
|
||||||
* @tparam E the type of the estimator
|
* @tparam E the type of the estimator
|
||||||
*/
|
*/
|
||||||
private[ml] class Instrumentation[E <: Estimator[_]] private (
|
private[spark] class Instrumentation[E <: Estimator[_]] private (
|
||||||
estimator: E, dataset: RDD[_]) extends Logging {
|
estimator: E, dataset: RDD[_]) extends Logging {
|
||||||
|
|
||||||
private val id = Instrumentation.counter.incrementAndGet()
|
private val id = Instrumentation.counter.incrementAndGet()
|
||||||
|
@ -95,7 +95,7 @@ private[ml] class Instrumentation[E <: Estimator[_]] private (
|
||||||
/**
|
/**
|
||||||
* Some common methods for logging information about a training session.
|
* Some common methods for logging information about a training session.
|
||||||
*/
|
*/
|
||||||
private[ml] object Instrumentation {
|
private[spark] object Instrumentation {
|
||||||
private val counter = new AtomicLong(0)
|
private val counter = new AtomicLong(0)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -21,6 +21,8 @@ import scala.collection.mutable.ArrayBuffer
|
||||||
|
|
||||||
import org.apache.spark.annotation.Since
|
import org.apache.spark.annotation.Since
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
|
import org.apache.spark.ml.clustering.{KMeans => NewKMeans}
|
||||||
|
import org.apache.spark.ml.util.Instrumentation
|
||||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
|
import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
|
||||||
import org.apache.spark.mllib.util.MLUtils
|
import org.apache.spark.mllib.util.MLUtils
|
||||||
|
@ -212,6 +214,12 @@ class KMeans private (
|
||||||
*/
|
*/
|
||||||
@Since("0.8.0")
|
@Since("0.8.0")
|
||||||
def run(data: RDD[Vector]): KMeansModel = {
|
def run(data: RDD[Vector]): KMeansModel = {
|
||||||
|
run(data, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
private[spark] def run(
|
||||||
|
data: RDD[Vector],
|
||||||
|
instr: Option[Instrumentation[NewKMeans]]): KMeansModel = {
|
||||||
|
|
||||||
if (data.getStorageLevel == StorageLevel.NONE) {
|
if (data.getStorageLevel == StorageLevel.NONE) {
|
||||||
logWarning("The input data is not directly cached, which may hurt performance if its"
|
logWarning("The input data is not directly cached, which may hurt performance if its"
|
||||||
|
@ -224,7 +232,7 @@ class KMeans private (
|
||||||
val zippedData = data.zip(norms).map { case (v, norm) =>
|
val zippedData = data.zip(norms).map { case (v, norm) =>
|
||||||
new VectorWithNorm(v, norm)
|
new VectorWithNorm(v, norm)
|
||||||
}
|
}
|
||||||
val model = runAlgorithm(zippedData)
|
val model = runAlgorithm(zippedData, instr)
|
||||||
norms.unpersist()
|
norms.unpersist()
|
||||||
|
|
||||||
// Warn at the end of the run as well, for increased visibility.
|
// Warn at the end of the run as well, for increased visibility.
|
||||||
|
@ -238,7 +246,9 @@ class KMeans private (
|
||||||
/**
|
/**
|
||||||
* Implementation of K-Means algorithm.
|
* Implementation of K-Means algorithm.
|
||||||
*/
|
*/
|
||||||
private def runAlgorithm(data: RDD[VectorWithNorm]): KMeansModel = {
|
private def runAlgorithm(
|
||||||
|
data: RDD[VectorWithNorm],
|
||||||
|
instr: Option[Instrumentation[NewKMeans]]): KMeansModel = {
|
||||||
|
|
||||||
val sc = data.sparkContext
|
val sc = data.sparkContext
|
||||||
|
|
||||||
|
@ -274,6 +284,8 @@ class KMeans private (
|
||||||
|
|
||||||
val iterationStartTime = System.nanoTime()
|
val iterationStartTime = System.nanoTime()
|
||||||
|
|
||||||
|
instr.map(_.logNumFeatures(centers(0)(0).vector.size))
|
||||||
|
|
||||||
// Execute iterations of Lloyd's algorithm until all runs have converged
|
// Execute iterations of Lloyd's algorithm until all runs have converged
|
||||||
while (iteration < maxIterations && !activeRuns.isEmpty) {
|
while (iteration < maxIterations && !activeRuns.isEmpty) {
|
||||||
type WeightedPoint = (Vector, Long)
|
type WeightedPoint = (Vector, Long)
|
||||||
|
|
Loading…
Reference in a new issue