[SPARK-25959][ML] GBTClassifier picks wrong impurity stats on loading

## What changes were proposed in this pull request?

Our `GBTClassifier` supports only `variance` impurity. But unfortunately, its `impurity` param by default contains the value `gini`: it is not even modifiable by the user and it differs from the actual impurity used, which is `variance`. This issue is not limited to a wrong value being returned when the user queries it via `getImpurity`: it also affects loading a saved model, as its `impurityStats` are created as `gini` (since this is the value stored for the model's impurity), which leads to wrong `featureImportances` in models loaded from saved ones.

This PR replaces the `impurity` param with one that allows only the value `variance`.

## How was this patch tested?

Modified existing unit tests to verify that `featureImportances` are preserved when a model is saved and loaded.

Closes #22986 from mgaido91/SPARK-25959.

Authored-by: Marco Gaido <marcogaido91@gmail.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit is contained in:
Marco Gaido 2018-11-17 09:46:45 -06:00 committed by Sean Owen
parent e557c53c59
commit e00cac9898
6 changed files with 27 additions and 12 deletions

View file

@ -427,7 +427,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
s" trees based on metadata but found ${trees.length} trees.")
val model = new GBTClassificationModel(metadata.uid,
trees, treeWeights, numFeatures)
metadata.getAndSetParams(model)
// We ignore the impurity while loading models because in previous models it was wrongly
// set to gini (see SPARK-25959).
metadata.getAndSetParams(model, Some(List("impurity")))
model
}
}

View file

@ -145,7 +145,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
@Since("1.4.0")
object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] {
/** Accessor for supported impurities: variance */
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities
@Since("2.0.0")
override def load(path: String): DecisionTreeRegressor = super.load(path)

View file

@ -146,7 +146,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{
/** Accessor for supported impurity settings: variance */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities
/** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
@Since("1.4.0")

View file

@ -258,11 +258,7 @@ private[ml] object TreeClassifierParams {
private[ml] trait DecisionTreeClassifierParams
extends DecisionTreeParams with TreeClassifierParams
/**
* Parameters for Decision Tree-based regression algorithms.
*/
private[ml] trait TreeRegressorParams extends Params {
private[ml] trait HasVarianceImpurity extends Params {
/**
* Criterion used for information gain calculation (case-insensitive).
* Supported: "variance".
@ -271,9 +267,9 @@ private[ml] trait TreeRegressorParams extends Params {
*/
final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
" information gain calculation (case-insensitive). Supported options:" +
s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
s" ${HasVarianceImpurity.supportedImpurities.mkString(", ")}",
(value: String) =>
TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
HasVarianceImpurity.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
setDefault(impurity -> "variance")
@ -299,12 +295,17 @@ private[ml] trait TreeRegressorParams extends Params {
}
}
private[ml] object TreeRegressorParams {
private[ml] object HasVarianceImpurity {
// These options should be lowercase.
final val supportedImpurities: Array[String] =
Array("variance").map(_.toLowerCase(Locale.ROOT))
}
/**
* Parameters for Decision Tree-based regression algorithms.
*/
private[ml] trait TreeRegressorParams extends HasVarianceImpurity
private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams
with TreeRegressorParams with HasVarianceCol {
@ -538,7 +539,7 @@ private[ml] object GBTClassifierParams {
Array("logistic").map(_.toLowerCase(Locale.ROOT))
}
private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams {
private[ml] trait GBTClassifierParams extends GBTParams with HasVarianceImpurity {
/**
* Loss function which GBT tries to minimize. (case-insensitive)

View file

@ -448,6 +448,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
model2: GBTClassificationModel): Unit = {
TreeTests.checkEqual(model, model2)
assert(model.numFeatures === model2.numFeatures)
assert(model.featureImportances == model2.featureImportances)
}
val gbt = new GBTClassifier()

View file

@ -36,6 +36,17 @@ object MimaExcludes {
// Exclude rules for 3.0.x
lazy val v30excludes = v24excludes ++ Seq(
// [SPARK-25959] GBTClassifier picks wrong impurity stats on loading
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"),
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"),
// [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskContext.isRunningLocally"),