[SPARK-15254][DOC] Improve ML pipeline Cross Validation Scaladoc & PyDoc
## What changes were proposed in this pull request?

Updated the ML pipeline Cross Validation Scaladoc & PyDoc.

## How was this patch tested?

Documentation update only.

Author: krishnakalyan3 <krishnakalyan3@gmail.com>

Closes #13894 from krishnakalyan3/kfold-cv.
This commit is contained in:
parent
045fc36066
commit
7e8279fde1
|
@ -55,7 +55,11 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* K-fold cross validation.
|
* K-fold cross validation performs model selection by splitting the dataset into a set of
|
||||||
|
* non-overlapping randomly partitioned folds which are used as separate training and test datasets.
|
||||||
|
* For example, with k=3 folds, K-fold cross validation will generate 3 (training, test) dataset pairs,
|
||||||
|
* each of which uses 2/3 of the data for training and 1/3 for testing. Each fold is used as the
|
||||||
|
* test set exactly once.
|
||||||
*/
|
*/
|
||||||
@Since("1.2.0")
|
@Since("1.2.0")
|
||||||
class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
|
class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
|
||||||
|
@ -188,7 +192,9 @@ object CrossValidator extends MLReadable[CrossValidator] {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Model from k-fold cross validation.
|
* CrossValidatorModel contains the model with the highest average cross-validation
|
||||||
|
* metric across folds and uses this model to transform input data. CrossValidatorModel
|
||||||
|
* also tracks the metrics for each param map evaluated.
|
||||||
*
|
*
|
||||||
* @param bestModel The best model selected from k-fold cross validation.
|
* @param bestModel The best model selected from k-fold cross validation.
|
||||||
* @param avgMetrics Average cross-validation metrics for each paramMap in
|
* @param avgMetrics Average cross-validation metrics for each paramMap in
|
||||||
|
|
|
@ -143,7 +143,13 @@ class ValidatorParams(HasSeed):
|
||||||
|
|
||||||
class CrossValidator(Estimator, ValidatorParams):
|
class CrossValidator(Estimator, ValidatorParams):
|
||||||
"""
|
"""
|
||||||
K-fold cross validation.
|
|
||||||
|
K-fold cross validation performs model selection by splitting the dataset into a set of
|
||||||
|
non-overlapping randomly partitioned folds which are used as separate training and test datasets.
|
||||||
|
For example, with k=3 folds, K-fold cross validation will generate 3 (training, test) dataset pairs,
|
||||||
|
each of which uses 2/3 of the data for training and 1/3 for testing. Each fold is used as the
|
||||||
|
test set exactly once.
|
||||||
|
|
||||||
|
|
||||||
>>> from pyspark.ml.classification import LogisticRegression
|
>>> from pyspark.ml.classification import LogisticRegression
|
||||||
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
|
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
|
||||||
|
@ -260,7 +266,10 @@ class CrossValidator(Estimator, ValidatorParams):
|
||||||
|
|
||||||
class CrossValidatorModel(Model, ValidatorParams):
|
class CrossValidatorModel(Model, ValidatorParams):
|
||||||
"""
|
"""
|
||||||
Model from k-fold cross validation.
|
|
||||||
|
CrossValidatorModel contains the model with the highest average cross-validation
|
||||||
|
metric across folds and uses this model to transform input data. CrossValidatorModel
|
||||||
|
also tracks the metrics for each param map evaluated.
|
||||||
|
|
||||||
.. versionadded:: 1.4.0
|
.. versionadded:: 1.4.0
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in a new issue