[SPARK-16863][ML] ProbabilisticClassifier.fit check threshoulds' length

## What changes were proposed in this pull request?

Add threshoulds' length checking for Classifiers which extends ProbabilisticClassifier

## How was this patch tested?

unit tests and manual tests

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #14470 from zhengruifeng/classifier_check_setThreshoulds_length.
This commit is contained in:
Zheng RuiFeng 2016-08-04 21:44:54 +01:00 committed by Sean Owen
parent 1d781572e8
commit 0e2e5d7d0b
4 changed files with 28 additions and 0 deletions

View file

@ -84,6 +84,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy = getOldStrategy(categoricalFeatures, numClasses)

View file

@ -292,6 +292,12 @@ class LogisticRegression @Since("1.2.0") (
val numClasses = histogram.length
val numFeatures = summarizer.mean.size
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
instr.logNumClasses(numClasses)
instr.logNumFeatures(numFeatures)

View file

@ -101,6 +101,14 @@ class NaiveBayes @Since("1.5.0") (
setDefault(modelType -> OldNaiveBayes.Multinomial)
override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
val numClasses = getNumClasses(dataset)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
val oldDataset: RDD[OldLabeledPoint] =
extractLabeledPoints(dataset).map(OldLabeledPoint.fromML)
val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))

View file

@ -100,6 +100,13 @@ class RandomForestClassifier @Since("1.4.0") (
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)