[SPARK-33592] Fix: Pyspark ML Validator params in estimatorParamMaps may be lost after saving and reloading

### What changes were proposed in this pull request?
Fix: Pyspark ML Validator params in estimatorParamMaps may be lost after saving and reloading

When saving the validator's estimatorParamMaps, the writer now checks all nested stages in the tuned estimator to resolve the correct param parent for each grid entry.
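
The grid params are owned by the nested stages (a param's `parent` is the owning stage's `uid`), not by the top-level estimator, so resolving them against only the top-level estimator drops them on save/load. A minimal illustration, assuming an active SparkSession and the same stage names as the first example below:

~~~python
from pyspark.ml import Pipeline
from pyspark.ml.feature import HashingTF, Tokenizer

tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol="words", outputCol="features")
pipeline = Pipeline(stages=[tokenizer, hashingTF])

# The grid param belongs to the nested HashingTF stage, not to the Pipeline:
assert hashingTF.numFeatures.parent == hashingTF.uid
assert hashingTF.numFeatures.parent != pipeline.uid
~~~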

Two typical cases to manually test:
~~~python
# Assumes an active SparkSession; imports and a save path are added so the
# snippet can be run as-is (the path below is only illustrative).
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit

tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression()
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

paramGrid = ParamGridBuilder() \
    .addGrid(hashingTF.numFeatures, [10, 100]) \
    .addGrid(lr.maxIter, [100, 200]) \
    .build()
tvs = TrainValidationSplit(estimator=pipeline,
                           estimatorParamMaps=paramGrid,
                           evaluator=MulticlassClassificationEvaluator())

tvsPath = "/tmp/tvs"  # illustrative path
tvs.save(tvsPath)
loadedTvs = TrainValidationSplit.load(tvsPath)

# check `loadedTvs.getEstimatorParamMaps()` restored correctly.
~~~
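
One way to spot-check that the maps were restored correctly (this mirrors the `assert_param_maps_equal` helper added to the tests below):

~~~python
restored = loadedTvs.getEstimatorParamMaps()
assert len(restored) == len(paramGrid)
for orig_map, loaded_map in zip(paramGrid, restored):
    # Param keys compare by (parent uid, name); before this fix the restored
    # maps could lose params whose parent is a nested stage.
    assert set(orig_map.keys()) == set(loaded_map.keys())
    for param in orig_map:
        assert orig_map[param] == loaded_map[param]  # grid values here are plain ints
~~~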

~~~python
# Imports as in the previous example, plus OneVsRest:
from pyspark.ml.classification import LogisticRegression, OneVsRest

lr = LogisticRegression()
ova = OneVsRest(classifier=lr)
grid = ParamGridBuilder().addGrid(lr.maxIter, [100, 200]).build()
evaluator = MulticlassClassificationEvaluator()
tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator)

tvsPath = "/tmp/tvs_ova"  # illustrative path
tvs.save(tvsPath)
loadedTvs = TrainValidationSplit.load(tvsPath)

# check `loadedTvs.getEstimatorParamMaps()` restored correctly.
~~~
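
The same spot-check as above applies here: after the fix, the restored map keys keep the nested `LogisticRegression`'s uid as their parent instead of being dropped when resolved against `OneVsRest`.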

### Why are the changes needed?
Bug fix.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Unit test.

Closes #30539 from WeichenXu123/fix_tuning_param_maps_io.

Authored-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Ruifeng Zheng <ruifengz@foxmail.com>
9 changed files with 268 additions and 107 deletions


@@ -564,6 +564,7 @@ pyspark_ml = Module(
"pyspark.ml.tests.test_stat",
"pyspark.ml.tests.test_training_summary",
"pyspark.ml.tests.test_tuning",
"pyspark.ml.tests.test_util",
"pyspark.ml.tests.test_wrapper",
],
excluded_python_implementations=[


@@ -36,7 +36,7 @@ from pyspark.ml.base import _PredictorParams
from pyspark.ml.util import JavaMLWritable, JavaMLReadable, HasTrainingSummary
from pyspark.ml.wrapper import JavaParams, \
JavaPredictor, JavaPredictionModel, JavaWrapper
from pyspark.ml.common import inherit_doc, _java2py, _py2java
from pyspark.ml.common import inherit_doc
from pyspark.ml.linalg import Vectors
from pyspark.sql import DataFrame
from pyspark.sql.functions import udf, when
@@ -2991,50 +2991,6 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, Jav
_java_obj.setRawPredictionCol(self.getRawPredictionCol())
return _java_obj
def _make_java_param_pair(self, param, value):
"""
Makes a Java param pair.
"""
sc = SparkContext._active_spark_context
param = self._resolveParam(param)
_java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
self.uid)
java_param = _java_obj.getParam(param.name)
if isinstance(value, JavaParams):
# used in the case of an estimator having another estimator as a parameter
# the reason why this is not in _py2java in common.py is that importing
# Estimator and Model in common.py results in a circular import with inherit_doc
java_value = value._to_java()
else:
java_value = _py2java(sc, value)
return java_param.w(java_value)
def _transfer_param_map_to_java(self, pyParamMap):
"""
Transforms a Python ParamMap into a Java ParamMap.
"""
paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
for param in self.params:
if param in pyParamMap:
pair = self._make_java_param_pair(param, pyParamMap[param])
paramMap.put([pair])
return paramMap
def _transfer_param_map_from_java(self, javaParamMap):
"""
Transforms a Java ParamMap into a Python ParamMap.
"""
sc = SparkContext._active_spark_context
paramMap = dict()
for pair in javaParamMap.toList():
param = pair.param()
if self.hasParam(str(param.name())):
if param.name() == "classifier":
paramMap[self.getParam(param.name())] = JavaParams._from_java(pair.value())
else:
paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
return paramMap
class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
"""


@@ -437,6 +437,12 @@ class Params(Identifiable, metaclass=ABCMeta):
else:
raise ValueError("Cannot resolve %r as a param." % param)
def _testOwnParam(self, param_parent, param_name):
"""
Test the ownership. Return True or False
"""
return self.uid == param_parent and self.hasParam(param_name)
@staticmethod
def _dummy():
"""


@@ -21,8 +21,8 @@ from pyspark.ml.base import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader, \
DefaultParamsReader, DefaultParamsWriter, MLWriter, MLReader, JavaMLWritable
from pyspark.ml.wrapper import JavaParams, JavaWrapper
from pyspark.ml.common import inherit_doc, _java2py, _py2java
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.common import inherit_doc
@inherit_doc
@@ -190,55 +190,6 @@ class Pipeline(Estimator, MLReadable, MLWritable):
return _java_obj
def _make_java_param_pair(self, param, value):
"""
Makes a Java param pair.
"""
sc = SparkContext._active_spark_context
param = self._resolveParam(param)
java_param = sc._jvm.org.apache.spark.ml.param.Param(param.parent, param.name, param.doc)
if isinstance(value, Params) and hasattr(value, "_to_java"):
# Convert JavaEstimator/JavaTransformer object or Estimator/Transformer object which
# implements `_to_java` method (such as OneVsRest, Pipeline object) to java object.
# used in the case of an estimator having another estimator as a parameter
# the reason why this is not in _py2java in common.py is that importing
# Estimator and Model in common.py results in a circular import with inherit_doc
java_value = value._to_java()
else:
java_value = _py2java(sc, value)
return java_param.w(java_value)
def _transfer_param_map_to_java(self, pyParamMap):
"""
Transforms a Python ParamMap into a Java ParamMap.
"""
paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
for param in self.params:
if param in pyParamMap:
pair = self._make_java_param_pair(param, pyParamMap[param])
paramMap.put([pair])
return paramMap
def _transfer_param_map_from_java(self, javaParamMap):
"""
Transforms a Java ParamMap into a Python ParamMap.
"""
sc = SparkContext._active_spark_context
paramMap = dict()
for pair in javaParamMap.toList():
param = pair.param()
if self.hasParam(str(param.name())):
java_obj = pair.value()
if sc._jvm.Class.forName("org.apache.spark.ml.PipelineStage").isInstance(java_obj):
# Note: JavaParams._from_java support both JavaEstimator/JavaTransformer class
# and Estimator/Transformer class which implements `_from_java` static method
# (such as OneVsRest, Pipeline class).
py_obj = JavaParams._from_java(java_obj)
else:
py_obj = _java2py(sc, java_obj)
paramMap[self.getParam(param.name())] = py_obj
return paramMap
@inherit_doc
class PipelineWriter(MLWriter):


@@ -73,7 +73,21 @@ class ParamGridBuilderTests(SparkSessionTestCase):
.build())
class CrossValidatorTests(SparkSessionTestCase):
class ValidatorTestUtilsMixin:
def assert_param_maps_equal(self, paramMaps1, paramMaps2):
self.assertEqual(len(paramMaps1), len(paramMaps2))
for paramMap1, paramMap2 in zip(paramMaps1, paramMaps2):
self.assertEqual(set(paramMap1.keys()), set(paramMap2.keys()))
for param in paramMap1.keys():
v1 = paramMap1[param]
v2 = paramMap2[param]
if isinstance(v1, Params):
self.assertEqual(v1.uid, v2.uid)
else:
self.assertEqual(v1, v2)
class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
def test_copy(self):
dataset = self.spark.createDataFrame([
@@ -256,7 +270,7 @@ class CrossValidatorTests(SparkSessionTestCase):
loadedCV = CrossValidator.load(cvPath)
self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
# test save/load of CrossValidatorModel
cvModelPath = temp_path + "/cvModel"
@@ -351,6 +365,7 @@ class CrossValidatorTests(SparkSessionTestCase):
cvPath = temp_path + "/cv"
cv.save(cvPath)
loadedCV = CrossValidator.load(cvPath)
self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), grid)
self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
@@ -367,6 +382,7 @@ class CrossValidatorTests(SparkSessionTestCase):
cvModelPath = temp_path + "/cvModel"
cvModel.save(cvModelPath)
loadedModel = CrossValidatorModel.load(cvModelPath)
self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid)
self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
def test_save_load_pipeline_estimator(self):
@@ -401,6 +417,11 @@ class CrossValidatorTests(SparkSessionTestCase):
estimatorParamMaps=paramGrid,
evaluator=MulticlassClassificationEvaluator(),
numFolds=2) # use 3+ folds in practice
cvPath = temp_path + "/cv"
crossval.save(cvPath)
loadedCV = CrossValidator.load(cvPath)
self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), paramGrid)
self.assertEqual(loadedCV.getEstimator().uid, crossval.getEstimator().uid)
# Run cross-validation, and choose the best set of parameters.
cvModel = crossval.fit(training)
@@ -421,6 +442,11 @@ class CrossValidatorTests(SparkSessionTestCase):
estimatorParamMaps=paramGrid,
evaluator=MulticlassClassificationEvaluator(),
numFolds=2) # use 3+ folds in practice
cv2Path = temp_path + "/cv2"
crossval2.save(cv2Path)
loadedCV2 = CrossValidator.load(cv2Path)
self.assert_param_maps_equal(loadedCV2.getEstimatorParamMaps(), paramGrid)
self.assertEqual(loadedCV2.getEstimator().uid, crossval2.getEstimator().uid)
# Run cross-validation, and choose the best set of parameters.
cvModel2 = crossval2.fit(training)
@@ -511,7 +537,7 @@ class CrossValidatorTests(SparkSessionTestCase):
cv.fit(dataset_with_folds)
class TrainValidationSplitTests(SparkSessionTestCase):
class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
def test_fit_minimize_metric(self):
dataset = self.spark.createDataFrame([
@@ -632,7 +658,8 @@ class TrainValidationSplitTests(SparkSessionTestCase):
loadedTvs = TrainValidationSplit.load(tvsPath)
self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
self.assert_param_maps_equal(
loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
tvsModelPath = temp_path + "/tvsModel"
tvsModel.save(tvsModelPath)
@@ -713,6 +740,7 @@ class TrainValidationSplitTests(SparkSessionTestCase):
tvsPath = temp_path + "/tvs"
tvs.save(tvsPath)
loadedTvs = TrainValidationSplit.load(tvsPath)
self.assert_param_maps_equal(loadedTvs.getEstimatorParamMaps(), grid)
self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
@@ -728,6 +756,7 @@ class TrainValidationSplitTests(SparkSessionTestCase):
tvsModelPath = temp_path + "/tvsModel"
tvsModel.save(tvsModelPath)
loadedModel = TrainValidationSplitModel.load(tvsModelPath)
self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid)
self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
def test_save_load_pipeline_estimator(self):
@@ -761,6 +790,11 @@ class TrainValidationSplitTests(SparkSessionTestCase):
tvs = TrainValidationSplit(estimator=pipeline,
estimatorParamMaps=paramGrid,
evaluator=MulticlassClassificationEvaluator())
tvsPath = temp_path + "/tvs"
tvs.save(tvsPath)
loadedTvs = TrainValidationSplit.load(tvsPath)
self.assert_param_maps_equal(loadedTvs.getEstimatorParamMaps(), paramGrid)
self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
# Run train validation split, and choose the best set of parameters.
tvsModel = tvs.fit(training)
@@ -780,6 +814,11 @@ class TrainValidationSplitTests(SparkSessionTestCase):
tvs2 = TrainValidationSplit(estimator=nested_pipeline,
estimatorParamMaps=paramGrid,
evaluator=MulticlassClassificationEvaluator())
tvs2Path = temp_path + "/tvs2"
tvs2.save(tvs2Path)
loadedTvs2 = TrainValidationSplit.load(tvs2Path)
self.assert_param_maps_equal(loadedTvs2.getEstimatorParamMaps(), paramGrid)
self.assertEqual(loadedTvs2.getEstimator().uid, tvs2.getEstimator().uid)
# Run train validation split, and choose the best set of parameters.
tvsModel2 = tvs2.fit(training)


@@ -0,0 +1,84 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression, OneVsRest
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.linalg import Vectors
from pyspark.ml.util import MetaAlgorithmReadWrite
from pyspark.testing.mlutils import SparkSessionTestCase
class MetaAlgorithmReadWriteTests(SparkSessionTestCase):
def test_getAllNestedStages(self):
def _check_uid_set_equal(stages, expected_stages):
uids = set(map(lambda x: x.uid, stages))
expected_uids = set(map(lambda x: x.uid, expected_stages))
self.assertEqual(uids, expected_uids)
df1 = self.spark.createDataFrame([
(Vectors.dense([1., 2.]), 1.0),
(Vectors.dense([-1., -2.]), 0.0),
], ['features', 'label'])
df2 = self.spark.createDataFrame([
(1., 2., 1.0),
(1., 2., 0.0),
], ['a', 'b', 'label'])
vs = VectorAssembler(inputCols=['a', 'b'], outputCol='features')
lr = LogisticRegression()
pipeline = Pipeline(stages=[vs, lr])
pipelineModel = pipeline.fit(df2)
ova = OneVsRest(classifier=lr)
ovaModel = ova.fit(df1)
ova_pipeline = Pipeline(stages=[vs, ova])
nested_pipeline = Pipeline(stages=[ova_pipeline])
_check_uid_set_equal(
MetaAlgorithmReadWrite.getAllNestedStages(pipeline),
[pipeline, vs, lr]
)
_check_uid_set_equal(
MetaAlgorithmReadWrite.getAllNestedStages(pipelineModel),
[pipelineModel] + pipelineModel.stages
)
_check_uid_set_equal(
MetaAlgorithmReadWrite.getAllNestedStages(ova),
[ova, lr]
)
_check_uid_set_equal(
MetaAlgorithmReadWrite.getAllNestedStages(ovaModel),
[ovaModel, lr] + ovaModel.models
)
_check_uid_set_equal(
MetaAlgorithmReadWrite.getAllNestedStages(nested_pipeline),
[nested_pipeline, ova_pipeline, vs, ova, lr]
)
if __name__ == "__main__":
from pyspark.ml.tests.test_util import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)


@@ -26,8 +26,9 @@ from pyspark.ml import Estimator, Model
from pyspark.ml.common import _py2java, _java2py
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader, \
MetaAlgorithmReadWrite
from pyspark.ml.wrapper import JavaParams, JavaEstimator, JavaWrapper
from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
from pyspark.sql.types import BooleanType
@@ -64,6 +65,10 @@ def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel):
def singleTask():
index, model = next(modelIter)
# TODO: duplicate evaluator to take extra params from input
# Note: Supporting tuning params in evaluator need update method
# `MetaAlgorithmReadWrite.getAllNestedStages`, make it return
# all nested stages and evaluators
metric = eva.evaluate(model.transform(validation, epm[index]))
return index, metric, model if collectSubModel else None
@@ -186,8 +191,16 @@ class _ValidatorParams(HasSeed):
# Load information from java_stage to the instance.
estimator = JavaParams._from_java(java_stage.getEstimator())
evaluator = JavaParams._from_java(java_stage.getEvaluator())
epms = [estimator._transfer_param_map_from_java(epm)
for epm in java_stage.getEstimatorParamMaps()]
if isinstance(estimator, JavaEstimator):
epms = [estimator._transfer_param_map_from_java(epm)
for epm in java_stage.getEstimatorParamMaps()]
elif MetaAlgorithmReadWrite.isMetaEstimator(estimator):
# Meta estimator such as Pipeline, OneVsRest
epms = _ValidatorSharedReadWrite.meta_estimator_transfer_param_maps_from_java(
estimator, java_stage.getEstimatorParamMaps())
else:
raise ValueError('Unsupported estimator used in tuning: ' + str(estimator))
return estimator, epms, evaluator
def _to_java_impl(self):
@@ -198,15 +211,82 @@ class _ValidatorParams(HasSeed):
gateway = SparkContext._gateway
cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
for idx, epm in enumerate(self.getEstimatorParamMaps()):
java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
estimator = self.getEstimator()
if isinstance(estimator, JavaEstimator):
java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
for idx, epm in enumerate(self.getEstimatorParamMaps()):
java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
elif MetaAlgorithmReadWrite.isMetaEstimator(estimator):
# Meta estimator such as Pipeline, OneVsRest
java_epms = _ValidatorSharedReadWrite.meta_estimator_transfer_param_maps_to_java(
estimator, self.getEstimatorParamMaps())
else:
raise ValueError('Unsupported estimator used in tuning: ' + str(estimator))
java_estimator = self.getEstimator()._to_java()
java_evaluator = self.getEvaluator()._to_java()
return java_estimator, java_epms, java_evaluator
class _ValidatorSharedReadWrite:
@staticmethod
def meta_estimator_transfer_param_maps_to_java(pyEstimator, pyParamMaps):
pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)
stagePairs = list(map(lambda stage: (stage, stage._to_java()), pyStages))
sc = SparkContext._active_spark_context
paramMapCls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
javaParamMaps = SparkContext._gateway.new_array(paramMapCls, len(pyParamMaps))
for idx, pyParamMap in enumerate(pyParamMaps):
javaParamMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
for pyParam, pyValue in pyParamMap.items():
javaParam = None
for pyStage, javaStage in stagePairs:
if pyStage._testOwnParam(pyParam.parent, pyParam.name):
javaParam = javaStage.getParam(pyParam.name)
break
if javaParam is None:
raise ValueError('Resolve param in estimatorParamMaps failed: ' + str(pyParam))
if isinstance(pyValue, Params) and hasattr(pyValue, "_to_java"):
javaValue = pyValue._to_java()
else:
javaValue = _py2java(sc, pyValue)
pair = javaParam.w(javaValue)
javaParamMap.put([pair])
javaParamMaps[idx] = javaParamMap
return javaParamMaps
@staticmethod
def meta_estimator_transfer_param_maps_from_java(pyEstimator, javaParamMaps):
pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)
stagePairs = list(map(lambda stage: (stage, stage._to_java()), pyStages))
sc = SparkContext._active_spark_context
pyParamMaps = []
for javaParamMap in javaParamMaps:
pyParamMap = dict()
for javaPair in javaParamMap.toList():
javaParam = javaPair.param()
pyParam = None
for pyStage, javaStage in stagePairs:
if pyStage._testOwnParam(javaParam.parent(), javaParam.name()):
pyParam = pyStage.getParam(javaParam.name())
if pyParam is None:
raise ValueError('Resolve param in estimatorParamMaps failed: ' +
javaParam.parent() + '.' + javaParam.name())
javaValue = javaPair.value()
if sc._jvm.Class.forName("org.apache.spark.ml.PipelineStage").isInstance(javaValue):
# Note: JavaParams._from_java support both JavaEstimator/JavaTransformer class
# and Estimator/Transformer class which implements `_from_java` static method
# (such as OneVsRest, Pipeline class).
pyValue = JavaParams._from_java(javaValue)
else:
pyValue = _java2py(sc, javaValue)
pyParamMap[pyParam] = pyValue
pyParamMaps.append(pyParamMap)
return pyParamMaps
class _CrossValidatorParams(_ValidatorParams):
"""
Params for :py:class:`CrossValidator` and :py:class:`CrossValidatorModel`.


@@ -592,3 +592,41 @@ class HasTrainingSummary(object):
no summary exists.
"""
return (self._call_java("summary"))
class MetaAlgorithmReadWrite:
@staticmethod
def isMetaEstimator(pyInstance):
from pyspark.ml import Estimator, Pipeline
from pyspark.ml.tuning import _ValidatorParams
from pyspark.ml.classification import OneVsRest
return isinstance(pyInstance, Pipeline) or isinstance(pyInstance, OneVsRest) or \
(isinstance(pyInstance, Estimator) and isinstance(pyInstance, _ValidatorParams))
@staticmethod
def getAllNestedStages(pyInstance):
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.tuning import _ValidatorParams
from pyspark.ml.classification import OneVsRest, OneVsRestModel
# TODO: We need to handle `RFormulaModel.pipelineModel` here after Pyspark RFormulaModel
# support pipelineModel property.
if isinstance(pyInstance, Pipeline):
pySubStages = pyInstance.getStages()
elif isinstance(pyInstance, PipelineModel):
pySubStages = pyInstance.stages
elif isinstance(pyInstance, _ValidatorParams):
raise ValueError('PySpark does not support nested validator.')
elif isinstance(pyInstance, OneVsRest):
pySubStages = [pyInstance.getClassifier()]
elif isinstance(pyInstance, OneVsRestModel):
pySubStages = [pyInstance.getClassifier()] + pyInstance.models
else:
pySubStages = []
nestedStages = []
for pySubStage in pySubStages:
nestedStages.extend(MetaAlgorithmReadWrite.getAllNestedStages(pySubStage))
return [pyInstance] + nestedStages


@@ -126,3 +126,9 @@ class HasTrainingSummary(Generic[S]):
def hasSummary(self) -> bool: ...
@property
def summary(self) -> S: ...
class MetaAlgorithmReadWrite:
@staticmethod
def isMetaEstimator(pyInstance: Any) -> bool: ...
@staticmethod
def getAllNestedStages(pyInstance: Any) -> list: ...