[SPARK-28969][PYTHON][ML] OneVsRestParams parity between scala and python
### What changes were proposed in this pull request? Follow the scala ```OneVsRestParams``` implementation, move ```setClassifier``` from ```OneVsRestParams``` to ```OneVsRest``` in Pyspark ### Why are the changes needed? 1. Maintain the parity between scala and python code. 2. ```Classifier``` can only be set in the estimator. ### Does this PR introduce any user-facing change? Yes. Previous behavior: ```OneVsRestModel``` has method ```setClassifier``` Current behavior: ```setClassifier``` is removed from ```OneVsRestModel```. ```classifier``` can only be set in ```OneVsRest```. ### How was this patch tested? Use existing tests Closes #25715 from huaxingao/spark-28969. Authored-by: Huaxin Gao <huaxing@us.ibm.com> Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit is contained in:
parent
fcf9b41b49
commit
77e9b58d4f
|
@ -1872,15 +1872,6 @@ class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCo
|
|||
|
||||
classifier = Param(Params._dummy(), "classifier", "base binary classifier")
|
||||
|
||||
@since("2.0.0")
|
||||
def setClassifier(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`classifier`.
|
||||
|
||||
.. note:: Only LogisticRegression and NaiveBayes are supported now.
|
||||
"""
|
||||
return self._set(classifier=value)
|
||||
|
||||
@since("2.0.0")
|
||||
def getClassifier(self):
|
||||
"""
|
||||
|
@ -1959,6 +1950,13 @@ class OneVsRest(Estimator, OneVsRestParams, HasParallelism, JavaMLReadable, Java
|
|||
kwargs = self._input_kwargs
|
||||
return self._set(**kwargs)
|
||||
|
||||
@since("2.0.0")
|
||||
def setClassifier(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`classifier`.
|
||||
"""
|
||||
return self._set(classifier=value)
|
||||
|
||||
def _fit(self, dataset):
|
||||
labelCol = self.getLabelCol()
|
||||
featuresCol = self.getFeaturesCol()
|
||||
|
@ -2212,7 +2210,8 @@ class OneVsRestModel(Model, OneVsRestParams, JavaMLReadable, JavaMLWritable):
|
|||
classifier = JavaParams._from_java(java_stage.getClassifier())
|
||||
models = [JavaParams._from_java(model) for model in java_stage.models()]
|
||||
py_stage = cls(models=models).setPredictionCol(predictionCol).setLabelCol(labelCol)\
|
||||
.setFeaturesCol(featuresCol).setClassifier(classifier)
|
||||
.setFeaturesCol(featuresCol)
|
||||
py_stage._set(classifier=classifier)
|
||||
py_stage._resetUid(java_stage.uid())
|
||||
return py_stage
|
||||
|
||||
|
|
Loading…
Reference in a new issue