[SPARK-9722] [ML] Pass random seed to spark.ml DecisionTree*
Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com> Closes #9402 from yu-iskw/SPARK-9722.
This commit is contained in:
parent
3e770a64a4
commit
e963070c13
|
@ -74,7 +74,7 @@ private[ml] object RandomForest extends Logging {
|
|||
// Find the splits and the corresponding bins (interval between the splits) using a sample
|
||||
// of the input data.
|
||||
timer.start("findSplitsBins")
|
||||
val splits = findSplits(retaggedInput, metadata)
|
||||
val splits = findSplits(retaggedInput, metadata, seed)
|
||||
timer.stop("findSplitsBins")
|
||||
logDebug("numBins: feature: number of bins")
|
||||
logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
|
||||
|
@ -815,6 +815,7 @@ private[ml] object RandomForest extends Logging {
|
|||
*
|
||||
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
|
||||
* @param metadata Learning and dataset metadata
|
||||
* @param seed random seed
|
||||
* @return A tuple of (splits, bins).
|
||||
* Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
|
||||
* of size (numFeatures, numSplits).
|
||||
|
@ -823,7 +824,8 @@ private[ml] object RandomForest extends Logging {
|
|||
*/
|
||||
protected[tree] def findSplits(
|
||||
input: RDD[LabeledPoint],
|
||||
metadata: DecisionTreeMetadata): Array[Array[Split]] = {
|
||||
metadata: DecisionTreeMetadata,
|
||||
seed : Long): Array[Array[Split]] = {
|
||||
|
||||
logDebug("isMulticlass = " + metadata.isMulticlass)
|
||||
|
||||
|
@ -840,7 +842,7 @@ private[ml] object RandomForest extends Logging {
|
|||
1.0
|
||||
}
|
||||
logDebug("fraction of data used for calculating quantiles = " + fraction)
|
||||
input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect()
|
||||
input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
|
||||
} else {
|
||||
new Array[LabeledPoint](0)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue