9fcf0ea718
Disallow the use of unused imports: - Unnecessary increases the memory footprint of the application - Removes the imports that are required for the examples in the docstring from the file-scope to the example itself. This keeps the files itself clean, and gives a more complete example as it also includes the imports :) ``` fokkodriesprongFan spark % flake8 python | grep -i "imported but unused" python/pyspark/cloudpickle.py:46:1: F401 'functools.partial' imported but unused python/pyspark/cloudpickle.py:55:1: F401 'traceback' imported but unused python/pyspark/heapq3.py:868:5: F401 '_heapq.*' imported but unused python/pyspark/__init__.py:61:1: F401 'pyspark.version.__version__' imported but unused python/pyspark/__init__.py:62:1: F401 'pyspark._globals._NoValue' imported but unused python/pyspark/__init__.py:115:1: F401 'pyspark.sql.SQLContext' imported but unused python/pyspark/__init__.py:115:1: F401 'pyspark.sql.HiveContext' imported but unused python/pyspark/__init__.py:115:1: F401 'pyspark.sql.Row' imported but unused python/pyspark/rdd.py:21:1: F401 're' imported but unused python/pyspark/rdd.py:29:1: F401 'tempfile.NamedTemporaryFile' imported but unused python/pyspark/mllib/regression.py:26:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused python/pyspark/mllib/clustering.py:28:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused python/pyspark/mllib/clustering.py:28:1: F401 'pyspark.mllib.linalg.DenseVector' imported but unused python/pyspark/mllib/classification.py:26:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused python/pyspark/mllib/feature.py:28:1: F401 'pyspark.mllib.linalg.DenseVector' imported but unused python/pyspark/mllib/feature.py:28:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused python/pyspark/mllib/feature.py:30:1: F401 'pyspark.mllib.regression.LabeledPoint' imported but unused python/pyspark/mllib/tests/test_linalg.py:18:1: F401 'sys' imported but unused 
python/pyspark/mllib/tests/test_linalg.py:642:5: F401 'pyspark.mllib.tests.test_linalg.*' imported but unused python/pyspark/mllib/tests/test_feature.py:21:1: F401 'numpy.random' imported but unused python/pyspark/mllib/tests/test_feature.py:21:1: F401 'numpy.exp' imported but unused python/pyspark/mllib/tests/test_feature.py:23:1: F401 'pyspark.mllib.linalg.Vector' imported but unused python/pyspark/mllib/tests/test_feature.py:23:1: F401 'pyspark.mllib.linalg.VectorUDT' imported but unused python/pyspark/mllib/tests/test_feature.py:185:5: F401 'pyspark.mllib.tests.test_feature.*' imported but unused python/pyspark/mllib/tests/test_util.py:97:5: F401 'pyspark.mllib.tests.test_util.*' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.Vector' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.DenseVector' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.VectorUDT' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg._convert_to_vector' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.DenseMatrix' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.SparseMatrix' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.MatrixUDT' imported but unused python/pyspark/mllib/tests/test_stat.py:181:5: F401 'pyspark.mllib.tests.test_stat.*' imported but unused python/pyspark/mllib/tests/test_streaming_algorithms.py:18:1: F401 'time.time' imported but unused python/pyspark/mllib/tests/test_streaming_algorithms.py:18:1: F401 'time.sleep' imported but unused python/pyspark/mllib/tests/test_streaming_algorithms.py:470:5: F401 'pyspark.mllib.tests.test_streaming_algorithms.*' imported but 
unused python/pyspark/mllib/tests/test_algorithms.py:295:5: F401 'pyspark.mllib.tests.test_algorithms.*' imported but unused python/pyspark/tests/test_serializers.py:90:13: F401 'xmlrunner' imported but unused python/pyspark/tests/test_rdd.py:21:1: F401 'sys' imported but unused python/pyspark/tests/test_rdd.py:29:1: F401 'pyspark.resource.ResourceProfile' imported but unused python/pyspark/tests/test_rdd.py:885:5: F401 'pyspark.tests.test_rdd.*' imported but unused python/pyspark/tests/test_readwrite.py:19:1: F401 'sys' imported but unused python/pyspark/tests/test_readwrite.py:22:1: F401 'array.array' imported but unused python/pyspark/tests/test_readwrite.py:309:5: F401 'pyspark.tests.test_readwrite.*' imported but unused python/pyspark/tests/test_join.py:62:5: F401 'pyspark.tests.test_join.*' imported but unused python/pyspark/tests/test_taskcontext.py:19:1: F401 'shutil' imported but unused python/pyspark/tests/test_taskcontext.py:325:5: F401 'pyspark.tests.test_taskcontext.*' imported but unused python/pyspark/tests/test_conf.py:36:5: F401 'pyspark.tests.test_conf.*' imported but unused python/pyspark/tests/test_broadcast.py:148:5: F401 'pyspark.tests.test_broadcast.*' imported but unused python/pyspark/tests/test_daemon.py:76:5: F401 'pyspark.tests.test_daemon.*' imported but unused python/pyspark/tests/test_util.py:77:5: F401 'pyspark.tests.test_util.*' imported but unused python/pyspark/tests/test_pin_thread.py:19:1: F401 'random' imported but unused python/pyspark/tests/test_pin_thread.py:149:5: F401 'pyspark.tests.test_pin_thread.*' imported but unused python/pyspark/tests/test_worker.py:19:1: F401 'sys' imported but unused python/pyspark/tests/test_worker.py:26:5: F401 'resource' imported but unused python/pyspark/tests/test_worker.py:203:5: F401 'pyspark.tests.test_worker.*' imported but unused python/pyspark/tests/test_profiler.py:101:5: F401 'pyspark.tests.test_profiler.*' imported but unused python/pyspark/tests/test_shuffle.py:18:1: F401 'sys' 
imported but unused python/pyspark/tests/test_shuffle.py:171:5: F401 'pyspark.tests.test_shuffle.*' imported but unused python/pyspark/tests/test_rddbarrier.py:43:5: F401 'pyspark.tests.test_rddbarrier.*' imported but unused python/pyspark/tests/test_context.py:129:13: F401 'userlibrary.UserClass' imported but unused python/pyspark/tests/test_context.py:140:13: F401 'userlib.UserClass' imported but unused python/pyspark/tests/test_context.py:310:5: F401 'pyspark.tests.test_context.*' imported but unused python/pyspark/tests/test_appsubmit.py:241:5: F401 'pyspark.tests.test_appsubmit.*' imported but unused python/pyspark/streaming/dstream.py:18:1: F401 'sys' imported but unused python/pyspark/streaming/tests/test_dstream.py:27:1: F401 'pyspark.RDD' imported but unused python/pyspark/streaming/tests/test_dstream.py:647:5: F401 'pyspark.streaming.tests.test_dstream.*' imported but unused python/pyspark/streaming/tests/test_kinesis.py:83:5: F401 'pyspark.streaming.tests.test_kinesis.*' imported but unused python/pyspark/streaming/tests/test_listener.py:152:5: F401 'pyspark.streaming.tests.test_listener.*' imported but unused python/pyspark/streaming/tests/test_context.py:178:5: F401 'pyspark.streaming.tests.test_context.*' imported but unused python/pyspark/testing/utils.py:30:5: F401 'scipy.sparse' imported but unused python/pyspark/testing/utils.py:36:5: F401 'numpy as np' imported but unused python/pyspark/ml/regression.py:25:1: F401 'pyspark.ml.tree._TreeEnsembleParams' imported but unused python/pyspark/ml/regression.py:25:1: F401 'pyspark.ml.tree._HasVarianceImpurity' imported but unused python/pyspark/ml/regression.py:29:1: F401 'pyspark.ml.wrapper.JavaParams' imported but unused python/pyspark/ml/util.py:19:1: F401 'sys' imported but unused python/pyspark/ml/__init__.py:25:1: F401 'pyspark.ml.pipeline' imported but unused python/pyspark/ml/pipeline.py:18:1: F401 'sys' imported but unused python/pyspark/ml/stat.py:22:1: F401 'pyspark.ml.linalg.DenseMatrix' 
imported but unused python/pyspark/ml/stat.py:22:1: F401 'pyspark.ml.linalg.Vectors' imported but unused python/pyspark/ml/tests/test_training_summary.py:18:1: F401 'sys' imported but unused python/pyspark/ml/tests/test_training_summary.py:364:5: F401 'pyspark.ml.tests.test_training_summary.*' imported but unused python/pyspark/ml/tests/test_linalg.py:381:5: F401 'pyspark.ml.tests.test_linalg.*' imported but unused python/pyspark/ml/tests/test_tuning.py:427:9: F401 'pyspark.sql.functions as F' imported but unused python/pyspark/ml/tests/test_tuning.py:757:5: F401 'pyspark.ml.tests.test_tuning.*' imported but unused python/pyspark/ml/tests/test_wrapper.py:120:5: F401 'pyspark.ml.tests.test_wrapper.*' imported but unused python/pyspark/ml/tests/test_feature.py:19:1: F401 'sys' imported but unused python/pyspark/ml/tests/test_feature.py:304:5: F401 'pyspark.ml.tests.test_feature.*' imported but unused python/pyspark/ml/tests/test_image.py:19:1: F401 'py4j' imported but unused python/pyspark/ml/tests/test_image.py:22:1: F401 'pyspark.testing.mlutils.PySparkTestCase' imported but unused python/pyspark/ml/tests/test_image.py:71:5: F401 'pyspark.ml.tests.test_image.*' imported but unused python/pyspark/ml/tests/test_persistence.py:456:5: F401 'pyspark.ml.tests.test_persistence.*' imported but unused python/pyspark/ml/tests/test_evaluation.py:56:5: F401 'pyspark.ml.tests.test_evaluation.*' imported but unused python/pyspark/ml/tests/test_stat.py:43:5: F401 'pyspark.ml.tests.test_stat.*' imported but unused python/pyspark/ml/tests/test_base.py:70:5: F401 'pyspark.ml.tests.test_base.*' imported but unused python/pyspark/ml/tests/test_param.py:20:1: F401 'sys' imported but unused python/pyspark/ml/tests/test_param.py:375:5: F401 'pyspark.ml.tests.test_param.*' imported but unused python/pyspark/ml/tests/test_pipeline.py:62:5: F401 'pyspark.ml.tests.test_pipeline.*' imported but unused python/pyspark/ml/tests/test_algorithms.py:333:5: F401 'pyspark.ml.tests.test_algorithms.*' 
imported but unused python/pyspark/ml/param/__init__.py:18:1: F401 'sys' imported but unused python/pyspark/resource/tests/test_resources.py:17:1: F401 'random' imported but unused python/pyspark/resource/tests/test_resources.py:20:1: F401 'pyspark.resource.ResourceProfile' imported but unused python/pyspark/resource/tests/test_resources.py:75:5: F401 'pyspark.resource.tests.test_resources.*' imported but unused python/pyspark/sql/functions.py:32:1: F401 'pyspark.sql.udf.UserDefinedFunction' imported but unused python/pyspark/sql/functions.py:34:1: F401 'pyspark.sql.pandas.functions.pandas_udf' imported but unused python/pyspark/sql/session.py:30:1: F401 'pyspark.sql.types.Row' imported but unused python/pyspark/sql/session.py:30:1: F401 'pyspark.sql.types.StringType' imported but unused python/pyspark/sql/readwriter.py:1084:5: F401 'pyspark.sql.Row' imported but unused python/pyspark/sql/context.py:26:1: F401 'pyspark.sql.types.IntegerType' imported but unused python/pyspark/sql/context.py:26:1: F401 'pyspark.sql.types.Row' imported but unused python/pyspark/sql/context.py:26:1: F401 'pyspark.sql.types.StringType' imported but unused python/pyspark/sql/context.py:27:1: F401 'pyspark.sql.udf.UDFRegistration' imported but unused python/pyspark/sql/streaming.py:1212:5: F401 'pyspark.sql.Row' imported but unused python/pyspark/sql/tests/test_utils.py:55:5: F401 'pyspark.sql.tests.test_utils.*' imported but unused python/pyspark/sql/tests/test_pandas_map.py:18:1: F401 'sys' imported but unused python/pyspark/sql/tests/test_pandas_map.py:22:1: F401 'pyspark.sql.functions.pandas_udf' imported but unused python/pyspark/sql/tests/test_pandas_map.py:22:1: F401 'pyspark.sql.functions.PandasUDFType' imported but unused python/pyspark/sql/tests/test_pandas_map.py:119:5: F401 'pyspark.sql.tests.test_pandas_map.*' imported but unused python/pyspark/sql/tests/test_catalog.py:193:5: F401 'pyspark.sql.tests.test_catalog.*' imported but unused 
python/pyspark/sql/tests/test_group.py:39:5: F401 'pyspark.sql.tests.test_group.*' imported but unused python/pyspark/sql/tests/test_session.py:361:5: F401 'pyspark.sql.tests.test_session.*' imported but unused python/pyspark/sql/tests/test_conf.py:49:5: F401 'pyspark.sql.tests.test_conf.*' imported but unused python/pyspark/sql/tests/test_pandas_cogrouped_map.py:19:1: F401 'sys' imported but unused python/pyspark/sql/tests/test_pandas_cogrouped_map.py:21:1: F401 'pyspark.sql.functions.sum' imported but unused python/pyspark/sql/tests/test_pandas_cogrouped_map.py:21:1: F401 'pyspark.sql.functions.PandasUDFType' imported but unused python/pyspark/sql/tests/test_pandas_cogrouped_map.py:29:5: F401 'pandas.util.testing.assert_series_equal' imported but unused python/pyspark/sql/tests/test_pandas_cogrouped_map.py:32:5: F401 'pyarrow as pa' imported but unused python/pyspark/sql/tests/test_pandas_cogrouped_map.py:248:5: F401 'pyspark.sql.tests.test_pandas_cogrouped_map.*' imported but unused python/pyspark/sql/tests/test_udf.py:24:1: F401 'py4j' imported but unused python/pyspark/sql/tests/test_pandas_udf_typehints.py:246:5: F401 'pyspark.sql.tests.test_pandas_udf_typehints.*' imported but unused python/pyspark/sql/tests/test_functions.py:19:1: F401 'sys' imported but unused python/pyspark/sql/tests/test_functions.py:362:9: F401 'pyspark.sql.functions.exists' imported but unused python/pyspark/sql/tests/test_functions.py:387:5: F401 'pyspark.sql.tests.test_functions.*' imported but unused python/pyspark/sql/tests/test_pandas_udf_scalar.py:21:1: F401 'sys' imported but unused python/pyspark/sql/tests/test_pandas_udf_scalar.py:45:5: F401 'pyarrow as pa' imported but unused python/pyspark/sql/tests/test_pandas_udf_window.py:355:5: F401 'pyspark.sql.tests.test_pandas_udf_window.*' imported but unused python/pyspark/sql/tests/test_arrow.py:38:5: F401 'pyarrow as pa' imported but unused python/pyspark/sql/tests/test_pandas_grouped_map.py:20:1: F401 'sys' imported but unused 
python/pyspark/sql/tests/test_pandas_grouped_map.py:38:5: F401 'pyarrow as pa' imported but unused python/pyspark/sql/tests/test_dataframe.py:382:9: F401 'pyspark.sql.DataFrame' imported but unused python/pyspark/sql/avro/functions.py:125:5: F401 'pyspark.sql.Row' imported but unused python/pyspark/sql/pandas/functions.py:19:1: F401 'sys' imported but unused ``` After: ``` fokkodriesprongFan spark % flake8 python | grep -i "imported but unused" fokkodriesprongFan spark % ``` ### What changes were proposed in this pull request? Removing unused imports from the Python files to keep everything nice and tidy. ### Why are the changes needed? Cleaning up of the imports that aren't used, and suppressing the imports that are used as references to other modules, preserving backward compatibility. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Adding the rule to the existing Flake8 checks. Closes #29121 from Fokko/SPARK-32319. Authored-by: Fokko Driesprong <fokko@apache.org> Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
432 lines
15 KiB
Python
432 lines
15 KiB
Python
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
# this work for additional information regarding copyright ownership.
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
# (the "License"); you may not use this file except in compliance with
|
|
# the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import os

from pyspark import keyword_only, since, SparkContext
from pyspark.ml.base import Estimator, Model, Transformer
from pyspark.ml.common import inherit_doc, _java2py, _py2java
from pyspark.ml.param import Param, Params
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaParams, JavaWrapper
|
|
|
|
|
|
@inherit_doc
class Pipeline(Estimator, MLReadable, MLWritable):
    """
    A simple pipeline, which acts as an estimator. A Pipeline consists
    of a sequence of stages, each of which is either an
    :py:class:`Estimator` or a :py:class:`Transformer`. When
    :py:meth:`Pipeline.fit` is called, the stages are executed in
    order. If a stage is an :py:class:`Estimator`, its
    :py:meth:`Estimator.fit` method will be called on the input
    dataset to fit a model. Then the model, which is a transformer,
    will be used to transform the dataset as the input to the next
    stage. If a stage is a :py:class:`Transformer`, its
    :py:meth:`Transformer.transform` method will be called to produce
    the dataset for the next stage. The fitted model from a
    :py:class:`Pipeline` is a :py:class:`PipelineModel`, which
    consists of fitted models and transformers, corresponding to the
    pipeline stages. If stages is an empty list, the pipeline acts as an
    identity transformer.

    .. versionadded:: 1.3.0
    """

    stages = Param(Params._dummy(), "stages", "a list of pipeline stages")

    @keyword_only
    def __init__(self, stages=None):
        """
        __init__(self, stages=None)
        """
        super(Pipeline, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @since("1.3.0")
    def setStages(self, value):
        """
        Set pipeline stages.

        :param value: a list of transformers or estimators
        :return: the pipeline instance
        """
        return self._set(stages=value)

    @since("1.3.0")
    def getStages(self):
        """
        Get pipeline stages.
        """
        return self.getOrDefault(self.stages)

    @keyword_only
    @since("1.3.0")
    def setParams(self, stages=None):
        """
        setParams(self, stages=None)
        Sets params for Pipeline.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _fit(self, dataset):
        """
        Fit all estimator stages in order, feeding each stage's output to the
        next, and return a :py:class:`PipelineModel` of the fitted models and
        pass-through transformers.
        """
        stages = self.getStages()
        for stage in stages:
            if not isinstance(stage, (Estimator, Transformer)):
                raise TypeError(
                    "Cannot recognize a pipeline stage of type %s." % type(stage))
        # Stages after the last estimator never influence fitting, so they are
        # appended verbatim without transforming the dataset.
        indexOfLastEstimator = -1
        for i, stage in enumerate(stages):
            if isinstance(stage, Estimator):
                indexOfLastEstimator = i
        transformers = []
        for i, stage in enumerate(stages):
            if i <= indexOfLastEstimator:
                if isinstance(stage, Transformer):
                    transformers.append(stage)
                    dataset = stage.transform(dataset)
                else:  # must be an Estimator
                    model = stage.fit(dataset)
                    transformers.append(model)
                    if i < indexOfLastEstimator:
                        dataset = model.transform(dataset)
            else:
                transformers.append(stage)
        return PipelineModel(transformers)

    @since("1.4.0")
    def copy(self, extra=None):
        """
        Creates a copy of this instance.

        :param extra: extra parameters
        :returns: new instance
        """
        if extra is None:
            extra = dict()
        that = Params.copy(self, extra)
        stages = [stage.copy(extra) for stage in that.getStages()]
        return that.setStages(stages)

    @since("2.0.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        # Use the (faster) Java writer only when every stage can round-trip
        # through the JVM; otherwise fall back to the Python implementation.
        allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.getStages())
        if allStagesAreJava:
            return JavaMLWriter(self)
        return PipelineWriter(self)

    @classmethod
    @since("2.0.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return PipelineReader(cls)

    @classmethod
    def _from_java(cls, java_stage):
        """
        Given a Java Pipeline, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        # Create a new instance of this stage.
        py_stage = cls()
        # Load information from java_stage to the instance.
        py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()]
        py_stage.setStages(py_stages)
        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self):
        """
        Transfer this instance to a Java Pipeline. Used for ML persistence.

        :return: Java object equivalent to this instance.
        """
        gateway = SparkContext._gateway
        cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
        java_stages = gateway.new_array(cls, len(self.getStages()))
        for idx, stage in enumerate(self.getStages()):
            java_stages[idx] = stage._to_java()

        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
        _java_obj.setStages(java_stages)

        return _java_obj

    def _make_java_param_pair(self, param, value):
        """
        Makes a Java param pair.
        """
        sc = SparkContext._active_spark_context
        param = self._resolveParam(param)
        java_param = sc._jvm.org.apache.spark.ml.param.Param(param.parent, param.name, param.doc)
        if isinstance(value, Params) and hasattr(value, "_to_java"):
            # Convert JavaEstimator/JavaTransformer object or Estimator/Transformer object which
            # implements `_to_java` method (such as OneVsRest, Pipeline object) to java object.
            # used in the case of an estimator having another estimator as a parameter
            # the reason why this is not in _py2java in common.py is that importing
            # Estimator and Model in common.py results in a circular import with inherit_doc
            java_value = value._to_java()
        else:
            java_value = _py2java(sc, value)
        return java_param.w(java_value)

    def _transfer_param_map_to_java(self, pyParamMap):
        """
        Transforms a Python ParamMap into a Java ParamMap.
        """
        paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
        for param in self.params:
            if param in pyParamMap:
                pair = self._make_java_param_pair(param, pyParamMap[param])
                paramMap.put([pair])
        return paramMap

    def _transfer_param_map_from_java(self, javaParamMap):
        """
        Transforms a Java ParamMap into a Python ParamMap.
        """
        sc = SparkContext._active_spark_context
        paramMap = dict()
        for pair in javaParamMap.toList():
            param = pair.param()
            if self.hasParam(str(param.name())):
                java_obj = pair.value()
                if sc._jvm.Class.forName("org.apache.spark.ml.PipelineStage").isInstance(java_obj):
                    # Note: JavaParams._from_java support both JavaEstimator/JavaTransformer class
                    # and Estimator/Transformer class which implements `_from_java` static method
                    # (such as OneVsRest, Pipeline class).
                    py_obj = JavaParams._from_java(java_obj)
                else:
                    py_obj = _java2py(sc, java_obj)
                paramMap[self.getParam(param.name())] = py_obj
        return paramMap
|
|
|
|
|
|
@inherit_doc
class PipelineWriter(MLWriter):
    """
    (Private) Specialization of :py:class:`MLWriter` for :py:class:`Pipeline` types
    """

    def __init__(self, instance):
        """
        :param instance: the :py:class:`Pipeline` to be written
        """
        super(PipelineWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path):
        """Validate all stages are writable, then save metadata and stages under ``path``."""
        stages = self.instance.getStages()
        PipelineSharedReadWrite.validateStages(stages)
        PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
|
|
|
|
|
|
@inherit_doc
class PipelineReader(MLReader):
    """
    (Private) Specialization of :py:class:`MLReader` for :py:class:`Pipeline` types
    """

    def __init__(self, cls):
        """
        :param cls: the :py:class:`Pipeline` class to instantiate on load
        """
        super(PipelineReader, self).__init__()
        self.cls = cls

    def load(self, path):
        """
        Load a :py:class:`Pipeline` from ``path``.

        Pipelines saved from the JVM have no ``language`` entry in their param
        map, so anything not explicitly marked 'Python' is delegated to the
        Java reader.
        """
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
            return JavaMLReader(self.cls).load(path)
        else:
            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
            return Pipeline(stages=stages)._resetUid(uid)
|
|
|
|
|
|
@inherit_doc
class PipelineModelWriter(MLWriter):
    """
    (Private) Specialization of :py:class:`MLWriter` for :py:class:`PipelineModel` types
    """

    def __init__(self, instance):
        """
        :param instance: the :py:class:`PipelineModel` to be written
        """
        super(PipelineModelWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path):
        """Validate all stages are writable, then save metadata and stages under ``path``."""
        stages = self.instance.stages
        PipelineSharedReadWrite.validateStages(stages)
        PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
|
|
|
|
|
|
@inherit_doc
class PipelineModelReader(MLReader):
    """
    (Private) Specialization of :py:class:`MLReader` for :py:class:`PipelineModel` types
    """

    def __init__(self, cls):
        """
        :param cls: the :py:class:`PipelineModel` class to instantiate on load
        """
        super(PipelineModelReader, self).__init__()
        self.cls = cls

    def load(self, path):
        """
        Load a :py:class:`PipelineModel` from ``path``.

        Models saved from the JVM have no ``language`` entry in their param
        map, so anything not explicitly marked 'Python' is delegated to the
        Java reader.
        """
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
            return JavaMLReader(self.cls).load(path)
        else:
            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
            return PipelineModel(stages=stages)._resetUid(uid)
|
|
|
|
|
|
@inherit_doc
class PipelineModel(Model, MLReadable, MLWritable):
    """
    Represents a compiled pipeline with transformers and fitted models.

    .. versionadded:: 1.3.0
    """

    def __init__(self, stages):
        """
        :param stages: list of fitted models / transformers, in pipeline order
        """
        super(PipelineModel, self).__init__()
        self.stages = stages

    def _transform(self, dataset):
        """Apply every stage's transform in order and return the final dataset."""
        for t in self.stages:
            dataset = t.transform(dataset)
        return dataset

    @since("1.4.0")
    def copy(self, extra=None):
        """
        Creates a copy of this instance.

        :param extra: extra parameters
        :returns: new instance
        """
        if extra is None:
            extra = dict()
        stages = [stage.copy(extra) for stage in self.stages]
        return PipelineModel(stages)

    @since("2.0.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        # Use the (faster) Java writer only when every stage can round-trip
        # through the JVM; otherwise fall back to the Python implementation.
        allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.stages)
        if allStagesAreJava:
            return JavaMLWriter(self)
        return PipelineModelWriter(self)

    @classmethod
    @since("2.0.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return PipelineModelReader(cls)

    @classmethod
    def _from_java(cls, java_stage):
        """
        Given a Java PipelineModel, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        # Load information from java_stage to the instance.
        py_stages = [JavaParams._from_java(s) for s in java_stage.stages()]
        # Create a new instance of this stage.
        py_stage = cls(py_stages)
        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self):
        """
        Transfer this instance to a Java PipelineModel. Used for ML persistence.

        :return: Java object equivalent to this instance.
        """
        gateway = SparkContext._gateway
        cls = SparkContext._jvm.org.apache.spark.ml.Transformer
        java_stages = gateway.new_array(cls, len(self.stages))
        for idx, stage in enumerate(self.stages):
            java_stages[idx] = stage._to_java()

        _java_obj = JavaParams._new_java_obj(
            "org.apache.spark.ml.PipelineModel", self.uid, java_stages)

        return _java_obj
|
|
|
|
|
|
@inherit_doc
class PipelineSharedReadWrite():
    """
    Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
    :py:class:`Pipeline` and :py:class:`PipelineModel`

    .. versionadded:: 2.3.0
    """

    @staticmethod
    def checkStagesForJava(stages):
        """Return True iff every stage is JavaMLWritable (so the JVM writer can be used)."""
        return all(isinstance(stage, JavaMLWritable) for stage in stages)

    @staticmethod
    def validateStages(stages):
        """
        Check that all stages are Writable
        """
        for stage in stages:
            if not isinstance(stage, MLWritable):
                # Format the offending stage into the message; previously the
                # uid/type were passed as extra ValueError args and never
                # interpolated into the "%s" placeholders.
                raise ValueError("Pipeline write will fail on this pipeline "
                                 "because stage %s of type %s is not MLWritable"
                                 % (stage.uid, type(stage)))

    @staticmethod
    def saveImpl(instance, stages, sc, path):
        """
        Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
        - save metadata to path/metadata
        - save stages to stages/IDX_UID
        """
        stageUids = [stage.uid for stage in stages]
        jsonParams = {'stageUids': stageUids, 'language': 'Python'}
        DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
        stagesDir = os.path.join(path, "stages")
        for index, stage in enumerate(stages):
            stage.write().save(PipelineSharedReadWrite
                               .getStagePath(stage.uid, index, len(stages), stagesDir))

    @staticmethod
    def load(metadata, sc, path):
        """
        Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`

        :return: (UID, list of stages)
        """
        stagesDir = os.path.join(path, "stages")
        stageUids = metadata['paramMap']['stageUids']
        stages = []
        for index, stageUid in enumerate(stageUids):
            stagePath = \
                PipelineSharedReadWrite.getStagePath(stageUid, index, len(stageUids), stagesDir)
            stage = DefaultParamsReader.loadParamsInstance(stagePath, sc)
            stages.append(stage)
        return (metadata['uid'], stages)

    @staticmethod
    def getStagePath(stageUid, stageIdx, numStages, stagesDir):
        """
        Get path for saving the given stage.

        The index is zero-padded to the width of ``numStages`` so that stage
        directories sort lexicographically in pipeline order.
        """
        stageIdxDigits = len(str(numStages))
        stageDir = str(stageIdx).zfill(stageIdxDigits) + "_" + stageUid
        stagePath = os.path.join(stagesDir, stageDir)
        return stagePath
|