[SPARK-20199][ML] : Provided featureSubsetStrategy to GBTClassifier and GBTRegressor
## What changes were proposed in this pull request? (Provided featureSubset Strategy to GBTClassifier a) Moved featureSubsetStrategy to TreeEnsembleParams b) Changed GBTClassifier to pass featureSubsetStrategy val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy)) ## How was this patch tested? a) Tested GradientBoostedTreeClassifierExample by adding .setFeatureSubsetStrategy with GBTClassifier b)Added test cases in GBTClassifierSuite and GBTRegressorSuite Author: Pralabh Kumar <pralabhkumar@gmail.com> Closes #18118 from pralabhkumar/develop.
This commit is contained in:
parent
28ab5bf597
commit
9b9827759a
|
@ -59,6 +59,7 @@ object GradientBoostedTreeClassifierExample {
|
||||||
.setLabelCol("indexedLabel")
|
.setLabelCol("indexedLabel")
|
||||||
.setFeaturesCol("indexedFeatures")
|
.setFeaturesCol("indexedFeatures")
|
||||||
.setMaxIter(10)
|
.setMaxIter(10)
|
||||||
|
.setFeatureSubsetStrategy("auto")
|
||||||
|
|
||||||
// Convert indexed labels back to original labels.
|
// Convert indexed labels back to original labels.
|
||||||
val labelConverter = new IndexToString()
|
val labelConverter = new IndexToString()
|
||||||
|
|
|
@ -135,6 +135,11 @@ class GBTClassifier @Since("1.4.0") (
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
override def setStepSize(value: Double): this.type = set(stepSize, value)
|
override def setStepSize(value: Double): this.type = set(stepSize, value)
|
||||||
|
|
||||||
|
/** @group setParam */
|
||||||
|
@Since("2.3.0")
|
||||||
|
override def setFeatureSubsetStrategy(value: String): this.type =
|
||||||
|
set(featureSubsetStrategy, value)
|
||||||
|
|
||||||
// Parameters from GBTClassifierParams:
|
// Parameters from GBTClassifierParams:
|
||||||
|
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
|
@ -167,12 +172,12 @@ class GBTClassifier @Since("1.4.0") (
|
||||||
val instr = Instrumentation.create(this, oldDataset)
|
val instr = Instrumentation.create(this, oldDataset)
|
||||||
instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
|
instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
|
||||||
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
|
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
|
||||||
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
|
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
|
||||||
instr.logNumFeatures(numFeatures)
|
instr.logNumFeatures(numFeatures)
|
||||||
instr.logNumClasses(numClasses)
|
instr.logNumClasses(numClasses)
|
||||||
|
|
||||||
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
|
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
|
||||||
$(seed))
|
$(seed), $(featureSubsetStrategy))
|
||||||
val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
|
val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
|
||||||
instr.logSuccess(m)
|
instr.logSuccess(m)
|
||||||
m
|
m
|
||||||
|
|
|
@ -158,7 +158,7 @@ object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifi
|
||||||
/** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
|
/** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
final val supportedFeatureSubsetStrategies: Array[String] =
|
final val supportedFeatureSubsetStrategies: Array[String] =
|
||||||
RandomForestParams.supportedFeatureSubsetStrategies
|
TreeEnsembleParams.supportedFeatureSubsetStrategies
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
override def load(path: String): RandomForestClassifier = super.load(path)
|
override def load(path: String): RandomForestClassifier = super.load(path)
|
||||||
|
|
|
@ -117,12 +117,14 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
|
||||||
}
|
}
|
||||||
|
|
||||||
/** (private[ml]) Train a decision tree on an RDD */
|
/** (private[ml]) Train a decision tree on an RDD */
|
||||||
private[ml] def train(data: RDD[LabeledPoint],
|
private[ml] def train(
|
||||||
oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
|
data: RDD[LabeledPoint],
|
||||||
|
oldStrategy: OldStrategy,
|
||||||
|
featureSubsetStrategy: String): DecisionTreeRegressionModel = {
|
||||||
val instr = Instrumentation.create(this, data)
|
val instr = Instrumentation.create(this, data)
|
||||||
instr.logParams(params: _*)
|
instr.logParams(params: _*)
|
||||||
|
|
||||||
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
|
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy,
|
||||||
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
|
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
|
||||||
|
|
||||||
val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
|
val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
|
||||||
|
|
|
@ -140,6 +140,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
def setLossType(value: String): this.type = set(lossType, value)
|
def setLossType(value: String): this.type = set(lossType, value)
|
||||||
|
|
||||||
|
/** @group setParam */
|
||||||
|
@Since("2.3.0")
|
||||||
|
override def setFeatureSubsetStrategy(value: String): this.type =
|
||||||
|
set(featureSubsetStrategy, value)
|
||||||
|
|
||||||
override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
|
override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
|
||||||
val categoricalFeatures: Map[Int, Int] =
|
val categoricalFeatures: Map[Int, Int] =
|
||||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||||
|
@ -150,11 +155,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
|
||||||
val instr = Instrumentation.create(this, oldDataset)
|
val instr = Instrumentation.create(this, oldDataset)
|
||||||
instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
|
instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
|
||||||
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
|
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
|
||||||
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
|
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
|
||||||
instr.logNumFeatures(numFeatures)
|
instr.logNumFeatures(numFeatures)
|
||||||
|
|
||||||
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
|
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
|
||||||
$(seed))
|
$(seed), $(featureSubsetStrategy))
|
||||||
val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
|
val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
|
||||||
instr.logSuccess(m)
|
instr.logSuccess(m)
|
||||||
m
|
m
|
||||||
|
|
|
@ -149,7 +149,7 @@ object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor
|
||||||
/** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
|
/** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
final val supportedFeatureSubsetStrategies: Array[String] =
|
final val supportedFeatureSubsetStrategies: Array[String] =
|
||||||
RandomForestParams.supportedFeatureSubsetStrategies
|
TreeEnsembleParams.supportedFeatureSubsetStrategies
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
override def load(path: String): RandomForestRegressor = super.load(path)
|
override def load(path: String): RandomForestRegressor = super.load(path)
|
||||||
|
|
|
@ -22,7 +22,7 @@ import scala.util.Try
|
||||||
|
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
import org.apache.spark.ml.feature.LabeledPoint
|
import org.apache.spark.ml.feature.LabeledPoint
|
||||||
import org.apache.spark.ml.tree.RandomForestParams
|
import org.apache.spark.ml.tree.TreeEnsembleParams
|
||||||
import org.apache.spark.mllib.tree.configuration.Algo._
|
import org.apache.spark.mllib.tree.configuration.Algo._
|
||||||
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
|
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
|
||||||
import org.apache.spark.mllib.tree.configuration.Strategy
|
import org.apache.spark.mllib.tree.configuration.Strategy
|
||||||
|
@ -200,7 +200,7 @@ private[spark] object DecisionTreeMetadata extends Logging {
|
||||||
Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match {
|
Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match {
|
||||||
case Some(value) => math.ceil(value * numFeatures).toInt
|
case Some(value) => math.ceil(value * numFeatures).toInt
|
||||||
case _ => throw new IllegalArgumentException(s"Supported values:" +
|
case _ => throw new IllegalArgumentException(s"Supported values:" +
|
||||||
s" ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
|
s" ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
|
||||||
s" (0.0-1.0], [1-n].")
|
s" (0.0-1.0], [1-n].")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,16 +42,18 @@ private[spark] object GradientBoostedTrees extends Logging {
|
||||||
def run(
|
def run(
|
||||||
input: RDD[LabeledPoint],
|
input: RDD[LabeledPoint],
|
||||||
boostingStrategy: OldBoostingStrategy,
|
boostingStrategy: OldBoostingStrategy,
|
||||||
seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
seed: Long,
|
||||||
|
featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
||||||
val algo = boostingStrategy.treeStrategy.algo
|
val algo = boostingStrategy.treeStrategy.algo
|
||||||
algo match {
|
algo match {
|
||||||
case OldAlgo.Regression =>
|
case OldAlgo.Regression =>
|
||||||
GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
|
GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false,
|
||||||
|
seed, featureSubsetStrategy)
|
||||||
case OldAlgo.Classification =>
|
case OldAlgo.Classification =>
|
||||||
// Map labels to -1, +1 so binary classification can be treated as regression.
|
// Map labels to -1, +1 so binary classification can be treated as regression.
|
||||||
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
|
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
|
||||||
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
|
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
|
||||||
seed)
|
seed, featureSubsetStrategy)
|
||||||
case _ =>
|
case _ =>
|
||||||
throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
|
throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
|
||||||
}
|
}
|
||||||
|
@ -73,11 +75,13 @@ private[spark] object GradientBoostedTrees extends Logging {
|
||||||
input: RDD[LabeledPoint],
|
input: RDD[LabeledPoint],
|
||||||
validationInput: RDD[LabeledPoint],
|
validationInput: RDD[LabeledPoint],
|
||||||
boostingStrategy: OldBoostingStrategy,
|
boostingStrategy: OldBoostingStrategy,
|
||||||
seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
seed: Long,
|
||||||
|
featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
||||||
val algo = boostingStrategy.treeStrategy.algo
|
val algo = boostingStrategy.treeStrategy.algo
|
||||||
algo match {
|
algo match {
|
||||||
case OldAlgo.Regression =>
|
case OldAlgo.Regression =>
|
||||||
GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
|
GradientBoostedTrees.boost(input, validationInput, boostingStrategy,
|
||||||
|
validate = true, seed, featureSubsetStrategy)
|
||||||
case OldAlgo.Classification =>
|
case OldAlgo.Classification =>
|
||||||
// Map labels to -1, +1 so binary classification can be treated as regression.
|
// Map labels to -1, +1 so binary classification can be treated as regression.
|
||||||
val remappedInput = input.map(
|
val remappedInput = input.map(
|
||||||
|
@ -85,7 +89,7 @@ private[spark] object GradientBoostedTrees extends Logging {
|
||||||
val remappedValidationInput = validationInput.map(
|
val remappedValidationInput = validationInput.map(
|
||||||
x => new LabeledPoint((x.label * 2) - 1, x.features))
|
x => new LabeledPoint((x.label * 2) - 1, x.features))
|
||||||
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
|
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
|
||||||
validate = true, seed)
|
validate = true, seed, featureSubsetStrategy)
|
||||||
case _ =>
|
case _ =>
|
||||||
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
|
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
|
||||||
}
|
}
|
||||||
|
@ -245,7 +249,8 @@ private[spark] object GradientBoostedTrees extends Logging {
|
||||||
validationInput: RDD[LabeledPoint],
|
validationInput: RDD[LabeledPoint],
|
||||||
boostingStrategy: OldBoostingStrategy,
|
boostingStrategy: OldBoostingStrategy,
|
||||||
validate: Boolean,
|
validate: Boolean,
|
||||||
seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
seed: Long,
|
||||||
|
featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
||||||
val timer = new TimeTracker()
|
val timer = new TimeTracker()
|
||||||
timer.start("total")
|
timer.start("total")
|
||||||
timer.start("init")
|
timer.start("init")
|
||||||
|
@ -258,6 +263,7 @@ private[spark] object GradientBoostedTrees extends Logging {
|
||||||
val baseLearnerWeights = new Array[Double](numIterations)
|
val baseLearnerWeights = new Array[Double](numIterations)
|
||||||
val loss = boostingStrategy.loss
|
val loss = boostingStrategy.loss
|
||||||
val learningRate = boostingStrategy.learningRate
|
val learningRate = boostingStrategy.learningRate
|
||||||
|
|
||||||
// Prepare strategy for individual trees, which use regression with variance impurity.
|
// Prepare strategy for individual trees, which use regression with variance impurity.
|
||||||
val treeStrategy = boostingStrategy.treeStrategy.copy
|
val treeStrategy = boostingStrategy.treeStrategy.copy
|
||||||
val validationTol = boostingStrategy.validationTol
|
val validationTol = boostingStrategy.validationTol
|
||||||
|
@ -288,7 +294,7 @@ private[spark] object GradientBoostedTrees extends Logging {
|
||||||
// Initialize tree
|
// Initialize tree
|
||||||
timer.start("building tree 0")
|
timer.start("building tree 0")
|
||||||
val firstTree = new DecisionTreeRegressor().setSeed(seed)
|
val firstTree = new DecisionTreeRegressor().setSeed(seed)
|
||||||
val firstTreeModel = firstTree.train(input, treeStrategy)
|
val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy)
|
||||||
val firstTreeWeight = 1.0
|
val firstTreeWeight = 1.0
|
||||||
baseLearners(0) = firstTreeModel
|
baseLearners(0) = firstTreeModel
|
||||||
baseLearnerWeights(0) = firstTreeWeight
|
baseLearnerWeights(0) = firstTreeWeight
|
||||||
|
@ -319,8 +325,9 @@ private[spark] object GradientBoostedTrees extends Logging {
|
||||||
logDebug("###################################################")
|
logDebug("###################################################")
|
||||||
logDebug("Gradient boosting tree iteration " + m)
|
logDebug("Gradient boosting tree iteration " + m)
|
||||||
logDebug("###################################################")
|
logDebug("###################################################")
|
||||||
|
|
||||||
val dt = new DecisionTreeRegressor().setSeed(seed + m)
|
val dt = new DecisionTreeRegressor().setSeed(seed + m)
|
||||||
val model = dt.train(data, treeStrategy)
|
val model = dt.train(data, treeStrategy, featureSubsetStrategy)
|
||||||
timer.stop(s"building tree $m")
|
timer.stop(s"building tree $m")
|
||||||
// Update partial model
|
// Update partial model
|
||||||
baseLearners(m) = model
|
baseLearners(m) = model
|
||||||
|
|
|
@ -320,6 +320,12 @@ private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private[spark] object TreeEnsembleParams {
|
||||||
|
// These options should be lowercase.
|
||||||
|
final val supportedFeatureSubsetStrategies: Array[String] =
|
||||||
|
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT))
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Parameters for Decision Tree-based ensemble algorithms.
|
* Parameters for Decision Tree-based ensemble algorithms.
|
||||||
*
|
*
|
||||||
|
@ -359,8 +365,58 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
|
||||||
oldImpurity: OldImpurity): OldStrategy = {
|
oldImpurity: OldImpurity): OldStrategy = {
|
||||||
super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
|
super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of features to consider for splits at each tree node.
|
||||||
|
* Supported options:
|
||||||
|
* - "auto": Choose automatically for task:
|
||||||
|
* If numTrees == 1, set to "all."
|
||||||
|
* If numTrees > 1 (forest), set to "sqrt" for classification and
|
||||||
|
* to "onethird" for regression.
|
||||||
|
* - "all": use all features
|
||||||
|
* - "onethird": use 1/3 of the features
|
||||||
|
* - "sqrt": use sqrt(number of features)
|
||||||
|
* - "log2": use log2(number of features)
|
||||||
|
* - "n": when n is in the range (0, 1.0], use n * number of features. When n
|
||||||
|
* is in the range (1, number of features), use n features.
|
||||||
|
* (default = "auto")
|
||||||
|
*
|
||||||
|
* These various settings are based on the following references:
|
||||||
|
* - log2: tested in Breiman (2001)
|
||||||
|
* - sqrt: recommended by Breiman manual for random forests
|
||||||
|
* - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
|
||||||
|
* package.
|
||||||
|
* @see <a href="http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf">Breiman (2001)</a>
|
||||||
|
* @see <a href="http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf">
|
||||||
|
* Breiman manual for random forests</a>
|
||||||
|
*
|
||||||
|
* @group param
|
||||||
|
*/
|
||||||
|
final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
|
||||||
|
"The number of features to consider for splits at each tree node." +
|
||||||
|
s" Supported options: ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}" +
|
||||||
|
s", (0.0-1.0], [1-n].",
|
||||||
|
(value: String) =>
|
||||||
|
TreeEnsembleParams.supportedFeatureSubsetStrategies.contains(
|
||||||
|
value.toLowerCase(Locale.ROOT))
|
||||||
|
|| Try(value.toInt).filter(_ > 0).isSuccess
|
||||||
|
|| Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess)
|
||||||
|
|
||||||
|
setDefault(featureSubsetStrategy -> "auto")
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated This method is deprecated and will be removed in 3.0.0
|
||||||
|
* @group setParam
|
||||||
|
*/
|
||||||
|
@deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0")
|
||||||
|
def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
|
||||||
|
|
||||||
|
/** @group getParam */
|
||||||
|
final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Parameters for Random Forest algorithms.
|
* Parameters for Random Forest algorithms.
|
||||||
*/
|
*/
|
||||||
|
@ -391,60 +447,6 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
|
||||||
|
|
||||||
/** @group getParam */
|
/** @group getParam */
|
||||||
final def getNumTrees: Int = $(numTrees)
|
final def getNumTrees: Int = $(numTrees)
|
||||||
|
|
||||||
/**
|
|
||||||
* The number of features to consider for splits at each tree node.
|
|
||||||
* Supported options:
|
|
||||||
* - "auto": Choose automatically for task:
|
|
||||||
* If numTrees == 1, set to "all."
|
|
||||||
* If numTrees > 1 (forest), set to "sqrt" for classification and
|
|
||||||
* to "onethird" for regression.
|
|
||||||
* - "all": use all features
|
|
||||||
* - "onethird": use 1/3 of the features
|
|
||||||
* - "sqrt": use sqrt(number of features)
|
|
||||||
* - "log2": use log2(number of features)
|
|
||||||
* - "n": when n is in the range (0, 1.0], use n * number of features. When n
|
|
||||||
* is in the range (1, number of features), use n features.
|
|
||||||
* (default = "auto")
|
|
||||||
*
|
|
||||||
* These various settings are based on the following references:
|
|
||||||
* - log2: tested in Breiman (2001)
|
|
||||||
* - sqrt: recommended by Breiman manual for random forests
|
|
||||||
* - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
|
|
||||||
* package.
|
|
||||||
* @see <a href="http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf">Breiman (2001)</a>
|
|
||||||
* @see <a href="http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf">
|
|
||||||
* Breiman manual for random forests</a>
|
|
||||||
*
|
|
||||||
* @group param
|
|
||||||
*/
|
|
||||||
final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
|
|
||||||
"The number of features to consider for splits at each tree node." +
|
|
||||||
s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" +
|
|
||||||
s", (0.0-1.0], [1-n].",
|
|
||||||
(value: String) =>
|
|
||||||
RandomForestParams.supportedFeatureSubsetStrategies.contains(
|
|
||||||
value.toLowerCase(Locale.ROOT))
|
|
||||||
|| Try(value.toInt).filter(_ > 0).isSuccess
|
|
||||||
|| Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess)
|
|
||||||
|
|
||||||
setDefault(featureSubsetStrategy -> "auto")
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @deprecated This method is deprecated and will be removed in 3.0.0.
|
|
||||||
* @group setParam
|
|
||||||
*/
|
|
||||||
@deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0")
|
|
||||||
def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
|
|
||||||
|
|
||||||
/** @group getParam */
|
|
||||||
final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT)
|
|
||||||
}
|
|
||||||
|
|
||||||
private[spark] object RandomForestParams {
|
|
||||||
// These options should be lowercase.
|
|
||||||
final val supportedFeatureSubsetStrategies: Array[String] =
|
|
||||||
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private[ml] trait RandomForestClassifierParams
|
private[ml] trait RandomForestClassifierParams
|
||||||
|
@ -497,6 +499,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
|
||||||
|
|
||||||
setDefault(maxIter -> 20, stepSize -> 0.1)
|
setDefault(maxIter -> 20, stepSize -> 0.1)
|
||||||
|
|
||||||
|
setDefault(featureSubsetStrategy -> "all")
|
||||||
|
|
||||||
/** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
|
/** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
|
||||||
private[ml] def getOldBoostingStrategy(
|
private[ml] def getOldBoostingStrategy(
|
||||||
categoricalFeatures: Map[Int, Int],
|
categoricalFeatures: Map[Int, Int],
|
||||||
|
|
|
@ -69,7 +69,7 @@ class GradientBoostedTrees private[spark] (
|
||||||
val algo = boostingStrategy.treeStrategy.algo
|
val algo = boostingStrategy.treeStrategy.algo
|
||||||
val (trees, treeWeights) = NewGBT.run(input.map { point =>
|
val (trees, treeWeights) = NewGBT.run(input.map { point =>
|
||||||
NewLabeledPoint(point.label, point.features.asML)
|
NewLabeledPoint(point.label, point.features.asML)
|
||||||
}, boostingStrategy, seed.toLong)
|
}, boostingStrategy, seed.toLong, "all")
|
||||||
new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
|
new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,7 +101,7 @@ class GradientBoostedTrees private[spark] (
|
||||||
NewLabeledPoint(point.label, point.features.asML)
|
NewLabeledPoint(point.label, point.features.asML)
|
||||||
}, validationInput.map { point =>
|
}, validationInput.map { point =>
|
||||||
NewLabeledPoint(point.label, point.features.asML)
|
NewLabeledPoint(point.label, point.features.asML)
|
||||||
}, boostingStrategy, seed.toLong)
|
}, boostingStrategy, seed.toLong, "all")
|
||||||
new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
|
new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@ import scala.util.Try
|
||||||
import org.apache.spark.annotation.Since
|
import org.apache.spark.annotation.Since
|
||||||
import org.apache.spark.api.java.JavaRDD
|
import org.apache.spark.api.java.JavaRDD
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, RandomForestParams => NewRFParams}
|
import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams}
|
||||||
import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
|
import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.tree.configuration.Algo._
|
import org.apache.spark.mllib.tree.configuration.Algo._
|
||||||
|
|
|
@ -83,6 +83,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
assert(gbt.getPredictionCol === "prediction")
|
assert(gbt.getPredictionCol === "prediction")
|
||||||
assert(gbt.getRawPredictionCol === "rawPrediction")
|
assert(gbt.getRawPredictionCol === "rawPrediction")
|
||||||
assert(gbt.getProbabilityCol === "probability")
|
assert(gbt.getProbabilityCol === "probability")
|
||||||
|
assert(gbt.getFeatureSubsetStrategy === "all")
|
||||||
val df = trainData.toDF()
|
val df = trainData.toDF()
|
||||||
val model = gbt.fit(df)
|
val model = gbt.fit(df)
|
||||||
model.transform(df)
|
model.transform(df)
|
||||||
|
@ -95,6 +96,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
assert(model.getPredictionCol === "prediction")
|
assert(model.getPredictionCol === "prediction")
|
||||||
assert(model.getRawPredictionCol === "rawPrediction")
|
assert(model.getRawPredictionCol === "rawPrediction")
|
||||||
assert(model.getProbabilityCol === "probability")
|
assert(model.getProbabilityCol === "probability")
|
||||||
|
assert(model.getFeatureSubsetStrategy === "all")
|
||||||
assert(model.hasParent)
|
assert(model.hasParent)
|
||||||
|
|
||||||
MLTestingUtils.checkCopyAndUids(gbt, model)
|
MLTestingUtils.checkCopyAndUids(gbt, model)
|
||||||
|
@ -356,6 +358,33 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
assert(importances.toArray.forall(_ >= 0.0))
|
assert(importances.toArray.forall(_ >= 0.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Tests of feature subset strategy
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
test("Tests of feature subset strategy") {
|
||||||
|
val numClasses = 2
|
||||||
|
val gbt = new GBTClassifier()
|
||||||
|
.setSeed(123)
|
||||||
|
.setMaxDepth(3)
|
||||||
|
.setMaxIter(5)
|
||||||
|
.setFeatureSubsetStrategy("all")
|
||||||
|
|
||||||
|
// In this data, feature 1 is very important.
|
||||||
|
val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
|
||||||
|
val categoricalFeatures = Map.empty[Int, Int]
|
||||||
|
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
|
||||||
|
|
||||||
|
val importances = gbt.fit(df).featureImportances
|
||||||
|
val mostImportantFeature = importances.argmax
|
||||||
|
assert(mostImportantFeature === 1)
|
||||||
|
|
||||||
|
// GBT with different featureSubsetStrategy
|
||||||
|
val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
|
||||||
|
val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
|
||||||
|
val mostIF = importanceFeatures.argmax
|
||||||
|
assert(mostImportantFeature !== mostIF)
|
||||||
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Tests of model save/load
|
// Tests of model save/load
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -165,6 +165,35 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
assert(importances.toArray.forall(_ >= 0.0))
|
assert(importances.toArray.forall(_ >= 0.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Tests of feature subset strategy
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
test("Tests of feature subset strategy") {
|
||||||
|
val numClasses = 2
|
||||||
|
val gbt = new GBTRegressor()
|
||||||
|
.setMaxDepth(3)
|
||||||
|
.setMaxIter(5)
|
||||||
|
.setSeed(123)
|
||||||
|
.setFeatureSubsetStrategy("all")
|
||||||
|
|
||||||
|
// In this data, feature 1 is very important.
|
||||||
|
val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
|
||||||
|
val categoricalFeatures = Map.empty[Int, Int]
|
||||||
|
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
|
||||||
|
|
||||||
|
val importances = gbt.fit(df).featureImportances
|
||||||
|
val mostImportantFeature = importances.argmax
|
||||||
|
assert(mostImportantFeature === 1)
|
||||||
|
|
||||||
|
// GBT with different featureSubsetStrategy
|
||||||
|
val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
|
||||||
|
val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
|
||||||
|
val mostIF = importanceFeatures.argmax
|
||||||
|
assert(mostImportantFeature !== mostIF)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Tests of model save/load
|
// Tests of model save/load
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -50,12 +50,12 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||||
val boostingStrategy =
|
val boostingStrategy =
|
||||||
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
|
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
|
||||||
val (validateTrees, validateTreeWeights) = GradientBoostedTrees
|
val (validateTrees, validateTreeWeights) = GradientBoostedTrees
|
||||||
.runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L)
|
.runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L, "all")
|
||||||
val numTrees = validateTrees.length
|
val numTrees = validateTrees.length
|
||||||
assert(numTrees !== numIterations)
|
assert(numTrees !== numIterations)
|
||||||
|
|
||||||
// Test that it performs better on the validation dataset.
|
// Test that it performs better on the validation dataset.
|
||||||
val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L)
|
val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L, "all")
|
||||||
val (errorWithoutValidation, errorWithValidation) = {
|
val (errorWithoutValidation, errorWithValidation) = {
|
||||||
if (algo == Classification) {
|
if (algo == Classification) {
|
||||||
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
|
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
|
||||||
|
|
Loading…
Reference in a new issue