[SPARK-16863][ML] ProbabilisticClassifier.fit check thresholds' length
## What changes were proposed in this pull request? Add thresholds' length checking for Classifiers which extend ProbabilisticClassifier. ## How was this patch tested? Unit tests and manual tests. Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #14470 from zhengruifeng/classifier_check_setThreshoulds_length.
This commit is contained in:
parent
1d781572e8
commit
0e2e5d7d0b
|
@ -84,6 +84,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
|
|||
val categoricalFeatures: Map[Int, Int] =
|
||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||
val numClasses: Int = getNumClasses(dataset)
|
||||
|
||||
if (isDefined(thresholds)) {
|
||||
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
|
||||
".train() called with non-matching numClasses and thresholds.length." +
|
||||
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
|
||||
}
|
||||
|
||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
|
||||
val strategy = getOldStrategy(categoricalFeatures, numClasses)
|
||||
|
||||
|
|
|
@ -292,6 +292,12 @@ class LogisticRegression @Since("1.2.0") (
|
|||
val numClasses = histogram.length
|
||||
val numFeatures = summarizer.mean.size
|
||||
|
||||
if (isDefined(thresholds)) {
|
||||
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
|
||||
".train() called with non-matching numClasses and thresholds.length." +
|
||||
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
|
||||
}
|
||||
|
||||
instr.logNumClasses(numClasses)
|
||||
instr.logNumFeatures(numFeatures)
|
||||
|
||||
|
|
|
@ -101,6 +101,14 @@ class NaiveBayes @Since("1.5.0") (
|
|||
setDefault(modelType -> OldNaiveBayes.Multinomial)
|
||||
|
||||
override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
|
||||
val numClasses = getNumClasses(dataset)
|
||||
|
||||
if (isDefined(thresholds)) {
|
||||
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
|
||||
".train() called with non-matching numClasses and thresholds.length." +
|
||||
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
|
||||
}
|
||||
|
||||
val oldDataset: RDD[OldLabeledPoint] =
|
||||
extractLabeledPoints(dataset).map(OldLabeledPoint.fromML)
|
||||
val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))
|
||||
|
|
|
@ -100,6 +100,13 @@ class RandomForestClassifier @Since("1.4.0") (
|
|||
val categoricalFeatures: Map[Int, Int] =
|
||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||
val numClasses: Int = getNumClasses(dataset)
|
||||
|
||||
if (isDefined(thresholds)) {
|
||||
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
|
||||
".train() called with non-matching numClasses and thresholds.length." +
|
||||
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
|
||||
}
|
||||
|
||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
|
||||
val strategy =
|
||||
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
|
||||
|
|
Loading…
Reference in a new issue