[SPARK-30504][PYTHON][ML] Set weightCol in OneVsRest(Model) _to_java and _from_java
### What changes were proposed in this pull request?

This PR adjusts `_to_java` and `_from_java` of `OneVsRest` and `OneVsRestModel` to preserve `weightCol`.

### Why are the changes needed?

Currently neither `OneVsRest` nor `OneVsRestModel` preserves the `weightCol` `Param` when it is saved / loaded:

```python
from pyspark.ml.classification import LogisticRegression, OneVsRest, OneVsRestModel
from pyspark.ml.linalg import DenseVector

df = spark.createDataFrame([(0, 1, DenseVector([1.0, 0.0])), (0, 1, DenseVector([1.0, 0.0]))], ("label", "w", "features"))

ovr = OneVsRest(classifier=LogisticRegression()).setWeightCol("w")
ovrm = ovr.fit(df)

ovr.getWeightCol()
## 'w'
ovrm.getWeightCol()
## 'w'

ovr.write().overwrite().save("/tmp/ovr")
ovr_ = OneVsRest.load("/tmp/ovr")
ovr_.getWeightCol()
## KeyError
## ...
## KeyError: Param(parent='OneVsRest_5145d56b6bd1', name='weightCol', doc='weight column name. ...)

ovrm.write().overwrite().save("/tmp/ovrm")
ovrm_ = OneVsRestModel.load("/tmp/ovrm")
ovrm_.getWeightCol()
## KeyError
## ...
## KeyError: Param(parent='OneVsRestModel_598c6d900fad', name='weightCol', doc='weight column name ...
```

### Does this PR introduce any user-facing change?

After this PR is merged, loaded objects will have the `weightCol` `Param` set.

### How was this patch tested?

- Manual testing.
- Extension of existing persistence tests.

Closes #27190 from zero323/SPARK-30504.

Authored-by: zero323 <mszymkiewicz@gmail.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
parent
5f6cd61913
commit
525c5695f8
|
@ -2571,6 +2571,8 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, Jav
|
|||
py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, predictionCol=predictionCol,
|
||||
rawPredictionCol=rawPredictionCol, classifier=classifier,
|
||||
parallelism=parallelism)
|
||||
if java_stage.isDefined(java_stage.getParam("weightCol")):
|
||||
py_stage.setWeightCol(java_stage.getWeightCol())
|
||||
py_stage._resetUid(java_stage.uid())
|
||||
return py_stage
|
||||
|
||||
|
@ -2587,6 +2589,8 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, Jav
|
|||
_java_obj.setFeaturesCol(self.getFeaturesCol())
|
||||
_java_obj.setLabelCol(self.getLabelCol())
|
||||
_java_obj.setPredictionCol(self.getPredictionCol())
|
||||
if (self.isDefined(self.weightCol) and self.getWeightCol()):
|
||||
_java_obj.setWeightCol(self.getWeightCol())
|
||||
_java_obj.setRawPredictionCol(self.getRawPredictionCol())
|
||||
return _java_obj
|
||||
|
||||
|
@ -2765,6 +2769,8 @@ class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
|
|||
py_stage = cls(models=models).setPredictionCol(predictionCol)\
|
||||
.setFeaturesCol(featuresCol)
|
||||
py_stage._set(labelCol=labelCol)
|
||||
if java_stage.isDefined(java_stage.getParam("weightCol")):
|
||||
py_stage._set(weightCol=java_stage.getWeightCol())
|
||||
py_stage._set(classifier=classifier)
|
||||
py_stage._resetUid(java_stage.uid())
|
||||
return py_stage
|
||||
|
@ -2786,6 +2792,8 @@ class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
|
|||
_java_obj.set("featuresCol", self.getFeaturesCol())
|
||||
_java_obj.set("labelCol", self.getLabelCol())
|
||||
_java_obj.set("predictionCol", self.getPredictionCol())
|
||||
if (self.isDefined(self.weightCol) and self.getWeightCol()):
|
||||
_java_obj.set("weightCol", self.getWeightCol())
|
||||
return _java_obj
|
||||
|
||||
|
||||
|
|
|
@ -269,21 +269,27 @@ class PersistenceTest(SparkSessionTestCase):
|
|||
|
||||
def test_onevsrest(self):
|
||||
temp_path = tempfile.mkdtemp()
|
||||
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
|
||||
(1.0, Vectors.sparse(2, [], [])),
|
||||
(2.0, Vectors.dense(0.5, 0.5))] * 10,
|
||||
["label", "features"])
|
||||
df = self.spark.createDataFrame([(0.0, 0.5, Vectors.dense(1.0, 0.8)),
|
||||
(1.0, 0.5, Vectors.sparse(2, [], [])),
|
||||
(2.0, 1.0, Vectors.dense(0.5, 0.5))] * 10,
|
||||
["label", "wt", "features"])
|
||||
|
||||
lr = LogisticRegression(maxIter=5, regParam=0.01)
|
||||
ovr = OneVsRest(classifier=lr)
|
||||
model = ovr.fit(df)
|
||||
ovrPath = temp_path + "/ovr"
|
||||
ovr.save(ovrPath)
|
||||
loadedOvr = OneVsRest.load(ovrPath)
|
||||
self._compare_pipelines(ovr, loadedOvr)
|
||||
modelPath = temp_path + "/ovrModel"
|
||||
model.save(modelPath)
|
||||
loadedModel = OneVsRestModel.load(modelPath)
|
||||
self._compare_pipelines(model, loadedModel)
|
||||
|
||||
def reload_and_compare(ovr, suffix):
|
||||
model = ovr.fit(df)
|
||||
ovrPath = temp_path + "/{}".format(suffix)
|
||||
ovr.save(ovrPath)
|
||||
loadedOvr = OneVsRest.load(ovrPath)
|
||||
self._compare_pipelines(ovr, loadedOvr)
|
||||
modelPath = temp_path + "/{}Model".format(suffix)
|
||||
model.save(modelPath)
|
||||
loadedModel = OneVsRestModel.load(modelPath)
|
||||
self._compare_pipelines(model, loadedModel)
|
||||
|
||||
reload_and_compare(OneVsRest(classifier=lr), "ovr")
|
||||
reload_and_compare(OneVsRest(classifier=lr).setWeightCol("wt"), "ovrw")
|
||||
|
||||
def test_decisiontree_classifier(self):
|
||||
dt = DecisionTreeClassifier(maxDepth=1)
|
||||
|
|
Loading…
Reference in a new issue