[SPARK-20199][ML] Provide featureSubsetStrategy to GBTClassifier and GBTRegressor

## What changes were proposed in this pull request?

Provided featureSubsetStrategy to GBTClassifier and GBTRegressor:

a) Moved featureSubsetStrategy from RandomForestParams to TreeEnsembleParams.
b) Changed GBTClassifier and GBTRegressor to pass featureSubsetStrategy through to the underlying per-iteration tree training, e.g.
`val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy)`
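
For reference, a minimal usage sketch of the new user-facing setter (the `trainingDF` DataFrame and its column names are illustrative, not part of this patch):

```scala
import org.apache.spark.ml.classification.GBTClassifier

// featureSubsetStrategy accepts the named strategies (auto, all, onethird,
// sqrt, log2), a fraction in (0.0, 1.0], or an integer feature count [1-n];
// "sqrt" considers sqrt(numFeatures) candidate features at each split.
val gbt = new GBTClassifier()
  .setLabelCol("label")        // illustrative column names
  .setFeaturesCol("features")
  .setMaxIter(10)
  .setFeatureSubsetStrategy("sqrt")

val model = gbt.fit(trainingDF)  // trainingDF: assumed DataFrame of (label, features)
```

The same setter is available on GBTRegressor; the GBT default stays "all", matching the previous behavior.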

## How was this patch tested?
a) Tested GradientBoostedTreeClassifierExample by adding .setFeatureSubsetStrategy to its GBTClassifier.

b) Added test cases in GBTClassifierSuite and GBTRegressorSuite.
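
The new suite tests verify that restricting the subset strategy actually changes training. A sketch of the idea (suite-style assertions; `gbt` and `df` as constructed in the suites shown further below, where one feature dominates the data):

```scala
// With all features available at each split, the dominant feature tops the
// importances; with only 1 randomly chosen candidate feature per split ("1"),
// splits are forced onto other features, so the argmax should move.
val allImportances = gbt.setFeatureSubsetStrategy("all").fit(df).featureImportances
val subsetImportances = gbt.setFeatureSubsetStrategy("1").fit(df).featureImportances
assert(allImportances.argmax !== subsetImportances.argmax)
```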

Author: Pralabh Kumar <pralabhkumar@gmail.com>

Closes #18118 from pralabhkumar/develop.
Committed by Nick Pentreath on 2017-11-10 13:17:25 +02:00.
Commit 9b9827759a (parent 28ab5bf597); 14 changed files with 161 additions and 79 deletions.


```diff
@@ -59,6 +59,7 @@ object GradientBoostedTreeClassifierExample {
       .setLabelCol("indexedLabel")
       .setFeaturesCol("indexedFeatures")
       .setMaxIter(10)
+      .setFeatureSubsetStrategy("auto")

     // Convert indexed labels back to original labels.
     val labelConverter = new IndexToString()
```


```diff
@@ -135,6 +135,11 @@ class GBTClassifier @Since("1.4.0") (
   @Since("1.4.0")
   override def setStepSize(value: Double): this.type = set(stepSize, value)

+  /** @group setParam */
+  @Since("2.3.0")
+  override def setFeatureSubsetStrategy(value: String): this.type =
+    set(featureSubsetStrategy, value)
+
   // Parameters from GBTClassifierParams:

   /** @group setParam */
@@ -167,12 +172,12 @@ class GBTClassifier @Since("1.4.0") (
     val instr = Instrumentation.create(this, oldDataset)
     instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
       maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
-      seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
+      seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
     instr.logNumFeatures(numFeatures)
     instr.logNumClasses(numClasses)

     val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
-      $(seed))
+      $(seed), $(featureSubsetStrategy))
     val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
     instr.logSuccess(m)
     m
```


```diff
@@ -158,7 +158,7 @@ object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifier] {
   /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
   @Since("1.4.0")
   final val supportedFeatureSubsetStrategies: Array[String] =
-    RandomForestParams.supportedFeatureSubsetStrategies
+    TreeEnsembleParams.supportedFeatureSubsetStrategies

   @Since("2.0.0")
   override def load(path: String): RandomForestClassifier = super.load(path)
```


```diff
@@ -117,12 +117,14 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   }

   /** (private[ml]) Train a decision tree on an RDD */
-  private[ml] def train(data: RDD[LabeledPoint],
-      oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
+  private[ml] def train(
+      data: RDD[LabeledPoint],
+      oldStrategy: OldStrategy,
+      featureSubsetStrategy: String): DecisionTreeRegressionModel = {
     val instr = Instrumentation.create(this, data)
     instr.logParams(params: _*)

-    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
+    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy,
       seed = $(seed), instr = Some(instr), parentUID = Some(uid))

     val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
```


```diff
@@ -140,6 +140,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   @Since("1.4.0")
   def setLossType(value: String): this.type = set(lossType, value)

+  /** @group setParam */
+  @Since("2.3.0")
+  override def setFeatureSubsetStrategy(value: String): this.type =
+    set(featureSubsetStrategy, value)
+
   override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -150,11 +155,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
     val instr = Instrumentation.create(this, oldDataset)
     instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
       maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
-      seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
+      seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
     instr.logNumFeatures(numFeatures)

     val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
-      $(seed))
+      $(seed), $(featureSubsetStrategy))
     val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
     instr.logSuccess(m)
     m
```


```diff
@@ -149,7 +149,7 @@ object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor] {
   /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
   @Since("1.4.0")
   final val supportedFeatureSubsetStrategies: Array[String] =
-    RandomForestParams.supportedFeatureSubsetStrategies
+    TreeEnsembleParams.supportedFeatureSubsetStrategies

   @Since("2.0.0")
   override def load(path: String): RandomForestRegressor = super.load(path)
```


```diff
@@ -22,7 +22,7 @@ import scala.util.Try
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.tree.RandomForestParams
+import org.apache.spark.ml.tree.TreeEnsembleParams
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 import org.apache.spark.mllib.tree.configuration.Strategy
@@ -200,7 +200,7 @@ private[spark] object DecisionTreeMetadata extends Logging {
         Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match {
           case Some(value) => math.ceil(value * numFeatures).toInt
           case _ => throw new IllegalArgumentException(s"Supported values:" +
-            s" ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
+            s" ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
             s" (0.0-1.0], [1-n].")
         }
       }
```


```diff
@@ -42,16 +42,18 @@ private[spark] object GradientBoostedTrees extends Logging {
   def run(
       input: RDD[LabeledPoint],
       boostingStrategy: OldBoostingStrategy,
-      seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+      seed: Long,
+      featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
     val algo = boostingStrategy.treeStrategy.algo
     algo match {
       case OldAlgo.Regression =>
-        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
+        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false,
+          seed, featureSubsetStrategy)
       case OldAlgo.Classification =>
         // Map labels to -1, +1 so binary classification can be treated as regression.
         val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
         GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
-          seed)
+          seed, featureSubsetStrategy)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
     }
@@ -73,11 +75,13 @@ private[spark] object GradientBoostedTrees extends Logging {
       input: RDD[LabeledPoint],
       validationInput: RDD[LabeledPoint],
       boostingStrategy: OldBoostingStrategy,
-      seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+      seed: Long,
+      featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
     val algo = boostingStrategy.treeStrategy.algo
     algo match {
       case OldAlgo.Regression =>
-        GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
+        GradientBoostedTrees.boost(input, validationInput, boostingStrategy,
+          validate = true, seed, featureSubsetStrategy)
       case OldAlgo.Classification =>
         // Map labels to -1, +1 so binary classification can be treated as regression.
         val remappedInput = input.map(
@@ -85,7 +89,7 @@ private[spark] object GradientBoostedTrees extends Logging {
         val remappedValidationInput = validationInput.map(
           x => new LabeledPoint((x.label * 2) - 1, x.features))
         GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
-          validate = true, seed)
+          validate = true, seed, featureSubsetStrategy)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
     }
@@ -245,7 +249,8 @@ private[spark] object GradientBoostedTrees extends Logging {
       validationInput: RDD[LabeledPoint],
       boostingStrategy: OldBoostingStrategy,
       validate: Boolean,
-      seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+      seed: Long,
+      featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
     val timer = new TimeTracker()
     timer.start("total")
     timer.start("init")
@@ -258,6 +263,7 @@ private[spark] object GradientBoostedTrees extends Logging {
     val baseLearnerWeights = new Array[Double](numIterations)
     val loss = boostingStrategy.loss
     val learningRate = boostingStrategy.learningRate
+
     // Prepare strategy for individual trees, which use regression with variance impurity.
     val treeStrategy = boostingStrategy.treeStrategy.copy
     val validationTol = boostingStrategy.validationTol
@@ -288,7 +294,7 @@ private[spark] object GradientBoostedTrees extends Logging {
     // Initialize tree
     timer.start("building tree 0")
     val firstTree = new DecisionTreeRegressor().setSeed(seed)
-    val firstTreeModel = firstTree.train(input, treeStrategy)
+    val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy)
     val firstTreeWeight = 1.0
     baseLearners(0) = firstTreeModel
     baseLearnerWeights(0) = firstTreeWeight
@@ -319,8 +325,9 @@ private[spark] object GradientBoostedTrees extends Logging {
       logDebug("###################################################")
       logDebug("Gradient boosting tree iteration " + m)
       logDebug("###################################################")
+
       val dt = new DecisionTreeRegressor().setSeed(seed + m)
-      val model = dt.train(data, treeStrategy)
+      val model = dt.train(data, treeStrategy, featureSubsetStrategy)
       timer.stop(s"building tree $m")
       // Update partial model
       baseLearners(m) = model
```


```diff
@@ -320,6 +320,12 @@ private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams
   }
 }

+private[spark] object TreeEnsembleParams {
+  // These options should be lowercase.
+  final val supportedFeatureSubsetStrategies: Array[String] =
+    Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT))
+}
+
 /**
  * Parameters for Decision Tree-based ensemble algorithms.
  *
@@ -359,8 +365,58 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
       oldImpurity: OldImpurity): OldStrategy = {
     super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
   }
+
+  /**
+   * The number of features to consider for splits at each tree node.
+   * Supported options:
+   *  - "auto": Choose automatically for task:
+   *            If numTrees == 1, set to "all".
+   *            If numTrees > 1 (forest), set to "sqrt" for classification and
+   *            to "onethird" for regression.
+   *  - "all": use all features
+   *  - "onethird": use 1/3 of the features
+   *  - "sqrt": use sqrt(number of features)
+   *  - "log2": use log2(number of features)
+   *  - "n": when n is in the range (0, 1.0], use n * number of features. When n
+   *         is in the range (1, number of features), use n features.
+   * (default = "auto")
+   *
+   * These various settings are based on the following references:
+   *  - log2: tested in Breiman (2001)
+   *  - sqrt: recommended by Breiman manual for random forests
+   *  - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
+   *    package.
+   * @see <a href="http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf">Breiman (2001)</a>
+   * @see <a href="http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf">
+   * Breiman manual for random forests</a>
+   *
+   * @group param
+   */
+  final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
+    "The number of features to consider for splits at each tree node." +
+      s" Supported options: ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}" +
+      s", (0.0-1.0], [1-n].",
+    (value: String) =>
+      TreeEnsembleParams.supportedFeatureSubsetStrategies.contains(
+        value.toLowerCase(Locale.ROOT))
+      || Try(value.toInt).filter(_ > 0).isSuccess
+      || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess)
+
+  setDefault(featureSubsetStrategy -> "auto")
+
+  /**
+   * @deprecated This method is deprecated and will be removed in 3.0.0.
+   * @group setParam
+   */
+  @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0")
+  def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
+
+  /** @group getParam */
+  final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT)
 }

 /**
  * Parameters for Random Forest algorithms.
  */
@@ -391,60 +447,6 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
   /** @group getParam */
   final def getNumTrees: Int = $(numTrees)
-
-  /**
-   * The number of features to consider for splits at each tree node.
-   * Supported options:
-   *  - "auto": Choose automatically for task:
-   *            If numTrees == 1, set to "all".
-   *            If numTrees > 1 (forest), set to "sqrt" for classification and
-   *            to "onethird" for regression.
-   *  - "all": use all features
-   *  - "onethird": use 1/3 of the features
-   *  - "sqrt": use sqrt(number of features)
-   *  - "log2": use log2(number of features)
-   *  - "n": when n is in the range (0, 1.0], use n * number of features. When n
-   *         is in the range (1, number of features), use n features.
-   * (default = "auto")
-   *
-   * These various settings are based on the following references:
-   *  - log2: tested in Breiman (2001)
-   *  - sqrt: recommended by Breiman manual for random forests
-   *  - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
-   *    package.
-   * @see <a href="http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf">Breiman (2001)</a>
-   * @see <a href="http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf">
-   * Breiman manual for random forests</a>
-   *
-   * @group param
-   */
-  final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
-    "The number of features to consider for splits at each tree node." +
-      s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" +
-      s", (0.0-1.0], [1-n].",
-    (value: String) =>
-      RandomForestParams.supportedFeatureSubsetStrategies.contains(
-        value.toLowerCase(Locale.ROOT))
-      || Try(value.toInt).filter(_ > 0).isSuccess
-      || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess)
-
-  setDefault(featureSubsetStrategy -> "auto")
-
-  /**
-   * @deprecated This method is deprecated and will be removed in 3.0.0.
-   * @group setParam
-   */
-  @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0")
-  def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
-
-  /** @group getParam */
-  final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT)
-}
-
-private[spark] object RandomForestParams {
-  // These options should be lowercase.
-  final val supportedFeatureSubsetStrategies: Array[String] =
-    Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT))
 }

 private[ml] trait RandomForestClassifierParams
@@ -497,6 +499,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize {
   setDefault(maxIter -> 20, stepSize -> 0.1)

+  setDefault(featureSubsetStrategy -> "all")
+
   /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
   private[ml] def getOldBoostingStrategy(
       categoricalFeatures: Map[Int, Int],
```


```diff
@@ -69,7 +69,7 @@ class GradientBoostedTrees private[spark] (
     val algo = boostingStrategy.treeStrategy.algo
     val (trees, treeWeights) = NewGBT.run(input.map { point =>
       NewLabeledPoint(point.label, point.features.asML)
-    }, boostingStrategy, seed.toLong)
+    }, boostingStrategy, seed.toLong, "all")
     new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
   }
@@ -101,7 +101,7 @@ class GradientBoostedTrees private[spark] (
       NewLabeledPoint(point.label, point.features.asML)
     }, validationInput.map { point =>
       NewLabeledPoint(point.label, point.features.asML)
-    }, boostingStrategy, seed.toLong)
+    }, boostingStrategy, seed.toLong, "all")
     new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
   }
```


```diff
@@ -23,7 +23,7 @@ import scala.util.Try
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, RandomForestParams => NewRFParams}
+import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams}
 import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
```


```diff
@@ -83,6 +83,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(gbt.getPredictionCol === "prediction")
     assert(gbt.getRawPredictionCol === "rawPrediction")
     assert(gbt.getProbabilityCol === "probability")
+    assert(gbt.getFeatureSubsetStrategy === "all")
     val df = trainData.toDF()
     val model = gbt.fit(df)
     model.transform(df)
@@ -95,6 +96,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(model.getPredictionCol === "prediction")
     assert(model.getRawPredictionCol === "rawPrediction")
     assert(model.getProbabilityCol === "probability")
+    assert(model.getFeatureSubsetStrategy === "all")
     assert(model.hasParent)

     MLTestingUtils.checkCopyAndUids(gbt, model)
@@ -356,6 +358,33 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(importances.toArray.forall(_ >= 0.0))
   }

+  /////////////////////////////////////////////////////////////////////////////
+  // Tests of feature subset strategy
+  /////////////////////////////////////////////////////////////////////////////
+
+  test("Tests of feature subset strategy") {
+    val numClasses = 2
+    val gbt = new GBTClassifier()
+      .setSeed(123)
+      .setMaxDepth(3)
+      .setMaxIter(5)
+      .setFeatureSubsetStrategy("all")
+
+    // In this data, feature 1 is very important.
+    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+    val categoricalFeatures = Map.empty[Int, Int]
+    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+
+    val importances = gbt.fit(df).featureImportances
+    val mostImportantFeature = importances.argmax
+    assert(mostImportantFeature === 1)
+
+    // GBT with different featureSubsetStrategy
+    val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
+    val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
+    val mostIF = importanceFeatures.argmax
+    assert(mostImportantFeature !== mostIF)
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
```


```diff
@@ -165,6 +165,35 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(importances.toArray.forall(_ >= 0.0))
   }

+  /////////////////////////////////////////////////////////////////////////////
+  // Tests of feature subset strategy
+  /////////////////////////////////////////////////////////////////////////////
+
+  test("Tests of feature subset strategy") {
+    val numClasses = 2
+    val gbt = new GBTRegressor()
+      .setMaxDepth(3)
+      .setMaxIter(5)
+      .setSeed(123)
+      .setFeatureSubsetStrategy("all")
+
+    // In this data, feature 1 is very important.
+    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+    val categoricalFeatures = Map.empty[Int, Int]
+    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+
+    val importances = gbt.fit(df).featureImportances
+    val mostImportantFeature = importances.argmax
+    assert(mostImportantFeature === 1)
+
+    // GBT with different featureSubsetStrategy
+    val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
+    val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
+    val mostIF = importanceFeatures.argmax
+    assert(mostImportantFeature !== mostIF)
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
```


```diff
@@ -50,12 +50,12 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
       val boostingStrategy =
         new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
       val (validateTrees, validateTreeWeights) = GradientBoostedTrees
-        .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L)
+        .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L, "all")
       val numTrees = validateTrees.length
       assert(numTrees !== numIterations)

       // Test that it performs better on the validation dataset.
-      val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L)
+      val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L, "all")
       val (errorWithoutValidation, errorWithValidation) = {
         if (algo == Classification) {
           val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
```