[SPARK-33520][ML][PYSPARK] make CrossValidator/TrainValidateSplit/OneVsRest Reader/Writer support Python backend estimator/evaluator
### What changes were proposed in this pull request?
Make the CrossValidator/TrainValidationSplit/OneVsRest Reader/Writer support Python-backend estimators and models.

### Why are the changes needed?
Currently, PySpark lets third-party libraries define Python-backend estimators and evaluators, i.e. estimators that inherit from `Estimator` instead of `JavaEstimator` and can only be used from PySpark. CrossValidator and TrainValidationSplit support tuning these Python-backend estimators, but they cannot save or load them, because their writers are implemented with `JavaMLWriter`, which requires converting the nested estimator and evaluator into Java instances. OneVsRest save/load likewise only supports Java-backend classifiers, for the same reason.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit tests.

Closes #30471 from WeichenXu123/support_pyio_tuning.

Authored-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Parent: 63f9d474b9
Commit: 7e759b2d95
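Taken together, the changes below let a tuning estimator whose nested stages are pure Python round-trip through save/load. A minimal sketch of that workflow, using the dummy Python-backend classes this patch adds to `pyspark.testing.mlutils` and mirroring the unit tests; the local SparkSession setup and temp paths are illustrative assumptions, not part of the patch:

```python
import tempfile

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder
from pyspark.testing.mlutils import DummyEvaluator, DummyLogisticRegression

spark = SparkSession.builder.master("local[2]").getOrCreate()
dataset = spark.createDataFrame(
    [(Vectors.dense([0.0]), 0.0), (Vectors.dense([1.0]), 1.0)] * 10,
    ["features", "label"])

lr = DummyLogisticRegression()  # pure-Python Estimator, no Java counterpart
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=DummyEvaluator())
cvModel = cv.fit(dataset)

temp_path = tempfile.mkdtemp()
cv.save(temp_path + "/cv")            # dispatches to the new CrossValidatorWriter
cvModel.save(temp_path + "/cvModel")  # dispatches to the new CrossValidatorModelWriter

loaded_cv = CrossValidator.load(temp_path + "/cv")
loaded_model = CrossValidatorModel.load(temp_path + "/cvModel")
assert loaded_model.bestModel.uid == cvModel.bestModel.uid
```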
@@ -15,6 +15,7 @@
# limitations under the License.
#

import os
import operator
import sys
import uuid

@@ -33,7 +34,9 @@ from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
    _HasVarianceImpurity, _TreeClassifierParams
from pyspark.ml.regression import _FactorizationMachinesParams, DecisionTreeRegressionModel
from pyspark.ml.base import _PredictorParams
from pyspark.ml.util import JavaMLWritable, JavaMLReadable, HasTrainingSummary
from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, \
    JavaMLReadable, JavaMLReader, JavaMLWritable, JavaMLWriter, \
    MLReader, MLReadable, MLWriter, MLWritable, HasTrainingSummary
from pyspark.ml.wrapper import JavaParams, \
    JavaPredictor, JavaPredictionModel, JavaWrapper
from pyspark.ml.common import inherit_doc

@@ -2760,7 +2763,7 @@ class _OneVsRestParams(_ClassifierParams, HasWeightCol):


@inherit_doc
class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, JavaMLWritable):
class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWritable):
    """
    Reduction of Multiclass Classification to Binary Classification.
    Performs reduction using one against all strategy.

@@ -2991,8 +2994,73 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, Jav
        _java_obj.setRawPredictionCol(self.getRawPredictionCol())
        return _java_obj

    @classmethod
    def read(cls):
        return OneVsRestReader(cls)

class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
    def write(self):
        if isinstance(self.getClassifier(), JavaMLWritable):
            return JavaMLWriter(self)
        else:
            return OneVsRestWriter(self)


class _OneVsRestSharedReadWrite:
    @staticmethod
    def saveImpl(instance, sc, path, extraMetadata=None):
        skipParams = ['classifier']
        jsonParams = DefaultParamsWriter.extractJsonParams(instance, skipParams)
        DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams,
                                         extraMetadata=extraMetadata)
        classifierPath = os.path.join(path, 'classifier')
        instance.getClassifier().save(classifierPath)

    @staticmethod
    def loadClassifier(path, sc):
        classifierPath = os.path.join(path, 'classifier')
        return DefaultParamsReader.loadParamsInstance(classifierPath, sc)

    @staticmethod
    def validateParams(instance):
        elems_to_check = [instance.getClassifier()]
        if isinstance(instance, OneVsRestModel):
            elems_to_check.extend(instance.models)

        for elem in elems_to_check:
            if not isinstance(elem, MLWritable):
                raise ValueError(f'OneVsRest write will fail because it contains {elem.uid} '
                                 f'which is not writable.')


@inherit_doc
class OneVsRestReader(MLReader):
    def __init__(self, cls):
        super(OneVsRestReader, self).__init__()
        self.cls = cls

    def load(self, path):
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if not DefaultParamsReader.isPythonParamsInstance(metadata):
            return JavaMLReader(self.cls).load(path)
        else:
            classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sc)
            ova = OneVsRest(classifier=classifier)._resetUid(metadata['uid'])
            DefaultParamsReader.getAndSetParams(ova, metadata, skipParams=['classifier'])
            return ova


@inherit_doc
class OneVsRestWriter(MLWriter):
    def __init__(self, instance):
        super(OneVsRestWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path):
        _OneVsRestSharedReadWrite.validateParams(self.instance)
        _OneVsRestSharedReadWrite.saveImpl(self.instance, self.sc, path)


class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
    """
    Model fitted by OneVsRest.
    This stores the models resulting from training k binary classifiers: one for each class.

@@ -3023,6 +3091,9 @@ class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
    def __init__(self, models):
        super(OneVsRestModel, self).__init__()
        self.models = models
        if not isinstance(models[0], JavaMLWritable):
            return
        # set java instance
        java_models = [model._to_java() for model in self.models]
        sc = SparkContext._active_spark_context
        java_models_array = JavaWrapper._new_java_array(java_models,

@@ -3160,6 +3231,57 @@ class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
            _java_obj.set("weightCol", self.getWeightCol())
        return _java_obj

    @classmethod
    def read(cls):
        return OneVsRestModelReader(cls)

    def write(self):
        if all(map(lambda elem: isinstance(elem, JavaMLWritable),
                   [self.getClassifier()] + self.models)):
            return JavaMLWriter(self)
        else:
            return OneVsRestModelWriter(self)


@inherit_doc
class OneVsRestModelReader(MLReader):
    def __init__(self, cls):
        super(OneVsRestModelReader, self).__init__()
        self.cls = cls

    def load(self, path):
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if not DefaultParamsReader.isPythonParamsInstance(metadata):
            return JavaMLReader(self.cls).load(path)
        else:
            classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sc)
            numClasses = metadata['numClasses']
            subModels = [None] * numClasses
            for idx in range(numClasses):
                subModelPath = os.path.join(path, f'model_{idx}')
                subModels[idx] = DefaultParamsReader.loadParamsInstance(subModelPath, self.sc)
            ovaModel = OneVsRestModel(subModels)._resetUid(metadata['uid'])
            ovaModel.set(ovaModel.classifier, classifier)
            DefaultParamsReader.getAndSetParams(ovaModel, metadata, skipParams=['classifier'])
            return ovaModel


@inherit_doc
class OneVsRestModelWriter(MLWriter):
    def __init__(self, instance):
        super(OneVsRestModelWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path):
        _OneVsRestSharedReadWrite.validateParams(self.instance)
        instance = self.instance
        numClasses = len(instance.models)
        extraMetadata = {'numClasses': numClasses}
        _OneVsRestSharedReadWrite.saveImpl(instance, self.sc, path, extraMetadata=extraMetadata)
        for idx in range(numClasses):
            subModelPath = os.path.join(path, f'model_{idx}')
            instance.models[idx].save(subModelPath)


@inherit_doc
class FMClassifier(_JavaProbabilisticClassifier, _FactorizationMachinesParams, JavaMLWritable,
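For OneVsRest, `write()` now picks a writer based on the nested classifier: a `JavaMLWritable` classifier keeps the old `JavaMLWriter` path, anything else goes through the new `OneVsRestWriter`, which persists the classifier under `<path>/classifier`. A hedged usage sketch (the dummy classifier and temp path are assumptions for illustration):

```python
import tempfile

from pyspark.ml.classification import OneVsRest
from pyspark.testing.mlutils import DummyLogisticRegression

ovr = OneVsRest(classifier=DummyLogisticRegression(maxIter=5))
path = tempfile.mkdtemp() + "/ovr"
# DummyLogisticRegression is not JavaMLWritable, so write() returns OneVsRestWriter.
ovr.write().overwrite().save(path)
loaded = OneVsRest.load(path)
assert isinstance(loaded.getClassifier(), DummyLogisticRegression)
```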
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.

from typing import Any, List, Optional
from typing import Any, List, Optional, Type
from pyspark.ml._typing import JM, M, P, T, ParamMap

import abc

@@ -53,7 +53,8 @@ from pyspark.ml.tree import (
    _TreeClassifierParams,
    _TreeEnsembleModel,
)
from pyspark.ml.util import HasTrainingSummary, JavaMLReadable, JavaMLWritable
from pyspark.ml.util import HasTrainingSummary, JavaMLReadable, JavaMLWritable, \
    MLReader, MLReadable, MLWriter, MLWritable
from pyspark.ml.wrapper import JavaPredictionModel, JavaPredictor, JavaWrapper

from pyspark.ml.linalg import Matrix, Vector

@@ -797,8 +798,8 @@ class OneVsRest(
    Estimator[OneVsRestModel],
    _OneVsRestParams,
    HasParallelism,
    JavaMLReadable[OneVsRest],
    JavaMLWritable,
    MLReadable[OneVsRest],
    MLWritable,
):
    def __init__(
        self,

@@ -832,7 +833,7 @@ class OneVsRest(
    def copy(self, extra: Optional[ParamMap] = ...) -> OneVsRest: ...

class OneVsRestModel(
    Model, _OneVsRestParams, JavaMLReadable[OneVsRestModel], JavaMLWritable
    Model, _OneVsRestParams, MLReadable[OneVsRestModel], MLWritable
):
    models: List[Transformer]
    def __init__(self, models: List[Transformer]) -> None: ...

@@ -841,6 +842,26 @@ class OneVsRestModel(
    def setRawPredictionCol(self, value: str) -> OneVsRestModel: ...
    def copy(self, extra: Optional[ParamMap] = ...) -> OneVsRestModel: ...

class OneVsRestWriter(MLWriter):
    instance: OneVsRest
    def __init__(self, instance: OneVsRest) -> None: ...
    def saveImpl(self, path: str) -> None: ...

class OneVsRestReader(MLReader[OneVsRest]):
    cls: Type[OneVsRest]
    def __init__(self, cls: Type[OneVsRest]) -> None: ...
    def load(self, path: str) -> OneVsRest: ...

class OneVsRestModelWriter(MLWriter):
    instance: OneVsRestModel
    def __init__(self, instance: OneVsRestModel) -> None: ...
    def saveImpl(self, path: str) -> None: ...

class OneVsRestModelReader(MLReader[OneVsRestModel]):
    cls: Type[OneVsRestModel]
    def __init__(self, cls: Type[OneVsRestModel]) -> None: ...
    def load(self, path: str) -> OneVsRestModel: ...

class FMClassifier(
    _JavaProbabilisticClassifier[FMClassificationModel],
    _FactorizationMachinesParams,
@@ -237,6 +237,11 @@ class PersistenceTest(SparkSessionTestCase):
            self.assertEqual(len(m1.models), len(m2.models))
            for x, y in zip(m1.models, m2.models):
                self._compare_pipelines(x, y)
        elif isinstance(m1, Params):
            # Test on python backend Estimator/Transformer/Model/Evaluator
            self.assertEqual(len(m1.params), len(m2.params))
            for p in m1.params:
                self._compare_params(m1, m2, p)
        else:
            raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1))

@@ -326,14 +331,14 @@ class PersistenceTest(SparkSessionTestCase):
        except OSError:
            pass

    def test_onevsrest(self):
    def _run_test_onevsrest(self, LogisticRegressionCls):
        temp_path = tempfile.mkdtemp()
        df = self.spark.createDataFrame([(0.0, 0.5, Vectors.dense(1.0, 0.8)),
                                         (1.0, 0.5, Vectors.sparse(2, [], [])),
                                         (2.0, 1.0, Vectors.dense(0.5, 0.5))] * 10,
                                        ["label", "wt", "features"])

        lr = LogisticRegression(maxIter=5, regParam=0.01)
        lr = LogisticRegressionCls(maxIter=5, regParam=0.01)
        ovr = OneVsRest(classifier=lr)

        def reload_and_compare(ovr, suffix):

@@ -350,6 +355,11 @@ class PersistenceTest(SparkSessionTestCase):
        reload_and_compare(OneVsRest(classifier=lr), "ovr")
        reload_and_compare(OneVsRest(classifier=lr).setWeightCol("wt"), "ovrw")

    def test_onevsrest(self):
        from pyspark.testing.mlutils import DummyLogisticRegression
        self._run_test_onevsrest(LogisticRegression)
        self._run_test_onevsrest(DummyLogisticRegression)

    def test_decisiontree_classifier(self):
        dt = DecisionTreeClassifier(maxDepth=1)
        path = tempfile.mkdtemp()
@@ -28,7 +28,8 @@ from pyspark.ml.param import Param, Params
from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder, \
    TrainValidationSplit, TrainValidationSplitModel
from pyspark.sql.functions import rand
from pyspark.testing.mlutils import SparkSessionTestCase
from pyspark.testing.mlutils import DummyEvaluator, DummyLogisticRegression, \
    DummyLogisticRegressionModel, SparkSessionTestCase


class HasInducedError(Params):

@@ -201,7 +202,7 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
            for v in param.values():
                assert(type(v) == float)

    def test_save_load_trained_model(self):
    def _run_test_save_load_trained_model(self, LogisticRegressionCls, LogisticRegressionModelCls):
        # This tests saving and loading the trained model only.
        # Save/load for CrossValidator will be added later: SPARK-13786
        temp_path = tempfile.mkdtemp()

@@ -212,7 +213,7 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
             (Vectors.dense([0.6]), 1.0),
             (Vectors.dense([1.0]), 1.0)] * 10,
            ["features", "label"])
        lr = LogisticRegression()
        lr = LogisticRegressionCls()
        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
        evaluator = BinaryClassificationEvaluator()
        cv = CrossValidator(

@@ -228,7 +229,7 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):

        lrModelPath = temp_path + "/lrModel"
        lrModel.save(lrModelPath)
        loadedLrModel = LogisticRegressionModel.load(lrModelPath)
        loadedLrModel = LogisticRegressionModelCls.load(lrModelPath)
        self.assertEqual(loadedLrModel.uid, lrModel.uid)
        self.assertEqual(loadedLrModel.intercept, lrModel.intercept)

@@ -248,7 +249,12 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
            loadedCvModel.isSet(param) for param in loadedCvModel.params
        ))

    def test_save_load_simple_estimator(self):
    def test_save_load_trained_model(self):
        self._run_test_save_load_trained_model(LogisticRegression, LogisticRegressionModel)
        self._run_test_save_load_trained_model(DummyLogisticRegression,
                                               DummyLogisticRegressionModel)

    def _run_test_save_load_simple_estimator(self, LogisticRegressionCls, evaluatorCls):
        temp_path = tempfile.mkdtemp()
        dataset = self.spark.createDataFrame(
            [(Vectors.dense([0.0]), 0.0),

@@ -258,9 +264,9 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
             (Vectors.dense([1.0]), 1.0)] * 10,
            ["features", "label"])

        lr = LogisticRegression()
        lr = LogisticRegressionCls()
        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
        evaluator = BinaryClassificationEvaluator()
        evaluator = evaluatorCls()

        # test save/load of CrossValidator
        cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)

@@ -278,6 +284,12 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
        loadedModel = CrossValidatorModel.load(cvModelPath)
        self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)

    def test_save_load_simple_estimator(self):
        self._run_test_save_load_simple_estimator(
            LogisticRegression, BinaryClassificationEvaluator)
        self._run_test_save_load_simple_estimator(
            DummyLogisticRegression, DummyEvaluator)

    def test_parallel_evaluation(self):
        dataset = self.spark.createDataFrame(
            [(Vectors.dense([0.0]), 0.0),

@@ -343,7 +355,7 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
            for j in range(len(grid)):
                self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid)

    def test_save_load_nested_estimator(self):
    def _run_test_save_load_nested_estimator(self, LogisticRegressionCls):
        temp_path = tempfile.mkdtemp()
        dataset = self.spark.createDataFrame(
            [(Vectors.dense([0.0]), 0.0),

@@ -353,9 +365,9 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
             (Vectors.dense([1.0]), 1.0)] * 10,
            ["features", "label"])

        ova = OneVsRest(classifier=LogisticRegression())
        lr1 = LogisticRegression().setMaxIter(100)
        lr2 = LogisticRegression().setMaxIter(150)
        ova = OneVsRest(classifier=LogisticRegressionCls())
        lr1 = LogisticRegressionCls().setMaxIter(100)
        lr2 = LogisticRegressionCls().setMaxIter(150)
        grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build()
        evaluator = MulticlassClassificationEvaluator()

@@ -385,7 +397,11 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
        self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid)
        self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)

    def test_save_load_pipeline_estimator(self):
    def test_save_load_nested_estimator(self):
        self._run_test_save_load_nested_estimator(LogisticRegression)
        self._run_test_save_load_nested_estimator(DummyLogisticRegression)

    def _run_test_save_load_pipeline_estimator(self, LogisticRegressionCls):
        temp_path = tempfile.mkdtemp()
        training = self.spark.createDataFrame([
            (0, "a b c d e spark", 1.0),

@@ -402,9 +418,9 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
        tokenizer = Tokenizer(inputCol="text", outputCol="words")
        hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")

        ova = OneVsRest(classifier=LogisticRegression())
        lr1 = LogisticRegression().setMaxIter(5)
        lr2 = LogisticRegression().setMaxIter(10)
        ova = OneVsRest(classifier=LogisticRegressionCls())
        lr1 = LogisticRegressionCls().setMaxIter(5)
        lr2 = LogisticRegressionCls().setMaxIter(10)

        pipeline = Pipeline(stages=[tokenizer, hashingTF, ova])

@@ -464,6 +480,10 @@ class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
                                              original_nested_pipeline_model.stages):
            self.assertEqual(loadedStage.uid, originalStage.uid)

    def test_save_load_pipeline_estimator(self):
        self._run_test_save_load_pipeline_estimator(LogisticRegression)
        self._run_test_save_load_pipeline_estimator(DummyLogisticRegression)

    def test_user_specified_folds(self):
        from pyspark.sql import functions as F

@@ -593,7 +613,7 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
                         "validationMetrics has the same size of grid parameter")
        self.assertEqual(1.0, max(validationMetrics))

    def test_save_load_trained_model(self):
    def _run_test_save_load_trained_model(self, LogisticRegressionCls, LogisticRegressionModelCls):
        # This tests saving and loading the trained model only.
        # Save/load for TrainValidationSplit will be added later: SPARK-13786
        temp_path = tempfile.mkdtemp()

@@ -604,7 +624,7 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
             (Vectors.dense([0.6]), 1.0),
             (Vectors.dense([1.0]), 1.0)] * 10,
            ["features", "label"])
        lr = LogisticRegression()
        lr = LogisticRegressionCls()
        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
        evaluator = BinaryClassificationEvaluator()
        tvs = TrainValidationSplit(

@@ -619,7 +639,7 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):

        lrModelPath = temp_path + "/lrModel"
        lrModel.save(lrModelPath)
        loadedLrModel = LogisticRegressionModel.load(lrModelPath)
        loadedLrModel = LogisticRegressionModelCls.load(lrModelPath)
        self.assertEqual(loadedLrModel.uid, lrModel.uid)
        self.assertEqual(loadedLrModel.intercept, lrModel.intercept)

@@ -636,7 +656,12 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
            loadedTvsModel.isSet(param) for param in loadedTvsModel.params
        ))

    def test_save_load_simple_estimator(self):
    def test_save_load_trained_model(self):
        self._run_test_save_load_trained_model(LogisticRegression, LogisticRegressionModel)
        self._run_test_save_load_trained_model(DummyLogisticRegression,
                                               DummyLogisticRegressionModel)

    def _run_test_save_load_simple_estimator(self, LogisticRegressionCls, evaluatorCls):
        # This tests saving and loading the trained model only.
        # Save/load for TrainValidationSplit will be added later: SPARK-13786
        temp_path = tempfile.mkdtemp()

@@ -647,9 +672,9 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
             (Vectors.dense([0.6]), 1.0),
             (Vectors.dense([1.0]), 1.0)] * 10,
            ["features", "label"])
        lr = LogisticRegression()
        lr = LogisticRegressionCls()
        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
        evaluator = BinaryClassificationEvaluator()
        evaluator = evaluatorCls()
        tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
        tvsModel = tvs.fit(dataset)

@@ -666,6 +691,12 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
        loadedModel = TrainValidationSplitModel.load(tvsModelPath)
        self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)

    def test_save_load_simple_estimator(self):
        self._run_test_save_load_simple_estimator(
            LogisticRegression, BinaryClassificationEvaluator)
        self._run_test_save_load_simple_estimator(
            DummyLogisticRegression, DummyEvaluator)

    def test_parallel_evaluation(self):
        dataset = self.spark.createDataFrame(
            [(Vectors.dense([0.0]), 0.0),

@@ -718,7 +749,7 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
        for i in range(len(grid)):
            self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid)

    def test_save_load_nested_estimator(self):
    def _run_test_save_load_nested_estimator(self, LogisticRegressionCls):
        # This tests saving and loading the trained model only.
        # Save/load for TrainValidationSplit will be added later: SPARK-13786
        temp_path = tempfile.mkdtemp()

@@ -729,9 +760,9 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
             (Vectors.dense([0.6]), 1.0),
             (Vectors.dense([1.0]), 1.0)] * 10,
            ["features", "label"])
        ova = OneVsRest(classifier=LogisticRegression())
        lr1 = LogisticRegression().setMaxIter(100)
        lr2 = LogisticRegression().setMaxIter(150)
        ova = OneVsRest(classifier=LogisticRegressionCls())
        lr1 = LogisticRegressionCls().setMaxIter(100)
        lr2 = LogisticRegressionCls().setMaxIter(150)
        grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build()
        evaluator = MulticlassClassificationEvaluator()

@@ -759,7 +790,11 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
        self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid)
        self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)

    def test_save_load_pipeline_estimator(self):
    def test_save_load_nested_estimator(self):
        self._run_test_save_load_nested_estimator(LogisticRegression)
        self._run_test_save_load_nested_estimator(DummyLogisticRegression)

    def _run_test_save_load_pipeline_estimator(self, LogisticRegressionCls):
        temp_path = tempfile.mkdtemp()
        training = self.spark.createDataFrame([
            (0, "a b c d e spark", 1.0),

@@ -776,9 +811,9 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
        tokenizer = Tokenizer(inputCol="text", outputCol="words")
        hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")

        ova = OneVsRest(classifier=LogisticRegression())
        lr1 = LogisticRegression().setMaxIter(5)
        lr2 = LogisticRegression().setMaxIter(10)
        ova = OneVsRest(classifier=LogisticRegressionCls())
        lr1 = LogisticRegressionCls().setMaxIter(5)
        lr2 = LogisticRegressionCls().setMaxIter(10)

        pipeline = Pipeline(stages=[tokenizer, hashingTF, ova])

@@ -836,6 +871,10 @@ class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
                                              original_nested_pipeline_model.stages):
            self.assertEqual(loadedStage.uid, originalStage.uid)

    def test_save_load_pipeline_estimator(self):
        self._run_test_save_load_pipeline_estimator(LogisticRegression)
        self._run_test_save_load_pipeline_estimator(DummyLogisticRegression)

    def test_copy(self):
        dataset = self.spark.createDataFrame([
            (10, 10.0),
@@ -15,6 +15,7 @@
# limitations under the License.
#

import os
import sys
import itertools
from multiprocessing.pool import ThreadPool

@@ -22,12 +23,13 @@ from multiprocessing.pool import ThreadPool
import numpy as np

from pyspark import keyword_only, since, SparkContext
from pyspark.ml import Estimator, Model
from pyspark.ml.common import _py2java, _java2py
from pyspark.ml import Estimator, Transformer, Model
from pyspark.ml.common import inherit_doc, _py2java, _java2py
from pyspark.ml.evaluation import Evaluator
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader, \
    MetaAlgorithmReadWrite
from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite, \
    MLReadable, MLReader, MLWritable, MLWriter, JavaMLReader, JavaMLWriter
from pyspark.ml.wrapper import JavaParams, JavaEstimator, JavaWrapper
from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
from pyspark.sql.types import BooleanType

@@ -229,6 +231,7 @@ class _ValidatorParams(HasSeed):


class _ValidatorSharedReadWrite:

    @staticmethod
    def meta_estimator_transfer_param_maps_to_java(pyEstimator, pyParamMaps):
        pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)

@@ -275,10 +278,8 @@ class _ValidatorSharedReadWrite:
                raise ValueError('Resolve param in estimatorParamMaps failed: ' +
                                 javaParam.parent() + '.' + javaParam.name())
            javaValue = javaPair.value()
            if sc._jvm.Class.forName("org.apache.spark.ml.PipelineStage").isInstance(javaValue):
                # Note: JavaParams._from_java support both JavaEstimator/JavaTransformer class
                # and Estimator/Transformer class which implements `_from_java` static method
                # (such as OneVsRest, Pipeline class).
            if sc._jvm.Class.forName("org.apache.spark.ml.util.DefaultParamsWritable") \
                    .isInstance(javaValue):
                pyValue = JavaParams._from_java(javaValue)
            else:
                pyValue = _java2py(sc, javaValue)

@@ -286,6 +287,222 @@ class _ValidatorSharedReadWrite:
            pyParamMaps.append(pyParamMap)
        return pyParamMaps

    @staticmethod
    def is_java_convertible(instance):
        allNestedStages = MetaAlgorithmReadWrite.getAllNestedStages(instance.getEstimator())
        evaluator_convertible = isinstance(instance.getEvaluator(), JavaParams)
        estimator_convertible = all(map(lambda stage: hasattr(stage, '_to_java'), allNestedStages))
        return estimator_convertible and evaluator_convertible

    @staticmethod
    def saveImpl(path, instance, sc, extraMetadata=None):
        numParamsNotJson = 0
        jsonEstimatorParamMaps = []
        for paramMap in instance.getEstimatorParamMaps():
            jsonParamMap = []
            for p, v in paramMap.items():
                jsonParam = {'parent': p.parent, 'name': p.name}
                if (isinstance(v, Estimator) and not MetaAlgorithmReadWrite.isMetaEstimator(v)) \
                        or isinstance(v, Transformer) or isinstance(v, Evaluator):
                    relative_path = f'epm_{p.name}{numParamsNotJson}'
                    param_path = os.path.join(path, relative_path)
                    numParamsNotJson += 1
                    v.save(param_path)
                    jsonParam['value'] = relative_path
                    jsonParam['isJson'] = False
                elif isinstance(v, MLWritable):
                    raise RuntimeError(
                        "ValidatorSharedReadWrite.saveImpl does not handle parameters of type: "
                        "MLWritable that are not Estimaor/Evaluator/Transformer, and if parameter "
                        "is estimator, it cannot be meta estimator such as Validator or OneVsRest")
                else:
                    jsonParam['value'] = v
                    jsonParam['isJson'] = True
                jsonParamMap.append(jsonParam)
            jsonEstimatorParamMaps.append(jsonParamMap)

        skipParams = ['estimator', 'evaluator', 'estimatorParamMaps']
        jsonParams = DefaultParamsWriter.extractJsonParams(instance, skipParams)
        jsonParams['estimatorParamMaps'] = jsonEstimatorParamMaps

        DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, jsonParams)
        evaluatorPath = os.path.join(path, 'evaluator')
        instance.getEvaluator().save(evaluatorPath)
        estimatorPath = os.path.join(path, 'estimator')
        instance.getEstimator().save(estimatorPath)

    @staticmethod
    def load(path, sc, metadata):
        evaluatorPath = os.path.join(path, 'evaluator')
        evaluator = DefaultParamsReader.loadParamsInstance(evaluatorPath, sc)
        estimatorPath = os.path.join(path, 'estimator')
        estimator = DefaultParamsReader.loadParamsInstance(estimatorPath, sc)

        uidToParams = MetaAlgorithmReadWrite.getUidMap(estimator)
        uidToParams[evaluator.uid] = evaluator

        jsonEstimatorParamMaps = metadata['paramMap']['estimatorParamMaps']

        estimatorParamMaps = []
        for jsonParamMap in jsonEstimatorParamMaps:
            paramMap = {}
            for jsonParam in jsonParamMap:
                est = uidToParams[jsonParam['parent']]
                param = getattr(est, jsonParam['name'])
                if 'isJson' not in jsonParam or ('isJson' in jsonParam and jsonParam['isJson']):
                    value = jsonParam['value']
                else:
                    relativePath = jsonParam['value']
                    valueSavedPath = os.path.join(path, relativePath)
                    value = DefaultParamsReader.loadParamsInstance(valueSavedPath, sc)
                paramMap[param] = value
            estimatorParamMaps.append(paramMap)

        return metadata, estimator, evaluator, estimatorParamMaps

    @staticmethod
    def validateParams(instance):
        estiamtor = instance.getEstimator()
        evaluator = instance.getEvaluator()
        uidMap = MetaAlgorithmReadWrite.getUidMap(estiamtor)

        for elem in [evaluator] + list(uidMap.values()):
            if not isinstance(elem, MLWritable):
                raise ValueError(f'Validator write will fail because it contains {elem.uid} '
                                 f'which is not writable.')

        estimatorParamMaps = instance.getEstimatorParamMaps()
        paramErr = 'Validator save requires all Params in estimatorParamMaps to apply to ' \
                   f'its Estimator, An extraneous Param was found: '
        for paramMap in estimatorParamMaps:
            for param in paramMap:
                if param.parent not in uidMap:
                    raise ValueError(paramErr + repr(param))

    @staticmethod
    def getValidatorModelWriterPersistSubModelsParam(writer):
        if 'persistsubmodels' in writer.optionMap:
            persistSubModelsParam = writer.optionMap['persistsubmodels'].lower()
            if persistSubModelsParam == 'true':
                return True
            elif persistSubModelsParam == 'false':
                return False
            else:
                raise ValueError(
                    f'persistSubModels option value {persistSubModelsParam} is invalid, '
                    f"the possible values are True, 'True' or False, 'False'")
        else:
            return writer.instance.subModels is not None


_save_with_persist_submodels_no_submodels_found_err = \
    'When persisting tuning models, you can only set persistSubModels to true if the tuning ' \
    'was done with collectSubModels set to true. To save the sub-models, try rerunning fitting ' \
    'with collectSubModels set to true.'


@inherit_doc
class CrossValidatorReader(MLReader):

    def __init__(self, cls):
        super(CrossValidatorReader, self).__init__()
        self.cls = cls

    def load(self, path):
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if not DefaultParamsReader.isPythonParamsInstance(metadata):
            return JavaMLReader(self.cls).load(path)
        else:
            metadata, estimator, evaluator, estimatorParamMaps = \
                _ValidatorSharedReadWrite.load(path, self.sc, metadata)
            cv = CrossValidator(estimator=estimator,
                                estimatorParamMaps=estimatorParamMaps,
                                evaluator=evaluator)
            cv = cv._resetUid(metadata['uid'])
            DefaultParamsReader.getAndSetParams(cv, metadata, skipParams=['estimatorParamMaps'])
            return cv


@inherit_doc
class CrossValidatorWriter(MLWriter):

    def __init__(self, instance):
        super(CrossValidatorWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path):
        _ValidatorSharedReadWrite.validateParams(self.instance)
        _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc)


@inherit_doc
class CrossValidatorModelReader(MLReader):

    def __init__(self, cls):
        super(CrossValidatorModelReader, self).__init__()
        self.cls = cls

    def load(self, path):
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if not DefaultParamsReader.isPythonParamsInstance(metadata):
            return JavaMLReader(self.cls).load(path)
        else:
            metadata, estimator, evaluator, estimatorParamMaps = \
                _ValidatorSharedReadWrite.load(path, self.sc, metadata)
            numFolds = metadata['paramMap']['numFolds']
            bestModelPath = os.path.join(path, 'bestModel')
            bestModel = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
            avgMetrics = metadata['avgMetrics']
            persistSubModels = ('persistSubModels' in metadata) and metadata['persistSubModels']

            if persistSubModels:
                subModels = [[None] * len(estimatorParamMaps)] * numFolds
                for splitIndex in range(numFolds):
                    for paramIndex in range(len(estimatorParamMaps)):
                        modelPath = os.path.join(
                            path, 'subModels', f'fold{splitIndex}', f'{paramIndex}')
                        subModels[splitIndex][paramIndex] = \
                            DefaultParamsReader.loadParamsInstance(modelPath, self.sc)
            else:
                subModels = None

            cvModel = CrossValidatorModel(bestModel, avgMetrics=avgMetrics, subModels=subModels)
            cvModel = cvModel._resetUid(metadata['uid'])
            cvModel.set(cvModel.estimator, estimator)
            cvModel.set(cvModel.estimatorParamMaps, estimatorParamMaps)
            cvModel.set(cvModel.evaluator, evaluator)
            DefaultParamsReader.getAndSetParams(
                cvModel, metadata, skipParams=['estimatorParamMaps'])
            return cvModel


@inherit_doc
class CrossValidatorModelWriter(MLWriter):

    def __init__(self, instance):
        super(CrossValidatorModelWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path):
        _ValidatorSharedReadWrite.validateParams(self.instance)
        instance = self.instance
        persistSubModels = _ValidatorSharedReadWrite \
            .getValidatorModelWriterPersistSubModelsParam(self)
        extraMetadata = {'avgMetrics': instance.avgMetrics,
                         'persistSubModels': persistSubModels}
        _ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata)
        bestModelPath = os.path.join(path, 'bestModel')
        instance.bestModel.save(bestModelPath)
        if persistSubModels:
            if instance.subModels is None:
                raise ValueError(_save_with_persist_submodels_no_submodels_found_err)
            subModelsPath = os.path.join(path, 'subModels')
            for splitIndex in range(instance.getNumFolds()):
                splitPath = os.path.join(subModelsPath, f'fold{splitIndex}')
                for paramIndex in range(len(instance.getEstimatorParamMaps())):
                    modelPath = os.path.join(splitPath, f'{paramIndex}')
                    instance.subModels[splitIndex][paramIndex].save(modelPath)


class _CrossValidatorParams(_ValidatorParams):
    """

@@ -553,13 +770,15 @@ class CrossValidator(Estimator, _CrossValidatorParams, HasParallelism, HasCollec
    @since("2.3.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)
        if _ValidatorSharedReadWrite.is_java_convertible(self):
            return JavaMLWriter(self)
        return CrossValidatorWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return JavaMLReader(cls)
        return CrossValidatorReader(cls)

    @classmethod
    def _from_java(cls, java_stage):

@@ -662,13 +881,15 @@ class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable):
    @since("2.3.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)
        if _ValidatorSharedReadWrite.is_java_convertible(self):
            return JavaMLWriter(self)
        return CrossValidatorModelWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return JavaMLReader(cls)
        return CrossValidatorModelReader(cls)

    @classmethod
    def _from_java(cls, java_stage):

@@ -738,6 +959,106 @@ class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable):
        return _java_obj


@inherit_doc
class TrainValidationSplitReader(MLReader):

    def __init__(self, cls):
        super(TrainValidationSplitReader, self).__init__()
        self.cls = cls

    def load(self, path):
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if not DefaultParamsReader.isPythonParamsInstance(metadata):
            return JavaMLReader(self.cls).load(path)
        else:
            metadata, estimator, evaluator, estimatorParamMaps = \
                _ValidatorSharedReadWrite.load(path, self.sc, metadata)
            tvs = TrainValidationSplit(estimator=estimator,
                                       estimatorParamMaps=estimatorParamMaps,
                                       evaluator=evaluator)
            tvs = tvs._resetUid(metadata['uid'])
            DefaultParamsReader.getAndSetParams(tvs, metadata, skipParams=['estimatorParamMaps'])
            return tvs


@inherit_doc
class TrainValidationSplitWriter(MLWriter):

    def __init__(self, instance):
        super(TrainValidationSplitWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path):
        _ValidatorSharedReadWrite.validateParams(self.instance)
        _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc)


@inherit_doc
class TrainValidationSplitModelReader(MLReader):

    def __init__(self, cls):
        super(TrainValidationSplitModelReader, self).__init__()
        self.cls = cls

    def load(self, path):
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if not DefaultParamsReader.isPythonParamsInstance(metadata):
            return JavaMLReader(self.cls).load(path)
        else:
            metadata, estimator, evaluator, estimatorParamMaps = \
                _ValidatorSharedReadWrite.load(path, self.sc, metadata)
            bestModelPath = os.path.join(path, 'bestModel')
            bestModel = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
            validationMetrics = metadata['validationMetrics']
            persistSubModels = ('persistSubModels' in metadata) and metadata['persistSubModels']

            if persistSubModels:
                subModels = [None] * len(estimatorParamMaps)
                for paramIndex in range(len(estimatorParamMaps)):
                    modelPath = os.path.join(path, 'subModels', f'{paramIndex}')
                    subModels[paramIndex] = \
                        DefaultParamsReader.loadParamsInstance(modelPath, self.sc)
            else:
                subModels = None

            tvsModel = TrainValidationSplitModel(
                bestModel, validationMetrics=validationMetrics, subModels=subModels)
            tvsModel = tvsModel._resetUid(metadata['uid'])
            tvsModel.set(tvsModel.estimator, estimator)
            tvsModel.set(tvsModel.estimatorParamMaps, estimatorParamMaps)
            tvsModel.set(tvsModel.evaluator, evaluator)
            DefaultParamsReader.getAndSetParams(
                tvsModel, metadata, skipParams=['estimatorParamMaps'])
            return tvsModel


@inherit_doc
class TrainValidationSplitModelWriter(MLWriter):

    def __init__(self, instance):
        super(TrainValidationSplitModelWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path):
        _ValidatorSharedReadWrite.validateParams(self.instance)
        instance = self.instance
        persistSubModels = _ValidatorSharedReadWrite \
            .getValidatorModelWriterPersistSubModelsParam(self)

        extraMetadata = {'validationMetrics': instance.validationMetrics,
                         'persistSubModels': persistSubModels}
        _ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata)
        bestModelPath = os.path.join(path, 'bestModel')
        instance.bestModel.save(bestModelPath)
        if persistSubModels:
            if instance.subModels is None:
                raise ValueError(_save_with_persist_submodels_no_submodels_found_err)
            subModelsPath = os.path.join(path, 'subModels')
            for paramIndex in range(len(instance.getEstimatorParamMaps())):
                modelPath = os.path.join(subModelsPath, f'{paramIndex}')
                instance.subModels[paramIndex].save(modelPath)


class _TrainValidationSplitParams(_ValidatorParams):
    """
    Params for :py:class:`TrainValidationSplit` and :py:class:`TrainValidationSplitModel`.

@@ -942,13 +1263,15 @@ class TrainValidationSplit(Estimator, _TrainValidationSplitParams, HasParallelis
    @since("2.3.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)
        if _ValidatorSharedReadWrite.is_java_convertible(self):
            return JavaMLWriter(self)
        return TrainValidationSplitWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return JavaMLReader(cls)
        return TrainValidationSplitReader(cls)

    @classmethod
    def _from_java(cls, java_stage):

@@ -1046,13 +1369,15 @@ class TrainValidationSplitModel(Model, _TrainValidationSplitParams, MLReadable,
    @since("2.3.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)
        if _ValidatorSharedReadWrite.is_java_convertible(self):
            return JavaMLWriter(self)
        return TrainValidationSplitModelWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return JavaMLReader(cls)
        return TrainValidationSplitModelReader(cls)

    @classmethod
    def _from_java(cls, java_stage):
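The Python model writers honor the `persistSubModels` option through `MLWriter.option()` and `getValidatorModelWriterPersistSubModelsParam`. A short sketch, assuming `cvModel` was fit with `collectSubModels=True` and `path` is an existing temp directory:

```python
# cvModel: a CrossValidatorModel fit with collectSubModels=True; path: a temp dir.
cvModel.write().option("persistSubModels", "true").save(path + "/cv_with_submodels")
cvModel.write().option("persistSubModels", "false").save(path + "/cv_best_only")

# Sub-models are only restored when they were persisted alongside the best model.
loaded = CrossValidatorModel.load(path + "/cv_with_submodels")
assert loaded.subModels is not None
```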
@@ -183,3 +183,43 @@ class TrainValidationSplitModel(
    def write(self) -> MLWriter: ...
    @classmethod
    def read(cls: Type[TrainValidationSplitModel]) -> MLReader: ...

class CrossValidatorWriter(MLWriter):
    instance: CrossValidator
    def __init__(self, instance: CrossValidator) -> None: ...
    def saveImpl(self, path: str) -> None: ...

class CrossValidatorReader(MLReader[CrossValidator]):
    cls: Type[CrossValidator]
    def __init__(self, cls: Type[CrossValidator]) -> None: ...
    def load(self, path: str) -> CrossValidator: ...

class CrossValidatorModelWriter(MLWriter):
    instance: CrossValidatorModel
    def __init__(self, instance: CrossValidatorModel) -> None: ...
    def saveImpl(self, path: str) -> None: ...

class CrossValidatorModelReader(MLReader[CrossValidatorModel]):
    cls: Type[CrossValidatorModel]
    def __init__(self, cls: Type[CrossValidatorModel]) -> None: ...
    def load(self, path: str) -> CrossValidatorModel: ...

class TrainValidationSplitWriter(MLWriter):
    instance: TrainValidationSplit
    def __init__(self, instance: TrainValidationSplit) -> None: ...
    def saveImpl(self, path: str) -> None: ...

class TrainValidationSplitReader(MLReader[TrainValidationSplit]):
    cls: Type[TrainValidationSplit]
    def __init__(self, cls: Type[TrainValidationSplit]) -> None: ...
    def load(self, path: str) -> TrainValidationSplit: ...

class TrainValidationSplitModelWriter(MLWriter):
    instance: TrainValidationSplitModel
    def __init__(self, instance: TrainValidationSplitModel) -> None: ...
    def saveImpl(self, path: str) -> None: ...

class TrainValidationSplitModelReader(MLReader[TrainValidationSplitModel]):
    cls: Type[TrainValidationSplitModel]
    def __init__(self, cls: Type[TrainValidationSplitModel]) -> None: ...
    def load(self, path: str) -> TrainValidationSplitModel: ...
@@ -106,6 +106,7 @@ class MLWriter(BaseReadWrite):
    def __init__(self):
        super(MLWriter, self).__init__()
        self.shouldOverwrite = False
        self.optionMap = {}

    def _handleOverwrite(self, path):
        from pyspark.ml.wrapper import JavaWrapper

@@ -132,6 +133,14 @@ class MLWriter(BaseReadWrite):
        self.shouldOverwrite = True
        return self

    def option(self, key, value):
        """
        Adds an option to the underlying MLWriter. See the documentation for the specific model's
        writer for possible options. The option name (key) is case-insensitive.
        """
        self.optionMap[key.lower()] = str(value)
        return self


@inherit_doc
class GeneralMLWriter(MLWriter):

@@ -375,6 +384,13 @@ class DefaultParamsWriter(MLWriter):
    def saveImpl(self, path):
        DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)

    @staticmethod
    def extractJsonParams(instance, skipParams):
        paramMap = instance.extractParamMap()
        jsonParams = {param.name: value for param, value in paramMap.items()
                      if param.name not in skipParams}
        return jsonParams

    @staticmethod
    def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
        """

@@ -530,15 +546,16 @@ class DefaultParamsReader(MLReader):
        return metadata

    @staticmethod
    def getAndSetParams(instance, metadata):
    def getAndSetParams(instance, metadata, skipParams=None):
        """
        Extract Params from metadata, and set them in the instance.
        """
        # Set user-supplied param values
        for paramName in metadata['paramMap']:
            param = instance.getParam(paramName)
            paramValue = metadata['paramMap'][paramName]
            instance.set(param, paramValue)
            if skipParams is None or paramName not in skipParams:
                paramValue = metadata['paramMap'][paramName]
                instance.set(param, paramValue)

        # Set default param values
        majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion'])

@@ -554,6 +571,10 @@ class DefaultParamsReader(MLReader):
            paramValue = metadata['defaultParamMap'][paramName]
            instance._setDefault(**{paramName: paramValue})

    @staticmethod
    def isPythonParamsInstance(metadata):
        return metadata['class'].startswith('pyspark.ml.')

    @staticmethod
    def loadParamsInstance(path, sc):
        """

@@ -561,7 +582,10 @@ class DefaultParamsReader(MLReader):
        This assumes the instance inherits from :py:class:`MLReadable`.
        """
        metadata = DefaultParamsReader.loadMetadata(path, sc)
        pythonClassName = metadata['class'].replace("org.apache.spark", "pyspark")
        if DefaultParamsReader.isPythonParamsInstance(metadata):
            pythonClassName = metadata['class']
        else:
            pythonClassName = metadata['class'].replace("org.apache.spark", "pyspark")
        py_type = DefaultParamsReader.__get_class(pythonClassName)
        instance = py_type.load(path)
        return instance

@@ -630,3 +654,13 @@ class MetaAlgorithmReadWrite:
            nestedStages.extend(MetaAlgorithmReadWrite.getAllNestedStages(pySubStage))

        return [pyInstance] + nestedStages

    @staticmethod
    def getUidMap(instance):
        nestedStages = MetaAlgorithmReadWrite.getAllNestedStages(instance)
        uidMap = {stage.uid: stage for stage in nestedStages}
        if len(nestedStages) != len(uidMap):
            raise RuntimeError(f'{instance.__class__.__module__}.{instance.__class__.__name__}'
                               f'.load found a compound estimator with stages with duplicate '
                               f'UIDs. List of UIDs: {list(uidMap.keys())}.')
        return uidMap
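A small illustration of the new option bookkeeping on `MLWriter`: keys are lower-cased and values stringified before they land in `optionMap`, which is why the validator writers look up `'persistsubmodels'` and parse the string themselves (sketch, not part of the patch):

```python
from pyspark.ml.util import MLWriter

w = MLWriter()
w.option("PersistSubModels", True)
# The key is lower-cased and the value stringified before being stored.
assert w.optionMap == {"persistsubmodels": "True"}
```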
@@ -132,3 +132,5 @@ class MetaAlgorithmReadWrite:
    def isMetaEstimator(pyInstance: Any) -> bool: ...
    @staticmethod
    def getAllNestedStages(pyInstance: Any) -> list: ...
    @staticmethod
    def getUidMap(instance: Any) -> dict: ...
@@ -17,8 +17,12 @@

import numpy as np

from pyspark import keyword_only
from pyspark.ml import Estimator, Model, Transformer, UnaryTransformer
from pyspark.ml.evaluation import Evaluator
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasMaxIter, HasRegParam
from pyspark.ml.classification import Classifier, ClassificationModel
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.ml.wrapper import _java2py  # type: ignore
from pyspark.sql import DataFrame, SparkSession

@@ -161,3 +165,86 @@ class MockEstimator(Estimator, HasFake):

class MockModel(MockTransformer, Model, HasFake):
    pass


class _DummyLogisticRegressionParams(HasMaxIter, HasRegParam):
    def setMaxIter(self, value):
        return self._set(maxIter=value)

    def setRegParam(self, value):
        return self._set(regParam=value)


# This is a dummy LogisticRegression used in test for python backend estimator/model
class DummyLogisticRegression(Classifier, _DummyLogisticRegressionParams,
                              DefaultParamsReadable, DefaultParamsWritable):
    @keyword_only
    def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxIter=100, regParam=0.0, rawPredictionCol="rawPrediction"):
        super(DummyLogisticRegression, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.0, rawPredictionCol="rawPrediction"):
        kwargs = self._input_kwargs
        self._set(**kwargs)
        return self

    def _fit(self, dataset):
        # Do nothing but create a dummy model
        return self._copyValues(DummyLogisticRegressionModel())


class DummyLogisticRegressionModel(ClassificationModel, _DummyLogisticRegressionParams,
                                   DefaultParamsReadable, DefaultParamsWritable):

    def __init__(self):
        super(DummyLogisticRegressionModel, self).__init__()

    def _transform(self, dataset):
        # A dummy transform impl which always predict label 1
        from pyspark.sql.functions import array, lit
        from pyspark.ml.functions import array_to_vector
        rawPredCol = self.getRawPredictionCol()
        if rawPredCol:
            dataset = dataset.withColumn(
                rawPredCol, array_to_vector(array(lit(-100.0), lit(100.0))))
        predCol = self.getPredictionCol()
        if predCol:
            dataset = dataset.withColumn(predCol, lit(1.0))

        return dataset

    @property
    def numClasses(self):
        # a dummy implementation for test.
        return 2

    @property
    def intercept(self):
        # a dummy implementation for test.
        return 0.0

    # This class only used in test. The following methods/properties are not used in tests.

    @property
    def coefficients(self):
        raise NotImplementedError()

    def predictRaw(self, value):
        raise NotImplementedError()

    def numFeatures(self):
        raise NotImplementedError()

    def predict(self, value):
        raise NotImplementedError()


class DummyEvaluator(Evaluator, DefaultParamsReadable, DefaultParamsWritable):

    def _evaluate(self, dataset):
        # a dummy implementation for test.
        return 1.0
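Because the dummy classes extend `DefaultParamsReadable`/`DefaultParamsWritable`, they can also be persisted on their own with plain JSON param metadata; a sketch of such a round trip (temp path assumed for illustration):

```python
import tempfile

from pyspark.testing.mlutils import DummyLogisticRegression

lr = DummyLogisticRegression(maxIter=10, regParam=0.1)
path = tempfile.mkdtemp() + "/dummy_lr"
lr.save(path)  # DefaultParamsWriter: JSON metadata only, no Java objects involved
loaded = DummyLogisticRegression.load(path)
assert loaded.getMaxIter() == 10 and loaded.getRegParam() == 0.1
```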