[SPARK-9722] [ML] Pass random seed to spark.ml DecisionTree*

Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>

Closes #9402 from yu-iskw/SPARK-9722.
This commit is contained in:
Yu ISHIKAWA 2015-11-01 23:52:50 -08:00 committed by DB Tsai
parent 3e770a64a4
commit e963070c13

View file

@ -74,7 +74,7 @@ private[ml] object RandomForest extends Logging {
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
timer.start("findSplitsBins")
val splits = findSplits(retaggedInput, metadata)
val splits = findSplits(retaggedInput, metadata, seed)
timer.stop("findSplitsBins")
logDebug("numBins: feature: number of bins")
logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
@ -815,6 +815,7 @@ private[ml] object RandomForest extends Logging {
*
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @param metadata Learning and dataset metadata
* @param seed Random seed for sampling the input data used to find the splits
* @return A tuple of (splits, bins).
* Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
* of size (numFeatures, numSplits).
@ -823,7 +824,8 @@ private[ml] object RandomForest extends Logging {
*/
protected[tree] def findSplits(
input: RDD[LabeledPoint],
metadata: DecisionTreeMetadata): Array[Array[Split]] = {
metadata: DecisionTreeMetadata,
seed : Long): Array[Array[Split]] = {
logDebug("isMulticlass = " + metadata.isMulticlass)
@ -840,7 +842,7 @@ private[ml] object RandomForest extends Logging {
1.0
}
logDebug("fraction of data used for calculating quantiles = " + fraction)
input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect()
input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
} else {
new Array[LabeledPoint](0)
}