[SPARK-33592] Fix: Pyspark ML Validator params in estimatorParamMaps may be lost after saving and reloading
### What changes were proposed in this pull request? Fix: Pyspark ML Validator params in estimatorParamMaps may be lost after saving and reloading When saving validator estimatorParamMaps, will check all nested stages in tuned estimator to get correct param parent. Two typical cases to manually test: ~~~python tokenizer = Tokenizer(inputCol="text", outputCol="words") hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") lr = LogisticRegression() pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) paramGrid = ParamGridBuilder() \ .addGrid(hashingTF.numFeatures, [10, 100]) \ .addGrid(lr.maxIter, [100, 200]) \ .build() tvs = TrainValidationSplit(estimator=pipeline, estimatorParamMaps=paramGrid, evaluator=MulticlassClassificationEvaluator()) tvs.save(tvsPath) loadedTvs = TrainValidationSplit.load(tvsPath) # check `loadedTvs.getEstimatorParamMaps()` restored correctly. ~~~ ~~~python lr = LogisticRegression() ova = OneVsRest(classifier=lr) grid = ParamGridBuilder().addGrid(lr.maxIter, [100, 200]).build() evaluator = MulticlassClassificationEvaluator() tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) tvs.save(tvsPath) loadedTvs = TrainValidationSplit.load(tvsPath) # check `loadedTvs.getEstimatorParamMaps()` restored correctly. ~~~ ### Why are the changes needed? Bug fix. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit test. Closes #30539 from WeichenXu123/fix_tuning_param_maps_io. Authored-by: Weichen Xu <weichen.xu@databricks.com> Signed-off-by: Ruifeng Zheng <ruifengz@foxmail.com>
This commit is contained in:
parent
aeb3649fb9
commit
80161238fe
|
@ -564,6 +564,7 @@ pyspark_ml = Module(
|
|||
"pyspark.ml.tests.test_stat",
|
||||
"pyspark.ml.tests.test_training_summary",
|
||||
"pyspark.ml.tests.test_tuning",
|
||||
"pyspark.ml.tests.test_util",
|
||||
"pyspark.ml.tests.test_wrapper",
|
||||
],
|
||||
excluded_python_implementations=[
|
||||
|
|
|
@ -36,7 +36,7 @@ from pyspark.ml.base import _PredictorParams
|
|||
from pyspark.ml.util import JavaMLWritable, JavaMLReadable, HasTrainingSummary
|
||||
from pyspark.ml.wrapper import JavaParams, \
|
||||
JavaPredictor, JavaPredictionModel, JavaWrapper
|
||||
from pyspark.ml.common import inherit_doc, _java2py, _py2java
|
||||
from pyspark.ml.common import inherit_doc
|
||||
from pyspark.ml.linalg import Vectors
|
||||
from pyspark.sql import DataFrame
|
||||
from pyspark.sql.functions import udf, when
|
||||
|
@ -2991,50 +2991,6 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, Jav
|
|||
_java_obj.setRawPredictionCol(self.getRawPredictionCol())
|
||||
return _java_obj
|
||||
|
||||
def _make_java_param_pair(self, param, value):
|
||||
"""
|
||||
Makes a Java param pair.
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
param = self._resolveParam(param)
|
||||
_java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
|
||||
self.uid)
|
||||
java_param = _java_obj.getParam(param.name)
|
||||
if isinstance(value, JavaParams):
|
||||
# used in the case of an estimator having another estimator as a parameter
|
||||
# the reason why this is not in _py2java in common.py is that importing
|
||||
# Estimator and Model in common.py results in a circular import with inherit_doc
|
||||
java_value = value._to_java()
|
||||
else:
|
||||
java_value = _py2java(sc, value)
|
||||
return java_param.w(java_value)
|
||||
|
||||
def _transfer_param_map_to_java(self, pyParamMap):
|
||||
"""
|
||||
Transforms a Python ParamMap into a Java ParamMap.
|
||||
"""
|
||||
paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
|
||||
for param in self.params:
|
||||
if param in pyParamMap:
|
||||
pair = self._make_java_param_pair(param, pyParamMap[param])
|
||||
paramMap.put([pair])
|
||||
return paramMap
|
||||
|
||||
def _transfer_param_map_from_java(self, javaParamMap):
|
||||
"""
|
||||
Transforms a Java ParamMap into a Python ParamMap.
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
paramMap = dict()
|
||||
for pair in javaParamMap.toList():
|
||||
param = pair.param()
|
||||
if self.hasParam(str(param.name())):
|
||||
if param.name() == "classifier":
|
||||
paramMap[self.getParam(param.name())] = JavaParams._from_java(pair.value())
|
||||
else:
|
||||
paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
|
||||
return paramMap
|
||||
|
||||
|
||||
class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
|
||||
"""
|
||||
|
|
|
@ -437,6 +437,12 @@ class Params(Identifiable, metaclass=ABCMeta):
|
|||
else:
|
||||
raise ValueError("Cannot resolve %r as a param." % param)
|
||||
|
||||
def _testOwnParam(self, param_parent, param_name):
|
||||
"""
|
||||
Test the ownership. Return True or False
|
||||
"""
|
||||
return self.uid == param_parent and self.hasParam(param_name)
|
||||
|
||||
@staticmethod
|
||||
def _dummy():
|
||||
"""
|
||||
|
|
|
@ -21,8 +21,8 @@ from pyspark.ml.base import Estimator, Model, Transformer
|
|||
from pyspark.ml.param import Param, Params
|
||||
from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader, \
|
||||
DefaultParamsReader, DefaultParamsWriter, MLWriter, MLReader, JavaMLWritable
|
||||
from pyspark.ml.wrapper import JavaParams, JavaWrapper
|
||||
from pyspark.ml.common import inherit_doc, _java2py, _py2java
|
||||
from pyspark.ml.wrapper import JavaParams
|
||||
from pyspark.ml.common import inherit_doc
|
||||
|
||||
|
||||
@inherit_doc
|
||||
|
@ -190,55 +190,6 @@ class Pipeline(Estimator, MLReadable, MLWritable):
|
|||
|
||||
return _java_obj
|
||||
|
||||
def _make_java_param_pair(self, param, value):
|
||||
"""
|
||||
Makes a Java param pair.
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
param = self._resolveParam(param)
|
||||
java_param = sc._jvm.org.apache.spark.ml.param.Param(param.parent, param.name, param.doc)
|
||||
if isinstance(value, Params) and hasattr(value, "_to_java"):
|
||||
# Convert JavaEstimator/JavaTransformer object or Estimator/Transformer object which
|
||||
# implements `_to_java` method (such as OneVsRest, Pipeline object) to java object.
|
||||
# used in the case of an estimator having another estimator as a parameter
|
||||
# the reason why this is not in _py2java in common.py is that importing
|
||||
# Estimator and Model in common.py results in a circular import with inherit_doc
|
||||
java_value = value._to_java()
|
||||
else:
|
||||
java_value = _py2java(sc, value)
|
||||
return java_param.w(java_value)
|
||||
|
||||
def _transfer_param_map_to_java(self, pyParamMap):
|
||||
"""
|
||||
Transforms a Python ParamMap into a Java ParamMap.
|
||||
"""
|
||||
paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
|
||||
for param in self.params:
|
||||
if param in pyParamMap:
|
||||
pair = self._make_java_param_pair(param, pyParamMap[param])
|
||||
paramMap.put([pair])
|
||||
return paramMap
|
||||
|
||||
def _transfer_param_map_from_java(self, javaParamMap):
|
||||
"""
|
||||
Transforms a Java ParamMap into a Python ParamMap.
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
paramMap = dict()
|
||||
for pair in javaParamMap.toList():
|
||||
param = pair.param()
|
||||
if self.hasParam(str(param.name())):
|
||||
java_obj = pair.value()
|
||||
if sc._jvm.Class.forName("org.apache.spark.ml.PipelineStage").isInstance(java_obj):
|
||||
# Note: JavaParams._from_java support both JavaEstimator/JavaTransformer class
|
||||
# and Estimator/Transformer class which implements `_from_java` static method
|
||||
# (such as OneVsRest, Pipeline class).
|
||||
py_obj = JavaParams._from_java(java_obj)
|
||||
else:
|
||||
py_obj = _java2py(sc, java_obj)
|
||||
paramMap[self.getParam(param.name())] = py_obj
|
||||
return paramMap
|
||||
|
||||
|
||||
@inherit_doc
|
||||
class PipelineWriter(MLWriter):
|
||||
|
|
|
@ -73,7 +73,21 @@ class ParamGridBuilderTests(SparkSessionTestCase):
|
|||
.build())
|
||||
|
||||
|
||||
class CrossValidatorTests(SparkSessionTestCase):
|
||||
class ValidatorTestUtilsMixin:
|
||||
def assert_param_maps_equal(self, paramMaps1, paramMaps2):
|
||||
self.assertEqual(len(paramMaps1), len(paramMaps2))
|
||||
for paramMap1, paramMap2 in zip(paramMaps1, paramMaps2):
|
||||
self.assertEqual(set(paramMap1.keys()), set(paramMap2.keys()))
|
||||
for param in paramMap1.keys():
|
||||
v1 = paramMap1[param]
|
||||
v2 = paramMap2[param]
|
||||
if isinstance(v1, Params):
|
||||
self.assertEqual(v1.uid, v2.uid)
|
||||
else:
|
||||
self.assertEqual(v1, v2)
|
||||
|
||||
|
||||
class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
|
||||
|
||||
def test_copy(self):
|
||||
dataset = self.spark.createDataFrame([
|
||||
|
@ -256,7 +270,7 @@ class CrossValidatorTests(SparkSessionTestCase):
|
|||
loadedCV = CrossValidator.load(cvPath)
|
||||
self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
|
||||
self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
|
||||
self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
|
||||
self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
|
||||
|
||||
# test save/load of CrossValidatorModel
|
||||
cvModelPath = temp_path + "/cvModel"
|
||||
|
@ -351,6 +365,7 @@ class CrossValidatorTests(SparkSessionTestCase):
|
|||
cvPath = temp_path + "/cv"
|
||||
cv.save(cvPath)
|
||||
loadedCV = CrossValidator.load(cvPath)
|
||||
self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), grid)
|
||||
self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
|
||||
self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
|
||||
|
||||
|
@ -367,6 +382,7 @@ class CrossValidatorTests(SparkSessionTestCase):
|
|||
cvModelPath = temp_path + "/cvModel"
|
||||
cvModel.save(cvModelPath)
|
||||
loadedModel = CrossValidatorModel.load(cvModelPath)
|
||||
self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid)
|
||||
self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
|
||||
|
||||
def test_save_load_pipeline_estimator(self):
|
||||
|
@ -401,6 +417,11 @@ class CrossValidatorTests(SparkSessionTestCase):
|
|||
estimatorParamMaps=paramGrid,
|
||||
evaluator=MulticlassClassificationEvaluator(),
|
||||
numFolds=2) # use 3+ folds in practice
|
||||
cvPath = temp_path + "/cv"
|
||||
crossval.save(cvPath)
|
||||
loadedCV = CrossValidator.load(cvPath)
|
||||
self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), paramGrid)
|
||||
self.assertEqual(loadedCV.getEstimator().uid, crossval.getEstimator().uid)
|
||||
|
||||
# Run cross-validation, and choose the best set of parameters.
|
||||
cvModel = crossval.fit(training)
|
||||
|
@ -421,6 +442,11 @@ class CrossValidatorTests(SparkSessionTestCase):
|
|||
estimatorParamMaps=paramGrid,
|
||||
evaluator=MulticlassClassificationEvaluator(),
|
||||
numFolds=2) # use 3+ folds in practice
|
||||
cv2Path = temp_path + "/cv2"
|
||||
crossval2.save(cv2Path)
|
||||
loadedCV2 = CrossValidator.load(cv2Path)
|
||||
self.assert_param_maps_equal(loadedCV2.getEstimatorParamMaps(), paramGrid)
|
||||
self.assertEqual(loadedCV2.getEstimator().uid, crossval2.getEstimator().uid)
|
||||
|
||||
# Run cross-validation, and choose the best set of parameters.
|
||||
cvModel2 = crossval2.fit(training)
|
||||
|
@ -511,7 +537,7 @@ class CrossValidatorTests(SparkSessionTestCase):
|
|||
cv.fit(dataset_with_folds)
|
||||
|
||||
|
||||
class TrainValidationSplitTests(SparkSessionTestCase):
|
||||
class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
|
||||
|
||||
def test_fit_minimize_metric(self):
|
||||
dataset = self.spark.createDataFrame([
|
||||
|
@ -632,7 +658,8 @@ class TrainValidationSplitTests(SparkSessionTestCase):
|
|||
loadedTvs = TrainValidationSplit.load(tvsPath)
|
||||
self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
|
||||
self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
|
||||
self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
|
||||
self.assert_param_maps_equal(
|
||||
loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
|
||||
|
||||
tvsModelPath = temp_path + "/tvsModel"
|
||||
tvsModel.save(tvsModelPath)
|
||||
|
@ -713,6 +740,7 @@ class TrainValidationSplitTests(SparkSessionTestCase):
|
|||
tvsPath = temp_path + "/tvs"
|
||||
tvs.save(tvsPath)
|
||||
loadedTvs = TrainValidationSplit.load(tvsPath)
|
||||
self.assert_param_maps_equal(loadedTvs.getEstimatorParamMaps(), grid)
|
||||
self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
|
||||
self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
|
||||
|
||||
|
@ -728,6 +756,7 @@ class TrainValidationSplitTests(SparkSessionTestCase):
|
|||
tvsModelPath = temp_path + "/tvsModel"
|
||||
tvsModel.save(tvsModelPath)
|
||||
loadedModel = TrainValidationSplitModel.load(tvsModelPath)
|
||||
self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid)
|
||||
self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
|
||||
|
||||
def test_save_load_pipeline_estimator(self):
|
||||
|
@ -761,6 +790,11 @@ class TrainValidationSplitTests(SparkSessionTestCase):
|
|||
tvs = TrainValidationSplit(estimator=pipeline,
|
||||
estimatorParamMaps=paramGrid,
|
||||
evaluator=MulticlassClassificationEvaluator())
|
||||
tvsPath = temp_path + "/tvs"
|
||||
tvs.save(tvsPath)
|
||||
loadedTvs = TrainValidationSplit.load(tvsPath)
|
||||
self.assert_param_maps_equal(loadedTvs.getEstimatorParamMaps(), paramGrid)
|
||||
self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
|
||||
|
||||
# Run train validation split, and choose the best set of parameters.
|
||||
tvsModel = tvs.fit(training)
|
||||
|
@ -780,6 +814,11 @@ class TrainValidationSplitTests(SparkSessionTestCase):
|
|||
tvs2 = TrainValidationSplit(estimator=nested_pipeline,
|
||||
estimatorParamMaps=paramGrid,
|
||||
evaluator=MulticlassClassificationEvaluator())
|
||||
tvs2Path = temp_path + "/tvs2"
|
||||
tvs2.save(tvs2Path)
|
||||
loadedTvs2 = TrainValidationSplit.load(tvs2Path)
|
||||
self.assert_param_maps_equal(loadedTvs2.getEstimatorParamMaps(), paramGrid)
|
||||
self.assertEqual(loadedTvs2.getEstimator().uid, tvs2.getEstimator().uid)
|
||||
|
||||
# Run train validation split, and choose the best set of parameters.
|
||||
tvsModel2 = tvs2.fit(training)
|
||||
|
|
84
python/pyspark/ml/tests/test_util.py
Normal file
84
python/pyspark/ml/tests/test_util.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pyspark.ml import Pipeline
|
||||
from pyspark.ml.classification import LogisticRegression, OneVsRest
|
||||
from pyspark.ml.feature import VectorAssembler
|
||||
from pyspark.ml.linalg import Vectors
|
||||
from pyspark.ml.util import MetaAlgorithmReadWrite
|
||||
from pyspark.testing.mlutils import SparkSessionTestCase
|
||||
|
||||
|
||||
class MetaAlgorithmReadWriteTests(SparkSessionTestCase):
|
||||
|
||||
def test_getAllNestedStages(self):
|
||||
def _check_uid_set_equal(stages, expected_stages):
|
||||
uids = set(map(lambda x: x.uid, stages))
|
||||
expected_uids = set(map(lambda x: x.uid, expected_stages))
|
||||
self.assertEqual(uids, expected_uids)
|
||||
|
||||
df1 = self.spark.createDataFrame([
|
||||
(Vectors.dense([1., 2.]), 1.0),
|
||||
(Vectors.dense([-1., -2.]), 0.0),
|
||||
], ['features', 'label'])
|
||||
df2 = self.spark.createDataFrame([
|
||||
(1., 2., 1.0),
|
||||
(1., 2., 0.0),
|
||||
], ['a', 'b', 'label'])
|
||||
vs = VectorAssembler(inputCols=['a', 'b'], outputCol='features')
|
||||
lr = LogisticRegression()
|
||||
pipeline = Pipeline(stages=[vs, lr])
|
||||
pipelineModel = pipeline.fit(df2)
|
||||
ova = OneVsRest(classifier=lr)
|
||||
ovaModel = ova.fit(df1)
|
||||
|
||||
ova_pipeline = Pipeline(stages=[vs, ova])
|
||||
nested_pipeline = Pipeline(stages=[ova_pipeline])
|
||||
|
||||
_check_uid_set_equal(
|
||||
MetaAlgorithmReadWrite.getAllNestedStages(pipeline),
|
||||
[pipeline, vs, lr]
|
||||
)
|
||||
_check_uid_set_equal(
|
||||
MetaAlgorithmReadWrite.getAllNestedStages(pipelineModel),
|
||||
[pipelineModel] + pipelineModel.stages
|
||||
)
|
||||
_check_uid_set_equal(
|
||||
MetaAlgorithmReadWrite.getAllNestedStages(ova),
|
||||
[ova, lr]
|
||||
)
|
||||
_check_uid_set_equal(
|
||||
MetaAlgorithmReadWrite.getAllNestedStages(ovaModel),
|
||||
[ovaModel, lr] + ovaModel.models
|
||||
)
|
||||
_check_uid_set_equal(
|
||||
MetaAlgorithmReadWrite.getAllNestedStages(nested_pipeline),
|
||||
[nested_pipeline, ova_pipeline, vs, ova, lr]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pyspark.ml.tests.test_util import * # noqa: F401
|
||||
|
||||
try:
|
||||
import xmlrunner # type: ignore[import]
|
||||
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
|
||||
except ImportError:
|
||||
testRunner = None
|
||||
unittest.main(testRunner=testRunner, verbosity=2)
|
|
@ -26,8 +26,9 @@ from pyspark.ml import Estimator, Model
|
|||
from pyspark.ml.common import _py2java, _java2py
|
||||
from pyspark.ml.param import Params, Param, TypeConverters
|
||||
from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
|
||||
from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader
|
||||
from pyspark.ml.wrapper import JavaParams
|
||||
from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader, \
|
||||
MetaAlgorithmReadWrite
|
||||
from pyspark.ml.wrapper import JavaParams, JavaEstimator, JavaWrapper
|
||||
from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
|
||||
from pyspark.sql.types import BooleanType
|
||||
|
||||
|
@ -64,6 +65,10 @@ def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel):
|
|||
|
||||
def singleTask():
|
||||
index, model = next(modelIter)
|
||||
# TODO: duplicate evaluator to take extra params from input
|
||||
# Note: Supporting tuning params in evaluator need update method
|
||||
# `MetaAlgorithmReadWrite.getAllNestedStages`, make it return
|
||||
# all nested stages and evaluators
|
||||
metric = eva.evaluate(model.transform(validation, epm[index]))
|
||||
return index, metric, model if collectSubModel else None
|
||||
|
||||
|
@ -186,8 +191,16 @@ class _ValidatorParams(HasSeed):
|
|||
# Load information from java_stage to the instance.
|
||||
estimator = JavaParams._from_java(java_stage.getEstimator())
|
||||
evaluator = JavaParams._from_java(java_stage.getEvaluator())
|
||||
epms = [estimator._transfer_param_map_from_java(epm)
|
||||
for epm in java_stage.getEstimatorParamMaps()]
|
||||
if isinstance(estimator, JavaEstimator):
|
||||
epms = [estimator._transfer_param_map_from_java(epm)
|
||||
for epm in java_stage.getEstimatorParamMaps()]
|
||||
elif MetaAlgorithmReadWrite.isMetaEstimator(estimator):
|
||||
# Meta estimator such as Pipeline, OneVsRest
|
||||
epms = _ValidatorSharedReadWrite.meta_estimator_transfer_param_maps_from_java(
|
||||
estimator, java_stage.getEstimatorParamMaps())
|
||||
else:
|
||||
raise ValueError('Unsupported estimator used in tuning: ' + str(estimator))
|
||||
|
||||
return estimator, epms, evaluator
|
||||
|
||||
def _to_java_impl(self):
|
||||
|
@ -198,15 +211,82 @@ class _ValidatorParams(HasSeed):
|
|||
gateway = SparkContext._gateway
|
||||
cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
|
||||
|
||||
java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
|
||||
for idx, epm in enumerate(self.getEstimatorParamMaps()):
|
||||
java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
|
||||
estimator = self.getEstimator()
|
||||
if isinstance(estimator, JavaEstimator):
|
||||
java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
|
||||
for idx, epm in enumerate(self.getEstimatorParamMaps()):
|
||||
java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
|
||||
elif MetaAlgorithmReadWrite.isMetaEstimator(estimator):
|
||||
# Meta estimator such as Pipeline, OneVsRest
|
||||
java_epms = _ValidatorSharedReadWrite.meta_estimator_transfer_param_maps_to_java(
|
||||
estimator, self.getEstimatorParamMaps())
|
||||
else:
|
||||
raise ValueError('Unsupported estimator used in tuning: ' + str(estimator))
|
||||
|
||||
java_estimator = self.getEstimator()._to_java()
|
||||
java_evaluator = self.getEvaluator()._to_java()
|
||||
return java_estimator, java_epms, java_evaluator
|
||||
|
||||
|
||||
class _ValidatorSharedReadWrite:
|
||||
@staticmethod
|
||||
def meta_estimator_transfer_param_maps_to_java(pyEstimator, pyParamMaps):
|
||||
pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)
|
||||
stagePairs = list(map(lambda stage: (stage, stage._to_java()), pyStages))
|
||||
sc = SparkContext._active_spark_context
|
||||
|
||||
paramMapCls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
|
||||
javaParamMaps = SparkContext._gateway.new_array(paramMapCls, len(pyParamMaps))
|
||||
|
||||
for idx, pyParamMap in enumerate(pyParamMaps):
|
||||
javaParamMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
|
||||
for pyParam, pyValue in pyParamMap.items():
|
||||
javaParam = None
|
||||
for pyStage, javaStage in stagePairs:
|
||||
if pyStage._testOwnParam(pyParam.parent, pyParam.name):
|
||||
javaParam = javaStage.getParam(pyParam.name)
|
||||
break
|
||||
if javaParam is None:
|
||||
raise ValueError('Resolve param in estimatorParamMaps failed: ' + str(pyParam))
|
||||
if isinstance(pyValue, Params) and hasattr(pyValue, "_to_java"):
|
||||
javaValue = pyValue._to_java()
|
||||
else:
|
||||
javaValue = _py2java(sc, pyValue)
|
||||
pair = javaParam.w(javaValue)
|
||||
javaParamMap.put([pair])
|
||||
javaParamMaps[idx] = javaParamMap
|
||||
return javaParamMaps
|
||||
|
||||
@staticmethod
|
||||
def meta_estimator_transfer_param_maps_from_java(pyEstimator, javaParamMaps):
|
||||
pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)
|
||||
stagePairs = list(map(lambda stage: (stage, stage._to_java()), pyStages))
|
||||
sc = SparkContext._active_spark_context
|
||||
pyParamMaps = []
|
||||
for javaParamMap in javaParamMaps:
|
||||
pyParamMap = dict()
|
||||
for javaPair in javaParamMap.toList():
|
||||
javaParam = javaPair.param()
|
||||
pyParam = None
|
||||
for pyStage, javaStage in stagePairs:
|
||||
if pyStage._testOwnParam(javaParam.parent(), javaParam.name()):
|
||||
pyParam = pyStage.getParam(javaParam.name())
|
||||
if pyParam is None:
|
||||
raise ValueError('Resolve param in estimatorParamMaps failed: ' +
|
||||
javaParam.parent() + '.' + javaParam.name())
|
||||
javaValue = javaPair.value()
|
||||
if sc._jvm.Class.forName("org.apache.spark.ml.PipelineStage").isInstance(javaValue):
|
||||
# Note: JavaParams._from_java support both JavaEstimator/JavaTransformer class
|
||||
# and Estimator/Transformer class which implements `_from_java` static method
|
||||
# (such as OneVsRest, Pipeline class).
|
||||
pyValue = JavaParams._from_java(javaValue)
|
||||
else:
|
||||
pyValue = _java2py(sc, javaValue)
|
||||
pyParamMap[pyParam] = pyValue
|
||||
pyParamMaps.append(pyParamMap)
|
||||
return pyParamMaps
|
||||
|
||||
|
||||
class _CrossValidatorParams(_ValidatorParams):
|
||||
"""
|
||||
Params for :py:class:`CrossValidator` and :py:class:`CrossValidatorModel`.
|
||||
|
|
|
@ -592,3 +592,41 @@ class HasTrainingSummary(object):
|
|||
no summary exists.
|
||||
"""
|
||||
return (self._call_java("summary"))
|
||||
|
||||
|
||||
class MetaAlgorithmReadWrite:
|
||||
|
||||
@staticmethod
|
||||
def isMetaEstimator(pyInstance):
|
||||
from pyspark.ml import Estimator, Pipeline
|
||||
from pyspark.ml.tuning import _ValidatorParams
|
||||
from pyspark.ml.classification import OneVsRest
|
||||
return isinstance(pyInstance, Pipeline) or isinstance(pyInstance, OneVsRest) or \
|
||||
(isinstance(pyInstance, Estimator) and isinstance(pyInstance, _ValidatorParams))
|
||||
|
||||
@staticmethod
|
||||
def getAllNestedStages(pyInstance):
|
||||
from pyspark.ml import Pipeline, PipelineModel
|
||||
from pyspark.ml.tuning import _ValidatorParams
|
||||
from pyspark.ml.classification import OneVsRest, OneVsRestModel
|
||||
|
||||
# TODO: We need to handle `RFormulaModel.pipelineModel` here after Pyspark RFormulaModel
|
||||
# support pipelineModel property.
|
||||
if isinstance(pyInstance, Pipeline):
|
||||
pySubStages = pyInstance.getStages()
|
||||
elif isinstance(pyInstance, PipelineModel):
|
||||
pySubStages = pyInstance.stages
|
||||
elif isinstance(pyInstance, _ValidatorParams):
|
||||
raise ValueError('PySpark does not support nested validator.')
|
||||
elif isinstance(pyInstance, OneVsRest):
|
||||
pySubStages = [pyInstance.getClassifier()]
|
||||
elif isinstance(pyInstance, OneVsRestModel):
|
||||
pySubStages = [pyInstance.getClassifier()] + pyInstance.models
|
||||
else:
|
||||
pySubStages = []
|
||||
|
||||
nestedStages = []
|
||||
for pySubStage in pySubStages:
|
||||
nestedStages.extend(MetaAlgorithmReadWrite.getAllNestedStages(pySubStage))
|
||||
|
||||
return [pyInstance] + nestedStages
|
||||
|
|
|
@ -126,3 +126,9 @@ class HasTrainingSummary(Generic[S]):
|
|||
def hasSummary(self) -> bool: ...
|
||||
@property
|
||||
def summary(self) -> S: ...
|
||||
|
||||
class MetaAlgorithmReadWrite:
|
||||
@staticmethod
|
||||
def isMetaEstimator(pyInstance: Any) -> bool: ...
|
||||
@staticmethod
|
||||
def getAllNestedStages(pyInstance: Any) -> list: ...
|
||||
|
|
Loading…
Reference in a new issue