[SPARK-30504][PYTHON][ML] Set weightCol in OneVsRest(Model) _to_java and _from_java

### What changes were proposed in this pull request?

This PR adjusts `_to_java` and `_from_java` of `OneVsRest` and `OneVsRestModel` to preserve `weightCol`.

### Why are the changes needed?

Currently both `Params` don't preserve `weightCol` `Params` when data is saved / loaded:

```python
from pyspark.ml.classification import LogisticRegression, OneVsRest, OneVsRestModel
from pyspark.ml.linalg import DenseVector

df = spark.createDataFrame([(0, 1, DenseVector([1.0, 0.0])), (0, 1, DenseVector([1.0, 0.0]))], ("label", "w", "features"))

ovr = OneVsRest(classifier=LogisticRegression()).setWeightCol("w")
ovrm = ovr.fit(df)
ovr.getWeightCol()
## 'w'
ovrm.getWeightCol()
## 'w'

ovr.write().overwrite().save("/tmp/ovr")
ovr_ = OneVsRest.load("/tmp/ovr")
ovr_.getWeightCol()
## KeyError
## ...
## KeyError: Param(parent='OneVsRest_5145d56b6bd1', name='weightCol', doc='weight column name. ...)

ovrm.write().overwrite().save("/tmp/ovrm")
ovrm_ = OneVsRestModel.load("/tmp/ovrm")
ovrm_ .getWeightCol()
## KeyError
## ...
## KeyError: Param(parent='OneVsRestModel_598c6d900fad', name='weightCol', doc='weight column name ...
```

### Does this PR introduce any user-facing change?

After this PR is merged, loaded objects will have `weightCol` `Param` set.

### How was this patch tested?

- Manual testing.
- Extension of existing persistence tests.

Closes #27190 from zero323/SPARK-30504.

Authored-by: zero323 <mszymkiewicz@gmail.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
zero323 2020-01-15 08:42:24 -06:00 committed by Sean Owen
parent 5f6cd61913
commit 525c5695f8
2 changed files with 27 additions and 13 deletions

View file

@ -2571,6 +2571,8 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, Jav
py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, predictionCol=predictionCol,
rawPredictionCol=rawPredictionCol, classifier=classifier,
parallelism=parallelism)
if java_stage.isDefined(java_stage.getParam("weightCol")):
py_stage.setWeightCol(java_stage.getWeightCol())
py_stage._resetUid(java_stage.uid())
return py_stage
@ -2587,6 +2589,8 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, Jav
_java_obj.setFeaturesCol(self.getFeaturesCol())
_java_obj.setLabelCol(self.getLabelCol())
_java_obj.setPredictionCol(self.getPredictionCol())
if (self.isDefined(self.weightCol) and self.getWeightCol()):
_java_obj.setWeightCol(self.getWeightCol())
_java_obj.setRawPredictionCol(self.getRawPredictionCol())
return _java_obj
@ -2765,6 +2769,8 @@ class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
py_stage = cls(models=models).setPredictionCol(predictionCol)\
.setFeaturesCol(featuresCol)
py_stage._set(labelCol=labelCol)
if java_stage.isDefined(java_stage.getParam("weightCol")):
py_stage._set(weightCol=java_stage.getWeightCol())
py_stage._set(classifier=classifier)
py_stage._resetUid(java_stage.uid())
return py_stage
@ -2786,6 +2792,8 @@ class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
_java_obj.set("featuresCol", self.getFeaturesCol())
_java_obj.set("labelCol", self.getLabelCol())
_java_obj.set("predictionCol", self.getPredictionCol())
if (self.isDefined(self.weightCol) and self.getWeightCol()):
_java_obj.set("weightCol", self.getWeightCol())
return _java_obj

View file

@ -269,21 +269,27 @@ class PersistenceTest(SparkSessionTestCase):
def test_onevsrest(self):
temp_path = tempfile.mkdtemp()
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
(1.0, Vectors.sparse(2, [], [])),
(2.0, Vectors.dense(0.5, 0.5))] * 10,
["label", "features"])
df = self.spark.createDataFrame([(0.0, 0.5, Vectors.dense(1.0, 0.8)),
(1.0, 0.5, Vectors.sparse(2, [], [])),
(2.0, 1.0, Vectors.dense(0.5, 0.5))] * 10,
["label", "wt", "features"])
lr = LogisticRegression(maxIter=5, regParam=0.01)
ovr = OneVsRest(classifier=lr)
model = ovr.fit(df)
ovrPath = temp_path + "/ovr"
ovr.save(ovrPath)
loadedOvr = OneVsRest.load(ovrPath)
self._compare_pipelines(ovr, loadedOvr)
modelPath = temp_path + "/ovrModel"
model.save(modelPath)
loadedModel = OneVsRestModel.load(modelPath)
self._compare_pipelines(model, loadedModel)
def reload_and_compare(ovr, suffix):
model = ovr.fit(df)
ovrPath = temp_path + "/{}".format(suffix)
ovr.save(ovrPath)
loadedOvr = OneVsRest.load(ovrPath)
self._compare_pipelines(ovr, loadedOvr)
modelPath = temp_path + "/{}Model".format(suffix)
model.save(modelPath)
loadedModel = OneVsRestModel.load(modelPath)
self._compare_pipelines(model, loadedModel)
reload_and_compare(OneVsRest(classifier=lr), "ovr")
reload_and_compare(OneVsRest(classifier=lr).setWeightCol("wt"), "ovrw")
def test_decisiontree_classifier(self):
dt = DecisionTreeClassifier(maxDepth=1)