2015-01-28 20:14:23 -05:00
|
|
|
#
|
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
|
|
# this work for additional information regarding copyright ownership.
|
|
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
|
|
# (the "License"); you may not use this file except in compliance with
|
|
|
|
# the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
|
|
|
|
2017-08-07 20:03:20 -04:00
|
|
|
import json
|
|
|
|
import os
|
|
|
|
import time
|
2015-01-28 20:14:23 -05:00
|
|
|
import uuid
|
2016-01-29 12:22:24 -05:00
|
|
|
|
|
|
|
from pyspark import SparkContext, since
|
2016-06-13 22:59:53 -04:00
|
|
|
from pyspark.ml.common import inherit_doc
|
2017-08-07 20:03:20 -04:00
|
|
|
from pyspark.sql import SparkSession
|
2018-05-15 19:50:09 -04:00
|
|
|
from pyspark.util import VersionUtils
|
2016-01-29 12:22:24 -05:00
|
|
|
|
|
|
|
|
|
|
|
def _jvm():
    """
    Returns the JVM view associated with SparkContext. Must be called
    after SparkContext is initialized.
    """
    active_jvm = SparkContext._jvm
    # Guard clause: SparkContext._jvm is only populated once a context exists.
    if not active_jvm:
        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
    return active_jvm
|
2015-01-28 20:14:23 -05:00
|
|
|
|
|
|
|
|
|
|
|
class Identifiable(object):
    """
    Object with a unique ID.
    """

    def __init__(self):
        #: A unique id for the object.
        self.uid = self._randomUID()

    def __repr__(self):
        # The uid doubles as the printable representation.
        return self.uid

    @classmethod
    def _randomUID(cls):
        """
        Generate a unique string id for the object. The default implementation
        concatenates the class name, "_", and 12 random hex chars.
        """
        suffix = uuid.uuid4().hex[-12:]
        return str("%s_%s" % (cls.__name__, suffix))
|
2016-01-29 12:22:24 -05:00
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class BaseReadWrite(object):
    """
    Base class for MLWriter and MLReader. Stores information about the SparkContext
    and SparkSession.

    .. versionadded:: 2.3.0
    """

    def __init__(self):
        # Lazily populated on first access of the ``sparkSession`` property.
        self._sparkSession = None

    def session(self, sparkSession):
        """
        Sets the Spark Session to use for saving/loading.
        """
        self._sparkSession = sparkSession
        return self

    @property
    def sparkSession(self):
        """
        Returns the user-specified Spark Session or the default.
        """
        session = self._sparkSession
        if session is None:
            # No session supplied via session(); fall back to the default one.
            session = SparkSession.builder.getOrCreate()
            self._sparkSession = session
        return session

    @property
    def sc(self):
        """
        Returns the underlying `SparkContext`.
        """
        return self.sparkSession.sparkContext
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class MLWriter(BaseReadWrite):
    """
    Utility class that can save ML instances.

    .. versionadded:: 2.0.0
    """

    def __init__(self):
        super(MLWriter, self).__init__()
        # Toggled by overwrite(); consulted by save() before writing.
        self.shouldOverwrite = False

    def _handleOverwrite(self, path):
        # Deferred import: pyspark.ml.wrapper is only needed by this method.
        from pyspark.ml.wrapper import JavaWrapper

        overwriter = JavaWrapper(
            JavaWrapper._new_java_obj("org.apache.spark.ml.util.FileSystemOverwrite"))
        overwriter._call_java("handleOverwrite", path, True, self.sparkSession._jsparkSession)

    def save(self, path):
        """Save the ML instance to the input path."""
        if self.shouldOverwrite:
            self._handleOverwrite(path)
        self.saveImpl(path)

    def saveImpl(self, path):
        """
        save() handles overwriting and then calls this method. Subclasses should override this
        method to implement the actual saving of the instance.
        """
        raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))

    def overwrite(self):
        """Overwrites if the output path already exists."""
        self.shouldOverwrite = True
        return self
|
|
|
|
|
2016-03-22 15:11:23 -04:00
|
|
|
|
2018-06-28 16:20:08 -04:00
|
|
|
@inherit_doc
class GeneralMLWriter(MLWriter):
    """
    Utility class that can save ML instances in different formats.

    .. versionadded:: 2.4.0
    """

    def format(self, source):
        """
        Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class
        name for export).
        """
        # Remember the chosen export format and allow chaining.
        self.source = source
        return self
|
|
|
|
|
|
|
|
|
2016-03-22 15:11:23 -04:00
|
|
|
@inherit_doc
class JavaMLWriter(MLWriter):
    """
    (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types
    """

    def __init__(self, instance):
        super(JavaMLWriter, self).__init__()
        # Delegate all work to the peer Java writer of the wrapped instance.
        self._jwrite = instance._to_java().write()

    def save(self, path):
        """Save the ML instance to the input path."""
        if not isinstance(path, str):
            raise TypeError("path should be a string, got type %s" % type(path))
        self._jwrite.save(path)

    def overwrite(self):
        """Overwrites if the output path already exists."""
        self._jwrite.overwrite()
        return self

    def option(self, key, value):
        # Pass an arbitrary writer option through to the Java side.
        self._jwrite.option(key, value)
        return self

    def session(self, sparkSession):
        """Sets the Spark Session to use for saving."""
        self._jwrite.session(sparkSession._jsparkSession)
        return self
|
|
|
|
|
2016-01-29 12:22:24 -05:00
|
|
|
|
2018-06-28 16:20:08 -04:00
|
|
|
@inherit_doc
class GeneralJavaMLWriter(JavaMLWriter):
    """
    (Private) Specialization of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types
    """

    def __init__(self, instance):
        super(GeneralJavaMLWriter, self).__init__(instance)

    def format(self, source):
        """
        Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class
        name for export).
        """
        # Forward the format selection to the underlying Java writer.
        self._jwrite.format(source)
        return self
|
|
|
|
|
|
|
|
|
2016-01-29 12:22:24 -05:00
|
|
|
@inherit_doc
class MLWritable(object):
    """
    Mixin for ML instances that provide :py:class:`MLWriter`.

    .. versionadded:: 2.0.0
    """

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self))

    def save(self, path):
        """Save this ML instance to the given path, a shortcut of 'write().save(path)'."""
        writer = self.write()
        writer.save(path)
|
2016-01-29 12:22:24 -05:00
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class JavaMLWritable(MLWritable):
    """
    (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
    """

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        # Wrap this instance in a writer that delegates to its Java peer.
        return JavaMLWriter(self)
|
2018-06-28 16:20:08 -04:00
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class GeneralJavaMLWritable(JavaMLWritable):
    """
    (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`.
    """

    def write(self):
        """Returns an GeneralMLWriter instance for this ML instance."""
        # Same as JavaMLWritable.write(), but the writer also supports format().
        return GeneralJavaMLWriter(self)
|
2016-03-22 15:11:23 -04:00
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class MLReader(BaseReadWrite):
    """
    Utility class that can load ML instances.

    .. versionadded:: 2.0.0
    """

    def __init__(self):
        super(MLReader, self).__init__()

    def load(self, path):
        """Load the ML instance from the input path."""
        # Abstract: concrete readers must implement the actual loading.
        raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class JavaMLReader(MLReader):
    """
    (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
    """

    def __init__(self, clazz):
        super(JavaMLReader, self).__init__()
        self._clazz = clazz
        # Peer Java reader for the corresponding Java class.
        self._jread = self._load_java_obj(clazz).read()

    def load(self, path):
        """Load the ML instance from the input path."""
        if not isinstance(path, str):
            raise TypeError("path should be a string, got type %s" % type(path))
        java_obj = self._jread.load(path)
        # Conversion back to Python requires the class to define _from_java.
        if not hasattr(self._clazz, "_from_java"):
            raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r"
                                      % self._clazz)
        return self._clazz._from_java(java_obj)

    def session(self, sparkSession):
        """Sets the Spark Session to use for loading."""
        self._jread.session(sparkSession._jsparkSession)
        return self

    @classmethod
    def _java_loader_class(cls, clazz):
        """
        Returns the full class name of the Java ML instance. The default
        implementation replaces "pyspark" by "org.apache.spark" in
        the Python full class name.
        """
        parts = clazz.__module__.replace("pyspark", "org.apache.spark").split(".")
        if clazz.__name__ in ("Pipeline", "PipelineModel"):
            # Remove the last package name "pipeline" for Pipeline and PipelineModel.
            parts = parts[:-1]
        return ".".join(parts + [clazz.__name__])

    @classmethod
    def _load_java_obj(cls, clazz):
        """Load the peer Java object of the ML instance."""
        # Walk the JVM view attribute-by-attribute to reach the Java class.
        target = _jvm()
        for piece in cls._java_loader_class(clazz).split("."):
            target = getattr(target, piece)
        return target
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class MLReadable(object):
    """
    Mixin for instances that provide :py:class:`MLReader`.

    .. versionadded:: 2.0.0
    """

    @classmethod
    def read(cls):
        """Returns an MLReader instance for this class."""
        raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls)

    @classmethod
    def load(cls, path):
        """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
        reader = cls.read()
        return reader.load(path)
|
2016-03-22 15:11:23 -04:00
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class JavaMLReadable(MLReadable):
    """
    (Private) Mixin for instances that provide JavaMLReader.
    """

    @classmethod
    def read(cls):
        """Returns an MLReader instance for this class."""
        # The reader is parameterized by the class so it can find the Java peer.
        return JavaMLReader(cls)
|
2016-08-22 06:21:22 -04:00
|
|
|
|
|
|
|
|
2017-08-07 20:03:20 -04:00
|
|
|
@inherit_doc
class DefaultParamsWritable(MLWritable):
    """
    Helper trait for making simple :py:class:`Params` types writable. If a :py:class:`Params`
    class stores all data as :py:class:`Param` values, then extending this trait will provide
    a default implementation of writing saved instances of the class.
    This only handles simple :py:class:`Param` types; e.g., it will not handle
    :py:class:`Dataset`. See :py:class:`DefaultParamsReadable`, the counterpart to this trait.

    .. versionadded:: 2.3.0
    """

    def write(self):
        """Returns a DefaultParamsWriter instance for this class.

        Raises
        ------
        TypeError
            If this instance does not also extend :py:class:`Params`.
        """
        # Deferred import: pyspark.ml.param is only needed for this check.
        from pyspark.ml.param import Params

        if isinstance(self, Params):
            return DefaultParamsWriter(self)
        else:
            # Bug fix: the original passed the format string and type(self) as two
            # separate TypeError arguments, so "%s" was never substituted, and the
            # class name was misspelled ("DefautParamsWritable").
            raise TypeError("Cannot use DefaultParamsWritable with type %s because it does not "
                            "extend Params." % type(self))
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class DefaultParamsWriter(MLWriter):
    """
    Specialization of :py:class:`MLWriter` for :py:class:`Params` types

    Class for writing Estimators and Transformers whose parameters are JSON-serializable.

    .. versionadded:: 2.3.0
    """

    def __init__(self, instance):
        super(DefaultParamsWriter, self).__init__()
        # The Params instance whose metadata/params will be written.
        self.instance = instance

    def saveImpl(self, path):
        # Called by MLWriter.save() after any overwrite handling.
        DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)

    @staticmethod
    def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
        """
        Saves metadata + Params to: path + "/metadata"

        - class
        - timestamp
        - sparkVersion
        - uid
        - paramMap
        - defaultParamMap (since 2.4.0)
        - (optionally, extra metadata)

        Parameters
        ----------
        extraMetadata : dict, optional
            Extra metadata to be saved at same level as uid, paramMap, etc.
        paramMap : dict, optional
            If given, this is saved in the "paramMap" field.
        """
        metadataPath = os.path.join(path, "metadata")
        metadataJson = DefaultParamsWriter._get_metadata_to_save(instance,
                                                                 sc,
                                                                 extraMetadata,
                                                                 paramMap)
        # Written as a single-partition text file so the metadata is one JSON document.
        sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath)

    @staticmethod
    def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
        """
        Helper for :py:meth:`DefaultParamsWriter.saveMetadata` which extracts the JSON to save.
        This is useful for ensemble models which need to save metadata for many sub-models.

        Notes
        -----
        See :py:meth:`DefaultParamsWriter.saveMetadata` for details on what this includes.
        """
        uid = instance.uid
        # Fully qualified Python class name, e.g. "pyspark.ml.feature.Binarizer".
        cls = instance.__module__ + '.' + instance.__class__.__name__

        # User-supplied param values
        params = instance._paramMap
        jsonParams = {}
        if paramMap is not None:
            # Caller-provided map takes precedence over the instance's param map.
            jsonParams = paramMap
        else:
            for p in params:
                jsonParams[p.name] = params[p]

        # Default param values
        jsonDefaultParams = {}
        for p in instance._defaultParamMap:
            jsonDefaultParams[p.name] = instance._defaultParamMap[p]

        # "timestamp" is epoch time in milliseconds (time.time() is seconds).
        basicMetadata = {"class": cls, "timestamp": int(round(time.time() * 1000)),
                         "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams,
                         "defaultParamMap": jsonDefaultParams}
        if extraMetadata is not None:
            # extraMetadata entries may overwrite the basic keys above.
            basicMetadata.update(extraMetadata)
        # Compact separators: no whitespace in the emitted JSON.
        return json.dumps(basicMetadata, separators=[',', ':'])
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class DefaultParamsReadable(MLReadable):
    """
    Helper trait for making simple :py:class:`Params` types readable.
    If a :py:class:`Params` class stores all data as :py:class:`Param` values,
    then extending this trait will provide a default implementation of reading saved
    instances of the class. This only handles simple :py:class:`Param` types;
    e.g., it will not handle :py:class:`Dataset`. See :py:class:`DefaultParamsWritable`,
    the counterpart to this trait.

    .. versionadded:: 2.3.0
    """

    @classmethod
    def read(cls):
        """Returns a DefaultParamsReader instance for this class."""
        # The reader is parameterized by the class so load() can instantiate it.
        return DefaultParamsReader(cls)
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class DefaultParamsReader(MLReader):
    """
    Specialization of :py:class:`MLReader` for :py:class:`Params` types

    Default :py:class:`MLReader` implementation for transformers and estimators that
    contain basic (json-serializable) params and no data. This will not handle
    more complex params or types with data (e.g., models with coefficients).

    .. versionadded:: 2.3.0
    """

    def __init__(self, cls):
        super(DefaultParamsReader, self).__init__()
        # The Params class this reader will instantiate in load().
        self.cls = cls

    @staticmethod
    def __get_class(clazz):
        """
        Loads Python class from its name.
        """
        # NOTE: double-underscore name is class-private (name-mangled); callers inside
        # this class reference it as DefaultParamsReader.__get_class.
        parts = clazz.split('.')
        module = ".".join(parts[:-1])
        # __import__ returns the top-level package; walk the remaining attributes
        # (submodules and finally the class) to reach the target.
        m = __import__(module)
        for comp in parts[1:]:
            m = getattr(m, comp)
        return m

    def load(self, path):
        # Load metadata, instantiate the recorded class, then restore uid and params.
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        py_type = DefaultParamsReader.__get_class(metadata['class'])
        instance = py_type()
        instance._resetUid(metadata['uid'])
        DefaultParamsReader.getAndSetParams(instance, metadata)
        return instance

    @staticmethod
    def loadMetadata(path, sc, expectedClassName=""):
        """
        Load metadata saved using :py:meth:`DefaultParamsWriter.saveMetadata`

        Parameters
        ----------
        path : str
        sc : :py:class:`pyspark.SparkContext`
        expectedClassName : str, optional
            If non empty, this is checked against the loaded metadata.
        """
        metadataPath = os.path.join(path, "metadata")
        # Metadata was written as a single-line JSON text file; first() reads that line.
        metadataStr = sc.textFile(metadataPath, 1).first()
        loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName)
        return loadedVals

    @staticmethod
    def _parseMetaData(metadataStr, expectedClassName=""):
        """
        Parse metadata JSON string produced by :py:meth`DefaultParamsWriter._get_metadata_to_save`.
        This is a helper function for :py:meth:`DefaultParamsReader.loadMetadata`.

        Parameters
        ----------
        metadataStr : str
            JSON string of metadata
        expectedClassName : str, optional
            If non empty, this is checked against the loaded metadata.
        """
        metadata = json.loads(metadataStr)
        className = metadata['class']
        if len(expectedClassName) > 0:
            assert className == expectedClassName, "Error loading metadata: Expected " + \
                "class name {} but found class name {}".format(expectedClassName, className)
        return metadata

    @staticmethod
    def getAndSetParams(instance, metadata):
        """
        Extract Params from metadata, and set them in the instance.
        """
        # Set user-supplied param values
        for paramName in metadata['paramMap']:
            param = instance.getParam(paramName)
            paramValue = metadata['paramMap'][paramName]
            instance.set(param, paramValue)

        # Set default param values
        majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion'])
        major = majorAndMinorVersions[0]
        minor = majorAndMinorVersions[1]

        # For metadata file prior to Spark 2.4, there is no default section.
        if major > 2 or (major == 2 and minor >= 4):
            assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \
                "`defaultParamMap` section not found"

            for paramName in metadata['defaultParamMap']:
                paramValue = metadata['defaultParamMap'][paramName]
                instance._setDefault(**{paramName: paramValue})

    @staticmethod
    def loadParamsInstance(path, sc):
        """
        Load a :py:class:`Params` instance from the given path, and return it.
        This assumes the instance inherits from :py:class:`MLReadable`.
        """
        metadata = DefaultParamsReader.loadMetadata(path, sc)
        # Map the stored Java/Scala class name back to its Python counterpart.
        pythonClassName = metadata['class'].replace("org.apache.spark", "pyspark")
        py_type = DefaultParamsReader.__get_class(pythonClassName)
        instance = py_type.load(path)
        return instance
|
2019-02-01 18:29:58 -05:00
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class HasTrainingSummary(object):
    """
    Base class for models that provides Training summary.

    .. versionadded:: 3.0.0
    """

    @property
    @since("2.1.0")
    def hasSummary(self):
        """
        Indicates whether a training summary exists for this model
        instance.
        """
        # Delegates to the Java-side model object.
        return self._call_java("hasSummary")

    @property
    @since("2.1.0")
    def summary(self):
        """
        Gets summary of the model trained on the training set. An exception is thrown if
        no summary exists.
        """
        return self._call_java("summary")
|
|
|
|
|
|
|
|
|
|
|
class MetaAlgorithmReadWrite:

    @staticmethod
    def isMetaEstimator(pyInstance):
        # Deferred imports avoid import cycles with the ml subpackages.
        from pyspark.ml import Estimator, Pipeline
        from pyspark.ml.tuning import _ValidatorParams
        from pyspark.ml.classification import OneVsRest

        if isinstance(pyInstance, (Pipeline, OneVsRest)):
            return True
        # Validators (CrossValidator / TrainValidationSplit) are Estimators
        # carrying _ValidatorParams.
        return isinstance(pyInstance, Estimator) and isinstance(pyInstance, _ValidatorParams)

    @staticmethod
    def getAllNestedStages(pyInstance):
        from pyspark.ml import Pipeline, PipelineModel
        from pyspark.ml.tuning import _ValidatorParams
        from pyspark.ml.classification import OneVsRest, OneVsRestModel

        # TODO: We need to handle `RFormulaModel.pipelineModel` here after Pyspark RFormulaModel
        # support pipelineModel property.
        if isinstance(pyInstance, Pipeline):
            children = pyInstance.getStages()
        elif isinstance(pyInstance, PipelineModel):
            children = pyInstance.stages
        elif isinstance(pyInstance, _ValidatorParams):
            raise ValueError('PySpark does not support nested validator.')
        elif isinstance(pyInstance, OneVsRest):
            children = [pyInstance.getClassifier()]
        elif isinstance(pyInstance, OneVsRestModel):
            children = [pyInstance.getClassifier()] + pyInstance.models
        else:
            children = []

        # Depth-first: the instance itself, followed by the flattened sub-stages.
        descendants = []
        for child in children:
            descendants.extend(MetaAlgorithmReadWrite.getAllNestedStages(child))
        return [pyInstance] + descendants
|