diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 763038ede8..0553a61c6c 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -15,6 +15,7 @@ # limitations under the License. # +import os import operator import sys import uuid @@ -33,7 +34,9 @@ from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \ _HasVarianceImpurity, _TreeClassifierParams from pyspark.ml.regression import _FactorizationMachinesParams, DecisionTreeRegressionModel from pyspark.ml.base import _PredictorParams -from pyspark.ml.util import JavaMLWritable, JavaMLReadable, HasTrainingSummary +from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, \ + JavaMLReadable, JavaMLReader, JavaMLWritable, JavaMLWriter, \ + MLReader, MLReadable, MLWriter, MLWritable, HasTrainingSummary from pyspark.ml.wrapper import JavaParams, \ JavaPredictor, JavaPredictionModel, JavaWrapper from pyspark.ml.common import inherit_doc @@ -2760,7 +2763,7 @@ class _OneVsRestParams(_ClassifierParams, HasWeightCol): @inherit_doc -class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, JavaMLWritable): +class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWritable): """ Reduction of Multiclass Classification to Binary Classification. Performs reduction using one against all strategy. @@ -2991,8 +2994,73 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, Jav _java_obj.setRawPredictionCol(self.getRawPredictionCol()) return _java_obj + @classmethod + def read(cls): + return OneVsRestReader(cls) -class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable): + def write(self): + if isinstance(self.getClassifier(), JavaMLWritable): + return JavaMLWriter(self) + else: + return OneVsRestWriter(self) + + +class _OneVsRestSharedReadWrite: + @staticmethod + def saveImpl(instance, sc, path, extraMetadata=None): + skipParams = ['classifier'] + jsonParams = DefaultParamsWriter.extractJsonParams(instance, skipParams) + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams, + extraMetadata=extraMetadata) + classifierPath = os.path.join(path, 'classifier') + instance.getClassifier().save(classifierPath) + + @staticmethod + def loadClassifier(path, sc): + classifierPath = os.path.join(path, 'classifier') + return DefaultParamsReader.loadParamsInstance(classifierPath, sc) + + @staticmethod + def validateParams(instance): + elems_to_check = [instance.getClassifier()] + if isinstance(instance, OneVsRestModel): + elems_to_check.extend(instance.models) + + for elem in elems_to_check: + if not isinstance(elem, MLWritable): + raise ValueError(f'OneVsRest write will fail because it contains {elem.uid} ' + f'which is not writable.') + + +@inherit_doc +class OneVsRestReader(MLReader): + def __init__(self, cls): + super(OneVsRestReader, self).__init__() + self.cls = cls + + def load(self, path): + metadata = DefaultParamsReader.loadMetadata(path, self.sc) + if not DefaultParamsReader.isPythonParamsInstance(metadata): + return JavaMLReader(self.cls).load(path) + else: + classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sc) + ova = OneVsRest(classifier=classifier)._resetUid(metadata['uid']) + DefaultParamsReader.getAndSetParams(ova, metadata, skipParams=['classifier']) + return ova + + +@inherit_doc +class OneVsRestWriter(MLWriter): + def __init__(self, instance): + super(OneVsRestWriter, self).__init__() + self.instance = instance + + def saveImpl(self, 
path): + _OneVsRestSharedReadWrite.validateParams(self.instance) + _OneVsRestSharedReadWrite.saveImpl(self.instance, self.sc, path) + + +class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable): """ Model fitted by OneVsRest. This stores the models resulting from training k binary classifiers: one for each class. @@ -3023,6 +3091,9 @@ class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable): def __init__(self, models): super(OneVsRestModel, self).__init__() self.models = models + if not isinstance(models[0], JavaMLWritable): + return + # set java instance java_models = [model._to_java() for model in self.models] sc = SparkContext._active_spark_context java_models_array = JavaWrapper._new_java_array(java_models, @@ -3160,6 +3231,57 @@ class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable): _java_obj.set("weightCol", self.getWeightCol()) return _java_obj + @classmethod + def read(cls): + return OneVsRestModelReader(cls) + + def write(self): + if all(map(lambda elem: isinstance(elem, JavaMLWritable), + [self.getClassifier()] + self.models)): + return JavaMLWriter(self) + else: + return OneVsRestModelWriter(self) + + +@inherit_doc +class OneVsRestModelReader(MLReader): + def __init__(self, cls): + super(OneVsRestModelReader, self).__init__() + self.cls = cls + + def load(self, path): + metadata = DefaultParamsReader.loadMetadata(path, self.sc) + if not DefaultParamsReader.isPythonParamsInstance(metadata): + return JavaMLReader(self.cls).load(path) + else: + classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sc) + numClasses = metadata['numClasses'] + subModels = [None] * numClasses + for idx in range(numClasses): + subModelPath = os.path.join(path, f'model_{idx}') + subModels[idx] = DefaultParamsReader.loadParamsInstance(subModelPath, self.sc) + ovaModel = OneVsRestModel(subModels)._resetUid(metadata['uid']) + ovaModel.set(ovaModel.classifier, classifier) + DefaultParamsReader.getAndSetParams(ovaModel, metadata, skipParams=['classifier']) + return ovaModel + + +@inherit_doc +class OneVsRestModelWriter(MLWriter): + def __init__(self, instance): + super(OneVsRestModelWriter, self).__init__() + self.instance = instance + + def saveImpl(self, path): + _OneVsRestSharedReadWrite.validateParams(self.instance) + instance = self.instance + numClasses = len(instance.models) + extraMetadata = {'numClasses': numClasses} + _OneVsRestSharedReadWrite.saveImpl(instance, self.sc, path, extraMetadata=extraMetadata) + for idx in range(numClasses): + subModelPath = os.path.join(path, f'model_{idx}') + instance.models[idx].save(subModelPath) + @inherit_doc class FMClassifier(_JavaProbabilisticClassifier, _FactorizationMachinesParams, JavaMLWritable, diff --git a/python/pyspark/ml/classification.pyi b/python/pyspark/ml/classification.pyi index c44176a13a..a4a3d21018 100644 --- a/python/pyspark/ml/classification.pyi +++ b/python/pyspark/ml/classification.pyi @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. 
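
A minimal usage sketch (not part of the diff) of the pure-Python OneVsRest persistence path added above. It assumes an active SparkSession bound to `spark` and uses the DummyLogisticRegression helper defined later in this patch; the temp paths are placeholders.

```python
# Sketch: exercising the new Python-backend OneVsRest writer/reader.
# Assumes a running SparkSession named `spark`.
import tempfile

from pyspark.ml.classification import OneVsRest, OneVsRestModel
from pyspark.ml.linalg import Vectors
from pyspark.testing.mlutils import DummyLogisticRegression

df = spark.createDataFrame(
    [(0.0, Vectors.dense(1.0, 0.8)),
     (1.0, Vectors.sparse(2, [], [])),
     (2.0, Vectors.dense(0.5, 0.5))] * 10,
    ["label", "features"])

# The classifier is not JavaMLWritable, so write() returns OneVsRestWriter
# rather than JavaMLWriter, and load() routes through OneVsRestReader.
ovr = OneVsRest(classifier=DummyLogisticRegression(maxIter=5))
ovr_path = tempfile.mkdtemp() + "/ovr"
ovr.save(ovr_path)
loaded_ovr = OneVsRest.load(ovr_path)
assert loaded_ovr.getClassifier().uid == ovr.getClassifier().uid

# OneVsRestModelWriter saves the classifier plus one model_<idx> directory per class.
model = ovr.fit(df)
model_path = tempfile.mkdtemp() + "/ovr_model"
model.save(model_path)
loaded_model = OneVsRestModel.load(model_path)
assert len(loaded_model.models) == len(model.models)
```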
-from typing import Any, List, Optional +from typing import Any, List, Optional, Type from pyspark.ml._typing import JM, M, P, T, ParamMap import abc @@ -53,7 +53,8 @@ from pyspark.ml.tree import ( _TreeClassifierParams, _TreeEnsembleModel, ) -from pyspark.ml.util import HasTrainingSummary, JavaMLReadable, JavaMLWritable +from pyspark.ml.util import HasTrainingSummary, JavaMLReadable, JavaMLWritable, \ + MLReader, MLReadable, MLWriter, MLWritable from pyspark.ml.wrapper import JavaPredictionModel, JavaPredictor, JavaWrapper from pyspark.ml.linalg import Matrix, Vector @@ -797,8 +798,8 @@ class OneVsRest( Estimator[OneVsRestModel], _OneVsRestParams, HasParallelism, - JavaMLReadable[OneVsRest], - JavaMLWritable, + MLReadable[OneVsRest], + MLWritable, ): def __init__( self, @@ -832,7 +833,7 @@ class OneVsRest( def copy(self, extra: Optional[ParamMap] = ...) -> OneVsRest: ... class OneVsRestModel( - Model, _OneVsRestParams, JavaMLReadable[OneVsRestModel], JavaMLWritable + Model, _OneVsRestParams, MLReadable[OneVsRestModel], MLWritable ): models: List[Transformer] def __init__(self, models: List[Transformer]) -> None: ... @@ -841,6 +842,26 @@ class OneVsRestModel( def setRawPredictionCol(self, value: str) -> OneVsRestModel: ... def copy(self, extra: Optional[ParamMap] = ...) -> OneVsRestModel: ... +class OneVsRestWriter(MLWriter): + instance: OneVsRest + def __init__(self, instance: OneVsRest) -> None: ... + def saveImpl(self, path: str) -> None: ... + +class OneVsRestReader(MLReader[OneVsRest]): + cls: Type[OneVsRest] + def __init__(self, cls: Type[OneVsRest]) -> None: ... + def load(self, path: str) -> OneVsRest: ... + +class OneVsRestModelWriter(MLWriter): + instance: OneVsRestModel + def __init__(self, instance: OneVsRestModel) -> None: ... + def saveImpl(self, path: str) -> None: ... + +class OneVsRestModelReader(MLReader[OneVsRestModel]): + cls: Type[OneVsRestModel] + def __init__(self, cls: Type[OneVsRestModel]) -> None: ... + def load(self, path: str) -> OneVsRestModel: ... 
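
The stubs above mirror the dispatch rule the new readers rely on: metadata written by the Python DefaultParamsWriter records a `pyspark.ml.` class name, while JVM-written metadata records an `org.apache.spark` class. A small sketch of that rule, mirroring `DefaultParamsReader.isPythonParamsInstance` added in util.py below; the metadata dicts are hand-written examples, not real saved output.

```python
# Hand-written illustration of the reader dispatch check; the dicts below are
# invented examples, not actual saved metadata.
def is_python_params_instance(metadata):
    # Mirrors DefaultParamsReader.isPythonParamsInstance in this patch.
    return metadata['class'].startswith('pyspark.ml.')

java_backed = {'class': 'org.apache.spark.ml.classification.OneVsRest', 'uid': 'OneVsRest_a1'}
python_only = {'class': 'pyspark.ml.classification.OneVsRest', 'uid': 'OneVsRest_b2'}

assert not is_python_params_instance(java_backed)   # loaded via JavaMLReader
assert is_python_params_instance(python_only)       # loaded via OneVsRestReader
```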
+ class FMClassifier( _JavaProbabilisticClassifier[FMClassificationModel], _FactorizationMachinesParams, diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py index 0bbcfcdf50..77a6c03096 100644 --- a/python/pyspark/ml/tests/test_persistence.py +++ b/python/pyspark/ml/tests/test_persistence.py @@ -237,6 +237,11 @@ class PersistenceTest(SparkSessionTestCase): self.assertEqual(len(m1.models), len(m2.models)) for x, y in zip(m1.models, m2.models): self._compare_pipelines(x, y) + elif isinstance(m1, Params): + # Test on python backend Estimator/Transformer/Model/Evaluator + self.assertEqual(len(m1.params), len(m2.params)) + for p in m1.params: + self._compare_params(m1, m2, p) else: raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) @@ -326,14 +331,14 @@ class PersistenceTest(SparkSessionTestCase): except OSError: pass - def test_onevsrest(self): + def _run_test_onevsrest(self, LogisticRegressionCls): temp_path = tempfile.mkdtemp() df = self.spark.createDataFrame([(0.0, 0.5, Vectors.dense(1.0, 0.8)), (1.0, 0.5, Vectors.sparse(2, [], [])), (2.0, 1.0, Vectors.dense(0.5, 0.5))] * 10, ["label", "wt", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01) + lr = LogisticRegressionCls(maxIter=5, regParam=0.01) ovr = OneVsRest(classifier=lr) def reload_and_compare(ovr, suffix): @@ -350,6 +355,11 @@ class PersistenceTest(SparkSessionTestCase): reload_and_compare(OneVsRest(classifier=lr), "ovr") reload_and_compare(OneVsRest(classifier=lr).setWeightCol("wt"), "ovrw") + def test_onevsrest(self): + from pyspark.testing.mlutils import DummyLogisticRegression + self._run_test_onevsrest(LogisticRegression) + self._run_test_onevsrest(DummyLogisticRegression) + def test_decisiontree_classifier(self): dt = DecisionTreeClassifier(maxDepth=1) path = tempfile.mkdtemp() diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py index ebd7457e4d..3cde34facb 100644 --- a/python/pyspark/ml/tests/test_tuning.py +++ b/python/pyspark/ml/tests/test_tuning.py @@ -28,7 +28,8 @@ from pyspark.ml.param import Param, Params from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder, \ TrainValidationSplit, TrainValidationSplitModel from pyspark.sql.functions import rand -from pyspark.testing.mlutils import SparkSessionTestCase +from pyspark.testing.mlutils import DummyEvaluator, DummyLogisticRegression, \ + DummyLogisticRegressionModel, SparkSessionTestCase class HasInducedError(Params): @@ -201,7 +202,7 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin): for v in param.values(): assert(type(v) == float) - def test_save_load_trained_model(self): + def _run_test_save_load_trained_model(self, LogisticRegressionCls, LogisticRegressionModelCls): # This tests saving and loading the trained model only. 
# Save/load for CrossValidator will be added later: SPARK-13786 temp_path = tempfile.mkdtemp() @@ -212,7 +213,7 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin): (Vectors.dense([0.6]), 1.0), (Vectors.dense([1.0]), 1.0)] * 10, ["features", "label"]) - lr = LogisticRegression() + lr = LogisticRegressionCls() grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() evaluator = BinaryClassificationEvaluator() cv = CrossValidator( @@ -228,7 +229,7 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin): lrModelPath = temp_path + "/lrModel" lrModel.save(lrModelPath) - loadedLrModel = LogisticRegressionModel.load(lrModelPath) + loadedLrModel = LogisticRegressionModelCls.load(lrModelPath) self.assertEqual(loadedLrModel.uid, lrModel.uid) self.assertEqual(loadedLrModel.intercept, lrModel.intercept) @@ -248,7 +249,12 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin): loadedCvModel.isSet(param) for param in loadedCvModel.params )) - def test_save_load_simple_estimator(self): + def test_save_load_trained_model(self): + self._run_test_save_load_trained_model(LogisticRegression, LogisticRegressionModel) + self._run_test_save_load_trained_model(DummyLogisticRegression, + DummyLogisticRegressionModel) + + def _run_test_save_load_simple_estimator(self, LogisticRegressionCls, evaluatorCls): temp_path = tempfile.mkdtemp() dataset = self.spark.createDataFrame( [(Vectors.dense([0.0]), 0.0), @@ -258,9 +264,9 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin): (Vectors.dense([1.0]), 1.0)] * 10, ["features", "label"]) - lr = LogisticRegression() + lr = LogisticRegressionCls() grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() + evaluator = evaluatorCls() # test save/load of CrossValidator cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) @@ -278,6 +284,12 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin): loadedModel = CrossValidatorModel.load(cvModelPath) self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + def test_save_load_simple_estimator(self): + self._run_test_save_load_simple_estimator( + LogisticRegression, BinaryClassificationEvaluator) + self._run_test_save_load_simple_estimator( + DummyLogisticRegression, DummyEvaluator) + def test_parallel_evaluation(self): dataset = self.spark.createDataFrame( [(Vectors.dense([0.0]), 0.0), @@ -343,7 +355,7 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin): for j in range(len(grid)): self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid) - def test_save_load_nested_estimator(self): + def _run_test_save_load_nested_estimator(self, LogisticRegressionCls): temp_path = tempfile.mkdtemp() dataset = self.spark.createDataFrame( [(Vectors.dense([0.0]), 0.0), @@ -353,9 +365,9 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin): (Vectors.dense([1.0]), 1.0)] * 10, ["features", "label"]) - ova = OneVsRest(classifier=LogisticRegression()) - lr1 = LogisticRegression().setMaxIter(100) - lr2 = LogisticRegression().setMaxIter(150) + ova = OneVsRest(classifier=LogisticRegressionCls()) + lr1 = LogisticRegressionCls().setMaxIter(100) + lr2 = LogisticRegressionCls().setMaxIter(150) grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() evaluator = MulticlassClassificationEvaluator() @@ -385,7 +397,11 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin): 
self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid) self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) - def test_save_load_pipeline_estimator(self): + def test_save_load_nested_estimator(self): + self._run_test_save_load_nested_estimator(LogisticRegression) + self._run_test_save_load_nested_estimator(DummyLogisticRegression) + + def _run_test_save_load_pipeline_estimator(self, LogisticRegressionCls): temp_path = tempfile.mkdtemp() training = self.spark.createDataFrame([ (0, "a b c d e spark", 1.0), @@ -402,9 +418,9 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin): tokenizer = Tokenizer(inputCol="text", outputCol="words") hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") - ova = OneVsRest(classifier=LogisticRegression()) - lr1 = LogisticRegression().setMaxIter(5) - lr2 = LogisticRegression().setMaxIter(10) + ova = OneVsRest(classifier=LogisticRegressionCls()) + lr1 = LogisticRegressionCls().setMaxIter(5) + lr2 = LogisticRegressionCls().setMaxIter(10) pipeline = Pipeline(stages=[tokenizer, hashingTF, ova]) @@ -464,6 +480,10 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin): original_nested_pipeline_model.stages): self.assertEqual(loadedStage.uid, originalStage.uid) + def test_save_load_pipeline_estimator(self): + self._run_test_save_load_pipeline_estimator(LogisticRegression) + self._run_test_save_load_pipeline_estimator(DummyLogisticRegression) + def test_user_specified_folds(self): from pyspark.sql import functions as F @@ -593,7 +613,7 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin): "validationMetrics has the same size of grid parameter") self.assertEqual(1.0, max(validationMetrics)) - def test_save_load_trained_model(self): + def _run_test_save_load_trained_model(self, LogisticRegressionCls, LogisticRegressionModelCls): # This tests saving and loading the trained model only. # Save/load for TrainValidationSplit will be added later: SPARK-13786 temp_path = tempfile.mkdtemp() @@ -604,7 +624,7 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin): (Vectors.dense([0.6]), 1.0), (Vectors.dense([1.0]), 1.0)] * 10, ["features", "label"]) - lr = LogisticRegression() + lr = LogisticRegressionCls() grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() evaluator = BinaryClassificationEvaluator() tvs = TrainValidationSplit( @@ -619,7 +639,7 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin): lrModelPath = temp_path + "/lrModel" lrModel.save(lrModelPath) - loadedLrModel = LogisticRegressionModel.load(lrModelPath) + loadedLrModel = LogisticRegressionModelCls.load(lrModelPath) self.assertEqual(loadedLrModel.uid, lrModel.uid) self.assertEqual(loadedLrModel.intercept, lrModel.intercept) @@ -636,7 +656,12 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin): loadedTvsModel.isSet(param) for param in loadedTvsModel.params )) - def test_save_load_simple_estimator(self): + def test_save_load_trained_model(self): + self._run_test_save_load_trained_model(LogisticRegression, LogisticRegressionModel) + self._run_test_save_load_trained_model(DummyLogisticRegression, + DummyLogisticRegressionModel) + + def _run_test_save_load_simple_estimator(self, LogisticRegressionCls, evaluatorCls): # This tests saving and loading the trained model only. 
# Save/load for TrainValidationSplit will be added later: SPARK-13786 temp_path = tempfile.mkdtemp() @@ -647,9 +672,9 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin): (Vectors.dense([0.6]), 1.0), (Vectors.dense([1.0]), 1.0)] * 10, ["features", "label"]) - lr = LogisticRegression() + lr = LogisticRegressionCls() grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() + evaluator = evaluatorCls() tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) tvsModel = tvs.fit(dataset) @@ -666,6 +691,12 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin): loadedModel = TrainValidationSplitModel.load(tvsModelPath) self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + def test_save_load_simple_estimator(self): + self._run_test_save_load_simple_estimator( + LogisticRegression, BinaryClassificationEvaluator) + self._run_test_save_load_simple_estimator( + DummyLogisticRegression, DummyEvaluator) + def test_parallel_evaluation(self): dataset = self.spark.createDataFrame( [(Vectors.dense([0.0]), 0.0), @@ -718,7 +749,7 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin): for i in range(len(grid)): self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid) - def test_save_load_nested_estimator(self): + def _run_test_save_load_nested_estimator(self, LogisticRegressionCls): # This tests saving and loading the trained model only. # Save/load for TrainValidationSplit will be added later: SPARK-13786 temp_path = tempfile.mkdtemp() @@ -729,9 +760,9 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin): (Vectors.dense([0.6]), 1.0), (Vectors.dense([1.0]), 1.0)] * 10, ["features", "label"]) - ova = OneVsRest(classifier=LogisticRegression()) - lr1 = LogisticRegression().setMaxIter(100) - lr2 = LogisticRegression().setMaxIter(150) + ova = OneVsRest(classifier=LogisticRegressionCls()) + lr1 = LogisticRegressionCls().setMaxIter(100) + lr2 = LogisticRegressionCls().setMaxIter(150) grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() evaluator = MulticlassClassificationEvaluator() @@ -759,7 +790,11 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin): self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid) self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) - def test_save_load_pipeline_estimator(self): + def test_save_load_nested_estimator(self): + self._run_test_save_load_nested_estimator(LogisticRegression) + self._run_test_save_load_nested_estimator(DummyLogisticRegression) + + def _run_test_save_load_pipeline_estimator(self, LogisticRegressionCls): temp_path = tempfile.mkdtemp() training = self.spark.createDataFrame([ (0, "a b c d e spark", 1.0), @@ -776,9 +811,9 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin): tokenizer = Tokenizer(inputCol="text", outputCol="words") hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") - ova = OneVsRest(classifier=LogisticRegression()) - lr1 = LogisticRegression().setMaxIter(5) - lr2 = LogisticRegression().setMaxIter(10) + ova = OneVsRest(classifier=LogisticRegressionCls()) + lr1 = LogisticRegressionCls().setMaxIter(5) + lr2 = LogisticRegressionCls().setMaxIter(10) pipeline = Pipeline(stages=[tokenizer, hashingTF, ova]) @@ -836,6 +871,10 @@ class TrainValidationSplitTests(SparkSessionTestCase, 
ValidatorTestUtilsMixin): original_nested_pipeline_model.stages): self.assertEqual(loadedStage.uid, originalStage.uid) + def test_save_load_pipeline_estimator(self): + self._run_test_save_load_pipeline_estimator(LogisticRegression) + self._run_test_save_load_pipeline_estimator(DummyLogisticRegression) + def test_copy(self): dataset = self.spark.createDataFrame([ (10, 10.0), diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 2b5a9857b0..2c083182de 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -15,6 +15,7 @@ # limitations under the License. # +import os import sys import itertools from multiprocessing.pool import ThreadPool @@ -22,12 +23,13 @@ from multiprocessing.pool import ThreadPool import numpy as np from pyspark import keyword_only, since, SparkContext -from pyspark.ml import Estimator, Model -from pyspark.ml.common import _py2java, _java2py +from pyspark.ml import Estimator, Transformer, Model +from pyspark.ml.common import inherit_doc, _py2java, _java2py +from pyspark.ml.evaluation import Evaluator from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed -from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader, \ - MetaAlgorithmReadWrite +from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite, \ + MLReadable, MLReader, MLWritable, MLWriter, JavaMLReader, JavaMLWriter from pyspark.ml.wrapper import JavaParams, JavaEstimator, JavaWrapper from pyspark.sql.functions import col, lit, rand, UserDefinedFunction from pyspark.sql.types import BooleanType @@ -229,6 +231,7 @@ class _ValidatorParams(HasSeed): class _ValidatorSharedReadWrite: + @staticmethod def meta_estimator_transfer_param_maps_to_java(pyEstimator, pyParamMaps): pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator) @@ -275,10 +278,8 @@ class _ValidatorSharedReadWrite: raise ValueError('Resolve param in estimatorParamMaps failed: ' + javaParam.parent() + '.' + javaParam.name()) javaValue = javaPair.value() - if sc._jvm.Class.forName("org.apache.spark.ml.PipelineStage").isInstance(javaValue): - # Note: JavaParams._from_java support both JavaEstimator/JavaTransformer class - # and Estimator/Transformer class which implements `_from_java` static method - # (such as OneVsRest, Pipeline class). 
+            if sc._jvm.Class.forName("org.apache.spark.ml.util.DefaultParamsWritable") \
+                    .isInstance(javaValue):
                 pyValue = JavaParams._from_java(javaValue)
             else:
                 pyValue = _java2py(sc, javaValue)
@@ -286,6 +287,222 @@ class _ValidatorSharedReadWrite:
             pyParamMaps.append(pyParamMap)
         return pyParamMaps
 
+    @staticmethod
+    def is_java_convertible(instance):
+        allNestedStages = MetaAlgorithmReadWrite.getAllNestedStages(instance.getEstimator())
+        evaluator_convertible = isinstance(instance.getEvaluator(), JavaParams)
+        estimator_convertible = all(map(lambda stage: hasattr(stage, '_to_java'), allNestedStages))
+        return estimator_convertible and evaluator_convertible
+
+    @staticmethod
+    def saveImpl(path, instance, sc, extraMetadata=None):
+        numParamsNotJson = 0
+        jsonEstimatorParamMaps = []
+        for paramMap in instance.getEstimatorParamMaps():
+            jsonParamMap = []
+            for p, v in paramMap.items():
+                jsonParam = {'parent': p.parent, 'name': p.name}
+                if (isinstance(v, Estimator) and not MetaAlgorithmReadWrite.isMetaEstimator(v)) \
+                        or isinstance(v, Transformer) or isinstance(v, Evaluator):
+                    relative_path = f'epm_{p.name}{numParamsNotJson}'
+                    param_path = os.path.join(path, relative_path)
+                    numParamsNotJson += 1
+                    v.save(param_path)
+                    jsonParam['value'] = relative_path
+                    jsonParam['isJson'] = False
+                elif isinstance(v, MLWritable):
+                    raise RuntimeError(
+                        "ValidatorSharedReadWrite.saveImpl does not handle parameters of type "
+                        "MLWritable that are not Estimator/Evaluator/Transformer; if the "
+                        "parameter is an Estimator, it cannot be a meta estimator such as "
+                        "Validator or OneVsRest")
+                else:
+                    jsonParam['value'] = v
+                    jsonParam['isJson'] = True
+                jsonParamMap.append(jsonParam)
+            jsonEstimatorParamMaps.append(jsonParamMap)
+
+        skipParams = ['estimator', 'evaluator', 'estimatorParamMaps']
+        jsonParams = DefaultParamsWriter.extractJsonParams(instance, skipParams)
+        jsonParams['estimatorParamMaps'] = jsonEstimatorParamMaps
+
+        DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, jsonParams)
+        evaluatorPath = os.path.join(path, 'evaluator')
+        instance.getEvaluator().save(evaluatorPath)
+        estimatorPath = os.path.join(path, 'estimator')
+        instance.getEstimator().save(estimatorPath)
+
+    @staticmethod
+    def load(path, sc, metadata):
+        evaluatorPath = os.path.join(path, 'evaluator')
+        evaluator = DefaultParamsReader.loadParamsInstance(evaluatorPath, sc)
+        estimatorPath = os.path.join(path, 'estimator')
+        estimator = DefaultParamsReader.loadParamsInstance(estimatorPath, sc)
+
+        uidToParams = MetaAlgorithmReadWrite.getUidMap(estimator)
+        uidToParams[evaluator.uid] = evaluator
+
+        jsonEstimatorParamMaps = metadata['paramMap']['estimatorParamMaps']
+
+        estimatorParamMaps = []
+        for jsonParamMap in jsonEstimatorParamMaps:
+            paramMap = {}
+            for jsonParam in jsonParamMap:
+                est = uidToParams[jsonParam['parent']]
+                param = getattr(est, jsonParam['name'])
+                if 'isJson' not in jsonParam or jsonParam['isJson']:
+                    value = jsonParam['value']
+                else:
+                    relativePath = jsonParam['value']
+                    valueSavedPath = os.path.join(path, relativePath)
+                    value = DefaultParamsReader.loadParamsInstance(valueSavedPath, sc)
+                paramMap[param] = value
+            estimatorParamMaps.append(paramMap)
+
+        return metadata, estimator, evaluator, estimatorParamMaps
+
+    @staticmethod
+    def validateParams(instance):
+        estimator = instance.getEstimator()
+        evaluator = instance.getEvaluator()
+        uidMap = MetaAlgorithmReadWrite.getUidMap(estimator)
+
+        for elem in [evaluator] + list(uidMap.values()):
+            if not isinstance(elem, MLWritable):
+                raise ValueError(f'Validator write will fail because it contains {elem.uid} '
+                                 f'which is not writable.')
+
+        estimatorParamMaps = instance.getEstimatorParamMaps()
+        paramErr = 'Validator save requires all Params in estimatorParamMaps to apply to ' \
+                   'its Estimator. An extraneous Param was found: '
+        for paramMap in estimatorParamMaps:
+            for param in paramMap:
+                if param.parent not in uidMap:
+                    raise ValueError(paramErr + repr(param))
+
+    @staticmethod
+    def getValidatorModelWriterPersistSubModelsParam(writer):
+        if 'persistsubmodels' in writer.optionMap:
+            persistSubModelsParam = writer.optionMap['persistsubmodels'].lower()
+            if persistSubModelsParam == 'true':
+                return True
+            elif persistSubModelsParam == 'false':
+                return False
+            else:
+                raise ValueError(
+                    f'persistSubModels option value {persistSubModelsParam} is invalid, '
+                    f"the possible values are True, 'True' or False, 'False'")
+        else:
+            return writer.instance.subModels is not None
+
+
+_save_with_persist_submodels_no_submodels_found_err = \
+    'When persisting tuning models, you can only set persistSubModels to true if the tuning ' \
+    'was done with collectSubModels set to true. To save the sub-models, try rerunning fitting ' \
+    'with collectSubModels set to true.'
+
+
+@inherit_doc
+class CrossValidatorReader(MLReader):
+
+    def __init__(self, cls):
+        super(CrossValidatorReader, self).__init__()
+        self.cls = cls
+
+    def load(self, path):
+        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        if not DefaultParamsReader.isPythonParamsInstance(metadata):
+            return JavaMLReader(self.cls).load(path)
+        else:
+            metadata, estimator, evaluator, estimatorParamMaps = \
+                _ValidatorSharedReadWrite.load(path, self.sc, metadata)
+            cv = CrossValidator(estimator=estimator,
+                                estimatorParamMaps=estimatorParamMaps,
+                                evaluator=evaluator)
+            cv = cv._resetUid(metadata['uid'])
+            DefaultParamsReader.getAndSetParams(cv, metadata, skipParams=['estimatorParamMaps'])
+            return cv
+
+
+@inherit_doc
+class CrossValidatorWriter(MLWriter):
+
+    def __init__(self, instance):
+        super(CrossValidatorWriter, self).__init__()
+        self.instance = instance
+
+    def saveImpl(self, path):
+        _ValidatorSharedReadWrite.validateParams(self.instance)
+        _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc)
+
+
+@inherit_doc
+class CrossValidatorModelReader(MLReader):
+
+    def __init__(self, cls):
+        super(CrossValidatorModelReader, self).__init__()
+        self.cls = cls
+
+    def load(self, path):
+        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        if not DefaultParamsReader.isPythonParamsInstance(metadata):
+            return JavaMLReader(self.cls).load(path)
+        else:
+            metadata, estimator, evaluator, estimatorParamMaps = \
+                _ValidatorSharedReadWrite.load(path, self.sc, metadata)
+            numFolds = metadata['paramMap']['numFolds']
+            bestModelPath = os.path.join(path, 'bestModel')
+            bestModel = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
+            avgMetrics = metadata['avgMetrics']
+            persistSubModels = ('persistSubModels' in metadata) and metadata['persistSubModels']
+
+            if persistSubModels:
+                # Build each fold's list independently; multiplying the outer list
+                # would alias a single row object across every fold.
+                subModels = [[None] * len(estimatorParamMaps) for _ in range(numFolds)]
+                for splitIndex in range(numFolds):
+                    for paramIndex in range(len(estimatorParamMaps)):
+                        modelPath = os.path.join(
+                            path, 'subModels', f'fold{splitIndex}', f'{paramIndex}')
+                        subModels[splitIndex][paramIndex] = \
+                            DefaultParamsReader.loadParamsInstance(modelPath, self.sc)
+            else:
+                subModels = None
+
+            cvModel = CrossValidatorModel(bestModel, avgMetrics=avgMetrics, subModels=subModels)
+            cvModel = 
cvModel._resetUid(metadata['uid']) + cvModel.set(cvModel.estimator, estimator) + cvModel.set(cvModel.estimatorParamMaps, estimatorParamMaps) + cvModel.set(cvModel.evaluator, evaluator) + DefaultParamsReader.getAndSetParams( + cvModel, metadata, skipParams=['estimatorParamMaps']) + return cvModel + + +@inherit_doc +class CrossValidatorModelWriter(MLWriter): + + def __init__(self, instance): + super(CrossValidatorModelWriter, self).__init__() + self.instance = instance + + def saveImpl(self, path): + _ValidatorSharedReadWrite.validateParams(self.instance) + instance = self.instance + persistSubModels = _ValidatorSharedReadWrite \ + .getValidatorModelWriterPersistSubModelsParam(self) + extraMetadata = {'avgMetrics': instance.avgMetrics, + 'persistSubModels': persistSubModels} + _ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata) + bestModelPath = os.path.join(path, 'bestModel') + instance.bestModel.save(bestModelPath) + if persistSubModels: + if instance.subModels is None: + raise ValueError(_save_with_persist_submodels_no_submodels_found_err) + subModelsPath = os.path.join(path, 'subModels') + for splitIndex in range(instance.getNumFolds()): + splitPath = os.path.join(subModelsPath, f'fold{splitIndex}') + for paramIndex in range(len(instance.getEstimatorParamMaps())): + modelPath = os.path.join(splitPath, f'{paramIndex}') + instance.subModels[splitIndex][paramIndex].save(modelPath) + class _CrossValidatorParams(_ValidatorParams): """ @@ -553,13 +770,15 @@ class CrossValidator(Estimator, _CrossValidatorParams, HasParallelism, HasCollec @since("2.3.0") def write(self): """Returns an MLWriter instance for this ML instance.""" - return JavaMLWriter(self) + if _ValidatorSharedReadWrite.is_java_convertible(self): + return JavaMLWriter(self) + return CrossValidatorWriter(self) @classmethod @since("2.3.0") def read(cls): """Returns an MLReader instance for this class.""" - return JavaMLReader(cls) + return CrossValidatorReader(cls) @classmethod def _from_java(cls, java_stage): @@ -662,13 +881,15 @@ class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable): @since("2.3.0") def write(self): """Returns an MLWriter instance for this ML instance.""" - return JavaMLWriter(self) + if _ValidatorSharedReadWrite.is_java_convertible(self): + return JavaMLWriter(self) + return CrossValidatorModelWriter(self) @classmethod @since("2.3.0") def read(cls): """Returns an MLReader instance for this class.""" - return JavaMLReader(cls) + return CrossValidatorModelReader(cls) @classmethod def _from_java(cls, java_stage): @@ -738,6 +959,106 @@ class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable): return _java_obj +@inherit_doc +class TrainValidationSplitReader(MLReader): + + def __init__(self, cls): + super(TrainValidationSplitReader, self).__init__() + self.cls = cls + + def load(self, path): + metadata = DefaultParamsReader.loadMetadata(path, self.sc) + if not DefaultParamsReader.isPythonParamsInstance(metadata): + return JavaMLReader(self.cls).load(path) + else: + metadata, estimator, evaluator, estimatorParamMaps = \ + _ValidatorSharedReadWrite.load(path, self.sc, metadata) + tvs = TrainValidationSplit(estimator=estimator, + estimatorParamMaps=estimatorParamMaps, + evaluator=evaluator) + tvs = tvs._resetUid(metadata['uid']) + DefaultParamsReader.getAndSetParams(tvs, metadata, skipParams=['estimatorParamMaps']) + return tvs + + +@inherit_doc +class TrainValidationSplitWriter(MLWriter): + + def __init__(self, instance): + 
super(TrainValidationSplitWriter, self).__init__() + self.instance = instance + + def saveImpl(self, path): + _ValidatorSharedReadWrite.validateParams(self.instance) + _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc) + + +@inherit_doc +class TrainValidationSplitModelReader(MLReader): + + def __init__(self, cls): + super(TrainValidationSplitModelReader, self).__init__() + self.cls = cls + + def load(self, path): + metadata = DefaultParamsReader.loadMetadata(path, self.sc) + if not DefaultParamsReader.isPythonParamsInstance(metadata): + return JavaMLReader(self.cls).load(path) + else: + metadata, estimator, evaluator, estimatorParamMaps = \ + _ValidatorSharedReadWrite.load(path, self.sc, metadata) + bestModelPath = os.path.join(path, 'bestModel') + bestModel = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc) + validationMetrics = metadata['validationMetrics'] + persistSubModels = ('persistSubModels' in metadata) and metadata['persistSubModels'] + + if persistSubModels: + subModels = [None] * len(estimatorParamMaps) + for paramIndex in range(len(estimatorParamMaps)): + modelPath = os.path.join(path, 'subModels', f'{paramIndex}') + subModels[paramIndex] = \ + DefaultParamsReader.loadParamsInstance(modelPath, self.sc) + else: + subModels = None + + tvsModel = TrainValidationSplitModel( + bestModel, validationMetrics=validationMetrics, subModels=subModels) + tvsModel = tvsModel._resetUid(metadata['uid']) + tvsModel.set(tvsModel.estimator, estimator) + tvsModel.set(tvsModel.estimatorParamMaps, estimatorParamMaps) + tvsModel.set(tvsModel.evaluator, evaluator) + DefaultParamsReader.getAndSetParams( + tvsModel, metadata, skipParams=['estimatorParamMaps']) + return tvsModel + + +@inherit_doc +class TrainValidationSplitModelWriter(MLWriter): + + def __init__(self, instance): + super(TrainValidationSplitModelWriter, self).__init__() + self.instance = instance + + def saveImpl(self, path): + _ValidatorSharedReadWrite.validateParams(self.instance) + instance = self.instance + persistSubModels = _ValidatorSharedReadWrite \ + .getValidatorModelWriterPersistSubModelsParam(self) + + extraMetadata = {'validationMetrics': instance.validationMetrics, + 'persistSubModels': persistSubModels} + _ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata) + bestModelPath = os.path.join(path, 'bestModel') + instance.bestModel.save(bestModelPath) + if persistSubModels: + if instance.subModels is None: + raise ValueError(_save_with_persist_submodels_no_submodels_found_err) + subModelsPath = os.path.join(path, 'subModels') + for paramIndex in range(len(instance.getEstimatorParamMaps())): + modelPath = os.path.join(subModelsPath, f'{paramIndex}') + instance.subModels[paramIndex].save(modelPath) + + class _TrainValidationSplitParams(_ValidatorParams): """ Params for :py:class:`TrainValidationSplit` and :py:class:`TrainValidationSplitModel`. 
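
A hedged sketch of how the `persistSubModels` writer option interacts with the model writers above. It assumes `tvs_model` is a fitted TrainValidationSplitModel whose estimator and evaluator are pure Python, so `write()` returns the Python writer; the paths are placeholders, and `option()` is the MLWriter method added in util.py below (keys are lower-cased).

```python
# Sketch only: `tvs_model` is an assumed fitted TrainValidationSplitModel with
# pure-Python stages; paths are placeholders.
tvs_model.write().option("persistSubModels", "false").save("/tmp/tvs_no_submodels")

# Without the option, sub-models are written only if they were collected during
# fitting (collectSubModels=True); forcing persistSubModels=true when none were
# collected raises _save_with_persist_submodels_no_submodels_found_err.
tvs_model.write().option("persistSubModels", "true").save("/tmp/tvs_with_submodels")
```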
@@ -942,13 +1263,15 @@ class TrainValidationSplit(Estimator, _TrainValidationSplitParams, HasParallelis @since("2.3.0") def write(self): """Returns an MLWriter instance for this ML instance.""" - return JavaMLWriter(self) + if _ValidatorSharedReadWrite.is_java_convertible(self): + return JavaMLWriter(self) + return TrainValidationSplitWriter(self) @classmethod @since("2.3.0") def read(cls): """Returns an MLReader instance for this class.""" - return JavaMLReader(cls) + return TrainValidationSplitReader(cls) @classmethod def _from_java(cls, java_stage): @@ -1046,13 +1369,15 @@ class TrainValidationSplitModel(Model, _TrainValidationSplitParams, MLReadable, @since("2.3.0") def write(self): """Returns an MLWriter instance for this ML instance.""" - return JavaMLWriter(self) + if _ValidatorSharedReadWrite.is_java_convertible(self): + return JavaMLWriter(self) + return TrainValidationSplitModelWriter(self) @classmethod @since("2.3.0") def read(cls): """Returns an MLReader instance for this class.""" - return JavaMLReader(cls) + return TrainValidationSplitModelReader(cls) @classmethod def _from_java(cls, java_stage): diff --git a/python/pyspark/ml/tuning.pyi b/python/pyspark/ml/tuning.pyi index 63cd75f0e1..e5f153d49e 100644 --- a/python/pyspark/ml/tuning.pyi +++ b/python/pyspark/ml/tuning.pyi @@ -183,3 +183,43 @@ class TrainValidationSplitModel( def write(self) -> MLWriter: ... @classmethod def read(cls: Type[TrainValidationSplitModel]) -> MLReader: ... + +class CrossValidatorWriter(MLWriter): + instance: CrossValidator + def __init__(self, instance: CrossValidator) -> None: ... + def saveImpl(self, path: str) -> None: ... + +class CrossValidatorReader(MLReader[CrossValidator]): + cls: Type[CrossValidator] + def __init__(self, cls: Type[CrossValidator]) -> None: ... + def load(self, path: str) -> CrossValidator: ... + +class CrossValidatorModelWriter(MLWriter): + instance: CrossValidatorModel + def __init__(self, instance: CrossValidatorModel) -> None: ... + def saveImpl(self, path: str) -> None: ... + +class CrossValidatorModelReader(MLReader[CrossValidatorModel]): + cls: Type[CrossValidatorModel] + def __init__(self, cls: Type[CrossValidatorModel]) -> None: ... + def load(self, path: str) -> CrossValidatorModel: ... + +class TrainValidationSplitWriter(MLWriter): + instance: TrainValidationSplit + def __init__(self, instance: TrainValidationSplit) -> None: ... + def saveImpl(self, path: str) -> None: ... + +class TrainValidationSplitReader(MLReader[TrainValidationSplit]): + cls: Type[TrainValidationSplit] + def __init__(self, cls: Type[TrainValidationSplit]) -> None: ... + def load(self, path: str) -> TrainValidationSplit: ... + +class TrainValidationSplitModelWriter(MLWriter): + instance: TrainValidationSplitModel + def __init__(self, instance: TrainValidationSplitModel) -> None: ... + def saveImpl(self, path: str) -> None: ... + +class TrainValidationSplitModelReader(MLReader[TrainValidationSplitModel]): + cls: Type[TrainValidationSplitModel] + def __init__(self, cls: Type[TrainValidationSplitModel]) -> None: ... + def load(self, path: str) -> TrainValidationSplitModel: ... 
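
For reference, a hand-written illustration of the estimatorParamMaps encoding produced by `_ValidatorSharedReadWrite.saveImpl` earlier in this patch; the uids below are made up and this is not actual saved metadata.

```python
# Plain values are stored inline with isJson=True; Estimator/Transformer/
# Evaluator values are saved under an 'epm_<paramName><n>' subdirectory and
# referenced by relative path with isJson=False.
example_estimator_param_maps = [
    [
        # e.g. ParamGridBuilder().addGrid(lr.maxIter, [0, 1]) entries:
        {'parent': 'LogisticRegression_0123', 'name': 'maxIter', 'value': 1, 'isJson': True},
        # e.g. addGrid(ova.classifier, [lr1, lr2]) entries reference saved instances:
        {'parent': 'OneVsRest_4567', 'name': 'classifier',
         'value': 'epm_classifier0', 'isJson': False},
    ],
]
# On load, 'parent' is resolved to a live Params object via
# MetaAlgorithmReadWrite.getUidMap(estimator), and non-JSON values are
# re-instantiated with DefaultParamsReader.loadParamsInstance.
```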
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index a34bfb5348..156e7f0fe6 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -106,6 +106,7 @@ class MLWriter(BaseReadWrite): def __init__(self): super(MLWriter, self).__init__() self.shouldOverwrite = False + self.optionMap = {} def _handleOverwrite(self, path): from pyspark.ml.wrapper import JavaWrapper @@ -132,6 +133,14 @@ class MLWriter(BaseReadWrite): self.shouldOverwrite = True return self + def option(self, key, value): + """ + Adds an option to the underlying MLWriter. See the documentation for the specific model's + writer for possible options. The option name (key) is case-insensitive. + """ + self.optionMap[key.lower()] = str(value) + return self + @inherit_doc class GeneralMLWriter(MLWriter): @@ -375,6 +384,13 @@ class DefaultParamsWriter(MLWriter): def saveImpl(self, path): DefaultParamsWriter.saveMetadata(self.instance, path, self.sc) + @staticmethod + def extractJsonParams(instance, skipParams): + paramMap = instance.extractParamMap() + jsonParams = {param.name: value for param, value in paramMap.items() + if param.name not in skipParams} + return jsonParams + @staticmethod def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None): """ @@ -530,15 +546,16 @@ class DefaultParamsReader(MLReader): return metadata @staticmethod - def getAndSetParams(instance, metadata): + def getAndSetParams(instance, metadata, skipParams=None): """ Extract Params from metadata, and set them in the instance. """ # Set user-supplied param values for paramName in metadata['paramMap']: param = instance.getParam(paramName) - paramValue = metadata['paramMap'][paramName] - instance.set(param, paramValue) + if skipParams is None or paramName not in skipParams: + paramValue = metadata['paramMap'][paramName] + instance.set(param, paramValue) # Set default param values majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion']) @@ -554,6 +571,10 @@ class DefaultParamsReader(MLReader): paramValue = metadata['defaultParamMap'][paramName] instance._setDefault(**{paramName: paramValue}) + @staticmethod + def isPythonParamsInstance(metadata): + return metadata['class'].startswith('pyspark.ml.') + @staticmethod def loadParamsInstance(path, sc): """ @@ -561,7 +582,10 @@ class DefaultParamsReader(MLReader): This assumes the instance inherits from :py:class:`MLReadable`. """ metadata = DefaultParamsReader.loadMetadata(path, sc) - pythonClassName = metadata['class'].replace("org.apache.spark", "pyspark") + if DefaultParamsReader.isPythonParamsInstance(metadata): + pythonClassName = metadata['class'] + else: + pythonClassName = metadata['class'].replace("org.apache.spark", "pyspark") py_type = DefaultParamsReader.__get_class(pythonClassName) instance = py_type.load(path) return instance @@ -630,3 +654,13 @@ class MetaAlgorithmReadWrite: nestedStages.extend(MetaAlgorithmReadWrite.getAllNestedStages(pySubStage)) return [pyInstance] + nestedStages + + @staticmethod + def getUidMap(instance): + nestedStages = MetaAlgorithmReadWrite.getAllNestedStages(instance) + uidMap = {stage.uid: stage for stage in nestedStages} + if len(nestedStages) != len(uidMap): + raise RuntimeError(f'{instance.__class__.__module__}.{instance.__class__.__name__}' + f'.load found a compound estimator with stages with duplicate ' + f'UIDs. 
List of UIDs: {list(uidMap.keys())}.') + return uidMap diff --git a/python/pyspark/ml/util.pyi b/python/pyspark/ml/util.pyi index e2496e181f..db28c095a5 100644 --- a/python/pyspark/ml/util.pyi +++ b/python/pyspark/ml/util.pyi @@ -132,3 +132,5 @@ class MetaAlgorithmReadWrite: def isMetaEstimator(pyInstance: Any) -> bool: ... @staticmethod def getAllNestedStages(pyInstance: Any) -> list: ... + @staticmethod + def getUidMap(instance: Any) -> dict: ... diff --git a/python/pyspark/testing/mlutils.py b/python/pyspark/testing/mlutils.py index a90a64e747..d6edf9d64a 100644 --- a/python/pyspark/testing/mlutils.py +++ b/python/pyspark/testing/mlutils.py @@ -17,8 +17,12 @@ import numpy as np +from pyspark import keyword_only from pyspark.ml import Estimator, Model, Transformer, UnaryTransformer +from pyspark.ml.evaluation import Evaluator from pyspark.ml.param import Param, Params, TypeConverters +from pyspark.ml.param.shared import HasMaxIter, HasRegParam +from pyspark.ml.classification import Classifier, ClassificationModel from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable from pyspark.ml.wrapper import _java2py # type: ignore from pyspark.sql import DataFrame, SparkSession @@ -161,3 +165,86 @@ class MockEstimator(Estimator, HasFake): class MockModel(MockTransformer, Model, HasFake): pass + + +class _DummyLogisticRegressionParams(HasMaxIter, HasRegParam): + def setMaxIter(self, value): + return self._set(maxIter=value) + + def setRegParam(self, value): + return self._set(regParam=value) + + +# This is a dummy LogisticRegression used in test for python backend estimator/model +class DummyLogisticRegression(Classifier, _DummyLogisticRegressionParams, + DefaultParamsReadable, DefaultParamsWritable): + @keyword_only + def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, regParam=0.0, rawPredictionCol="rawPrediction"): + super(DummyLogisticRegression, self).__init__() + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, regParam=0.0, rawPredictionCol="rawPrediction"): + kwargs = self._input_kwargs + self._set(**kwargs) + return self + + def _fit(self, dataset): + # Do nothing but create a dummy model + return self._copyValues(DummyLogisticRegressionModel()) + + +class DummyLogisticRegressionModel(ClassificationModel, _DummyLogisticRegressionParams, + DefaultParamsReadable, DefaultParamsWritable): + + def __init__(self): + super(DummyLogisticRegressionModel, self).__init__() + + def _transform(self, dataset): + # A dummy transform impl which always predict label 1 + from pyspark.sql.functions import array, lit + from pyspark.ml.functions import array_to_vector + rawPredCol = self.getRawPredictionCol() + if rawPredCol: + dataset = dataset.withColumn( + rawPredCol, array_to_vector(array(lit(-100.0), lit(100.0)))) + predCol = self.getPredictionCol() + if predCol: + dataset = dataset.withColumn(predCol, lit(1.0)) + + return dataset + + @property + def numClasses(self): + # a dummy implementation for test. + return 2 + + @property + def intercept(self): + # a dummy implementation for test. + return 0.0 + + # This class only used in test. The following methods/properties are not used in tests. 
+ + @property + def coefficients(self): + raise NotImplementedError() + + def predictRaw(self, value): + raise NotImplementedError() + + def numFeatures(self): + raise NotImplementedError() + + def predict(self, value): + raise NotImplementedError() + + +class DummyEvaluator(Evaluator, DefaultParamsReadable, DefaultParamsWritable): + + def _evaluate(self, dataset): + # a dummy implementation for test. + return 1.0
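
Closing the loop, a minimal sketch of the pure-Python tuning persistence exercised with the dummy helpers defined above, mirroring `_run_test_save_load_simple_estimator`. It assumes an active SparkSession bound to `spark`; the temp paths are placeholders.

```python
# Sketch: CrossValidator save/load with no JVM-backed stages, so write() returns
# CrossValidatorWriter and load() goes through CrossValidatorReader.
import tempfile

from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder
from pyspark.testing.mlutils import DummyEvaluator, DummyLogisticRegression

dataset = spark.createDataFrame(
    [(Vectors.dense([0.0]), 0.0),
     (Vectors.dense([0.4]), 1.0),
     (Vectors.dense([0.5]), 0.0),
     (Vectors.dense([0.6]), 1.0),
     (Vectors.dense([1.0]), 1.0)] * 10,
    ["features", "label"])

lr = DummyLogisticRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=DummyEvaluator())

cv_path = tempfile.mkdtemp() + "/cv"
cv.save(cv_path)
loaded_cv = CrossValidator.load(cv_path)
assert loaded_cv.uid == cv.uid

cv_model = cv.fit(dataset)
model_path = tempfile.mkdtemp() + "/cv_model"
cv_model.save(model_path)
loaded_model = CrossValidatorModel.load(model_path)
assert loaded_model.bestModel.uid == cv_model.bestModel.uid
```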