[SPARK-9694][ML] Add random seed Param to Scala CrossValidator

Add random seed Param to Scala CrossValidator

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9108 from yanboliang/spark-9694.
This commit is contained in:
Yanbo Liang 2015-12-16 11:05:37 -08:00 committed by Joseph K. Bradley
parent 7b6dc29d0e
commit 860dc7f2f8
2 changed files with 16 additions and 3 deletions

View file

@ -29,8 +29,9 @@ import org.apache.spark.ml.classification.OneVsRestParams
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@ -39,7 +40,7 @@ import org.apache.spark.sql.types.StructType
/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
private[ml] trait CrossValidatorParams extends ValidatorParams {
private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed {
/**
* Param for number of folds for cross validation. Must be >= 2.
* Default: 3
@ -85,6 +86,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@Since("1.2.0")
def setNumFolds(value: Int): this.type = set(numFolds, value)
/** @group setParam */
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)
@Since("1.4.0")
override def fit(dataset: DataFrame): CrossValidatorModel = {
val schema = dataset.schema
@ -95,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val epm = $(estimatorParamMaps)
val numModels = epm.length
val metrics = new Array[Double](epm.length)
val splits = MLUtils.kFold(dataset.rdd, $(numFolds), 0)
val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed))
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()

View file

@ -265,6 +265,14 @@ object MLUtils {
*/
@Since("1.0.0")
def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
kFold(rdd, numFolds, seed.toLong)
}
/**
* Version of [[kFold()]] taking a Long seed.
*/
@Since("2.0.0")
def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Long): Array[(RDD[T], RDD[T])] = {
val numFoldsF = numFolds.toFloat
(1 to numFolds).map { fold =>
val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,