[SPARK-24132][ML] Instrumentation improvement for classification

## What changes were proposed in this pull request?

- Add OptionalInstrumentation as an argument to getNumClasses in ml.classification.Classifier

- Update the getNumClasses call sites in train() in ml.classification.DecisionTreeClassifier, ml.classification.RandomForestClassifier, and ml.classification.NaiveBayes

- Modify the instrumentation creation in ml.classification.LinearSVC

- Change the log call in ml.classification.OneVsRest and ml.classification.LinearSVC

## How was this patch tested?

Manual testing; verified the instrumentation log output for the affected classifiers.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Lu WANG <lu.wang@databricks.com>

Closes #21204 from ludatabricks/SPARK-23686.
This commit is contained in:
Lu WANG 2018-05-08 21:20:58 -07:00 committed by Xiangrui Meng
parent 9498e528d2
commit 7e7350285d
5 changed files with 19 additions and 10 deletions

View file

@ -97,9 +97,11 @@ class DecisionTreeClassifier @Since("1.4.0") (
override def setSeed(value: Long): this.type = set(seed, value)
override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
val instr = Instrumentation.create(this, dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
instr.logNumClasses(numClasses)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@ -110,8 +112,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
cacheNodeIds, checkpointInterval, impurity, seed)
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
@ -125,7 +127,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
val instr = Instrumentation.create(this, data)
instr.logParams(params: _*)
instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
cacheNodeIds, checkpointInterval, impurity, seed)
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
seed = 0L, instr = Some(instr), parentUID = Some(uid))

View file

@ -170,7 +170,7 @@ class LinearSVC @Since("2.2.0") (
Instance(label, weight, features)
}
val instr = Instrumentation.create(this, instances)
val instr = Instrumentation.create(this, dataset)
instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, threshold,
aggregationDepth)
@ -187,6 +187,9 @@ class LinearSVC @Since("2.2.0") (
(new MultivariateOnlineSummarizer, new MultiClassSummarizer)
)(seqOp, combOp, $(aggregationDepth))
}
instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count)
instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)
val histogram = labelSummarizer.histogram
val numInvalid = labelSummarizer.countInvalid
@ -209,7 +212,7 @@ class LinearSVC @Since("2.2.0") (
if (numInvalid != 0) {
val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
s"Found $numInvalid invalid labels."
logError(msg)
instr.logError(msg)
throw new SparkException(msg)
}
@ -246,7 +249,7 @@ class LinearSVC @Since("2.2.0") (
bcFeaturesStd.destroy(blocking = false)
if (state == null) {
val msg = s"${optimizer.getClass.getName} failed."
logError(msg)
instr.logError(msg)
throw new SparkException(msg)
}

View file

@ -126,8 +126,10 @@ class NaiveBayes @Since("1.5.0") (
private[spark] def trainWithLabelCheck(
dataset: Dataset[_],
positiveLabel: Boolean): NaiveBayesModel = {
val instr = Instrumentation.create(this, dataset)
if (positiveLabel && isDefined(thresholds)) {
val numClasses = getNumClasses(dataset)
instr.logNumClasses(numClasses)
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
@ -146,7 +148,6 @@ class NaiveBayes @Since("1.5.0") (
}
}
val instr = Instrumentation.create(this, dataset)
instr.logParams(labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
probabilityCol, modelType, smoothing, thresholds)

View file

@ -366,7 +366,7 @@ final class OneVsRest @Since("1.4.0") (
transformSchema(dataset.schema)
val instr = Instrumentation.create(this, dataset)
instr.logParams(labelCol, featuresCol, predictionCol, parallelism)
instr.logParams(labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol)
instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName)
// determine number of classes either from metadata if provided, or via computation.
@ -383,7 +383,7 @@ final class OneVsRest @Since("1.4.0") (
getClassifier match {
case _: HasWeightCol => true
case c =>
logWarning(s"weightCol is ignored, as it is not supported by $c now.")
instr.logWarning(s"weightCol is ignored, as it is not supported by $c now.")
false
}
}

View file

@ -116,6 +116,7 @@ class RandomForestClassifier @Since("1.4.0") (
set(featureSubsetStrategy, value)
override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
val instr = Instrumentation.create(this, dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
@ -130,7 +131,6 @@ class RandomForestClassifier @Since("1.4.0") (
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
val instr = Instrumentation.create(this, oldDataset)
instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol,
impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
@ -141,6 +141,8 @@ class RandomForestClassifier @Since("1.4.0") (
val numFeatures = oldDataset.first().features.size
val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
instr.logNumClasses(numClasses)
instr.logNumFeatures(numFeatures)
instr.logSuccess(m)
m
}