From 80161238fe9393aabd5fcd56752ff1e43f6989b1 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Tue, 1 Dec 2020 09:36:42 +0800 Subject: [PATCH] [SPARK-33592] Fix: Pyspark ML Validator params in estimatorParamMaps may be lost after saving and reloading ### What changes were proposed in this pull request? Fix: Pyspark ML Validator params in estimatorParamMaps may be lost after saving and reloading When saving validator estimatorParamMaps, will check all nested stages in tuned estimator to get correct param parent. Two typical cases to manually test: ~~~python tokenizer = Tokenizer(inputCol="text", outputCol="words") hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") lr = LogisticRegression() pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) paramGrid = ParamGridBuilder() \ .addGrid(hashingTF.numFeatures, [10, 100]) \ .addGrid(lr.maxIter, [100, 200]) \ .build() tvs = TrainValidationSplit(estimator=pipeline, estimatorParamMaps=paramGrid, evaluator=MulticlassClassificationEvaluator()) tvs.save(tvsPath) loadedTvs = TrainValidationSplit.load(tvsPath) # check `loadedTvs.getEstimatorParamMaps()` restored correctly. ~~~ ~~~python lr = LogisticRegression() ova = OneVsRest(classifier=lr) grid = ParamGridBuilder().addGrid(lr.maxIter, [100, 200]).build() evaluator = MulticlassClassificationEvaluator() tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) tvs.save(tvsPath) loadedTvs = TrainValidationSplit.load(tvsPath) # check `loadedTvs.getEstimatorParamMaps()` restored correctly. ~~~ ### Why are the changes needed? Bug fix. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit test. Closes #30539 from WeichenXu123/fix_tuning_param_maps_io. Authored-by: Weichen Xu Signed-off-by: Ruifeng Zheng --- dev/sparktestsupport/modules.py | 1 + python/pyspark/ml/classification.py | 46 +------------ python/pyspark/ml/param/__init__.py | 6 ++ python/pyspark/ml/pipeline.py | 53 +-------------- python/pyspark/ml/tests/test_tuning.py | 47 +++++++++++-- python/pyspark/ml/tests/test_util.py | 84 +++++++++++++++++++++++ python/pyspark/ml/tuning.py | 94 ++++++++++++++++++++++++-- python/pyspark/ml/util.py | 38 +++++++++++ python/pyspark/ml/util.pyi | 6 ++ 9 files changed, 268 insertions(+), 107 deletions(-) create mode 100644 python/pyspark/ml/tests/test_util.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 868e4a5d23..5d8b714711 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -564,6 +564,7 @@ pyspark_ml = Module( "pyspark.ml.tests.test_stat", "pyspark.ml.tests.test_training_summary", "pyspark.ml.tests.test_tuning", + "pyspark.ml.tests.test_util", "pyspark.ml.tests.test_wrapper", ], excluded_python_implementations=[ diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 50882fc895..763038ede8 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -36,7 +36,7 @@ from pyspark.ml.base import _PredictorParams from pyspark.ml.util import JavaMLWritable, JavaMLReadable, HasTrainingSummary from pyspark.ml.wrapper import JavaParams, \ JavaPredictor, JavaPredictionModel, JavaWrapper -from pyspark.ml.common import inherit_doc, _java2py, _py2java +from pyspark.ml.common import inherit_doc from pyspark.ml.linalg import Vectors from pyspark.sql import DataFrame from pyspark.sql.functions import udf, when @@ -2991,50 +2991,6 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, Jav _java_obj.setRawPredictionCol(self.getRawPredictionCol()) return _java_obj - def _make_java_param_pair(self, param, value): - """ - Makes a Java param pair. - """ - sc = SparkContext._active_spark_context - param = self._resolveParam(param) - _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest", - self.uid) - java_param = _java_obj.getParam(param.name) - if isinstance(value, JavaParams): - # used in the case of an estimator having another estimator as a parameter - # the reason why this is not in _py2java in common.py is that importing - # Estimator and Model in common.py results in a circular import with inherit_doc - java_value = value._to_java() - else: - java_value = _py2java(sc, value) - return java_param.w(java_value) - - def _transfer_param_map_to_java(self, pyParamMap): - """ - Transforms a Python ParamMap into a Java ParamMap. - """ - paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap") - for param in self.params: - if param in pyParamMap: - pair = self._make_java_param_pair(param, pyParamMap[param]) - paramMap.put([pair]) - return paramMap - - def _transfer_param_map_from_java(self, javaParamMap): - """ - Transforms a Java ParamMap into a Python ParamMap. - """ - sc = SparkContext._active_spark_context - paramMap = dict() - for pair in javaParamMap.toList(): - param = pair.param() - if self.hasParam(str(param.name())): - if param.name() == "classifier": - paramMap[self.getParam(param.name())] = JavaParams._from_java(pair.value()) - else: - paramMap[self.getParam(param.name())] = _java2py(sc, pair.value()) - return paramMap - class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable): """ diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index f2381a4c42..3eab6607aa 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -437,6 +437,12 @@ class Params(Identifiable, metaclass=ABCMeta): else: raise ValueError("Cannot resolve %r as a param." % param) + def _testOwnParam(self, param_parent, param_name): + """ + Test the ownership. Return True or False + """ + return self.uid == param_parent and self.hasParam(param_name) + @staticmethod def _dummy(): """ diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a6471a8dd1..b0aa735709 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -21,8 +21,8 @@ from pyspark.ml.base import Estimator, Model, Transformer from pyspark.ml.param import Param, Params from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader, \ DefaultParamsReader, DefaultParamsWriter, MLWriter, MLReader, JavaMLWritable -from pyspark.ml.wrapper import JavaParams, JavaWrapper -from pyspark.ml.common import inherit_doc, _java2py, _py2java +from pyspark.ml.wrapper import JavaParams +from pyspark.ml.common import inherit_doc @inherit_doc @@ -190,55 +190,6 @@ class Pipeline(Estimator, MLReadable, MLWritable): return _java_obj - def _make_java_param_pair(self, param, value): - """ - Makes a Java param pair. - """ - sc = SparkContext._active_spark_context - param = self._resolveParam(param) - java_param = sc._jvm.org.apache.spark.ml.param.Param(param.parent, param.name, param.doc) - if isinstance(value, Params) and hasattr(value, "_to_java"): - # Convert JavaEstimator/JavaTransformer object or Estimator/Transformer object which - # implements `_to_java` method (such as OneVsRest, Pipeline object) to java object. - # used in the case of an estimator having another estimator as a parameter - # the reason why this is not in _py2java in common.py is that importing - # Estimator and Model in common.py results in a circular import with inherit_doc - java_value = value._to_java() - else: - java_value = _py2java(sc, value) - return java_param.w(java_value) - - def _transfer_param_map_to_java(self, pyParamMap): - """ - Transforms a Python ParamMap into a Java ParamMap. - """ - paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap") - for param in self.params: - if param in pyParamMap: - pair = self._make_java_param_pair(param, pyParamMap[param]) - paramMap.put([pair]) - return paramMap - - def _transfer_param_map_from_java(self, javaParamMap): - """ - Transforms a Java ParamMap into a Python ParamMap. - """ - sc = SparkContext._active_spark_context - paramMap = dict() - for pair in javaParamMap.toList(): - param = pair.param() - if self.hasParam(str(param.name())): - java_obj = pair.value() - if sc._jvm.Class.forName("org.apache.spark.ml.PipelineStage").isInstance(java_obj): - # Note: JavaParams._from_java support both JavaEstimator/JavaTransformer class - # and Estimator/Transformer class which implements `_from_java` static method - # (such as OneVsRest, Pipeline class). - py_obj = JavaParams._from_java(java_obj) - else: - py_obj = _java2py(sc, java_obj) - paramMap[self.getParam(param.name())] = py_obj - return paramMap - @inherit_doc class PipelineWriter(MLWriter): diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py index ced32c07f2..ebd7457e4d 100644 --- a/python/pyspark/ml/tests/test_tuning.py +++ b/python/pyspark/ml/tests/test_tuning.py @@ -73,7 +73,21 @@ class ParamGridBuilderTests(SparkSessionTestCase): .build()) -class CrossValidatorTests(SparkSessionTestCase): +class ValidatorTestUtilsMixin: + def assert_param_maps_equal(self, paramMaps1, paramMaps2): + self.assertEqual(len(paramMaps1), len(paramMaps2)) + for paramMap1, paramMap2 in zip(paramMaps1, paramMaps2): + self.assertEqual(set(paramMap1.keys()), set(paramMap2.keys())) + for param in paramMap1.keys(): + v1 = paramMap1[param] + v2 = paramMap2[param] + if isinstance(v1, Params): + self.assertEqual(v1.uid, v2.uid) + else: + self.assertEqual(v1, v2) + + +class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin): def test_copy(self): dataset = self.spark.createDataFrame([ @@ -256,7 +270,7 @@ class CrossValidatorTests(SparkSessionTestCase): loadedCV = CrossValidator.load(cvPath) self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) - self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps()) + self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps()) # test save/load of CrossValidatorModel cvModelPath = temp_path + "/cvModel" @@ -351,6 +365,7 @@ class CrossValidatorTests(SparkSessionTestCase): cvPath = temp_path + "/cv" cv.save(cvPath) loadedCV = CrossValidator.load(cvPath) + self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), grid) self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) @@ -367,6 +382,7 @@ class CrossValidatorTests(SparkSessionTestCase): cvModelPath = temp_path + "/cvModel" cvModel.save(cvModelPath) loadedModel = CrossValidatorModel.load(cvModelPath) + self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid) self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) def test_save_load_pipeline_estimator(self): @@ -401,6 +417,11 @@ class CrossValidatorTests(SparkSessionTestCase): estimatorParamMaps=paramGrid, evaluator=MulticlassClassificationEvaluator(), numFolds=2) # use 3+ folds in practice + cvPath = temp_path + "/cv" + crossval.save(cvPath) + loadedCV = CrossValidator.load(cvPath) + self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), paramGrid) + self.assertEqual(loadedCV.getEstimator().uid, crossval.getEstimator().uid) # Run cross-validation, and choose the best set of parameters. cvModel = crossval.fit(training) @@ -421,6 +442,11 @@ class CrossValidatorTests(SparkSessionTestCase): estimatorParamMaps=paramGrid, evaluator=MulticlassClassificationEvaluator(), numFolds=2) # use 3+ folds in practice + cv2Path = temp_path + "/cv2" + crossval2.save(cv2Path) + loadedCV2 = CrossValidator.load(cv2Path) + self.assert_param_maps_equal(loadedCV2.getEstimatorParamMaps(), paramGrid) + self.assertEqual(loadedCV2.getEstimator().uid, crossval2.getEstimator().uid) # Run cross-validation, and choose the best set of parameters. cvModel2 = crossval2.fit(training) @@ -511,7 +537,7 @@ class CrossValidatorTests(SparkSessionTestCase): cv.fit(dataset_with_folds) -class TrainValidationSplitTests(SparkSessionTestCase): +class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin): def test_fit_minimize_metric(self): dataset = self.spark.createDataFrame([ @@ -632,7 +658,8 @@ class TrainValidationSplitTests(SparkSessionTestCase): loadedTvs = TrainValidationSplit.load(tvsPath) self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) - self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps()) + self.assert_param_maps_equal( + loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps()) tvsModelPath = temp_path + "/tvsModel" tvsModel.save(tvsModelPath) @@ -713,6 +740,7 @@ class TrainValidationSplitTests(SparkSessionTestCase): tvsPath = temp_path + "/tvs" tvs.save(tvsPath) loadedTvs = TrainValidationSplit.load(tvsPath) + self.assert_param_maps_equal(loadedTvs.getEstimatorParamMaps(), grid) self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) @@ -728,6 +756,7 @@ class TrainValidationSplitTests(SparkSessionTestCase): tvsModelPath = temp_path + "/tvsModel" tvsModel.save(tvsModelPath) loadedModel = TrainValidationSplitModel.load(tvsModelPath) + self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid) self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) def test_save_load_pipeline_estimator(self): @@ -761,6 +790,11 @@ class TrainValidationSplitTests(SparkSessionTestCase): tvs = TrainValidationSplit(estimator=pipeline, estimatorParamMaps=paramGrid, evaluator=MulticlassClassificationEvaluator()) + tvsPath = temp_path + "/tvs" + tvs.save(tvsPath) + loadedTvs = TrainValidationSplit.load(tvsPath) + self.assert_param_maps_equal(loadedTvs.getEstimatorParamMaps(), paramGrid) + self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) # Run train validation split, and choose the best set of parameters. tvsModel = tvs.fit(training) @@ -780,6 +814,11 @@ class TrainValidationSplitTests(SparkSessionTestCase): tvs2 = TrainValidationSplit(estimator=nested_pipeline, estimatorParamMaps=paramGrid, evaluator=MulticlassClassificationEvaluator()) + tvs2Path = temp_path + "/tvs2" + tvs2.save(tvs2Path) + loadedTvs2 = TrainValidationSplit.load(tvs2Path) + self.assert_param_maps_equal(loadedTvs2.getEstimatorParamMaps(), paramGrid) + self.assertEqual(loadedTvs2.getEstimator().uid, tvs2.getEstimator().uid) # Run train validation split, and choose the best set of parameters. tvsModel2 = tvs2.fit(training) diff --git a/python/pyspark/ml/tests/test_util.py b/python/pyspark/ml/tests/test_util.py new file mode 100644 index 0000000000..498a649e48 --- /dev/null +++ b/python/pyspark/ml/tests/test_util.py @@ -0,0 +1,84 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression, OneVsRest +from pyspark.ml.feature import VectorAssembler +from pyspark.ml.linalg import Vectors +from pyspark.ml.util import MetaAlgorithmReadWrite +from pyspark.testing.mlutils import SparkSessionTestCase + + +class MetaAlgorithmReadWriteTests(SparkSessionTestCase): + + def test_getAllNestedStages(self): + def _check_uid_set_equal(stages, expected_stages): + uids = set(map(lambda x: x.uid, stages)) + expected_uids = set(map(lambda x: x.uid, expected_stages)) + self.assertEqual(uids, expected_uids) + + df1 = self.spark.createDataFrame([ + (Vectors.dense([1., 2.]), 1.0), + (Vectors.dense([-1., -2.]), 0.0), + ], ['features', 'label']) + df2 = self.spark.createDataFrame([ + (1., 2., 1.0), + (1., 2., 0.0), + ], ['a', 'b', 'label']) + vs = VectorAssembler(inputCols=['a', 'b'], outputCol='features') + lr = LogisticRegression() + pipeline = Pipeline(stages=[vs, lr]) + pipelineModel = pipeline.fit(df2) + ova = OneVsRest(classifier=lr) + ovaModel = ova.fit(df1) + + ova_pipeline = Pipeline(stages=[vs, ova]) + nested_pipeline = Pipeline(stages=[ova_pipeline]) + + _check_uid_set_equal( + MetaAlgorithmReadWrite.getAllNestedStages(pipeline), + [pipeline, vs, lr] + ) + _check_uid_set_equal( + MetaAlgorithmReadWrite.getAllNestedStages(pipelineModel), + [pipelineModel] + pipelineModel.stages + ) + _check_uid_set_equal( + MetaAlgorithmReadWrite.getAllNestedStages(ova), + [ova, lr] + ) + _check_uid_set_equal( + MetaAlgorithmReadWrite.getAllNestedStages(ovaModel), + [ovaModel, lr] + ovaModel.models + ) + _check_uid_set_equal( + MetaAlgorithmReadWrite.getAllNestedStages(nested_pipeline), + [nested_pipeline, ova_pipeline, vs, ova, lr] + ) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_util import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 6f4ad99484..2b5a9857b0 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -26,8 +26,9 @@ from pyspark.ml import Estimator, Model from pyspark.ml.common import _py2java, _java2py from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed -from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader -from pyspark.ml.wrapper import JavaParams +from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader, \ + MetaAlgorithmReadWrite +from pyspark.ml.wrapper import JavaParams, JavaEstimator, JavaWrapper from pyspark.sql.functions import col, lit, rand, UserDefinedFunction from pyspark.sql.types import BooleanType @@ -64,6 +65,10 @@ def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel): def singleTask(): index, model = next(modelIter) + # TODO: duplicate evaluator to take extra params from input + # Note: Supporting tuning params in evaluator need update method + # `MetaAlgorithmReadWrite.getAllNestedStages`, make it return + # all nested stages and evaluators metric = eva.evaluate(model.transform(validation, epm[index])) return index, metric, model if collectSubModel else None @@ -186,8 +191,16 @@ class _ValidatorParams(HasSeed): # Load information from java_stage to the instance. estimator = JavaParams._from_java(java_stage.getEstimator()) evaluator = JavaParams._from_java(java_stage.getEvaluator()) - epms = [estimator._transfer_param_map_from_java(epm) - for epm in java_stage.getEstimatorParamMaps()] + if isinstance(estimator, JavaEstimator): + epms = [estimator._transfer_param_map_from_java(epm) + for epm in java_stage.getEstimatorParamMaps()] + elif MetaAlgorithmReadWrite.isMetaEstimator(estimator): + # Meta estimator such as Pipeline, OneVsRest + epms = _ValidatorSharedReadWrite.meta_estimator_transfer_param_maps_from_java( + estimator, java_stage.getEstimatorParamMaps()) + else: + raise ValueError('Unsupported estimator used in tuning: ' + str(estimator)) + return estimator, epms, evaluator def _to_java_impl(self): @@ -198,15 +211,82 @@ class _ValidatorParams(HasSeed): gateway = SparkContext._gateway cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap - java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps())) - for idx, epm in enumerate(self.getEstimatorParamMaps()): - java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm) + estimator = self.getEstimator() + if isinstance(estimator, JavaEstimator): + java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps())) + for idx, epm in enumerate(self.getEstimatorParamMaps()): + java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm) + elif MetaAlgorithmReadWrite.isMetaEstimator(estimator): + # Meta estimator such as Pipeline, OneVsRest + java_epms = _ValidatorSharedReadWrite.meta_estimator_transfer_param_maps_to_java( + estimator, self.getEstimatorParamMaps()) + else: + raise ValueError('Unsupported estimator used in tuning: ' + str(estimator)) java_estimator = self.getEstimator()._to_java() java_evaluator = self.getEvaluator()._to_java() return java_estimator, java_epms, java_evaluator +class _ValidatorSharedReadWrite: + @staticmethod + def meta_estimator_transfer_param_maps_to_java(pyEstimator, pyParamMaps): + pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator) + stagePairs = list(map(lambda stage: (stage, stage._to_java()), pyStages)) + sc = SparkContext._active_spark_context + + paramMapCls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap + javaParamMaps = SparkContext._gateway.new_array(paramMapCls, len(pyParamMaps)) + + for idx, pyParamMap in enumerate(pyParamMaps): + javaParamMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap") + for pyParam, pyValue in pyParamMap.items(): + javaParam = None + for pyStage, javaStage in stagePairs: + if pyStage._testOwnParam(pyParam.parent, pyParam.name): + javaParam = javaStage.getParam(pyParam.name) + break + if javaParam is None: + raise ValueError('Resolve param in estimatorParamMaps failed: ' + str(pyParam)) + if isinstance(pyValue, Params) and hasattr(pyValue, "_to_java"): + javaValue = pyValue._to_java() + else: + javaValue = _py2java(sc, pyValue) + pair = javaParam.w(javaValue) + javaParamMap.put([pair]) + javaParamMaps[idx] = javaParamMap + return javaParamMaps + + @staticmethod + def meta_estimator_transfer_param_maps_from_java(pyEstimator, javaParamMaps): + pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator) + stagePairs = list(map(lambda stage: (stage, stage._to_java()), pyStages)) + sc = SparkContext._active_spark_context + pyParamMaps = [] + for javaParamMap in javaParamMaps: + pyParamMap = dict() + for javaPair in javaParamMap.toList(): + javaParam = javaPair.param() + pyParam = None + for pyStage, javaStage in stagePairs: + if pyStage._testOwnParam(javaParam.parent(), javaParam.name()): + pyParam = pyStage.getParam(javaParam.name()) + if pyParam is None: + raise ValueError('Resolve param in estimatorParamMaps failed: ' + + javaParam.parent() + '.' + javaParam.name()) + javaValue = javaPair.value() + if sc._jvm.Class.forName("org.apache.spark.ml.PipelineStage").isInstance(javaValue): + # Note: JavaParams._from_java support both JavaEstimator/JavaTransformer class + # and Estimator/Transformer class which implements `_from_java` static method + # (such as OneVsRest, Pipeline class). + pyValue = JavaParams._from_java(javaValue) + else: + pyValue = _java2py(sc, javaValue) + pyParamMap[pyParam] = pyValue + pyParamMaps.append(pyParamMap) + return pyParamMaps + + class _CrossValidatorParams(_ValidatorParams): """ Params for :py:class:`CrossValidator` and :py:class:`CrossValidatorModel`. diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index a7b5a79d75..a34bfb5348 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -592,3 +592,41 @@ class HasTrainingSummary(object): no summary exists. """ return (self._call_java("summary")) + + +class MetaAlgorithmReadWrite: + + @staticmethod + def isMetaEstimator(pyInstance): + from pyspark.ml import Estimator, Pipeline + from pyspark.ml.tuning import _ValidatorParams + from pyspark.ml.classification import OneVsRest + return isinstance(pyInstance, Pipeline) or isinstance(pyInstance, OneVsRest) or \ + (isinstance(pyInstance, Estimator) and isinstance(pyInstance, _ValidatorParams)) + + @staticmethod + def getAllNestedStages(pyInstance): + from pyspark.ml import Pipeline, PipelineModel + from pyspark.ml.tuning import _ValidatorParams + from pyspark.ml.classification import OneVsRest, OneVsRestModel + + # TODO: We need to handle `RFormulaModel.pipelineModel` here after Pyspark RFormulaModel + # support pipelineModel property. + if isinstance(pyInstance, Pipeline): + pySubStages = pyInstance.getStages() + elif isinstance(pyInstance, PipelineModel): + pySubStages = pyInstance.stages + elif isinstance(pyInstance, _ValidatorParams): + raise ValueError('PySpark does not support nested validator.') + elif isinstance(pyInstance, OneVsRest): + pySubStages = [pyInstance.getClassifier()] + elif isinstance(pyInstance, OneVsRestModel): + pySubStages = [pyInstance.getClassifier()] + pyInstance.models + else: + pySubStages = [] + + nestedStages = [] + for pySubStage in pySubStages: + nestedStages.extend(MetaAlgorithmReadWrite.getAllNestedStages(pySubStage)) + + return [pyInstance] + nestedStages diff --git a/python/pyspark/ml/util.pyi b/python/pyspark/ml/util.pyi index d0781b2e26..e2496e181f 100644 --- a/python/pyspark/ml/util.pyi +++ b/python/pyspark/ml/util.pyi @@ -126,3 +126,9 @@ class HasTrainingSummary(Generic[S]): def hasSummary(self) -> bool: ... @property def summary(self) -> S: ... + +class MetaAlgorithmReadWrite: + @staticmethod + def isMetaEstimator(pyInstance: Any) -> bool: ... + @staticmethod + def getAllNestedStages(pyInstance: Any) -> list: ...