diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 868e4a5d23..5d8b714711 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -564,6 +564,7 @@ pyspark_ml = Module(
         "pyspark.ml.tests.test_stat",
         "pyspark.ml.tests.test_training_summary",
         "pyspark.ml.tests.test_tuning",
+        "pyspark.ml.tests.test_util",
         "pyspark.ml.tests.test_wrapper",
     ],
     excluded_python_implementations=[
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 50882fc895..763038ede8 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -36,7 +36,7 @@ from pyspark.ml.base import _PredictorParams
 from pyspark.ml.util import JavaMLWritable, JavaMLReadable, HasTrainingSummary
 from pyspark.ml.wrapper import JavaParams, \
     JavaPredictor, JavaPredictionModel, JavaWrapper
-from pyspark.ml.common import inherit_doc, _java2py, _py2java
+from pyspark.ml.common import inherit_doc
 from pyspark.ml.linalg import Vectors
 from pyspark.sql import DataFrame
 from pyspark.sql.functions import udf, when
@@ -2991,50 +2991,6 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, Jav
         _java_obj.setRawPredictionCol(self.getRawPredictionCol())
         return _java_obj
 
-    def _make_java_param_pair(self, param, value):
-        """
-        Makes a Java param pair.
-        """
-        sc = SparkContext._active_spark_context
-        param = self._resolveParam(param)
-        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
-                                             self.uid)
-        java_param = _java_obj.getParam(param.name)
-        if isinstance(value, JavaParams):
-            # used in the case of an estimator having another estimator as a parameter
-            # the reason why this is not in _py2java in common.py is that importing
-            # Estimator and Model in common.py results in a circular import with inherit_doc
-            java_value = value._to_java()
-        else:
-            java_value = _py2java(sc, value)
-        return java_param.w(java_value)
-
-    def _transfer_param_map_to_java(self, pyParamMap):
-        """
-        Transforms a Python ParamMap into a Java ParamMap.
-        """
-        paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
-        for param in self.params:
-            if param in pyParamMap:
-                pair = self._make_java_param_pair(param, pyParamMap[param])
-                paramMap.put([pair])
-        return paramMap
-
-    def _transfer_param_map_from_java(self, javaParamMap):
-        """
-        Transforms a Java ParamMap into a Python ParamMap.
-        """
-        sc = SparkContext._active_spark_context
-        paramMap = dict()
-        for pair in javaParamMap.toList():
-            param = pair.param()
-            if self.hasParam(str(param.name())):
-                if param.name() == "classifier":
-                    paramMap[self.getParam(param.name())] = JavaParams._from_java(pair.value())
-                else:
-                    paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
-        return paramMap
-
 
 class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
     """
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index f2381a4c42..3eab6607aa 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -437,6 +437,12 @@ class Params(Identifiable, metaclass=ABCMeta):
         else:
             raise ValueError("Cannot resolve %r as a param." % param)
 
+    def _testOwnParam(self, param_parent, param_name):
+        """
+        Test the ownership. Return True or False
+        """
+        return self.uid == param_parent and self.hasParam(param_name)
+
     @staticmethod
     def _dummy():
         """
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a6471a8dd1..b0aa735709 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -21,8 +21,8 @@ from pyspark.ml.base import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader, \
     DefaultParamsReader, DefaultParamsWriter, MLWriter, MLReader, JavaMLWritable
-from pyspark.ml.wrapper import JavaParams, JavaWrapper
-from pyspark.ml.common import inherit_doc, _java2py, _py2java
+from pyspark.ml.wrapper import JavaParams
+from pyspark.ml.common import inherit_doc
 
 
 @inherit_doc
@@ -190,55 +190,6 @@ class Pipeline(Estimator, MLReadable, MLWritable):
 
         return _java_obj
 
-    def _make_java_param_pair(self, param, value):
-        """
-        Makes a Java param pair.
-        """
-        sc = SparkContext._active_spark_context
-        param = self._resolveParam(param)
-        java_param = sc._jvm.org.apache.spark.ml.param.Param(param.parent, param.name, param.doc)
-        if isinstance(value, Params) and hasattr(value, "_to_java"):
-            # Convert JavaEstimator/JavaTransformer object or Estimator/Transformer object which
-            # implements `_to_java` method (such as OneVsRest, Pipeline object) to java object.
-            # used in the case of an estimator having another estimator as a parameter
-            # the reason why this is not in _py2java in common.py is that importing
-            # Estimator and Model in common.py results in a circular import with inherit_doc
-            java_value = value._to_java()
-        else:
-            java_value = _py2java(sc, value)
-        return java_param.w(java_value)
-
-    def _transfer_param_map_to_java(self, pyParamMap):
-        """
-        Transforms a Python ParamMap into a Java ParamMap.
-        """
-        paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
-        for param in self.params:
-            if param in pyParamMap:
-                pair = self._make_java_param_pair(param, pyParamMap[param])
-                paramMap.put([pair])
-        return paramMap
-
-    def _transfer_param_map_from_java(self, javaParamMap):
-        """
-        Transforms a Java ParamMap into a Python ParamMap.
-        """
-        sc = SparkContext._active_spark_context
-        paramMap = dict()
-        for pair in javaParamMap.toList():
-            param = pair.param()
-            if self.hasParam(str(param.name())):
-                java_obj = pair.value()
-                if sc._jvm.Class.forName("org.apache.spark.ml.PipelineStage").isInstance(java_obj):
-                    # Note: JavaParams._from_java support both JavaEstimator/JavaTransformer class
-                    # and Estimator/Transformer class which implements `_from_java` static method
-                    # (such as OneVsRest, Pipeline class).
-                    py_obj = JavaParams._from_java(java_obj)
-                else:
-                    py_obj = _java2py(sc, java_obj)
-                paramMap[self.getParam(param.name())] = py_obj
-        return paramMap
-
 
 @inherit_doc
 class PipelineWriter(MLWriter):
diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py
index ced32c07f2..ebd7457e4d 100644
--- a/python/pyspark/ml/tests/test_tuning.py
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -73,7 +73,21 @@ class ParamGridBuilderTests(SparkSessionTestCase):
                     .build())
 
 
-class CrossValidatorTests(SparkSessionTestCase):
+class ValidatorTestUtilsMixin:
+    def assert_param_maps_equal(self, paramMaps1, paramMaps2):
+        self.assertEqual(len(paramMaps1), len(paramMaps2))
+        for paramMap1, paramMap2 in zip(paramMaps1, paramMaps2):
+            self.assertEqual(set(paramMap1.keys()), set(paramMap2.keys()))
+            for param in paramMap1.keys():
+                v1 = paramMap1[param]
+                v2 = paramMap2[param]
+                if isinstance(v1, Params):
+                    self.assertEqual(v1.uid, v2.uid)
+                else:
+                    self.assertEqual(v1, v2)
+
+
+class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
 
     def test_copy(self):
         dataset = self.spark.createDataFrame([
@@ -256,7 +270,7 @@ class CrossValidatorTests(SparkSessionTestCase):
         loadedCV = CrossValidator.load(cvPath)
         self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
         self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
-        self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
+        self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
 
         # test save/load of CrossValidatorModel
         cvModelPath = temp_path + "/cvModel"
@@ -351,6 +365,7 @@ class CrossValidatorTests(SparkSessionTestCase):
         cvPath = temp_path + "/cv"
         cv.save(cvPath)
         loadedCV = CrossValidator.load(cvPath)
+        self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), grid)
         self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
         self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
 
@@ -367,6 +382,7 @@ class CrossValidatorTests(SparkSessionTestCase):
         cvModelPath = temp_path + "/cvModel"
         cvModel.save(cvModelPath)
         loadedModel = CrossValidatorModel.load(cvModelPath)
+        self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid)
         self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
 
     def test_save_load_pipeline_estimator(self):
@@ -401,6 +417,11 @@ class CrossValidatorTests(SparkSessionTestCase):
                                   estimatorParamMaps=paramGrid,
                                   evaluator=MulticlassClassificationEvaluator(),
                                   numFolds=2)  # use 3+ folds in practice
+        cvPath = temp_path + "/cv"
+        crossval.save(cvPath)
+        loadedCV = CrossValidator.load(cvPath)
+        self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), paramGrid)
+        self.assertEqual(loadedCV.getEstimator().uid, crossval.getEstimator().uid)
 
         # Run cross-validation, and choose the best set of parameters.
         cvModel = crossval.fit(training)
@@ -421,6 +442,11 @@ class CrossValidatorTests(SparkSessionTestCase):
                                    estimatorParamMaps=paramGrid,
                                    evaluator=MulticlassClassificationEvaluator(),
                                    numFolds=2)  # use 3+ folds in practice
+        cv2Path = temp_path + "/cv2"
+        crossval2.save(cv2Path)
+        loadedCV2 = CrossValidator.load(cv2Path)
+        self.assert_param_maps_equal(loadedCV2.getEstimatorParamMaps(), paramGrid)
+        self.assertEqual(loadedCV2.getEstimator().uid, crossval2.getEstimator().uid)
 
         # Run cross-validation, and choose the best set of parameters.
         cvModel2 = crossval2.fit(training)
@@ -511,7 +537,7 @@ class CrossValidatorTests(SparkSessionTestCase):
             cv.fit(dataset_with_folds)
 
 
-class TrainValidationSplitTests(SparkSessionTestCase):
+class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
 
     def test_fit_minimize_metric(self):
         dataset = self.spark.createDataFrame([
@@ -632,7 +658,8 @@ class TrainValidationSplitTests(SparkSessionTestCase):
         loadedTvs = TrainValidationSplit.load(tvsPath)
         self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
         self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
-        self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
+        self.assert_param_maps_equal(
+            loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
 
         tvsModelPath = temp_path + "/tvsModel"
         tvsModel.save(tvsModelPath)
@@ -713,6 +740,7 @@ class TrainValidationSplitTests(SparkSessionTestCase):
         tvsPath = temp_path + "/tvs"
         tvs.save(tvsPath)
         loadedTvs = TrainValidationSplit.load(tvsPath)
+        self.assert_param_maps_equal(loadedTvs.getEstimatorParamMaps(), grid)
         self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
         self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
 
@@ -728,6 +756,7 @@ class TrainValidationSplitTests(SparkSessionTestCase):
         tvsModelPath = temp_path + "/tvsModel"
         tvsModel.save(tvsModelPath)
         loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+        self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid)
         self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
 
     def test_save_load_pipeline_estimator(self):
@@ -761,6 +790,11 @@ class TrainValidationSplitTests(SparkSessionTestCase):
         tvs = TrainValidationSplit(estimator=pipeline,
                                    estimatorParamMaps=paramGrid,
                                    evaluator=MulticlassClassificationEvaluator())
+        tvsPath = temp_path + "/tvs"
+        tvs.save(tvsPath)
+        loadedTvs = TrainValidationSplit.load(tvsPath)
+        self.assert_param_maps_equal(loadedTvs.getEstimatorParamMaps(), paramGrid)
+        self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
 
         # Run train validation split, and choose the best set of parameters.
         tvsModel = tvs.fit(training)
@@ -780,6 +814,11 @@ class TrainValidationSplitTests(SparkSessionTestCase):
         tvs2 = TrainValidationSplit(estimator=nested_pipeline,
                                     estimatorParamMaps=paramGrid,
                                     evaluator=MulticlassClassificationEvaluator())
+        tvs2Path = temp_path + "/tvs2"
+        tvs2.save(tvs2Path)
+        loadedTvs2 = TrainValidationSplit.load(tvs2Path)
+        self.assert_param_maps_equal(loadedTvs2.getEstimatorParamMaps(), paramGrid)
+        self.assertEqual(loadedTvs2.getEstimator().uid, tvs2.getEstimator().uid)
 
         # Run train validation split, and choose the best set of parameters.
         tvsModel2 = tvs2.fit(training)
diff --git a/python/pyspark/ml/tests/test_util.py b/python/pyspark/ml/tests/test_util.py
new file mode 100644
index 0000000000..498a649e48
--- /dev/null
+++ b/python/pyspark/ml/tests/test_util.py
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.ml import Pipeline
+from pyspark.ml.classification import LogisticRegression, OneVsRest
+from pyspark.ml.feature import VectorAssembler
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.util import MetaAlgorithmReadWrite
+from pyspark.testing.mlutils import SparkSessionTestCase
+
+
+class MetaAlgorithmReadWriteTests(SparkSessionTestCase):
+
+    def test_getAllNestedStages(self):
+        def _check_uid_set_equal(stages, expected_stages):
+            uids = set(map(lambda x: x.uid, stages))
+            expected_uids = set(map(lambda x: x.uid, expected_stages))
+            self.assertEqual(uids, expected_uids)
+
+        df1 = self.spark.createDataFrame([
+            (Vectors.dense([1., 2.]), 1.0),
+            (Vectors.dense([-1., -2.]), 0.0),
+        ], ['features', 'label'])
+        df2 = self.spark.createDataFrame([
+            (1., 2., 1.0),
+            (1., 2., 0.0),
+        ], ['a', 'b', 'label'])
+        vs = VectorAssembler(inputCols=['a', 'b'], outputCol='features')
+        lr = LogisticRegression()
+        pipeline = Pipeline(stages=[vs, lr])
+        pipelineModel = pipeline.fit(df2)
+        ova = OneVsRest(classifier=lr)
+        ovaModel = ova.fit(df1)
+
+        ova_pipeline = Pipeline(stages=[vs, ova])
+        nested_pipeline = Pipeline(stages=[ova_pipeline])
+
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(pipeline),
+            [pipeline, vs, lr]
+        )
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(pipelineModel),
+            [pipelineModel] + pipelineModel.stages
+        )
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(ova),
+            [ova, lr]
+        )
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(ovaModel),
+            [ovaModel, lr] + ovaModel.models
+        )
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(nested_pipeline),
+            [nested_pipeline, ova_pipeline, vs, ova, lr]
+        )
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_util import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 6f4ad99484..2b5a9857b0 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -26,8 +26,9 @@ from pyspark.ml import Estimator, Model
 from pyspark.ml.common import _py2java, _java2py
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
-from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader
-from pyspark.ml.wrapper import JavaParams
+from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader, \
+    MetaAlgorithmReadWrite
+from pyspark.ml.wrapper import JavaParams, JavaEstimator, JavaWrapper
 from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
 from pyspark.sql.types import BooleanType
 
@@ -64,6 +65,10 @@ def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel):
 
     def singleTask():
         index, model = next(modelIter)
+        # TODO: duplicate evaluator to take extra params from input
+        #  Note: Supporting tuning params in the evaluator would require updating
+        #  `MetaAlgorithmReadWrite.getAllNestedStages` to return all nested stages
+        #  together with the evaluators
         metric = eva.evaluate(model.transform(validation, epm[index]))
         return index, metric, model if collectSubModel else None
 
@@ -186,8 +191,16 @@ class _ValidatorParams(HasSeed):
         # Load information from java_stage to the instance.
         estimator = JavaParams._from_java(java_stage.getEstimator())
         evaluator = JavaParams._from_java(java_stage.getEvaluator())
-        epms = [estimator._transfer_param_map_from_java(epm)
-                for epm in java_stage.getEstimatorParamMaps()]
+        if isinstance(estimator, JavaEstimator):
+            epms = [estimator._transfer_param_map_from_java(epm)
+                    for epm in java_stage.getEstimatorParamMaps()]
+        elif MetaAlgorithmReadWrite.isMetaEstimator(estimator):
+            # Meta estimator such as Pipeline, OneVsRest
+            epms = _ValidatorSharedReadWrite.meta_estimator_transfer_param_maps_from_java(
+                estimator, java_stage.getEstimatorParamMaps())
+        else:
+            raise ValueError('Unsupported estimator used in tuning: ' + str(estimator))
+
         return estimator, epms, evaluator
 
     def _to_java_impl(self):
@@ -198,15 +211,82 @@ class _ValidatorParams(HasSeed):
 
         gateway = SparkContext._gateway
         cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
-        java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
-        for idx, epm in enumerate(self.getEstimatorParamMaps()):
-            java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
+        estimator = self.getEstimator()
+        if isinstance(estimator, JavaEstimator):
+            java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
+            for idx, epm in enumerate(self.getEstimatorParamMaps()):
+                java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
+        elif MetaAlgorithmReadWrite.isMetaEstimator(estimator):
+            # Meta estimator such as Pipeline, OneVsRest
+            java_epms = _ValidatorSharedReadWrite.meta_estimator_transfer_param_maps_to_java(
+                estimator, self.getEstimatorParamMaps())
+        else:
+            raise ValueError('Unsupported estimator used in tuning: ' + str(estimator))
 
         java_estimator = self.getEstimator()._to_java()
         java_evaluator = self.getEvaluator()._to_java()
         return java_estimator, java_epms, java_evaluator
 
 
+class _ValidatorSharedReadWrite:
+    @staticmethod
+    def meta_estimator_transfer_param_maps_to_java(pyEstimator, pyParamMaps):
+        pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)
+        stagePairs = list(map(lambda stage: (stage, stage._to_java()), pyStages))
+        sc = SparkContext._active_spark_context
+
+        paramMapCls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
+        javaParamMaps = SparkContext._gateway.new_array(paramMapCls, len(pyParamMaps))
+
+        for idx, pyParamMap in enumerate(pyParamMaps):
+            javaParamMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
+            for pyParam, pyValue in pyParamMap.items():
+                javaParam = None
+                for pyStage, javaStage in stagePairs:
+                    if pyStage._testOwnParam(pyParam.parent, pyParam.name):
+                        javaParam = javaStage.getParam(pyParam.name)
+                        break
+                if javaParam is None:
+                    raise ValueError('Resolve param in estimatorParamMaps failed: ' + str(pyParam))
+                if isinstance(pyValue, Params) and hasattr(pyValue, "_to_java"):
+                    javaValue = pyValue._to_java()
+                else:
+                    javaValue = _py2java(sc, pyValue)
+                pair = javaParam.w(javaValue)
+                javaParamMap.put([pair])
+            javaParamMaps[idx] = javaParamMap
+        return javaParamMaps
+
+    @staticmethod
+    def meta_estimator_transfer_param_maps_from_java(pyEstimator, javaParamMaps):
+        pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)
+        stagePairs = list(map(lambda stage: (stage, stage._to_java()), pyStages))
+        sc = SparkContext._active_spark_context
+        pyParamMaps = []
+        for javaParamMap in javaParamMaps:
+            pyParamMap = dict()
+            for javaPair in javaParamMap.toList():
+                javaParam = javaPair.param()
+                pyParam = None
+                for pyStage, javaStage in stagePairs:
+                    if pyStage._testOwnParam(javaParam.parent(), javaParam.name()):
+                        pyParam = pyStage.getParam(javaParam.name())
+                if pyParam is None:
+                    raise ValueError('Resolve param in estimatorParamMaps failed: ' +
+                                     javaParam.parent() + '.' + javaParam.name())
+                javaValue = javaPair.value()
+                if sc._jvm.Class.forName("org.apache.spark.ml.PipelineStage").isInstance(javaValue):
+                    # Note: JavaParams._from_java supports both the JavaEstimator/JavaTransformer
+                    # classes and the Estimator/Transformer classes that implement the
+                    # `_from_java` static method (such as OneVsRest, Pipeline).
+                    pyValue = JavaParams._from_java(javaValue)
+                else:
+                    pyValue = _java2py(sc, javaValue)
+                pyParamMap[pyParam] = pyValue
+            pyParamMaps.append(pyParamMap)
+        return pyParamMaps
+
+
 class _CrossValidatorParams(_ValidatorParams):
     """
     Params for :py:class:`CrossValidator` and :py:class:`CrossValidatorModel`.
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index a7b5a79d75..a34bfb5348 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -592,3 +592,41 @@ class HasTrainingSummary(object):
         no summary exists.
         """
         return (self._call_java("summary"))
+
+
+class MetaAlgorithmReadWrite:
+
+    @staticmethod
+    def isMetaEstimator(pyInstance):
+        from pyspark.ml import Estimator, Pipeline
+        from pyspark.ml.tuning import _ValidatorParams
+        from pyspark.ml.classification import OneVsRest
+        return isinstance(pyInstance, Pipeline) or isinstance(pyInstance, OneVsRest) or \
+            (isinstance(pyInstance, Estimator) and isinstance(pyInstance, _ValidatorParams))
+
+    @staticmethod
+    def getAllNestedStages(pyInstance):
+        from pyspark.ml import Pipeline, PipelineModel
+        from pyspark.ml.tuning import _ValidatorParams
+        from pyspark.ml.classification import OneVsRest, OneVsRestModel
+
+        # TODO: We need to handle `RFormulaModel.pipelineModel` here after PySpark
+        #  RFormulaModel supports the `pipelineModel` property.
+        if isinstance(pyInstance, Pipeline):
+            pySubStages = pyInstance.getStages()
+        elif isinstance(pyInstance, PipelineModel):
+            pySubStages = pyInstance.stages
+        elif isinstance(pyInstance, _ValidatorParams):
+            raise ValueError('PySpark does not support nested validator.')
+        elif isinstance(pyInstance, OneVsRest):
+            pySubStages = [pyInstance.getClassifier()]
+        elif isinstance(pyInstance, OneVsRestModel):
+            pySubStages = [pyInstance.getClassifier()] + pyInstance.models
+        else:
+            pySubStages = []
+
+        nestedStages = []
+        for pySubStage in pySubStages:
+            nestedStages.extend(MetaAlgorithmReadWrite.getAllNestedStages(pySubStage))
+
+        return [pyInstance] + nestedStages
diff --git a/python/pyspark/ml/util.pyi b/python/pyspark/ml/util.pyi
index d0781b2e26..e2496e181f 100644
--- a/python/pyspark/ml/util.pyi
+++ b/python/pyspark/ml/util.pyi
@@ -126,3 +126,9 @@ class HasTrainingSummary(Generic[S]):
     def hasSummary(self) -> bool: ...
     @property
     def summary(self) -> S: ...
+
+class MetaAlgorithmReadWrite:
+    @staticmethod
+    def isMetaEstimator(pyInstance: Any) -> bool: ...
+    @staticmethod
+    def getAllNestedStages(pyInstance: Any) -> list: ...