[SPARK-33520][ML][PYSPARK] Make CrossValidator/TrainValidationSplit/OneVsRest Reader/Writer support Python backend estimator/evaluator

### What changes were proposed in this pull request?
Make the CrossValidator/TrainValidationSplit/OneVsRest Reader/Writer support Python backend estimators/models.

### Why are the changes needed?
Currently, PySpark allows third-party libraries to define Python backend estimators/evaluators, i.e., estimators that inherit `Estimator` instead of `JavaEstimator` and can only be used from PySpark.

CrossValidator and TrainValidationSplit support tuning these Python backend estimators,
but they cannot be saved or loaded, because the CrossValidator and TrainValidationSplit writers are implemented with `JavaMLWriter`, which requires converting the nested estimator and evaluator into Java instances.
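
For illustration, a minimal sketch of the round trip this change is meant to enable (it assumes an active SparkSession; `DummyLogisticRegression` and `DummyEvaluator` are the Python-only test helpers added to `pyspark.testing.mlutils` in this PR, but any pure-Python `Estimator`/`Evaluator` behaves the same):

```python
import tempfile

from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.testing.mlutils import DummyEvaluator, DummyLogisticRegression

# Python-only estimator and evaluator, i.e. no Java backend behind them.
lr = DummyLogisticRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=DummyEvaluator())

path = tempfile.mkdtemp() + "/cv"
cv.save(path)                       # dispatches to CrossValidatorWriter instead of JavaMLWriter
loaded = CrossValidator.load(path)  # CrossValidatorReader restores the Python estimator/evaluator
```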

For the same reason, OneVsRest save/load currently only supports Java backend classifiers.
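
After this change, OneVsRest can be saved with a Python-only classifier as well; a sketch under the same assumptions as above:

```python
import tempfile

from pyspark.ml.classification import OneVsRest
from pyspark.testing.mlutils import DummyLogisticRegression

ovr = OneVsRest(classifier=DummyLogisticRegression(maxIter=5))
path = tempfile.mkdtemp() + "/ovr"
ovr.save(path)                 # OneVsRestWriter saves the nested Python classifier under <path>/classifier
loaded = OneVsRest.load(path)  # OneVsRestReader reloads it via DefaultParamsReader
```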

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Unit test.

Closes #30471 from WeichenXu123/support_pyio_tuning.

Authored-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Commit 7e759b2d95 (parent 63f9d474b9), Weichen Xu, 2020-12-04 08:35:50 +08:00
9 changed files with 739 additions and 59 deletions

View file

@ -15,6 +15,7 @@
# limitations under the License.
#
import os
import operator
import sys
import uuid
@ -33,7 +34,9 @@ from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
_HasVarianceImpurity, _TreeClassifierParams
from pyspark.ml.regression import _FactorizationMachinesParams, DecisionTreeRegressionModel
from pyspark.ml.base import _PredictorParams
from pyspark.ml.util import JavaMLWritable, JavaMLReadable, HasTrainingSummary
from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, \
JavaMLReadable, JavaMLReader, JavaMLWritable, JavaMLWriter, \
MLReader, MLReadable, MLWriter, MLWritable, HasTrainingSummary
from pyspark.ml.wrapper import JavaParams, \
JavaPredictor, JavaPredictionModel, JavaWrapper
from pyspark.ml.common import inherit_doc
@ -2760,7 +2763,7 @@ class _OneVsRestParams(_ClassifierParams, HasWeightCol):
@inherit_doc
class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, JavaMLWritable):
class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWritable):
"""
Reduction of Multiclass Classification to Binary Classification.
Performs reduction using one against all strategy.
@ -2991,8 +2994,73 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, Jav
_java_obj.setRawPredictionCol(self.getRawPredictionCol())
return _java_obj
@classmethod
def read(cls):
return OneVsRestReader(cls)
class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
def write(self):
if isinstance(self.getClassifier(), JavaMLWritable):
return JavaMLWriter(self)
else:
return OneVsRestWriter(self)
class _OneVsRestSharedReadWrite:
@staticmethod
def saveImpl(instance, sc, path, extraMetadata=None):
skipParams = ['classifier']
jsonParams = DefaultParamsWriter.extractJsonParams(instance, skipParams)
DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams,
extraMetadata=extraMetadata)
classifierPath = os.path.join(path, 'classifier')
instance.getClassifier().save(classifierPath)
@staticmethod
def loadClassifier(path, sc):
classifierPath = os.path.join(path, 'classifier')
return DefaultParamsReader.loadParamsInstance(classifierPath, sc)
@staticmethod
def validateParams(instance):
elems_to_check = [instance.getClassifier()]
if isinstance(instance, OneVsRestModel):
elems_to_check.extend(instance.models)
for elem in elems_to_check:
if not isinstance(elem, MLWritable):
raise ValueError(f'OneVsRest write will fail because it contains {elem.uid} '
f'which is not writable.')
@inherit_doc
class OneVsRestReader(MLReader):
def __init__(self, cls):
super(OneVsRestReader, self).__init__()
self.cls = cls
def load(self, path):
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path)
else:
classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sc)
ova = OneVsRest(classifier=classifier)._resetUid(metadata['uid'])
DefaultParamsReader.getAndSetParams(ova, metadata, skipParams=['classifier'])
return ova
@inherit_doc
class OneVsRestWriter(MLWriter):
def __init__(self, instance):
super(OneVsRestWriter, self).__init__()
self.instance = instance
def saveImpl(self, path):
_OneVsRestSharedReadWrite.validateParams(self.instance)
_OneVsRestSharedReadWrite.saveImpl(self.instance, self.sc, path)
class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
"""
Model fitted by OneVsRest.
This stores the models resulting from training k binary classifiers: one for each class.
@ -3023,6 +3091,9 @@ class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
def __init__(self, models):
super(OneVsRestModel, self).__init__()
self.models = models
if not isinstance(models[0], JavaMLWritable):
return
# set java instance
java_models = [model._to_java() for model in self.models]
sc = SparkContext._active_spark_context
java_models_array = JavaWrapper._new_java_array(java_models,
@ -3160,6 +3231,57 @@ class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
_java_obj.set("weightCol", self.getWeightCol())
return _java_obj
@classmethod
def read(cls):
return OneVsRestModelReader(cls)
def write(self):
if all(map(lambda elem: isinstance(elem, JavaMLWritable),
[self.getClassifier()] + self.models)):
return JavaMLWriter(self)
else:
return OneVsRestModelWriter(self)
@inherit_doc
class OneVsRestModelReader(MLReader):
def __init__(self, cls):
super(OneVsRestModelReader, self).__init__()
self.cls = cls
def load(self, path):
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path)
else:
classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sc)
numClasses = metadata['numClasses']
subModels = [None] * numClasses
for idx in range(numClasses):
subModelPath = os.path.join(path, f'model_{idx}')
subModels[idx] = DefaultParamsReader.loadParamsInstance(subModelPath, self.sc)
ovaModel = OneVsRestModel(subModels)._resetUid(metadata['uid'])
ovaModel.set(ovaModel.classifier, classifier)
DefaultParamsReader.getAndSetParams(ovaModel, metadata, skipParams=['classifier'])
return ovaModel
@inherit_doc
class OneVsRestModelWriter(MLWriter):
def __init__(self, instance):
super(OneVsRestModelWriter, self).__init__()
self.instance = instance
def saveImpl(self, path):
_OneVsRestSharedReadWrite.validateParams(self.instance)
instance = self.instance
numClasses = len(instance.models)
extraMetadata = {'numClasses': numClasses}
_OneVsRestSharedReadWrite.saveImpl(instance, self.sc, path, extraMetadata=extraMetadata)
for idx in range(numClasses):
subModelPath = os.path.join(path, f'model_{idx}')
instance.models[idx].save(subModelPath)
@inherit_doc
class FMClassifier(_JavaProbabilisticClassifier, _FactorizationMachinesParams, JavaMLWritable,

View file

@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
from typing import Any, List, Optional
from typing import Any, List, Optional, Type
from pyspark.ml._typing import JM, M, P, T, ParamMap
import abc
@ -53,7 +53,8 @@ from pyspark.ml.tree import (
_TreeClassifierParams,
_TreeEnsembleModel,
)
from pyspark.ml.util import HasTrainingSummary, JavaMLReadable, JavaMLWritable
from pyspark.ml.util import HasTrainingSummary, JavaMLReadable, JavaMLWritable, \
MLReader, MLReadable, MLWriter, MLWritable
from pyspark.ml.wrapper import JavaPredictionModel, JavaPredictor, JavaWrapper
from pyspark.ml.linalg import Matrix, Vector
@ -797,8 +798,8 @@ class OneVsRest(
Estimator[OneVsRestModel],
_OneVsRestParams,
HasParallelism,
JavaMLReadable[OneVsRest],
JavaMLWritable,
MLReadable[OneVsRest],
MLWritable,
):
def __init__(
self,
@ -832,7 +833,7 @@ class OneVsRest(
def copy(self, extra: Optional[ParamMap] = ...) -> OneVsRest: ...
class OneVsRestModel(
Model, _OneVsRestParams, JavaMLReadable[OneVsRestModel], JavaMLWritable
Model, _OneVsRestParams, MLReadable[OneVsRestModel], MLWritable
):
models: List[Transformer]
def __init__(self, models: List[Transformer]) -> None: ...
@ -841,6 +842,26 @@ class OneVsRestModel(
def setRawPredictionCol(self, value: str) -> OneVsRestModel: ...
def copy(self, extra: Optional[ParamMap] = ...) -> OneVsRestModel: ...
class OneVsRestWriter(MLWriter):
instance: OneVsRest
def __init__(self, instance: OneVsRest) -> None: ...
def saveImpl(self, path: str) -> None: ...
class OneVsRestReader(MLReader[OneVsRest]):
cls: Type[OneVsRest]
def __init__(self, cls: Type[OneVsRest]) -> None: ...
def load(self, path: str) -> OneVsRest: ...
class OneVsRestModelWriter(MLWriter):
instance: OneVsRestModel
def __init__(self, instance: OneVsRestModel) -> None: ...
def saveImpl(self, path: str) -> None: ...
class OneVsRestModelReader(MLReader[OneVsRestModel]):
cls: Type[OneVsRestModel]
def __init__(self, cls: Type[OneVsRestModel]) -> None: ...
def load(self, path: str) -> OneVsRestModel: ...
class FMClassifier(
_JavaProbabilisticClassifier[FMClassificationModel],
_FactorizationMachinesParams,

View file

@ -237,6 +237,11 @@ class PersistenceTest(SparkSessionTestCase):
self.assertEqual(len(m1.models), len(m2.models))
for x, y in zip(m1.models, m2.models):
self._compare_pipelines(x, y)
elif isinstance(m1, Params):
# Test on python backend Estimator/Transformer/Model/Evaluator
self.assertEqual(len(m1.params), len(m2.params))
for p in m1.params:
self._compare_params(m1, m2, p)
else:
raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1))
@ -326,14 +331,14 @@ class PersistenceTest(SparkSessionTestCase):
except OSError:
pass
def test_onevsrest(self):
def _run_test_onevsrest(self, LogisticRegressionCls):
temp_path = tempfile.mkdtemp()
df = self.spark.createDataFrame([(0.0, 0.5, Vectors.dense(1.0, 0.8)),
(1.0, 0.5, Vectors.sparse(2, [], [])),
(2.0, 1.0, Vectors.dense(0.5, 0.5))] * 10,
["label", "wt", "features"])
lr = LogisticRegression(maxIter=5, regParam=0.01)
lr = LogisticRegressionCls(maxIter=5, regParam=0.01)
ovr = OneVsRest(classifier=lr)
def reload_and_compare(ovr, suffix):
@ -350,6 +355,11 @@ class PersistenceTest(SparkSessionTestCase):
reload_and_compare(OneVsRest(classifier=lr), "ovr")
reload_and_compare(OneVsRest(classifier=lr).setWeightCol("wt"), "ovrw")
def test_onevsrest(self):
from pyspark.testing.mlutils import DummyLogisticRegression
self._run_test_onevsrest(LogisticRegression)
self._run_test_onevsrest(DummyLogisticRegression)
def test_decisiontree_classifier(self):
dt = DecisionTreeClassifier(maxDepth=1)
path = tempfile.mkdtemp()

View file

@ -28,7 +28,8 @@ from pyspark.ml.param import Param, Params
from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder, \
TrainValidationSplit, TrainValidationSplitModel
from pyspark.sql.functions import rand
from pyspark.testing.mlutils import SparkSessionTestCase
from pyspark.testing.mlutils import DummyEvaluator, DummyLogisticRegression, \
DummyLogisticRegressionModel, SparkSessionTestCase
class HasInducedError(Params):
@ -201,7 +202,7 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
for v in param.values():
assert(type(v) == float)
def test_save_load_trained_model(self):
def _run_test_save_load_trained_model(self, LogisticRegressionCls, LogisticRegressionModelCls):
# This tests saving and loading the trained model only.
# Save/load for CrossValidator will be added later: SPARK-13786
temp_path = tempfile.mkdtemp()
@ -212,7 +213,7 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
(Vectors.dense([0.6]), 1.0),
(Vectors.dense([1.0]), 1.0)] * 10,
["features", "label"])
lr = LogisticRegression()
lr = LogisticRegressionCls()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
evaluator = BinaryClassificationEvaluator()
cv = CrossValidator(
@ -228,7 +229,7 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
lrModelPath = temp_path + "/lrModel"
lrModel.save(lrModelPath)
loadedLrModel = LogisticRegressionModel.load(lrModelPath)
loadedLrModel = LogisticRegressionModelCls.load(lrModelPath)
self.assertEqual(loadedLrModel.uid, lrModel.uid)
self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
@ -248,7 +249,12 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
loadedCvModel.isSet(param) for param in loadedCvModel.params
))
def test_save_load_simple_estimator(self):
def test_save_load_trained_model(self):
self._run_test_save_load_trained_model(LogisticRegression, LogisticRegressionModel)
self._run_test_save_load_trained_model(DummyLogisticRegression,
DummyLogisticRegressionModel)
def _run_test_save_load_simple_estimator(self, LogisticRegressionCls, evaluatorCls):
temp_path = tempfile.mkdtemp()
dataset = self.spark.createDataFrame(
[(Vectors.dense([0.0]), 0.0),
@ -258,9 +264,9 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
(Vectors.dense([1.0]), 1.0)] * 10,
["features", "label"])
lr = LogisticRegression()
lr = LogisticRegressionCls()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
evaluator = BinaryClassificationEvaluator()
evaluator = evaluatorCls()
# test save/load of CrossValidator
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
@ -278,6 +284,12 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
loadedModel = CrossValidatorModel.load(cvModelPath)
self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
def test_save_load_simple_estimator(self):
self._run_test_save_load_simple_estimator(
LogisticRegression, BinaryClassificationEvaluator)
self._run_test_save_load_simple_estimator(
DummyLogisticRegression, DummyEvaluator)
def test_parallel_evaluation(self):
dataset = self.spark.createDataFrame(
[(Vectors.dense([0.0]), 0.0),
@ -343,7 +355,7 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
for j in range(len(grid)):
self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid)
def test_save_load_nested_estimator(self):
def _run_test_save_load_nested_estimator(self, LogisticRegressionCls):
temp_path = tempfile.mkdtemp()
dataset = self.spark.createDataFrame(
[(Vectors.dense([0.0]), 0.0),
@ -353,9 +365,9 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
(Vectors.dense([1.0]), 1.0)] * 10,
["features", "label"])
ova = OneVsRest(classifier=LogisticRegression())
lr1 = LogisticRegression().setMaxIter(100)
lr2 = LogisticRegression().setMaxIter(150)
ova = OneVsRest(classifier=LogisticRegressionCls())
lr1 = LogisticRegressionCls().setMaxIter(100)
lr2 = LogisticRegressionCls().setMaxIter(150)
grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build()
evaluator = MulticlassClassificationEvaluator()
@ -385,7 +397,11 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid)
self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
def test_save_load_pipeline_estimator(self):
def test_save_load_nested_estimator(self):
self._run_test_save_load_nested_estimator(LogisticRegression)
self._run_test_save_load_nested_estimator(DummyLogisticRegression)
def _run_test_save_load_pipeline_estimator(self, LogisticRegressionCls):
temp_path = tempfile.mkdtemp()
training = self.spark.createDataFrame([
(0, "a b c d e spark", 1.0),
@ -402,9 +418,9 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
ova = OneVsRest(classifier=LogisticRegression())
lr1 = LogisticRegression().setMaxIter(5)
lr2 = LogisticRegression().setMaxIter(10)
ova = OneVsRest(classifier=LogisticRegressionCls())
lr1 = LogisticRegressionCls().setMaxIter(5)
lr2 = LogisticRegressionCls().setMaxIter(10)
pipeline = Pipeline(stages=[tokenizer, hashingTF, ova])
@ -464,6 +480,10 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
original_nested_pipeline_model.stages):
self.assertEqual(loadedStage.uid, originalStage.uid)
def test_save_load_pipeline_estimator(self):
self._run_test_save_load_pipeline_estimator(LogisticRegression)
self._run_test_save_load_pipeline_estimator(DummyLogisticRegression)
def test_user_specified_folds(self):
from pyspark.sql import functions as F
@ -593,7 +613,7 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
"validationMetrics has the same size of grid parameter")
self.assertEqual(1.0, max(validationMetrics))
def test_save_load_trained_model(self):
def _run_test_save_load_trained_model(self, LogisticRegressionCls, LogisticRegressionModelCls):
# This tests saving and loading the trained model only.
# Save/load for TrainValidationSplit will be added later: SPARK-13786
temp_path = tempfile.mkdtemp()
@ -604,7 +624,7 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
(Vectors.dense([0.6]), 1.0),
(Vectors.dense([1.0]), 1.0)] * 10,
["features", "label"])
lr = LogisticRegression()
lr = LogisticRegressionCls()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
evaluator = BinaryClassificationEvaluator()
tvs = TrainValidationSplit(
@ -619,7 +639,7 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
lrModelPath = temp_path + "/lrModel"
lrModel.save(lrModelPath)
loadedLrModel = LogisticRegressionModel.load(lrModelPath)
loadedLrModel = LogisticRegressionModelCls.load(lrModelPath)
self.assertEqual(loadedLrModel.uid, lrModel.uid)
self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
@ -636,7 +656,12 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
loadedTvsModel.isSet(param) for param in loadedTvsModel.params
))
def test_save_load_simple_estimator(self):
def test_save_load_trained_model(self):
self._run_test_save_load_trained_model(LogisticRegression, LogisticRegressionModel)
self._run_test_save_load_trained_model(DummyLogisticRegression,
DummyLogisticRegressionModel)
def _run_test_save_load_simple_estimator(self, LogisticRegressionCls, evaluatorCls):
# This tests saving and loading the trained model only.
# Save/load for TrainValidationSplit will be added later: SPARK-13786
temp_path = tempfile.mkdtemp()
@ -647,9 +672,9 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
(Vectors.dense([0.6]), 1.0),
(Vectors.dense([1.0]), 1.0)] * 10,
["features", "label"])
lr = LogisticRegression()
lr = LogisticRegressionCls()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
evaluator = BinaryClassificationEvaluator()
evaluator = evaluatorCls()
tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
tvsModel = tvs.fit(dataset)
@ -666,6 +691,12 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
loadedModel = TrainValidationSplitModel.load(tvsModelPath)
self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
def test_save_load_simple_estimator(self):
self._run_test_save_load_simple_estimator(
LogisticRegression, BinaryClassificationEvaluator)
self._run_test_save_load_simple_estimator(
DummyLogisticRegression, DummyEvaluator)
def test_parallel_evaluation(self):
dataset = self.spark.createDataFrame(
[(Vectors.dense([0.0]), 0.0),
@ -718,7 +749,7 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
for i in range(len(grid)):
self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid)
def test_save_load_nested_estimator(self):
def _run_test_save_load_nested_estimator(self, LogisticRegressionCls):
# This tests saving and loading the trained model only.
# Save/load for TrainValidationSplit will be added later: SPARK-13786
temp_path = tempfile.mkdtemp()
@ -729,9 +760,9 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
(Vectors.dense([0.6]), 1.0),
(Vectors.dense([1.0]), 1.0)] * 10,
["features", "label"])
ova = OneVsRest(classifier=LogisticRegression())
lr1 = LogisticRegression().setMaxIter(100)
lr2 = LogisticRegression().setMaxIter(150)
ova = OneVsRest(classifier=LogisticRegressionCls())
lr1 = LogisticRegressionCls().setMaxIter(100)
lr2 = LogisticRegressionCls().setMaxIter(150)
grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build()
evaluator = MulticlassClassificationEvaluator()
@ -759,7 +790,11 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid)
self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
def test_save_load_pipeline_estimator(self):
def test_save_load_nested_estimator(self):
self._run_test_save_load_nested_estimator(LogisticRegression)
self._run_test_save_load_nested_estimator(DummyLogisticRegression)
def _run_test_save_load_pipeline_estimator(self, LogisticRegressionCls):
temp_path = tempfile.mkdtemp()
training = self.spark.createDataFrame([
(0, "a b c d e spark", 1.0),
@ -776,9 +811,9 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
ova = OneVsRest(classifier=LogisticRegression())
lr1 = LogisticRegression().setMaxIter(5)
lr2 = LogisticRegression().setMaxIter(10)
ova = OneVsRest(classifier=LogisticRegressionCls())
lr1 = LogisticRegressionCls().setMaxIter(5)
lr2 = LogisticRegressionCls().setMaxIter(10)
pipeline = Pipeline(stages=[tokenizer, hashingTF, ova])
@ -836,6 +871,10 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
original_nested_pipeline_model.stages):
self.assertEqual(loadedStage.uid, originalStage.uid)
def test_save_load_pipeline_estimator(self):
self._run_test_save_load_pipeline_estimator(LogisticRegression)
self._run_test_save_load_pipeline_estimator(DummyLogisticRegression)
def test_copy(self):
dataset = self.spark.createDataFrame([
(10, 10.0),

View file

@ -15,6 +15,7 @@
# limitations under the License.
#
import os
import sys
import itertools
from multiprocessing.pool import ThreadPool
@ -22,12 +23,13 @@ from multiprocessing.pool import ThreadPool
import numpy as np
from pyspark import keyword_only, since, SparkContext
from pyspark.ml import Estimator, Model
from pyspark.ml.common import _py2java, _java2py
from pyspark.ml import Estimator, Transformer, Model
from pyspark.ml.common import inherit_doc, _py2java, _java2py
from pyspark.ml.evaluation import Evaluator
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader, \
MetaAlgorithmReadWrite
from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite, \
MLReadable, MLReader, MLWritable, MLWriter, JavaMLReader, JavaMLWriter
from pyspark.ml.wrapper import JavaParams, JavaEstimator, JavaWrapper
from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
from pyspark.sql.types import BooleanType
@ -229,6 +231,7 @@ class _ValidatorParams(HasSeed):
class _ValidatorSharedReadWrite:
@staticmethod
def meta_estimator_transfer_param_maps_to_java(pyEstimator, pyParamMaps):
pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)
@ -275,10 +278,8 @@ class _ValidatorSharedReadWrite:
raise ValueError('Resolve param in estimatorParamMaps failed: ' +
javaParam.parent() + '.' + javaParam.name())
javaValue = javaPair.value()
if sc._jvm.Class.forName("org.apache.spark.ml.PipelineStage").isInstance(javaValue):
# Note: JavaParams._from_java support both JavaEstimator/JavaTransformer class
# and Estimator/Transformer class which implements `_from_java` static method
# (such as OneVsRest, Pipeline class).
if sc._jvm.Class.forName("org.apache.spark.ml.util.DefaultParamsWritable") \
.isInstance(javaValue):
pyValue = JavaParams._from_java(javaValue)
else:
pyValue = _java2py(sc, javaValue)
@ -286,6 +287,222 @@ class _ValidatorSharedReadWrite:
pyParamMaps.append(pyParamMap)
return pyParamMaps
@staticmethod
def is_java_convertible(instance):
allNestedStages = MetaAlgorithmReadWrite.getAllNestedStages(instance.getEstimator())
evaluator_convertible = isinstance(instance.getEvaluator(), JavaParams)
estimator_convertible = all(map(lambda stage: hasattr(stage, '_to_java'), allNestedStages))
return estimator_convertible and evaluator_convertible
@staticmethod
def saveImpl(path, instance, sc, extraMetadata=None):
numParamsNotJson = 0
jsonEstimatorParamMaps = []
for paramMap in instance.getEstimatorParamMaps():
jsonParamMap = []
for p, v in paramMap.items():
jsonParam = {'parent': p.parent, 'name': p.name}
if (isinstance(v, Estimator) and not MetaAlgorithmReadWrite.isMetaEstimator(v)) \
or isinstance(v, Transformer) or isinstance(v, Evaluator):
relative_path = f'epm_{p.name}{numParamsNotJson}'
param_path = os.path.join(path, relative_path)
numParamsNotJson += 1
v.save(param_path)
jsonParam['value'] = relative_path
jsonParam['isJson'] = False
elif isinstance(v, MLWritable):
raise RuntimeError(
"ValidatorSharedReadWrite.saveImpl does not handle parameters of type: "
"MLWritable that are not Estimaor/Evaluator/Transformer, and if parameter "
"is estimator, it cannot be meta estimator such as Validator or OneVsRest")
else:
jsonParam['value'] = v
jsonParam['isJson'] = True
jsonParamMap.append(jsonParam)
jsonEstimatorParamMaps.append(jsonParamMap)
skipParams = ['estimator', 'evaluator', 'estimatorParamMaps']
jsonParams = DefaultParamsWriter.extractJsonParams(instance, skipParams)
jsonParams['estimatorParamMaps'] = jsonEstimatorParamMaps
DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, jsonParams)
evaluatorPath = os.path.join(path, 'evaluator')
instance.getEvaluator().save(evaluatorPath)
estimatorPath = os.path.join(path, 'estimator')
instance.getEstimator().save(estimatorPath)
@staticmethod
def load(path, sc, metadata):
evaluatorPath = os.path.join(path, 'evaluator')
evaluator = DefaultParamsReader.loadParamsInstance(evaluatorPath, sc)
estimatorPath = os.path.join(path, 'estimator')
estimator = DefaultParamsReader.loadParamsInstance(estimatorPath, sc)
uidToParams = MetaAlgorithmReadWrite.getUidMap(estimator)
uidToParams[evaluator.uid] = evaluator
jsonEstimatorParamMaps = metadata['paramMap']['estimatorParamMaps']
estimatorParamMaps = []
for jsonParamMap in jsonEstimatorParamMaps:
paramMap = {}
for jsonParam in jsonParamMap:
est = uidToParams[jsonParam['parent']]
param = getattr(est, jsonParam['name'])
if 'isJson' not in jsonParam or ('isJson' in jsonParam and jsonParam['isJson']):
value = jsonParam['value']
else:
relativePath = jsonParam['value']
valueSavedPath = os.path.join(path, relativePath)
value = DefaultParamsReader.loadParamsInstance(valueSavedPath, sc)
paramMap[param] = value
estimatorParamMaps.append(paramMap)
return metadata, estimator, evaluator, estimatorParamMaps
@staticmethod
def validateParams(instance):
estimator = instance.getEstimator()
evaluator = instance.getEvaluator()
uidMap = MetaAlgorithmReadWrite.getUidMap(estimator)
for elem in [evaluator] + list(uidMap.values()):
if not isinstance(elem, MLWritable):
raise ValueError(f'Validator write will fail because it contains {elem.uid} '
f'which is not writable.')
estimatorParamMaps = instance.getEstimatorParamMaps()
paramErr = 'Validator save requires all Params in estimatorParamMaps to apply to ' \
f'its Estimator. An extraneous Param was found: '
for paramMap in estimatorParamMaps:
for param in paramMap:
if param.parent not in uidMap:
raise ValueError(paramErr + repr(param))
@staticmethod
def getValidatorModelWriterPersistSubModelsParam(writer):
if 'persistsubmodels' in writer.optionMap:
persistSubModelsParam = writer.optionMap['persistsubmodels'].lower()
if persistSubModelsParam == 'true':
return True
elif persistSubModelsParam == 'false':
return False
else:
raise ValueError(
f'persistSubModels option value {persistSubModelsParam} is invalid, '
f"the possible values are True, 'True' or False, 'False'")
else:
return writer.instance.subModels is not None
_save_with_persist_submodels_no_submodels_found_err = \
'When persisting tuning models, you can only set persistSubModels to true if the tuning ' \
'was done with collectSubModels set to true. To save the sub-models, try rerunning fitting ' \
'with collectSubModels set to true.'
@inherit_doc
class CrossValidatorReader(MLReader):
def __init__(self, cls):
super(CrossValidatorReader, self).__init__()
self.cls = cls
def load(self, path):
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path)
else:
metadata, estimator, evaluator, estimatorParamMaps = \
_ValidatorSharedReadWrite.load(path, self.sc, metadata)
cv = CrossValidator(estimator=estimator,
estimatorParamMaps=estimatorParamMaps,
evaluator=evaluator)
cv = cv._resetUid(metadata['uid'])
DefaultParamsReader.getAndSetParams(cv, metadata, skipParams=['estimatorParamMaps'])
return cv
@inherit_doc
class CrossValidatorWriter(MLWriter):
def __init__(self, instance):
super(CrossValidatorWriter, self).__init__()
self.instance = instance
def saveImpl(self, path):
_ValidatorSharedReadWrite.validateParams(self.instance)
_ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc)
@inherit_doc
class CrossValidatorModelReader(MLReader):
def __init__(self, cls):
super(CrossValidatorModelReader, self).__init__()
self.cls = cls
def load(self, path):
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path)
else:
metadata, estimator, evaluator, estimatorParamMaps = \
_ValidatorSharedReadWrite.load(path, self.sc, metadata)
numFolds = metadata['paramMap']['numFolds']
bestModelPath = os.path.join(path, 'bestModel')
bestModel = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
avgMetrics = metadata['avgMetrics']
persistSubModels = ('persistSubModels' in metadata) and metadata['persistSubModels']
if persistSubModels:
subModels = [[None] * len(estimatorParamMaps) for _ in range(numFolds)]
for splitIndex in range(numFolds):
for paramIndex in range(len(estimatorParamMaps)):
modelPath = os.path.join(
path, 'subModels', f'fold{splitIndex}', f'{paramIndex}')
subModels[splitIndex][paramIndex] = \
DefaultParamsReader.loadParamsInstance(modelPath, self.sc)
else:
subModels = None
cvModel = CrossValidatorModel(bestModel, avgMetrics=avgMetrics, subModels=subModels)
cvModel = cvModel._resetUid(metadata['uid'])
cvModel.set(cvModel.estimator, estimator)
cvModel.set(cvModel.estimatorParamMaps, estimatorParamMaps)
cvModel.set(cvModel.evaluator, evaluator)
DefaultParamsReader.getAndSetParams(
cvModel, metadata, skipParams=['estimatorParamMaps'])
return cvModel
@inherit_doc
class CrossValidatorModelWriter(MLWriter):
def __init__(self, instance):
super(CrossValidatorModelWriter, self).__init__()
self.instance = instance
def saveImpl(self, path):
_ValidatorSharedReadWrite.validateParams(self.instance)
instance = self.instance
persistSubModels = _ValidatorSharedReadWrite \
.getValidatorModelWriterPersistSubModelsParam(self)
extraMetadata = {'avgMetrics': instance.avgMetrics,
'persistSubModels': persistSubModels}
_ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata)
bestModelPath = os.path.join(path, 'bestModel')
instance.bestModel.save(bestModelPath)
if persistSubModels:
if instance.subModels is None:
raise ValueError(_save_with_persist_submodels_no_submodels_found_err)
subModelsPath = os.path.join(path, 'subModels')
for splitIndex in range(instance.getNumFolds()):
splitPath = os.path.join(subModelsPath, f'fold{splitIndex}')
for paramIndex in range(len(instance.getEstimatorParamMaps())):
modelPath = os.path.join(splitPath, f'{paramIndex}')
instance.subModels[splitIndex][paramIndex].save(modelPath)
class _CrossValidatorParams(_ValidatorParams):
"""
@ -553,13 +770,15 @@ class CrossValidator(Estimator, _CrossValidatorParams, HasParallelism, HasCollec
@since("2.3.0")
def write(self):
"""Returns an MLWriter instance for this ML instance."""
if _ValidatorSharedReadWrite.is_java_convertible(self):
return JavaMLWriter(self)
return CrossValidatorWriter(self)
@classmethod
@since("2.3.0")
def read(cls):
"""Returns an MLReader instance for this class."""
return JavaMLReader(cls)
return CrossValidatorReader(cls)
@classmethod
def _from_java(cls, java_stage):
@ -662,13 +881,15 @@ class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable):
@since("2.3.0")
def write(self):
"""Returns an MLWriter instance for this ML instance."""
if _ValidatorSharedReadWrite.is_java_convertible(self):
return JavaMLWriter(self)
return CrossValidatorModelWriter(self)
@classmethod
@since("2.3.0")
def read(cls):
"""Returns an MLReader instance for this class."""
return JavaMLReader(cls)
return CrossValidatorModelReader(cls)
@classmethod
def _from_java(cls, java_stage):
@ -738,6 +959,106 @@ class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable):
return _java_obj
@inherit_doc
class TrainValidationSplitReader(MLReader):
def __init__(self, cls):
super(TrainValidationSplitReader, self).__init__()
self.cls = cls
def load(self, path):
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path)
else:
metadata, estimator, evaluator, estimatorParamMaps = \
_ValidatorSharedReadWrite.load(path, self.sc, metadata)
tvs = TrainValidationSplit(estimator=estimator,
estimatorParamMaps=estimatorParamMaps,
evaluator=evaluator)
tvs = tvs._resetUid(metadata['uid'])
DefaultParamsReader.getAndSetParams(tvs, metadata, skipParams=['estimatorParamMaps'])
return tvs
@inherit_doc
class TrainValidationSplitWriter(MLWriter):
def __init__(self, instance):
super(TrainValidationSplitWriter, self).__init__()
self.instance = instance
def saveImpl(self, path):
_ValidatorSharedReadWrite.validateParams(self.instance)
_ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc)
@inherit_doc
class TrainValidationSplitModelReader(MLReader):
def __init__(self, cls):
super(TrainValidationSplitModelReader, self).__init__()
self.cls = cls
def load(self, path):
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path)
else:
metadata, estimator, evaluator, estimatorParamMaps = \
_ValidatorSharedReadWrite.load(path, self.sc, metadata)
bestModelPath = os.path.join(path, 'bestModel')
bestModel = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
validationMetrics = metadata['validationMetrics']
persistSubModels = ('persistSubModels' in metadata) and metadata['persistSubModels']
if persistSubModels:
subModels = [None] * len(estimatorParamMaps)
for paramIndex in range(len(estimatorParamMaps)):
modelPath = os.path.join(path, 'subModels', f'{paramIndex}')
subModels[paramIndex] = \
DefaultParamsReader.loadParamsInstance(modelPath, self.sc)
else:
subModels = None
tvsModel = TrainValidationSplitModel(
bestModel, validationMetrics=validationMetrics, subModels=subModels)
tvsModel = tvsModel._resetUid(metadata['uid'])
tvsModel.set(tvsModel.estimator, estimator)
tvsModel.set(tvsModel.estimatorParamMaps, estimatorParamMaps)
tvsModel.set(tvsModel.evaluator, evaluator)
DefaultParamsReader.getAndSetParams(
tvsModel, metadata, skipParams=['estimatorParamMaps'])
return tvsModel
@inherit_doc
class TrainValidationSplitModelWriter(MLWriter):
def __init__(self, instance):
super(TrainValidationSplitModelWriter, self).__init__()
self.instance = instance
def saveImpl(self, path):
_ValidatorSharedReadWrite.validateParams(self.instance)
instance = self.instance
persistSubModels = _ValidatorSharedReadWrite \
.getValidatorModelWriterPersistSubModelsParam(self)
extraMetadata = {'validationMetrics': instance.validationMetrics,
'persistSubModels': persistSubModels}
_ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata)
bestModelPath = os.path.join(path, 'bestModel')
instance.bestModel.save(bestModelPath)
if persistSubModels:
if instance.subModels is None:
raise ValueError(_save_with_persist_submodels_no_submodels_found_err)
subModelsPath = os.path.join(path, 'subModels')
for paramIndex in range(len(instance.getEstimatorParamMaps())):
modelPath = os.path.join(subModelsPath, f'{paramIndex}')
instance.subModels[paramIndex].save(modelPath)
class _TrainValidationSplitParams(_ValidatorParams):
"""
Params for :py:class:`TrainValidationSplit` and :py:class:`TrainValidationSplitModel`.
@ -942,13 +1263,15 @@ class TrainValidationSplit(Estimator, _TrainValidationSplitParams, HasParallelis
@since("2.3.0")
def write(self):
"""Returns an MLWriter instance for this ML instance."""
if _ValidatorSharedReadWrite.is_java_convertible(self):
return JavaMLWriter(self)
return TrainValidationSplitWriter(self)
@classmethod
@since("2.3.0")
def read(cls):
"""Returns an MLReader instance for this class."""
return JavaMLReader(cls)
return TrainValidationSplitReader(cls)
@classmethod
def _from_java(cls, java_stage):
@ -1046,13 +1369,15 @@ class TrainValidationSplitModel(Model, _TrainValidationSplitParams, MLReadable,
@since("2.3.0")
def write(self):
"""Returns an MLWriter instance for this ML instance."""
if _ValidatorSharedReadWrite.is_java_convertible(self):
return JavaMLWriter(self)
return TrainValidationSplitModelWriter(self)
@classmethod
@since("2.3.0")
def read(cls):
"""Returns an MLReader instance for this class."""
return JavaMLReader(cls)
return TrainValidationSplitModelReader(cls)
@classmethod
def _from_java(cls, java_stage):

View file

@ -183,3 +183,43 @@ class TrainValidationSplitModel(
def write(self) -> MLWriter: ...
@classmethod
def read(cls: Type[TrainValidationSplitModel]) -> MLReader: ...
class CrossValidatorWriter(MLWriter):
instance: CrossValidator
def __init__(self, instance: CrossValidator) -> None: ...
def saveImpl(self, path: str) -> None: ...
class CrossValidatorReader(MLReader[CrossValidator]):
cls: Type[CrossValidator]
def __init__(self, cls: Type[CrossValidator]) -> None: ...
def load(self, path: str) -> CrossValidator: ...
class CrossValidatorModelWriter(MLWriter):
instance: CrossValidatorModel
def __init__(self, instance: CrossValidatorModel) -> None: ...
def saveImpl(self, path: str) -> None: ...
class CrossValidatorModelReader(MLReader[CrossValidatorModel]):
cls: Type[CrossValidatorModel]
def __init__(self, cls: Type[CrossValidatorModel]) -> None: ...
def load(self, path: str) -> CrossValidatorModel: ...
class TrainValidationSplitWriter(MLWriter):
instance: TrainValidationSplit
def __init__(self, instance: TrainValidationSplit) -> None: ...
def saveImpl(self, path: str) -> None: ...
class TrainValidationSplitReader(MLReader[TrainValidationSplit]):
cls: Type[TrainValidationSplit]
def __init__(self, cls: Type[TrainValidationSplit]) -> None: ...
def load(self, path: str) -> TrainValidationSplit: ...
class TrainValidationSplitModelWriter(MLWriter):
instance: TrainValidationSplitModel
def __init__(self, instance: TrainValidationSplitModel) -> None: ...
def saveImpl(self, path: str) -> None: ...
class TrainValidationSplitModelReader(MLReader[TrainValidationSplitModel]):
cls: Type[TrainValidationSplitModel]
def __init__(self, cls: Type[TrainValidationSplitModel]) -> None: ...
def load(self, path: str) -> TrainValidationSplitModel: ...

View file

@ -106,6 +106,7 @@ class MLWriter(BaseReadWrite):
def __init__(self):
super(MLWriter, self).__init__()
self.shouldOverwrite = False
self.optionMap = {}
def _handleOverwrite(self, path):
from pyspark.ml.wrapper import JavaWrapper
@ -132,6 +133,14 @@ class MLWriter(BaseReadWrite):
self.shouldOverwrite = True
return self
def option(self, key, value):
"""
Adds an option to the underlying MLWriter. See the documentation for the specific model's
writer for possible options. The option name (key) is case-insensitive.
"""
self.optionMap[key.lower()] = str(value)
return self
@inherit_doc
class GeneralMLWriter(MLWriter):
@ -375,6 +384,13 @@ class DefaultParamsWriter(MLWriter):
def saveImpl(self, path):
DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)
@staticmethod
def extractJsonParams(instance, skipParams):
paramMap = instance.extractParamMap()
jsonParams = {param.name: value for param, value in paramMap.items()
if param.name not in skipParams}
return jsonParams
@staticmethod
def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
"""
@ -530,13 +546,14 @@ class DefaultParamsReader(MLReader):
return metadata
@staticmethod
def getAndSetParams(instance, metadata):
def getAndSetParams(instance, metadata, skipParams=None):
"""
Extract Params from metadata, and set them in the instance.
"""
# Set user-supplied param values
for paramName in metadata['paramMap']:
param = instance.getParam(paramName)
if skipParams is None or paramName not in skipParams:
paramValue = metadata['paramMap'][paramName]
instance.set(param, paramValue)
@ -554,6 +571,10 @@ class DefaultParamsReader(MLReader):
paramValue = metadata['defaultParamMap'][paramName]
instance._setDefault(**{paramName: paramValue})
@staticmethod
def isPythonParamsInstance(metadata):
return metadata['class'].startswith('pyspark.ml.')
@staticmethod
def loadParamsInstance(path, sc):
"""
@ -561,6 +582,9 @@ class DefaultParamsReader(MLReader):
This assumes the instance inherits from :py:class:`MLReadable`.
"""
metadata = DefaultParamsReader.loadMetadata(path, sc)
if DefaultParamsReader.isPythonParamsInstance(metadata):
pythonClassName = metadata['class']
else:
pythonClassName = metadata['class'].replace("org.apache.spark", "pyspark")
py_type = DefaultParamsReader.__get_class(pythonClassName)
instance = py_type.load(path)
@ -630,3 +654,13 @@ class MetaAlgorithmReadWrite:
nestedStages.extend(MetaAlgorithmReadWrite.getAllNestedStages(pySubStage))
return [pyInstance] + nestedStages
@staticmethod
def getUidMap(instance):
nestedStages = MetaAlgorithmReadWrite.getAllNestedStages(instance)
uidMap = {stage.uid: stage for stage in nestedStages}
if len(nestedStages) != len(uidMap):
raise RuntimeError(f'{instance.__class__.__module__}.{instance.__class__.__name__}'
f'.load found a compound estimator with stages with duplicate '
f'UIDs. List of UIDs: {list(uidMap.keys())}.')
return uidMap

View file

@ -132,3 +132,5 @@ class MetaAlgorithmReadWrite:
def isMetaEstimator(pyInstance: Any) -> bool: ...
@staticmethod
def getAllNestedStages(pyInstance: Any) -> list: ...
@staticmethod
def getUidMap(instance: Any) -> dict: ...

View file

@ -17,8 +17,12 @@
import numpy as np
from pyspark import keyword_only
from pyspark.ml import Estimator, Model, Transformer, UnaryTransformer
from pyspark.ml.evaluation import Evaluator
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasMaxIter, HasRegParam
from pyspark.ml.classification import Classifier, ClassificationModel
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.ml.wrapper import _java2py # type: ignore
from pyspark.sql import DataFrame, SparkSession
@ -161,3 +165,86 @@ class MockEstimator(Estimator, HasFake):
class MockModel(MockTransformer, Model, HasFake):
pass
class _DummyLogisticRegressionParams(HasMaxIter, HasRegParam):
def setMaxIter(self, value):
return self._set(maxIter=value)
def setRegParam(self, value):
return self._set(regParam=value)
# This is a dummy LogisticRegression used in tests for Python backend estimator/model
class DummyLogisticRegression(Classifier, _DummyLogisticRegressionParams,
DefaultParamsReadable, DefaultParamsWritable):
@keyword_only
def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, rawPredictionCol="rawPrediction"):
super(DummyLogisticRegression, self).__init__()
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, rawPredictionCol="rawPrediction"):
kwargs = self._input_kwargs
self._set(**kwargs)
return self
def _fit(self, dataset):
# Do nothing but create a dummy model
return self._copyValues(DummyLogisticRegressionModel())
class DummyLogisticRegressionModel(ClassificationModel, _DummyLogisticRegressionParams,
DefaultParamsReadable, DefaultParamsWritable):
def __init__(self):
super(DummyLogisticRegressionModel, self).__init__()
def _transform(self, dataset):
# A dummy transform impl which always predicts label 1
from pyspark.sql.functions import array, lit
from pyspark.ml.functions import array_to_vector
rawPredCol = self.getRawPredictionCol()
if rawPredCol:
dataset = dataset.withColumn(
rawPredCol, array_to_vector(array(lit(-100.0), lit(100.0))))
predCol = self.getPredictionCol()
if predCol:
dataset = dataset.withColumn(predCol, lit(1.0))
return dataset
@property
def numClasses(self):
# a dummy implementation for test.
return 2
@property
def intercept(self):
# a dummy implementation for test.
return 0.0
# This class is only used in tests. The following methods/properties are not used in tests.
@property
def coefficients(self):
raise NotImplementedError()
def predictRaw(self, value):
raise NotImplementedError()
def numFeatures(self):
raise NotImplementedError()
def predict(self, value):
raise NotImplementedError()
class DummyEvaluator(Evaluator, DefaultParamsReadable, DefaultParamsWritable):
def _evaluate(self, dataset):
# a dummy implementation for test.
return 1.0