[SPARK-20199][ML] Provided featureSubsetStrategy to GBTClassifier and GBTRegressor
## What changes were proposed in this pull request?

Provided featureSubsetStrategy to GBTClassifier:
a) Moved featureSubsetStrategy to TreeEnsembleParams.
b) Changed GBTClassifier to pass featureSubsetStrategy when training each tree, e.g. `val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy)`.

## How was this patch tested?

a) Tested GradientBoostedTreeClassifierExample by adding .setFeatureSubsetStrategy with GBTClassifier.
b) Added test cases in GBTClassifierSuite and GBTRegressorSuite.

Author: Pralabh Kumar <pralabhkumar@gmail.com>

Closes #18118 from pralabhkumar/develop.
This commit is contained in:
parent
28ab5bf597
commit
9b9827759a
|
@ -59,6 +59,7 @@ object GradientBoostedTreeClassifierExample {
|
|||
.setLabelCol("indexedLabel")
|
||||
.setFeaturesCol("indexedFeatures")
|
||||
.setMaxIter(10)
|
||||
.setFeatureSubsetStrategy("auto")
|
||||
|
||||
// Convert indexed labels back to original labels.
|
||||
val labelConverter = new IndexToString()
|
||||
|
|
|
@ -135,6 +135,11 @@ class GBTClassifier @Since("1.4.0") (
|
|||
@Since("1.4.0")
|
||||
override def setStepSize(value: Double): this.type = set(stepSize, value)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("2.3.0")
|
||||
override def setFeatureSubsetStrategy(value: String): this.type =
|
||||
set(featureSubsetStrategy, value)
|
||||
|
||||
// Parameters from GBTClassifierParams:
|
||||
|
||||
/** @group setParam */
|
||||
|
@ -167,12 +172,12 @@ class GBTClassifier @Since("1.4.0") (
|
|||
val instr = Instrumentation.create(this, oldDataset)
|
||||
instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
|
||||
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
|
||||
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
|
||||
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
|
||||
instr.logNumFeatures(numFeatures)
|
||||
instr.logNumClasses(numClasses)
|
||||
|
||||
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
|
||||
$(seed))
|
||||
$(seed), $(featureSubsetStrategy))
|
||||
val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
|
||||
instr.logSuccess(m)
|
||||
m
|
||||
|
|
|
@ -158,7 +158,7 @@ object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifi
|
|||
/** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
|
||||
@Since("1.4.0")
|
||||
final val supportedFeatureSubsetStrategies: Array[String] =
|
||||
RandomForestParams.supportedFeatureSubsetStrategies
|
||||
TreeEnsembleParams.supportedFeatureSubsetStrategies
|
||||
|
||||
@Since("2.0.0")
|
||||
override def load(path: String): RandomForestClassifier = super.load(path)
|
||||
|
|
|
@ -117,12 +117,14 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
|
|||
}
|
||||
|
||||
/** (private[ml]) Train a decision tree on an RDD */
|
||||
private[ml] def train(data: RDD[LabeledPoint],
|
||||
oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
|
||||
private[ml] def train(
|
||||
data: RDD[LabeledPoint],
|
||||
oldStrategy: OldStrategy,
|
||||
featureSubsetStrategy: String): DecisionTreeRegressionModel = {
|
||||
val instr = Instrumentation.create(this, data)
|
||||
instr.logParams(params: _*)
|
||||
|
||||
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
|
||||
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy,
|
||||
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
|
||||
|
||||
val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
|
||||
|
|
|
@ -140,6 +140,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
|
|||
@Since("1.4.0")
|
||||
def setLossType(value: String): this.type = set(lossType, value)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("2.3.0")
|
||||
override def setFeatureSubsetStrategy(value: String): this.type =
|
||||
set(featureSubsetStrategy, value)
|
||||
|
||||
override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
|
||||
val categoricalFeatures: Map[Int, Int] =
|
||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||
|
@ -150,11 +155,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
|
|||
val instr = Instrumentation.create(this, oldDataset)
|
||||
instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
|
||||
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
|
||||
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
|
||||
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
|
||||
instr.logNumFeatures(numFeatures)
|
||||
|
||||
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
|
||||
$(seed))
|
||||
$(seed), $(featureSubsetStrategy))
|
||||
val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
|
||||
instr.logSuccess(m)
|
||||
m
|
||||
|
|
|
@ -149,7 +149,7 @@ object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor
|
|||
/** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
|
||||
@Since("1.4.0")
|
||||
final val supportedFeatureSubsetStrategies: Array[String] =
|
||||
RandomForestParams.supportedFeatureSubsetStrategies
|
||||
TreeEnsembleParams.supportedFeatureSubsetStrategies
|
||||
|
||||
@Since("2.0.0")
|
||||
override def load(path: String): RandomForestRegressor = super.load(path)
|
||||
|
|
|
@ -22,7 +22,7 @@ import scala.util.Try
|
|||
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.tree.RandomForestParams
|
||||
import org.apache.spark.ml.tree.TreeEnsembleParams
|
||||
import org.apache.spark.mllib.tree.configuration.Algo._
|
||||
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
|
||||
import org.apache.spark.mllib.tree.configuration.Strategy
|
||||
|
@ -200,7 +200,7 @@ private[spark] object DecisionTreeMetadata extends Logging {
|
|||
Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match {
|
||||
case Some(value) => math.ceil(value * numFeatures).toInt
|
||||
case _ => throw new IllegalArgumentException(s"Supported values:" +
|
||||
s" ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
|
||||
s" ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
|
||||
s" (0.0-1.0], [1-n].")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -42,16 +42,18 @@ private[spark] object GradientBoostedTrees extends Logging {
|
|||
def run(
|
||||
input: RDD[LabeledPoint],
|
||||
boostingStrategy: OldBoostingStrategy,
|
||||
seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
||||
seed: Long,
|
||||
featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
||||
val algo = boostingStrategy.treeStrategy.algo
|
||||
algo match {
|
||||
case OldAlgo.Regression =>
|
||||
GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
|
||||
GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false,
|
||||
seed, featureSubsetStrategy)
|
||||
case OldAlgo.Classification =>
|
||||
// Map labels to -1, +1 so binary classification can be treated as regression.
|
||||
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
|
||||
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
|
||||
seed)
|
||||
seed, featureSubsetStrategy)
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
|
||||
}
|
||||
|
@ -73,11 +75,13 @@ private[spark] object GradientBoostedTrees extends Logging {
|
|||
input: RDD[LabeledPoint],
|
||||
validationInput: RDD[LabeledPoint],
|
||||
boostingStrategy: OldBoostingStrategy,
|
||||
seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
||||
seed: Long,
|
||||
featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
||||
val algo = boostingStrategy.treeStrategy.algo
|
||||
algo match {
|
||||
case OldAlgo.Regression =>
|
||||
GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
|
||||
GradientBoostedTrees.boost(input, validationInput, boostingStrategy,
|
||||
validate = true, seed, featureSubsetStrategy)
|
||||
case OldAlgo.Classification =>
|
||||
// Map labels to -1, +1 so binary classification can be treated as regression.
|
||||
val remappedInput = input.map(
|
||||
|
@ -85,7 +89,7 @@ private[spark] object GradientBoostedTrees extends Logging {
|
|||
val remappedValidationInput = validationInput.map(
|
||||
x => new LabeledPoint((x.label * 2) - 1, x.features))
|
||||
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
|
||||
validate = true, seed)
|
||||
validate = true, seed, featureSubsetStrategy)
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
|
||||
}
|
||||
|
@ -245,7 +249,8 @@ private[spark] object GradientBoostedTrees extends Logging {
|
|||
validationInput: RDD[LabeledPoint],
|
||||
boostingStrategy: OldBoostingStrategy,
|
||||
validate: Boolean,
|
||||
seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
||||
seed: Long,
|
||||
featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
||||
val timer = new TimeTracker()
|
||||
timer.start("total")
|
||||
timer.start("init")
|
||||
|
@ -258,6 +263,7 @@ private[spark] object GradientBoostedTrees extends Logging {
|
|||
val baseLearnerWeights = new Array[Double](numIterations)
|
||||
val loss = boostingStrategy.loss
|
||||
val learningRate = boostingStrategy.learningRate
|
||||
|
||||
// Prepare strategy for individual trees, which use regression with variance impurity.
|
||||
val treeStrategy = boostingStrategy.treeStrategy.copy
|
||||
val validationTol = boostingStrategy.validationTol
|
||||
|
@ -288,7 +294,7 @@ private[spark] object GradientBoostedTrees extends Logging {
|
|||
// Initialize tree
|
||||
timer.start("building tree 0")
|
||||
val firstTree = new DecisionTreeRegressor().setSeed(seed)
|
||||
val firstTreeModel = firstTree.train(input, treeStrategy)
|
||||
val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy)
|
||||
val firstTreeWeight = 1.0
|
||||
baseLearners(0) = firstTreeModel
|
||||
baseLearnerWeights(0) = firstTreeWeight
|
||||
|
@ -319,8 +325,9 @@ private[spark] object GradientBoostedTrees extends Logging {
|
|||
logDebug("###################################################")
|
||||
logDebug("Gradient boosting tree iteration " + m)
|
||||
logDebug("###################################################")
|
||||
|
||||
val dt = new DecisionTreeRegressor().setSeed(seed + m)
|
||||
val model = dt.train(data, treeStrategy)
|
||||
val model = dt.train(data, treeStrategy, featureSubsetStrategy)
|
||||
timer.stop(s"building tree $m")
|
||||
// Update partial model
|
||||
baseLearners(m) = model
|
||||
|
|
|
@ -320,6 +320,12 @@ private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams
|
|||
}
|
||||
}
|
||||
|
||||
private[spark] object TreeEnsembleParams {
|
||||
// These options should be lowercase.
|
||||
final val supportedFeatureSubsetStrategies: Array[String] =
|
||||
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT))
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameters for Decision Tree-based ensemble algorithms.
|
||||
*
|
||||
|
@ -359,8 +365,58 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
|
|||
oldImpurity: OldImpurity): OldStrategy = {
|
||||
super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
|
||||
}
|
||||
|
||||
/**
|
||||
* The number of features to consider for splits at each tree node.
|
||||
* Supported options:
|
||||
* - "auto": Choose automatically for task:
|
||||
* If numTrees == 1, set to "all".
|
||||
* If numTrees > 1 (forest), set to "sqrt" for classification and
|
||||
* to "onethird" for regression.
|
||||
* - "all": use all features
|
||||
* - "onethird": use 1/3 of the features
|
||||
* - "sqrt": use sqrt(number of features)
|
||||
* - "log2": use log2(number of features)
|
||||
* - "n": when n is in the range (0, 1.0], use n * number of features. When n
|
||||
* is an integer in the range [1, number of features], use n features.
|
||||
* (default = "auto")
|
||||
*
|
||||
* These various settings are based on the following references:
|
||||
* - log2: tested in Breiman (2001)
|
||||
* - sqrt: recommended by Breiman manual for random forests
|
||||
* - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
|
||||
* package.
|
||||
* @see <a href="http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf">Breiman (2001)</a>
|
||||
* @see <a href="http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf">
|
||||
* Breiman manual for random forests</a>
|
||||
*
|
||||
* @group param
|
||||
*/
|
||||
final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
|
||||
"The number of features to consider for splits at each tree node." +
|
||||
s" Supported options: ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}" +
|
||||
s", (0.0-1.0], [1-n].",
|
||||
(value: String) =>
|
||||
TreeEnsembleParams.supportedFeatureSubsetStrategies.contains(
|
||||
value.toLowerCase(Locale.ROOT))
|
||||
|| Try(value.toInt).filter(_ > 0).isSuccess
|
||||
|| Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess)
|
||||
|
||||
setDefault(featureSubsetStrategy -> "auto")
|
||||
|
||||
/**
|
||||
* @deprecated This method is deprecated and will be removed in 3.0.0
|
||||
* @group setParam
|
||||
*/
|
||||
@deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0")
|
||||
def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
|
||||
|
||||
/** @group getParam */
|
||||
final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT)
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Parameters for Random Forest algorithms.
|
||||
*/
|
||||
|
@ -391,60 +447,6 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
|
|||
|
||||
/** @group getParam */
|
||||
final def getNumTrees: Int = $(numTrees)
|
||||
|
||||
/**
|
||||
* The number of features to consider for splits at each tree node.
|
||||
* Supported options:
|
||||
* - "auto": Choose automatically for task:
|
||||
* If numTrees == 1, set to "all."
|
||||
* If numTrees > 1 (forest), set to "sqrt" for classification and
|
||||
* to "onethird" for regression.
|
||||
* - "all": use all features
|
||||
* - "onethird": use 1/3 of the features
|
||||
* - "sqrt": use sqrt(number of features)
|
||||
* - "log2": use log2(number of features)
|
||||
* - "n": when n is in the range (0, 1.0], use n * number of features. When n
|
||||
* is in the range (1, number of features), use n features.
|
||||
* (default = "auto")
|
||||
*
|
||||
* These various settings are based on the following references:
|
||||
* - log2: tested in Breiman (2001)
|
||||
* - sqrt: recommended by Breiman manual for random forests
|
||||
* - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
|
||||
* package.
|
||||
* @see <a href="http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf">Breiman (2001)</a>
|
||||
* @see <a href="http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf">
|
||||
* Breiman manual for random forests</a>
|
||||
*
|
||||
* @group param
|
||||
*/
|
||||
final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
|
||||
"The number of features to consider for splits at each tree node." +
|
||||
s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" +
|
||||
s", (0.0-1.0], [1-n].",
|
||||
(value: String) =>
|
||||
RandomForestParams.supportedFeatureSubsetStrategies.contains(
|
||||
value.toLowerCase(Locale.ROOT))
|
||||
|| Try(value.toInt).filter(_ > 0).isSuccess
|
||||
|| Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess)
|
||||
|
||||
setDefault(featureSubsetStrategy -> "auto")
|
||||
|
||||
/**
|
||||
* @deprecated This method is deprecated and will be removed in 3.0.0.
|
||||
* @group setParam
|
||||
*/
|
||||
@deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0")
|
||||
def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
|
||||
|
||||
/** @group getParam */
|
||||
final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT)
|
||||
}
|
||||
|
||||
private[spark] object RandomForestParams {
|
||||
// These options should be lowercase.
|
||||
final val supportedFeatureSubsetStrategies: Array[String] =
|
||||
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT))
|
||||
}
|
||||
|
||||
private[ml] trait RandomForestClassifierParams
|
||||
|
@ -497,6 +499,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
|
|||
|
||||
setDefault(maxIter -> 20, stepSize -> 0.1)
|
||||
|
||||
setDefault(featureSubsetStrategy -> "all")
|
||||
|
||||
/** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
|
||||
private[ml] def getOldBoostingStrategy(
|
||||
categoricalFeatures: Map[Int, Int],
|
||||
|
|
|
@ -69,7 +69,7 @@ class GradientBoostedTrees private[spark] (
|
|||
val algo = boostingStrategy.treeStrategy.algo
|
||||
val (trees, treeWeights) = NewGBT.run(input.map { point =>
|
||||
NewLabeledPoint(point.label, point.features.asML)
|
||||
}, boostingStrategy, seed.toLong)
|
||||
}, boostingStrategy, seed.toLong, "all")
|
||||
new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
|
||||
}
|
||||
|
||||
|
@ -101,7 +101,7 @@ class GradientBoostedTrees private[spark] (
|
|||
NewLabeledPoint(point.label, point.features.asML)
|
||||
}, validationInput.map { point =>
|
||||
NewLabeledPoint(point.label, point.features.asML)
|
||||
}, boostingStrategy, seed.toLong)
|
||||
}, boostingStrategy, seed.toLong, "all")
|
||||
new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
|
||||
}
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ import scala.util.Try
|
|||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.api.java.JavaRDD
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, RandomForestParams => NewRFParams}
|
||||
import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams}
|
||||
import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.mllib.tree.configuration.Algo._
|
||||
|
|
|
@ -83,6 +83,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
|
|||
assert(gbt.getPredictionCol === "prediction")
|
||||
assert(gbt.getRawPredictionCol === "rawPrediction")
|
||||
assert(gbt.getProbabilityCol === "probability")
|
||||
assert(gbt.getFeatureSubsetStrategy === "all")
|
||||
val df = trainData.toDF()
|
||||
val model = gbt.fit(df)
|
||||
model.transform(df)
|
||||
|
@ -95,6 +96,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
|
|||
assert(model.getPredictionCol === "prediction")
|
||||
assert(model.getRawPredictionCol === "rawPrediction")
|
||||
assert(model.getProbabilityCol === "probability")
|
||||
assert(model.getFeatureSubsetStrategy === "all")
|
||||
assert(model.hasParent)
|
||||
|
||||
MLTestingUtils.checkCopyAndUids(gbt, model)
|
||||
|
@ -356,6 +358,33 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
|
|||
assert(importances.toArray.forall(_ >= 0.0))
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Tests of feature subset strategy
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
test("Tests of feature subset strategy") {
|
||||
val numClasses = 2
|
||||
val gbt = new GBTClassifier()
|
||||
.setSeed(123)
|
||||
.setMaxDepth(3)
|
||||
.setMaxIter(5)
|
||||
.setFeatureSubsetStrategy("all")
|
||||
|
||||
// In this data, feature 1 is very important.
|
||||
val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
|
||||
val categoricalFeatures = Map.empty[Int, Int]
|
||||
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
|
||||
|
||||
val importances = gbt.fit(df).featureImportances
|
||||
val mostImportantFeature = importances.argmax
|
||||
assert(mostImportantFeature === 1)
|
||||
|
||||
// GBT with different featureSubsetStrategy
|
||||
val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
|
||||
val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
|
||||
val mostIF = importanceFeatures.argmax
|
||||
assert(mostImportantFeature !== mostIF)
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Tests of model save/load
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -165,6 +165,35 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
|
|||
assert(importances.toArray.forall(_ >= 0.0))
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Tests of feature subset strategy
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
test("Tests of feature subset strategy") {
|
||||
val numClasses = 2
|
||||
val gbt = new GBTRegressor()
|
||||
.setMaxDepth(3)
|
||||
.setMaxIter(5)
|
||||
.setSeed(123)
|
||||
.setFeatureSubsetStrategy("all")
|
||||
|
||||
// In this data, feature 1 is very important.
|
||||
val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
|
||||
val categoricalFeatures = Map.empty[Int, Int]
|
||||
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
|
||||
|
||||
val importances = gbt.fit(df).featureImportances
|
||||
val mostImportantFeature = importances.argmax
|
||||
assert(mostImportantFeature === 1)
|
||||
|
||||
// GBT with different featureSubsetStrategy
|
||||
val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
|
||||
val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
|
||||
val mostIF = importanceFeatures.argmax
|
||||
assert(mostImportantFeature !== mostIF)
|
||||
}
|
||||
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Tests of model save/load
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -50,12 +50,12 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
|
|||
val boostingStrategy =
|
||||
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
|
||||
val (validateTrees, validateTreeWeights) = GradientBoostedTrees
|
||||
.runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L)
|
||||
.runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L, "all")
|
||||
val numTrees = validateTrees.length
|
||||
assert(numTrees !== numIterations)
|
||||
|
||||
// Test that it performs better on the validation dataset.
|
||||
val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L)
|
||||
val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L, "all")
|
||||
val (errorWithoutValidation, errorWithValidation) = {
|
||||
if (algo == Classification) {
|
||||
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
|
||||
|
|
Loading…
Reference in a new issue