[SPARK-7474] [MLLIB] update ParamGridBuilder doctest
Multiline commands are properly handled in this PR. oefirouz ![screen shot 2015-05-07 at 10 53 25 pm](https://cloud.githubusercontent.com/assets/829644/7531290/02ad2fd4-f50c-11e4-8c04-e58d1a61ad69.png) Author: Xiangrui Meng <meng@databricks.com> Closes #6001 from mengxr/SPARK-7474 and squashes the following commits: b94b11d [Xiangrui Meng] update ParamGridBuilder doctest
This commit is contained in:
parent
f5ff4a84c4
commit
65afd3ce8b
|
@ -27,24 +27,22 @@ __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel']
|
|||
|
||||
|
||||
class ParamGridBuilder(object):
|
||||
"""
|
||||
r"""
|
||||
Builder for a param grid used in grid search-based model selection.
|
||||
|
||||
>>> from classification import LogisticRegression
|
||||
>>> from pyspark.ml.classification import LogisticRegression
|
||||
>>> lr = LogisticRegression()
|
||||
>>> output = ParamGridBuilder().baseOn({lr.labelCol: 'l'}) \
|
||||
.baseOn([lr.predictionCol, 'p']) \
|
||||
.addGrid(lr.regParam, [1.0, 2.0, 3.0]) \
|
||||
.addGrid(lr.maxIter, [1, 5]) \
|
||||
.addGrid(lr.featuresCol, ['f']) \
|
||||
.build()
|
||||
>>> expected = [ \
|
||||
{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
|
||||
{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
|
||||
{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
|
||||
{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
|
||||
{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
|
||||
{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
|
||||
>>> output = ParamGridBuilder() \
|
||||
... .baseOn({lr.labelCol: 'l'}) \
|
||||
... .baseOn([lr.predictionCol, 'p']) \
|
||||
... .addGrid(lr.regParam, [1.0, 2.0]) \
|
||||
... .addGrid(lr.maxIter, [1, 5]) \
|
||||
... .build()
|
||||
>>> expected = [
|
||||
... {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
|
||||
... {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
|
||||
... {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},
|
||||
... {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
|
||||
>>> len(output) == len(expected)
|
||||
True
|
||||
>>> all([m in expected for m in output])
|
||||
|
|
Loading…
Reference in a new issue