[SPARK-25959][ML] GBTClassifier picks wrong impurity stats on loading
## What changes were proposed in this pull request? Our `GBTClassifier` supports only the `variance` impurity. But unfortunately, its `impurity` param by default contains the value `gini`: it is not even modifiable by the user, and it differs from the actual impurity used, which is `variance`. This issue is not limited to a wrong value being returned when the user queries `getImpurity`; it also affects the loading of a saved model, since its `impurityStats` are created as `gini` (because this is the value stored for the model's impurity), which leads to wrong `featureImportances` in models loaded from saved ones. The PR changes the `impurity` param used to one which allows only the value `variance`. ## How was this patch tested? Modified UT. Closes #22986 from mgaido91/SPARK-25959. Authored-by: Marco Gaido <marcogaido91@gmail.com> Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit is contained in:
parent
e557c53c59
commit
e00cac9898
|
@ -427,7 +427,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
|
|||
s" trees based on metadata but found ${trees.length} trees.")
|
||||
val model = new GBTClassificationModel(metadata.uid,
|
||||
trees, treeWeights, numFeatures)
|
||||
metadata.getAndSetParams(model)
|
||||
// We ignore the impurity while loading models because in previous models it was wrongly
|
||||
// set to gini (see SPARK-25959).
|
||||
metadata.getAndSetParams(model, Some(List("impurity")))
|
||||
model
|
||||
}
|
||||
}
|
||||
|
|
|
@ -145,7 +145,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
|
|||
@Since("1.4.0")
|
||||
object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] {
|
||||
/** Accessor for supported impurities: variance */
|
||||
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
|
||||
final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities
|
||||
|
||||
@Since("2.0.0")
|
||||
override def load(path: String): DecisionTreeRegressor = super.load(path)
|
||||
|
|
|
@ -146,7 +146,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
|
|||
object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{
|
||||
/** Accessor for supported impurity settings: variance */
|
||||
@Since("1.4.0")
|
||||
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
|
||||
final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities
|
||||
|
||||
/** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
|
||||
@Since("1.4.0")
|
||||
|
|
|
@ -258,11 +258,7 @@ private[ml] object TreeClassifierParams {
|
|||
private[ml] trait DecisionTreeClassifierParams
|
||||
extends DecisionTreeParams with TreeClassifierParams
|
||||
|
||||
/**
|
||||
* Parameters for Decision Tree-based regression algorithms.
|
||||
*/
|
||||
private[ml] trait TreeRegressorParams extends Params {
|
||||
|
||||
private[ml] trait HasVarianceImpurity extends Params {
|
||||
/**
|
||||
* Criterion used for information gain calculation (case-insensitive).
|
||||
* Supported: "variance".
|
||||
|
@ -271,9 +267,9 @@ private[ml] trait TreeRegressorParams extends Params {
|
|||
*/
|
||||
final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
|
||||
" information gain calculation (case-insensitive). Supported options:" +
|
||||
s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
|
||||
s" ${HasVarianceImpurity.supportedImpurities.mkString(", ")}",
|
||||
(value: String) =>
|
||||
TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
|
||||
HasVarianceImpurity.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
|
||||
|
||||
setDefault(impurity -> "variance")
|
||||
|
||||
|
@ -299,12 +295,17 @@ private[ml] trait TreeRegressorParams extends Params {
|
|||
}
|
||||
}
|
||||
|
||||
private[ml] object TreeRegressorParams {
|
||||
private[ml] object HasVarianceImpurity {
|
||||
// These options should be lowercase.
|
||||
final val supportedImpurities: Array[String] =
|
||||
Array("variance").map(_.toLowerCase(Locale.ROOT))
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameters for Decision Tree-based regression algorithms.
|
||||
*/
|
||||
private[ml] trait TreeRegressorParams extends HasVarianceImpurity
|
||||
|
||||
private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams
|
||||
with TreeRegressorParams with HasVarianceCol {
|
||||
|
||||
|
@ -538,7 +539,7 @@ private[ml] object GBTClassifierParams {
|
|||
Array("logistic").map(_.toLowerCase(Locale.ROOT))
|
||||
}
|
||||
|
||||
private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams {
|
||||
private[ml] trait GBTClassifierParams extends GBTParams with HasVarianceImpurity {
|
||||
|
||||
/**
|
||||
* Loss function which GBT tries to minimize. (case-insensitive)
|
||||
|
|
|
@ -448,6 +448,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
|
|||
model2: GBTClassificationModel): Unit = {
|
||||
TreeTests.checkEqual(model, model2)
|
||||
assert(model.numFeatures === model2.numFeatures)
|
||||
assert(model.featureImportances == model2.featureImportances)
|
||||
}
|
||||
|
||||
val gbt = new GBTClassifier()
|
||||
|
|
|
@ -36,6 +36,17 @@ object MimaExcludes {
|
|||
|
||||
// Exclude rules for 3.0.x
|
||||
lazy val v30excludes = v24excludes ++ Seq(
|
||||
// [SPARK-25959] GBTClassifier picks wrong impurity stats on loading
|
||||
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"),
|
||||
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
|
||||
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
|
||||
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
|
||||
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
|
||||
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
|
||||
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"),
|
||||
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"),
|
||||
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"),
|
||||
|
||||
// [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3
|
||||
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"),
|
||||
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskContext.isRunningLocally"),
|
||||
|
|
Loading…
Reference in a new issue