spark-instrumented-optimizer/python/pyspark/ml/util.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import sys
import os
import time
import uuid
import warnings

if sys.version > '3':
    basestring = str
    unicode = str
    long = int

from pyspark import SparkContext, since
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession
from pyspark.util import VersionUtils


def _jvm():
    """
    Returns the JVM view associated with SparkContext. Must be called
    after SparkContext is initialized.
    """
    jvm = SparkContext._jvm
    if jvm:
        return jvm
    else:
        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
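
# Usage sketch (illustrative, not part of the original module): once a
# SparkContext is running, _jvm() returns the Py4J JVM view, through which
# arbitrary JVM classes can be reached; java.lang.System is used here purely
# for demonstration.
#
#     sc = SparkContext.getOrCreate()
#     jvm = _jvm()
#     print(jvm.java.lang.System.currentTimeMillis())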


class Identifiable(object):
    """
    Object with a unique ID.
    """

    def __init__(self):
        #: A unique id for the object.
        self.uid = self._randomUID()

    def __repr__(self):
        return self.uid

    @classmethod
    def _randomUID(cls):
        """
        Generate a unique unicode id for the object. The default implementation
        concatenates the class name, "_", and 12 random hex chars.
        """
        return unicode(cls.__name__ + "_" + uuid.uuid4().hex[-12:])
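
# Illustrative sketch (hypothetical subclass): any Identifiable subclass gets
# a uid of the form <ClassName>_<12 hex chars> at construction time.
#
#     class MyStage(Identifiable):
#         pass
#
#     stage = MyStage()
#     print(stage.uid)  # e.g. MyStage_4f8ba0d2c1ee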


@inherit_doc
class BaseReadWrite(object):
    """
    Base class for MLWriter and MLReader. Stores information about the SparkContext
    and SparkSession.

    .. versionadded:: 2.3.0
    """

    def __init__(self):
        self._sparkSession = None

    def session(self, sparkSession):
        """
        Sets the Spark Session to use for saving/loading.
        """
        self._sparkSession = sparkSession
        return self

    @property
    def sparkSession(self):
        """
        Returns the user-specified Spark Session or the default.
        """
        if self._sparkSession is None:
            self._sparkSession = SparkSession.builder.getOrCreate()
        return self._sparkSession

    @property
    def sc(self):
        """
        Returns the underlying `SparkContext`.
        """
        return self.sparkSession.sparkContext
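
# Usage sketch (illustrative): session() pins a specific SparkSession for
# saving/loading; otherwise sparkSession lazily falls back to the default
# builder. Shown with MLWriter, a concrete subclass defined below.
#
#     spark = SparkSession.builder.getOrCreate()
#     writer = MLWriter().session(spark)
#     assert writer.sc is spark.sparkContext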


@inherit_doc
class MLWriter(BaseReadWrite):
    """
    Utility class that can save ML instances.

    .. versionadded:: 2.0.0
    """

    def __init__(self):
        super(MLWriter, self).__init__()
        self.shouldOverwrite = False

    def _handleOverwrite(self, path):
        from pyspark.ml.wrapper import JavaWrapper

        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.util.FileSystemOverwrite")
        wrapper = JavaWrapper(_java_obj)
        wrapper._call_java("handleOverwrite", path, True, self.sparkSession._jsparkSession)

    def save(self, path):
        """Save the ML instance to the input path."""
        if self.shouldOverwrite:
            self._handleOverwrite(path)
        self.saveImpl(path)

    def saveImpl(self, path):
        """
        save() handles overwriting and then calls this method. Subclasses should override this
        method to implement the actual saving of the instance.
        """
        raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))

    def overwrite(self):
        """Overwrites if the output path already exists."""
        self.shouldOverwrite = True
        return self
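
# Usage sketch (illustrative): overwrite() only sets a flag; save() then
# deletes any existing output before delegating to saveImpl(). `model` and
# the path are hypothetical.
#
#     model.write().overwrite().save("/tmp/model")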


@inherit_doc
class GeneralMLWriter(MLWriter):
    """
    Utility class that can save ML instances in different formats.

    .. versionadded:: 2.4.0
    """

    def format(self, source):
        """
        Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class
        name for export).
        """
        self.source = source
        return self
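
# Usage sketch (illustrative): alternative export formats are requested via
# format() before save(); e.g. Spark's linear regression model supports a
# "pmml" export. `lrModel` and the path are hypothetical.
#
#     lrModel.write().format("pmml").save("/tmp/lr-pmml")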


@inherit_doc
class JavaMLWriter(MLWriter):
    """
    (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types
    """

    def __init__(self, instance):
        super(JavaMLWriter, self).__init__()
        _java_obj = instance._to_java()
        self._jwrite = _java_obj.write()

    def save(self, path):
        """Save the ML instance to the input path."""
        if not isinstance(path, basestring):
            raise TypeError("path should be a basestring, got type %s" % type(path))
        self._jwrite.save(path)

    def overwrite(self):
        """Overwrites if the output path already exists."""
        self._jwrite.overwrite()
        return self

    def option(self, key, value):
        self._jwrite.option(key, value)
        return self

    def session(self, sparkSession):
        """Sets the Spark Session to use for saving."""
        self._jwrite.session(sparkSession._jsparkSession)
        return self


@inherit_doc
class GeneralJavaMLWriter(JavaMLWriter):
    """
    (Private) Specialization of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types
    """

    def __init__(self, instance):
        super(GeneralJavaMLWriter, self).__init__(instance)

    def format(self, source):
        """
        Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class
        name for export).
        """
        self._jwrite.format(source)
        return self


@inherit_doc
class MLWritable(object):
    """
    Mixin for ML instances that provide :py:class:`MLWriter`.

    .. versionadded:: 2.0.0
    """

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self))

    def save(self, path):
        """Save this ML instance to the given path, a shortcut of 'write().save(path)'."""
        self.write().save(path)
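
# Usage sketch (illustrative): for any MLWritable instance, save(path) is
# simply shorthand for the explicit writer chain. `model` and the paths are
# hypothetical.
#
#     model.save("/tmp/model-a")           # shortcut for the line below
#     model.write().save("/tmp/model-b")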


@inherit_doc
class JavaMLWritable(MLWritable):
    """
    (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
    """

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)


@inherit_doc
class GeneralJavaMLWritable(JavaMLWritable):
    """
    (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`.
    """

    def write(self):
        """Returns a GeneralJavaMLWriter instance for this ML instance."""
        return GeneralJavaMLWriter(self)


@inherit_doc
class MLReader(BaseReadWrite):
    """
    Utility class that can load ML instances.

    .. versionadded:: 2.0.0
    """

    def __init__(self):
        super(MLReader, self).__init__()

    def load(self, path):
        """Load the ML instance from the input path."""
        raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))


@inherit_doc
class JavaMLReader(MLReader):
    """
    (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
    """

    def __init__(self, clazz):
        super(JavaMLReader, self).__init__()
        self._clazz = clazz
        self._jread = self._load_java_obj(clazz).read()

    def load(self, path):
        """Load the ML instance from the input path."""
        if not isinstance(path, basestring):
            raise TypeError("path should be a basestring, got type %s" % type(path))
        java_obj = self._jread.load(path)
        if not hasattr(self._clazz, "_from_java"):
            raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r"
                                      % self._clazz)
        return self._clazz._from_java(java_obj)

    def session(self, sparkSession):
        """Sets the Spark Session to use for loading."""
        self._jread.session(sparkSession._jsparkSession)
        return self
    @classmethod
    def _java_loader_class(cls, clazz):
        """
        Returns the full class name of the Java ML instance. The default
        implementation replaces "pyspark" by "org.apache.spark" in
        the Python full class name.
        """
        java_package = clazz.__module__.replace("pyspark", "org.apache.spark")
        if clazz.__name__ in ("Pipeline", "PipelineModel"):
            # Remove the last package name "pipeline" for Pipeline and PipelineModel.
            java_package = ".".join(java_package.split(".")[0:-1])
        return java_package + "." + clazz.__name__
    @classmethod
    def _load_java_obj(cls, clazz):
        """Load the peer Java object of the ML instance."""
        java_class = cls._java_loader_class(clazz)
        java_obj = _jvm()
        for name in java_class.split("."):
            java_obj = getattr(java_obj, name)
        return java_obj
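
# Illustrative sketch of the Python-to-Java class-name mapping performed by
# _java_loader_class above:
#
#     pyspark.ml.feature.Binarizer  ->  org.apache.spark.ml.feature.Binarizer
#     pyspark.ml.pipeline.Pipeline  ->  org.apache.spark.ml.Pipeline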


@inherit_doc
class MLReadable(object):
    """
    Mixin for instances that provide :py:class:`MLReader`.

    .. versionadded:: 2.0.0
    """

    @classmethod
    def read(cls):
        """Returns an MLReader instance for this class."""
        raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls)

    @classmethod
    def load(cls, path):
        """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
        return cls.read().load(path)
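
# Usage sketch (illustrative): load() is shorthand for read().load(); the
# path is hypothetical.
#
#     from pyspark.ml.pipeline import PipelineModel
#     model = PipelineModel.load("/tmp/pipeline-model")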


@inherit_doc
class JavaMLReadable(MLReadable):
    """
    (Private) Mixin for instances that provide JavaMLReader.
    """

    @classmethod
    def read(cls):
        """Returns an MLReader instance for this class."""
        return JavaMLReader(cls)


@inherit_doc
class DefaultParamsWritable(MLWritable):
    """
    .. note:: DeveloperApi

    Helper trait for making simple :py:class:`Params` types writable. If a :py:class:`Params`
    class stores all data as :py:class:`Param` values, then extending this trait will provide
    a default implementation of writing saved instances of the class.
    This only handles simple :py:class:`Param` types; e.g., it will not handle
    :py:class:`Dataset`. See :py:class:`DefaultParamsReadable`, the counterpart to this trait.

    .. versionadded:: 2.3.0
    """

    def write(self):
        """Returns a DefaultParamsWriter instance for this class."""
        from pyspark.ml.param import Params

        if isinstance(self, Params):
            return DefaultParamsWriter(self)
        else:
            raise TypeError("Cannot use DefaultParamsWritable with type %s because it does not "
                            "extend Params." % type(self))


@inherit_doc
class DefaultParamsWriter(MLWriter):
    """
    .. note:: DeveloperApi

    Specialization of :py:class:`MLWriter` for :py:class:`Params` types

    Class for writing Estimators and Transformers whose parameters are JSON-serializable.

    .. versionadded:: 2.3.0
    """

    def __init__(self, instance):
        super(DefaultParamsWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path):
        DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)
    @staticmethod
    def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
        """
        Saves metadata + Params to: path + "/metadata"

        - class
        - timestamp
        - sparkVersion
        - uid
        - paramMap
        - defaultParamMap (since 2.4.0)
        - (optionally, extra metadata)

        :param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc.
        :param paramMap: If given, this is saved in the "paramMap" field.
        """
        metadataPath = os.path.join(path, "metadata")
        metadataJson = DefaultParamsWriter._get_metadata_to_save(instance,
                                                                 sc,
                                                                 extraMetadata,
                                                                 paramMap)
        sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath)
    @staticmethod
    def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
        """
        Helper for :py:meth:`DefaultParamsWriter.saveMetadata` which extracts the JSON to save.
        This is useful for ensemble models which need to save metadata for many sub-models.

        .. note:: :py:meth:`DefaultParamsWriter.saveMetadata` for details on what this includes.
        """
        uid = instance.uid
        cls = instance.__module__ + '.' + instance.__class__.__name__

        # User-supplied param values
        params = instance._paramMap
        jsonParams = {}
        if paramMap is not None:
            jsonParams = paramMap
        else:
            for p in params:
                jsonParams[p.name] = params[p]

        # Default param values
        jsonDefaultParams = {}
        for p in instance._defaultParamMap:
            jsonDefaultParams[p.name] = instance._defaultParamMap[p]

        basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)),
                         "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams,
                         "defaultParamMap": jsonDefaultParams}
        if extraMetadata is not None:
            basicMetadata.update(extraMetadata)
        return json.dumps(basicMetadata, separators=[',', ':'])
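
# Illustrative sketch of a resulting metadata line (all field values are
# made up):
#
#     {"class":"pyspark.ml.feature.Binarizer","timestamp":1555000000000,
#      "sparkVersion":"2.4.4","uid":"Binarizer_abc123def456",
#      "paramMap":{"outputCol":"features"},"defaultParamMap":{"threshold":0.0}}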


@inherit_doc
class DefaultParamsReadable(MLReadable):
    """
    .. note:: DeveloperApi

    Helper trait for making simple :py:class:`Params` types readable.
    If a :py:class:`Params` class stores all data as :py:class:`Param` values,
    then extending this trait will provide a default implementation of reading saved
    instances of the class. This only handles simple :py:class:`Param` types;
    e.g., it will not handle :py:class:`Dataset`. See :py:class:`DefaultParamsWritable`,
    the counterpart to this trait.

    .. versionadded:: 2.3.0
    """

    @classmethod
    def read(cls):
        """Returns a DefaultParamsReader instance for this class."""
        return DefaultParamsReader(cls)


@inherit_doc
class DefaultParamsReader(MLReader):
    """
    .. note:: DeveloperApi

    Specialization of :py:class:`MLReader` for :py:class:`Params` types

    Default :py:class:`MLReader` implementation for transformers and estimators that
    contain basic (json-serializable) params and no data. This will not handle
    more complex params or types with data (e.g., models with coefficients).

    .. versionadded:: 2.3.0
    """

    def __init__(self, cls):
        super(DefaultParamsReader, self).__init__()
        self.cls = cls

    @staticmethod
    def __get_class(clazz):
        """
        Loads Python class from its name.
        """
        parts = clazz.split('.')
        module = ".".join(parts[:-1])
        m = __import__(module)
        for comp in parts[1:]:
            m = getattr(m, comp)
        return m
    def load(self, path):
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        py_type = DefaultParamsReader.__get_class(metadata['class'])
        instance = py_type()
        instance._resetUid(metadata['uid'])
        DefaultParamsReader.getAndSetParams(instance, metadata)
        return instance
    @staticmethod
    def loadMetadata(path, sc, expectedClassName=""):
        """
        Load metadata saved using :py:meth:`DefaultParamsWriter.saveMetadata`

        :param expectedClassName: If non empty, this is checked against the loaded metadata.
        """
        metadataPath = os.path.join(path, "metadata")
        metadataStr = sc.textFile(metadataPath, 1).first()
        loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName)
        return loadedVals
    @staticmethod
    def _parseMetaData(metadataStr, expectedClassName=""):
        """
        Parse metadata JSON string produced by :py:meth:`DefaultParamsWriter._get_metadata_to_save`.
        This is a helper function for :py:meth:`DefaultParamsReader.loadMetadata`.

        :param metadataStr: JSON string of metadata
        :param expectedClassName: If non empty, this is checked against the loaded metadata.
        """
        metadata = json.loads(metadataStr)
        className = metadata['class']
        if len(expectedClassName) > 0:
            assert className == expectedClassName, "Error loading metadata: Expected " + \
                "class name {} but found class name {}".format(expectedClassName, className)
        return metadata
    @staticmethod
    def getAndSetParams(instance, metadata):
        """
        Extract Params from metadata, and set them in the instance.
        """
        # Set user-supplied param values
        for paramName in metadata['paramMap']:
            param = instance.getParam(paramName)
            paramValue = metadata['paramMap'][paramName]
            instance.set(param, paramValue)

        # Set default param values
        majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion'])
        major = majorAndMinorVersions[0]
        minor = majorAndMinorVersions[1]

        # For metadata files written prior to Spark 2.4, there is no default section.
        if major > 2 or (major == 2 and minor >= 4):
            assert 'defaultParamMap' in metadata, "Error loading metadata: expected a " + \
                "`defaultParamMap` section, but none was found"
            for paramName in metadata['defaultParamMap']:
                paramValue = metadata['defaultParamMap'][paramName]
                instance._setDefault(**{paramName: paramValue})
    @staticmethod
    def loadParamsInstance(path, sc):
        """
        Load a :py:class:`Params` instance from the given path, and return it.
        This assumes the instance inherits from :py:class:`MLReadable`.
        """
        metadata = DefaultParamsReader.loadMetadata(path, sc)
        pythonClassName = metadata['class'].replace("org.apache.spark", "pyspark")
        py_type = DefaultParamsReader.__get_class(pythonClassName)
        instance = py_type.load(path)
        return instance


@inherit_doc
class HasTrainingSummary(object):
    """
    Base class for models that provide a training summary.

    .. versionadded:: 3.0.0
    """

    @property
    @since("2.1.0")
    def hasSummary(self):
        """
        Indicates whether a training summary exists for this model
        instance.
        """
        return self._call_java("hasSummary")

    @property
    @since("2.1.0")
    def summary(self):
        """
        Gets summary of the model trained on the training set. An exception is thrown if
        no summary exists.
        """
        return self._call_java("summary")
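
# Usage sketch (illustrative): models mixing in HasTrainingSummary expose the
# summary computed during fit(); `lrModel` is a hypothetical fitted
# LogisticRegressionModel.
#
#     if lrModel.hasSummary:
#         print(lrModel.summary.accuracy)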