[SPARK-24132][ML] Instrumentation improvement for classification
## What changes were proposed in this pull request?
- Add OptionalInstrumentation as an argument for getNumClasses in ml.classification.Classifier.
- Change the getNumClasses function call in train() in ml.classification.DecisionTreeClassifier, ml.classification.RandomForestClassifier, and ml.classification.NaiveBayes.
- Modify the instrumentation creation in ml.classification.LinearSVC.
- Change the log calls in ml.classification.OneVsRest and ml.classification.LinearSVC.

## How was this patch tested?
Manually.

Author: Lu WANG <lu.wang@databricks.com>
Closes #21204 from ludatabricks/SPARK-23686.
This commit is contained in:
parent
9498e528d2
commit
7e7350285d
|
@ -97,9 +97,11 @@ class DecisionTreeClassifier @Since("1.4.0") (
|
|||
override def setSeed(value: Long): this.type = set(seed, value)
|
||||
|
||||
override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
|
||||
val instr = Instrumentation.create(this, dataset)
|
||||
val categoricalFeatures: Map[Int, Int] =
|
||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||
val numClasses: Int = getNumClasses(dataset)
|
||||
instr.logNumClasses(numClasses)
|
||||
|
||||
if (isDefined(thresholds)) {
|
||||
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
|
||||
|
@ -110,8 +112,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
|
|||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
|
||||
val strategy = getOldStrategy(categoricalFeatures, numClasses)
|
||||
|
||||
val instr = Instrumentation.create(this, oldDataset)
|
||||
instr.logParams(params: _*)
|
||||
instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
|
||||
cacheNodeIds, checkpointInterval, impurity, seed)
|
||||
|
||||
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
|
||||
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
|
||||
|
@ -125,7 +127,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
|
|||
private[ml] def train(data: RDD[LabeledPoint],
|
||||
oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
|
||||
val instr = Instrumentation.create(this, data)
|
||||
instr.logParams(params: _*)
|
||||
instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
|
||||
cacheNodeIds, checkpointInterval, impurity, seed)
|
||||
|
||||
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
|
||||
seed = 0L, instr = Some(instr), parentUID = Some(uid))
|
||||
|
|
|
@ -170,7 +170,7 @@ class LinearSVC @Since("2.2.0") (
|
|||
Instance(label, weight, features)
|
||||
}
|
||||
|
||||
val instr = Instrumentation.create(this, instances)
|
||||
val instr = Instrumentation.create(this, dataset)
|
||||
instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, threshold,
|
||||
aggregationDepth)
|
||||
|
||||
|
@ -187,6 +187,9 @@ class LinearSVC @Since("2.2.0") (
|
|||
(new MultivariateOnlineSummarizer, new MultiClassSummarizer)
|
||||
)(seqOp, combOp, $(aggregationDepth))
|
||||
}
|
||||
instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count)
|
||||
instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
|
||||
instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)
|
||||
|
||||
val histogram = labelSummarizer.histogram
|
||||
val numInvalid = labelSummarizer.countInvalid
|
||||
|
@ -209,7 +212,7 @@ class LinearSVC @Since("2.2.0") (
|
|||
if (numInvalid != 0) {
|
||||
val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
|
||||
s"Found $numInvalid invalid labels."
|
||||
logError(msg)
|
||||
instr.logError(msg)
|
||||
throw new SparkException(msg)
|
||||
}
|
||||
|
||||
|
@ -246,7 +249,7 @@ class LinearSVC @Since("2.2.0") (
|
|||
bcFeaturesStd.destroy(blocking = false)
|
||||
if (state == null) {
|
||||
val msg = s"${optimizer.getClass.getName} failed."
|
||||
logError(msg)
|
||||
instr.logError(msg)
|
||||
throw new SparkException(msg)
|
||||
}
|
||||
|
||||
|
|
|
@ -126,8 +126,10 @@ class NaiveBayes @Since("1.5.0") (
|
|||
private[spark] def trainWithLabelCheck(
|
||||
dataset: Dataset[_],
|
||||
positiveLabel: Boolean): NaiveBayesModel = {
|
||||
val instr = Instrumentation.create(this, dataset)
|
||||
if (positiveLabel && isDefined(thresholds)) {
|
||||
val numClasses = getNumClasses(dataset)
|
||||
instr.logNumClasses(numClasses)
|
||||
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
|
||||
".train() called with non-matching numClasses and thresholds.length." +
|
||||
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
|
||||
|
@ -146,7 +148,6 @@ class NaiveBayes @Since("1.5.0") (
|
|||
}
|
||||
}
|
||||
|
||||
val instr = Instrumentation.create(this, dataset)
|
||||
instr.logParams(labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
|
||||
probabilityCol, modelType, smoothing, thresholds)
|
||||
|
||||
|
|
|
@ -366,7 +366,7 @@ final class OneVsRest @Since("1.4.0") (
|
|||
transformSchema(dataset.schema)
|
||||
|
||||
val instr = Instrumentation.create(this, dataset)
|
||||
instr.logParams(labelCol, featuresCol, predictionCol, parallelism)
|
||||
instr.logParams(labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol)
|
||||
instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName)
|
||||
|
||||
// determine number of classes either from metadata if provided, or via computation.
|
||||
|
@ -383,7 +383,7 @@ final class OneVsRest @Since("1.4.0") (
|
|||
getClassifier match {
|
||||
case _: HasWeightCol => true
|
||||
case c =>
|
||||
logWarning(s"weightCol is ignored, as it is not supported by $c now.")
|
||||
instr.logWarning(s"weightCol is ignored, as it is not supported by $c now.")
|
||||
false
|
||||
}
|
||||
}
|
||||
|
|
|
@ -116,6 +116,7 @@ class RandomForestClassifier @Since("1.4.0") (
|
|||
set(featureSubsetStrategy, value)
|
||||
|
||||
override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
|
||||
val instr = Instrumentation.create(this, dataset)
|
||||
val categoricalFeatures: Map[Int, Int] =
|
||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||
val numClasses: Int = getNumClasses(dataset)
|
||||
|
@ -130,7 +131,6 @@ class RandomForestClassifier @Since("1.4.0") (
|
|||
val strategy =
|
||||
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
|
||||
|
||||
val instr = Instrumentation.create(this, oldDataset)
|
||||
instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol,
|
||||
impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
|
||||
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
|
||||
|
@ -141,6 +141,8 @@ class RandomForestClassifier @Since("1.4.0") (
|
|||
|
||||
val numFeatures = oldDataset.first().features.size
|
||||
val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
|
||||
instr.logNumClasses(numClasses)
|
||||
instr.logNumFeatures(numFeatures)
|
||||
instr.logSuccess(m)
|
||||
m
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue