[SPARK-13952][ML] Add random seed to GBT

## What changes were proposed in this pull request?

`GBTClassifier` and `GBTRegressor` should use random seed for reproducible results. Because of the nature of current unit tests, which compare GBTs in ML and GBTs in MLlib for equality, I also added a random seed to MLlib GBT algorithm. I made alternate constructors in `mllib.tree.GradientBoostedTrees` to accept a random seed, but left them as private so as to not change the API unnecessarily.

## How was this patch tested?

Existing unit tests verify that functionality did not change. Other ML algorithms do not seem to have unit tests that directly test the functionality of random seeding, but reproducibility with seeding for GBTs is effectively verified in existing tests. I can add more tests if needed.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #11903 from sethah/SPARK-13952.
This commit is contained in:
sethah 2016-03-23 15:08:47 -07:00 committed by Joseph K. Bradley
parent 5dfc01976b
commit 69bc2c17f1
9 changed files with 66 additions and 39 deletions

View file

@ -96,10 +96,7 @@ final class GBTClassifier @Since("1.4.0") (
override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
@Since("1.4.0")
override def setSeed(value: Long): this.type = {
logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
super.setSeed(value)
}
override def setSeed(value: Long): this.type = super.setSeed(value)
// Parameters from GBTParams:
@ -158,7 +155,8 @@ final class GBTClassifier @Since("1.4.0") (
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
$(seed))
new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
}

View file

@ -97,7 +97,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
seed = 0L, parentUID = Some(uid))
seed = $(seed), parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeRegressionModel]
}

View file

@ -92,10 +92,7 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
@Since("1.4.0")
override def setSeed(value: Long): this.type = {
logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
super.setSeed(value)
}
override def setSeed(value: Long): this.type = super.setSeed(value)
// Parameters from GBTParams:
@Since("1.4.0")
@ -145,7 +142,8 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
$(seed))
new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
}

View file

@ -34,20 +34,23 @@ private[ml] object GradientBoostedTrees extends Logging {
/**
* Method to train a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @param seed Random seed.
* @return tuple of ensemble models and weights:
* (array of decision tree models, array of model weights)
*/
def run(input: RDD[LabeledPoint],
boostingStrategy: OldBoostingStrategy
): (Array[DecisionTreeRegressionModel], Array[Double]) = {
def run(
input: RDD[LabeledPoint],
boostingStrategy: OldBoostingStrategy,
seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
case OldAlgo.Regression =>
GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
case OldAlgo.Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
seed)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
}
@ -61,18 +64,19 @@ private[ml] object GradientBoostedTrees extends Logging {
* but it should follow the same distribution.
* E.g., these two datasets could be created from an original dataset
* by using [[org.apache.spark.rdd.RDD.randomSplit()]]
* @param seed Random seed.
* @return tuple of ensemble models and weights:
* (array of decision tree models, array of model weights)
*/
def runWithValidation(
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint],
boostingStrategy: OldBoostingStrategy
): (Array[DecisionTreeRegressionModel], Array[Double]) = {
boostingStrategy: OldBoostingStrategy,
seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
case OldAlgo.Regression =>
GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true)
GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
case OldAlgo.Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(
@ -80,7 +84,7 @@ private[ml] object GradientBoostedTrees extends Logging {
val remappedValidationInput = validationInput.map(
x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
validate = true)
validate = true, seed)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
@ -142,6 +146,7 @@ private[ml] object GradientBoostedTrees extends Logging {
* @param validationInput validation dataset, ignored if validate is set to false.
* @param boostingStrategy boosting parameters
* @param validate whether or not to use the validation dataset.
* @param seed Random seed.
* @return tuple of ensemble models and weights:
* (array of decision tree models, array of model weights)
*/
@ -149,7 +154,8 @@ private[ml] object GradientBoostedTrees extends Logging {
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint],
boostingStrategy: OldBoostingStrategy,
validate: Boolean): (Array[DecisionTreeRegressionModel], Array[Double]) = {
validate: Boolean,
seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
val timer = new TimeTracker()
timer.start("total")
timer.start("init")
@ -191,7 +197,7 @@ private[ml] object GradientBoostedTrees extends Logging {
// Initialize tree
timer.start("building tree 0")
val firstTree = new DecisionTreeRegressor()
val firstTree = new DecisionTreeRegressor().setSeed(seed)
val firstTreeModel = firstTree.train(input, treeStrategy)
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
@ -223,7 +229,7 @@ private[ml] object GradientBoostedTrees extends Logging {
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")
val dt = new DecisionTreeRegressor()
val dt = new DecisionTreeRegressor().setSeed(seed + m)
val model = dt.train(data, treeStrategy)
timer.stop(s"building tree $m")
// Update partial model

View file

@ -43,11 +43,20 @@ import org.apache.spark.util.random.XORShiftRandom
* @param strategy The configuration parameters for the tree algorithm which specify the type
* of decision tree (classification or regression), feature type (continuous,
* categorical), depth of the tree, quantile calculation strategy, etc.
* @param seed Random seed.
*/
@Since("1.0.0")
class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
class DecisionTree private[spark] (private val strategy: Strategy, private val seed: Int)
extends Serializable with Logging {
/**
* @param strategy The configuration parameters for the tree algorithm which specify the type
* of decision tree (classification or regression), feature type (continuous,
* categorical), depth of the tree, quantile calculation strategy, etc.
*/
@Since("1.0.0")
def this(strategy: Strategy) = this(strategy, seed = 0)
strategy.assertValid()
/**
@ -58,8 +67,8 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
*/
@Since("1.2.0")
def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
// Note: random seed will not be used since numTrees = 1.
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = seed)
val rfModel = rf.run(input)
rfModel.trees(0)
}

View file

@ -47,11 +47,20 @@ import org.apache.spark.storage.StorageLevel
* for other loss functions.
*
* @param boostingStrategy Parameters for the gradient boosting algorithm.
* @param seed Random seed.
*/
@Since("1.2.0")
class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy)
class GradientBoostedTrees private[spark] (
private val boostingStrategy: BoostingStrategy,
private val seed: Int)
extends Serializable with Logging {
/**
* @param boostingStrategy Parameters for the gradient boosting algorithm.
*/
@Since("1.2.0")
def this(boostingStrategy: BoostingStrategy) = this(boostingStrategy, seed = 0)
/**
* Method to train a gradient boosting model
*
@ -63,11 +72,12 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti
val algo = boostingStrategy.treeStrategy.algo
algo match {
case Regression =>
GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
seed)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
@ -99,7 +109,7 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti
val algo = boostingStrategy.treeStrategy.algo
algo match {
case Regression =>
GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true)
GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(
@ -107,7 +117,7 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti
val remappedValidationInput = validationInput.map(
x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
validate = true)
validate = true, seed)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
@ -140,7 +150,7 @@ object GradientBoostedTrees extends Logging {
def train(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
new GradientBoostedTrees(boostingStrategy).run(input)
new GradientBoostedTrees(boostingStrategy, seed = 0).run(input)
}
/**
@ -159,13 +169,15 @@ object GradientBoostedTrees extends Logging {
* @param validationInput Validation dataset, ignored if validate is set to false.
* @param boostingStrategy Boosting parameters.
* @param validate Whether or not to use the validation dataset.
* @param seed Random seed.
* @return GradientBoostedTreesModel that can be used for prediction.
*/
private def boost(
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy,
validate: Boolean): GradientBoostedTreesModel = {
validate: Boolean,
seed: Int): GradientBoostedTreesModel = {
val timer = new TimeTracker()
timer.start("total")
timer.start("init")
@ -207,7 +219,7 @@ object GradientBoostedTrees extends Logging {
// Initialize tree
timer.start("building tree 0")
val firstTreeModel = new DecisionTree(treeStrategy).run(input)
val firstTreeModel = new DecisionTree(treeStrategy, seed).run(input)
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight
@ -238,7 +250,7 @@ object GradientBoostedTrees extends Logging {
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")
val model = new DecisionTree(treeStrategy).run(data)
val model = new DecisionTree(treeStrategy, seed + m).run(data)
timer.stop(s"building tree $m")
// Update partial model
baseLearners(m) = model

View file

@ -74,6 +74,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
.setLossType("logistic")
.setMaxIter(maxIter)
.setStepSize(learningRate)
.setSeed(123)
compareAPIs(data, None, gbt, categoricalFeatures)
}
}
@ -91,6 +92,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
.setMaxIter(5)
.setStepSize(0.1)
.setCheckpointInterval(2)
.setSeed(123)
val model = gbt.fit(df)
// copied model must have the same parent.
@ -159,7 +161,7 @@ private object GBTClassifierSuite extends SparkFunSuite {
val numFeatures = data.first().features.size
val oldBoostingStrategy =
gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val oldGBT = new OldGBT(oldBoostingStrategy)
val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
val oldModel = oldGBT.run(data)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
val newModel = gbt.fit(newData)

View file

@ -65,6 +65,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
.setLossType(loss)
.setMaxIter(maxIter)
.setStepSize(learningRate)
.setSeed(123)
compareAPIs(data, None, gbt, categoricalFeatures)
}
}
@ -104,6 +105,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
.setMaxIter(5)
.setStepSize(0.1)
.setCheckpointInterval(2)
.setSeed(123)
val model = gbt.fit(df)
sc.checkpointDir = None
@ -169,7 +171,7 @@ private object GBTRegressorSuite extends SparkFunSuite {
categoricalFeatures: Map[Int, Int]): Unit = {
val numFeatures = data.first().features.size
val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
val oldGBT = new OldGBT(oldBoostingStrategy)
val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
val oldModel = oldGBT.run(data)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newModel = gbt.fit(newData)

View file

@ -171,13 +171,13 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
categoricalFeaturesInfo = Map.empty)
val boostingStrategy =
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
val gbtValidate = new GradientBoostedTrees(boostingStrategy)
val gbtValidate = new GradientBoostedTrees(boostingStrategy, seed = 0)
.runWithValidation(trainRdd, validateRdd)
val numTrees = gbtValidate.numTrees
assert(numTrees !== numIterations)
// Test that it performs better on the validation dataset.
val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
val gbt = new GradientBoostedTrees(boostingStrategy, seed = 0).run(trainRdd)
val (errorWithoutValidation, errorWithValidation) = {
if (algo == Classification) {
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))