[SPARK-13952][ML] Add random seed to GBT
## What changes were proposed in this pull request? `GBTClassifier` and `GBTRegressor` should use random seed for reproducible results. Because of the nature of current unit tests, which compare GBTs in ML and GBTs in MLlib for equality, I also added a random seed to MLlib GBT algorithm. I made alternate constructors in `mllib.tree.GradientBoostedTrees` to accept a random seed, but left them as private so as to not change the API unnecessarily. ## How was this patch tested? Existing unit tests verify that functionality did not change. Other ML algorithms do not seem to have unit tests that directly test the functionality of random seeding, but reproducibility with seeding for GBTs is effectively verified in existing tests. I can add more tests if needed. Author: sethah <seth.hendrickson16@gmail.com> Closes #11903 from sethah/SPARK-13952.
This commit is contained in:
parent
5dfc01976b
commit
69bc2c17f1
|
@ -96,10 +96,7 @@ final class GBTClassifier @Since("1.4.0") (
|
|||
override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
|
||||
|
||||
@Since("1.4.0")
|
||||
override def setSeed(value: Long): this.type = {
|
||||
logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
|
||||
super.setSeed(value)
|
||||
}
|
||||
override def setSeed(value: Long): this.type = super.setSeed(value)
|
||||
|
||||
// Parameters from GBTParams:
|
||||
|
||||
|
@ -158,7 +155,8 @@ final class GBTClassifier @Since("1.4.0") (
|
|||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
||||
val numFeatures = oldDataset.first().features.size
|
||||
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
|
||||
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
|
||||
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
|
||||
$(seed))
|
||||
new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
|
||||
}
|
||||
|
||||
|
|
|
@ -97,7 +97,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
|
|||
private[ml] def train(data: RDD[LabeledPoint],
|
||||
oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
|
||||
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
|
||||
seed = 0L, parentUID = Some(uid))
|
||||
seed = $(seed), parentUID = Some(uid))
|
||||
trees.head.asInstanceOf[DecisionTreeRegressionModel]
|
||||
}
|
||||
|
||||
|
|
|
@ -92,10 +92,7 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
|
|||
override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
|
||||
|
||||
@Since("1.4.0")
|
||||
override def setSeed(value: Long): this.type = {
|
||||
logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
|
||||
super.setSeed(value)
|
||||
}
|
||||
override def setSeed(value: Long): this.type = super.setSeed(value)
|
||||
|
||||
// Parameters from GBTParams:
|
||||
@Since("1.4.0")
|
||||
|
@ -145,7 +142,8 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
|
|||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
||||
val numFeatures = oldDataset.first().features.size
|
||||
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
|
||||
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
|
||||
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
|
||||
$(seed))
|
||||
new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
|
||||
}
|
||||
|
||||
|
|
|
@ -34,20 +34,23 @@ private[ml] object GradientBoostedTrees extends Logging {
|
|||
/**
|
||||
* Method to train a gradient boosting model
|
||||
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
|
||||
* @param seed Random seed.
|
||||
* @return tuple of ensemble models and weights:
|
||||
* (array of decision tree models, array of model weights)
|
||||
*/
|
||||
def run(input: RDD[LabeledPoint],
|
||||
boostingStrategy: OldBoostingStrategy
|
||||
): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
||||
def run(
|
||||
input: RDD[LabeledPoint],
|
||||
boostingStrategy: OldBoostingStrategy,
|
||||
seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
||||
val algo = boostingStrategy.treeStrategy.algo
|
||||
algo match {
|
||||
case OldAlgo.Regression =>
|
||||
GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
|
||||
GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
|
||||
case OldAlgo.Classification =>
|
||||
// Map labels to -1, +1 so binary classification can be treated as regression.
|
||||
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
|
||||
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
|
||||
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
|
||||
seed)
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
|
||||
}
|
||||
|
@ -61,18 +64,19 @@ private[ml] object GradientBoostedTrees extends Logging {
|
|||
* but it should follow the same distribution.
|
||||
* E.g., these two datasets could be created from an original dataset
|
||||
* by using [[org.apache.spark.rdd.RDD.randomSplit()]]
|
||||
* @param seed Random seed.
|
||||
* @return tuple of ensemble models and weights:
|
||||
* (array of decision tree models, array of model weights)
|
||||
*/
|
||||
def runWithValidation(
|
||||
input: RDD[LabeledPoint],
|
||||
validationInput: RDD[LabeledPoint],
|
||||
boostingStrategy: OldBoostingStrategy
|
||||
): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
||||
boostingStrategy: OldBoostingStrategy,
|
||||
seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
||||
val algo = boostingStrategy.treeStrategy.algo
|
||||
algo match {
|
||||
case OldAlgo.Regression =>
|
||||
GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true)
|
||||
GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
|
||||
case OldAlgo.Classification =>
|
||||
// Map labels to -1, +1 so binary classification can be treated as regression.
|
||||
val remappedInput = input.map(
|
||||
|
@ -80,7 +84,7 @@ private[ml] object GradientBoostedTrees extends Logging {
|
|||
val remappedValidationInput = validationInput.map(
|
||||
x => new LabeledPoint((x.label * 2) - 1, x.features))
|
||||
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
|
||||
validate = true)
|
||||
validate = true, seed)
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
|
||||
}
|
||||
|
@ -142,6 +146,7 @@ private[ml] object GradientBoostedTrees extends Logging {
|
|||
* @param validationInput validation dataset, ignored if validate is set to false.
|
||||
* @param boostingStrategy boosting parameters
|
||||
* @param validate whether or not to use the validation dataset.
|
||||
* @param seed Random seed.
|
||||
* @return tuple of ensemble models and weights:
|
||||
* (array of decision tree models, array of model weights)
|
||||
*/
|
||||
|
@ -149,7 +154,8 @@ private[ml] object GradientBoostedTrees extends Logging {
|
|||
input: RDD[LabeledPoint],
|
||||
validationInput: RDD[LabeledPoint],
|
||||
boostingStrategy: OldBoostingStrategy,
|
||||
validate: Boolean): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
||||
validate: Boolean,
|
||||
seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
|
||||
val timer = new TimeTracker()
|
||||
timer.start("total")
|
||||
timer.start("init")
|
||||
|
@ -191,7 +197,7 @@ private[ml] object GradientBoostedTrees extends Logging {
|
|||
|
||||
// Initialize tree
|
||||
timer.start("building tree 0")
|
||||
val firstTree = new DecisionTreeRegressor()
|
||||
val firstTree = new DecisionTreeRegressor().setSeed(seed)
|
||||
val firstTreeModel = firstTree.train(input, treeStrategy)
|
||||
val firstTreeWeight = 1.0
|
||||
baseLearners(0) = firstTreeModel
|
||||
|
@ -223,7 +229,7 @@ private[ml] object GradientBoostedTrees extends Logging {
|
|||
logDebug("###################################################")
|
||||
logDebug("Gradient boosting tree iteration " + m)
|
||||
logDebug("###################################################")
|
||||
val dt = new DecisionTreeRegressor()
|
||||
val dt = new DecisionTreeRegressor().setSeed(seed + m)
|
||||
val model = dt.train(data, treeStrategy)
|
||||
timer.stop(s"building tree $m")
|
||||
// Update partial model
|
||||
|
|
|
@ -43,11 +43,20 @@ import org.apache.spark.util.random.XORShiftRandom
|
|||
* @param strategy The configuration parameters for the tree algorithm which specify the type
|
||||
* of decision tree (classification or regression), feature type (continuous,
|
||||
* categorical), depth of the tree, quantile calculation strategy, etc.
|
||||
* @param seed Random seed.
|
||||
*/
|
||||
@Since("1.0.0")
|
||||
class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
|
||||
class DecisionTree private[spark] (private val strategy: Strategy, private val seed: Int)
|
||||
extends Serializable with Logging {
|
||||
|
||||
/**
|
||||
* @param strategy The configuration parameters for the tree algorithm which specify the type
|
||||
* of decision tree (classification or regression), feature type (continuous,
|
||||
* categorical), depth of the tree, quantile calculation strategy, etc.
|
||||
*/
|
||||
@Since("1.0.0")
|
||||
def this(strategy: Strategy) = this(strategy, seed = 0)
|
||||
|
||||
strategy.assertValid()
|
||||
|
||||
/**
|
||||
|
@ -58,8 +67,8 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
|
|||
*/
|
||||
@Since("1.2.0")
|
||||
def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
|
||||
// Note: random seed will not be used since numTrees = 1.
|
||||
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
|
||||
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all",
|
||||
seed = seed)
|
||||
val rfModel = rf.run(input)
|
||||
rfModel.trees(0)
|
||||
}
|
||||
|
|
|
@ -47,11 +47,20 @@ import org.apache.spark.storage.StorageLevel
|
|||
* for other loss functions.
|
||||
*
|
||||
* @param boostingStrategy Parameters for the gradient boosting algorithm.
|
||||
* @param seed Random seed.
|
||||
*/
|
||||
@Since("1.2.0")
|
||||
class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy)
|
||||
class GradientBoostedTrees private[spark] (
|
||||
private val boostingStrategy: BoostingStrategy,
|
||||
private val seed: Int)
|
||||
extends Serializable with Logging {
|
||||
|
||||
/**
|
||||
* @param boostingStrategy Parameters for the gradient boosting algorithm.
|
||||
*/
|
||||
@Since("1.2.0")
|
||||
def this(boostingStrategy: BoostingStrategy) = this(boostingStrategy, seed = 0)
|
||||
|
||||
/**
|
||||
* Method to train a gradient boosting model
|
||||
*
|
||||
|
@ -63,11 +72,12 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti
|
|||
val algo = boostingStrategy.treeStrategy.algo
|
||||
algo match {
|
||||
case Regression =>
|
||||
GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
|
||||
GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
|
||||
case Classification =>
|
||||
// Map labels to -1, +1 so binary classification can be treated as regression.
|
||||
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
|
||||
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
|
||||
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
|
||||
seed)
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
|
||||
}
|
||||
|
@ -99,7 +109,7 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti
|
|||
val algo = boostingStrategy.treeStrategy.algo
|
||||
algo match {
|
||||
case Regression =>
|
||||
GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true)
|
||||
GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
|
||||
case Classification =>
|
||||
// Map labels to -1, +1 so binary classification can be treated as regression.
|
||||
val remappedInput = input.map(
|
||||
|
@ -107,7 +117,7 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti
|
|||
val remappedValidationInput = validationInput.map(
|
||||
x => new LabeledPoint((x.label * 2) - 1, x.features))
|
||||
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
|
||||
validate = true)
|
||||
validate = true, seed)
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
|
||||
}
|
||||
|
@ -140,7 +150,7 @@ object GradientBoostedTrees extends Logging {
|
|||
def train(
|
||||
input: RDD[LabeledPoint],
|
||||
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
|
||||
new GradientBoostedTrees(boostingStrategy).run(input)
|
||||
new GradientBoostedTrees(boostingStrategy, seed = 0).run(input)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -159,13 +169,15 @@ object GradientBoostedTrees extends Logging {
|
|||
* @param validationInput Validation dataset, ignored if validate is set to false.
|
||||
* @param boostingStrategy Boosting parameters.
|
||||
* @param validate Whether or not to use the validation dataset.
|
||||
* @param seed Random seed.
|
||||
* @return GradientBoostedTreesModel that can be used for prediction.
|
||||
*/
|
||||
private def boost(
|
||||
input: RDD[LabeledPoint],
|
||||
validationInput: RDD[LabeledPoint],
|
||||
boostingStrategy: BoostingStrategy,
|
||||
validate: Boolean): GradientBoostedTreesModel = {
|
||||
validate: Boolean,
|
||||
seed: Int): GradientBoostedTreesModel = {
|
||||
val timer = new TimeTracker()
|
||||
timer.start("total")
|
||||
timer.start("init")
|
||||
|
@ -207,7 +219,7 @@ object GradientBoostedTrees extends Logging {
|
|||
|
||||
// Initialize tree
|
||||
timer.start("building tree 0")
|
||||
val firstTreeModel = new DecisionTree(treeStrategy).run(input)
|
||||
val firstTreeModel = new DecisionTree(treeStrategy, seed).run(input)
|
||||
val firstTreeWeight = 1.0
|
||||
baseLearners(0) = firstTreeModel
|
||||
baseLearnerWeights(0) = firstTreeWeight
|
||||
|
@ -238,7 +250,7 @@ object GradientBoostedTrees extends Logging {
|
|||
logDebug("###################################################")
|
||||
logDebug("Gradient boosting tree iteration " + m)
|
||||
logDebug("###################################################")
|
||||
val model = new DecisionTree(treeStrategy).run(data)
|
||||
val model = new DecisionTree(treeStrategy, seed + m).run(data)
|
||||
timer.stop(s"building tree $m")
|
||||
// Update partial model
|
||||
baseLearners(m) = model
|
||||
|
|
|
@ -74,6 +74,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
.setLossType("logistic")
|
||||
.setMaxIter(maxIter)
|
||||
.setStepSize(learningRate)
|
||||
.setSeed(123)
|
||||
compareAPIs(data, None, gbt, categoricalFeatures)
|
||||
}
|
||||
}
|
||||
|
@ -91,6 +92,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
.setMaxIter(5)
|
||||
.setStepSize(0.1)
|
||||
.setCheckpointInterval(2)
|
||||
.setSeed(123)
|
||||
val model = gbt.fit(df)
|
||||
|
||||
// copied model must have the same parent.
|
||||
|
@ -159,7 +161,7 @@ private object GBTClassifierSuite extends SparkFunSuite {
|
|||
val numFeatures = data.first().features.size
|
||||
val oldBoostingStrategy =
|
||||
gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
|
||||
val oldGBT = new OldGBT(oldBoostingStrategy)
|
||||
val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
|
||||
val oldModel = oldGBT.run(data)
|
||||
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
|
||||
val newModel = gbt.fit(newData)
|
||||
|
|
|
@ -65,6 +65,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
.setLossType(loss)
|
||||
.setMaxIter(maxIter)
|
||||
.setStepSize(learningRate)
|
||||
.setSeed(123)
|
||||
compareAPIs(data, None, gbt, categoricalFeatures)
|
||||
}
|
||||
}
|
||||
|
@ -104,6 +105,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
.setMaxIter(5)
|
||||
.setStepSize(0.1)
|
||||
.setCheckpointInterval(2)
|
||||
.setSeed(123)
|
||||
val model = gbt.fit(df)
|
||||
|
||||
sc.checkpointDir = None
|
||||
|
@ -169,7 +171,7 @@ private object GBTRegressorSuite extends SparkFunSuite {
|
|||
categoricalFeatures: Map[Int, Int]): Unit = {
|
||||
val numFeatures = data.first().features.size
|
||||
val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
|
||||
val oldGBT = new OldGBT(oldBoostingStrategy)
|
||||
val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
|
||||
val oldModel = oldGBT.run(data)
|
||||
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
|
||||
val newModel = gbt.fit(newData)
|
||||
|
|
|
@ -171,13 +171,13 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
|
|||
categoricalFeaturesInfo = Map.empty)
|
||||
val boostingStrategy =
|
||||
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
|
||||
val gbtValidate = new GradientBoostedTrees(boostingStrategy)
|
||||
val gbtValidate = new GradientBoostedTrees(boostingStrategy, seed = 0)
|
||||
.runWithValidation(trainRdd, validateRdd)
|
||||
val numTrees = gbtValidate.numTrees
|
||||
assert(numTrees !== numIterations)
|
||||
|
||||
// Test that it performs better on the validation dataset.
|
||||
val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
|
||||
val gbt = new GradientBoostedTrees(boostingStrategy, seed = 0).run(trainRdd)
|
||||
val (errorWithoutValidation, errorWithValidation) = {
|
||||
if (algo == Classification) {
|
||||
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
|
||||
|
|
Loading…
Reference in a new issue