[SPARK-9694][ML] Add random seed Param to Scala CrossValidator
Add random seed Param to Scala CrossValidator Author: Yanbo Liang <ybliang8@gmail.com> Closes #9108 from yanboliang/spark-9694.
This commit is contained in:
parent
7b6dc29d0e
commit
860dc7f2f8
|
@ -29,8 +29,9 @@ import org.apache.spark.ml.classification.OneVsRestParams
|
|||
import org.apache.spark.ml.evaluation.Evaluator
|
||||
import org.apache.spark.ml.feature.RFormulaModel
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
|
||||
import org.apache.spark.ml.param.shared.HasSeed
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
import org.apache.spark.sql.DataFrame
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
@ -39,7 +40,7 @@ import org.apache.spark.sql.types.StructType
|
|||
/**
|
||||
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
|
||||
*/
|
||||
private[ml] trait CrossValidatorParams extends ValidatorParams {
|
||||
private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed {
|
||||
/**
|
||||
* Param for number of folds for cross validation. Must be >= 2.
|
||||
* Default: 3
|
||||
|
@ -85,6 +86,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
|
|||
@Since("1.2.0")
|
||||
def setNumFolds(value: Int): this.type = set(numFolds, value)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("2.0.0")
|
||||
def setSeed(value: Long): this.type = set(seed, value)
|
||||
|
||||
@Since("1.4.0")
|
||||
override def fit(dataset: DataFrame): CrossValidatorModel = {
|
||||
val schema = dataset.schema
|
||||
|
@ -95,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
|
|||
val epm = $(estimatorParamMaps)
|
||||
val numModels = epm.length
|
||||
val metrics = new Array[Double](epm.length)
|
||||
val splits = MLUtils.kFold(dataset.rdd, $(numFolds), 0)
|
||||
val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed))
|
||||
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
|
||||
val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
|
||||
val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
|
||||
|
|
|
@ -265,6 +265,14 @@ object MLUtils {
|
|||
*/
|
||||
@Since("1.0.0")
|
||||
def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
|
||||
kFold(rdd, numFolds, seed.toLong)
|
||||
}
|
||||
|
||||
/**
|
||||
* Version of [[kFold()]] taking a Long seed.
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Long): Array[(RDD[T], RDD[T])] = {
|
||||
val numFoldsF = numFolds.toFloat
|
||||
(1 to numFolds).map { fold =>
|
||||
val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
|
||||
|
|
Loading…
Reference in a new issue