spark-instrumented-optimizer/python/pyspark/ml/util.py
heleny fb1f868d4f [SPARK-28776][ML] SparkML Writer gets hadoop conf from session state

### What changes were proposed in this pull request?
The Spark ML writer now gets its Hadoop configuration from the session state instead of the SparkContext.

### Why are the changes needed?
This allows multiple sessions that share one SparkContext to use different Hadoop configurations.
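
The difference can be illustrated with a toy model. This is a minimal sketch, not Spark code: `FakeContext`, `FakeSession`, `new_hadoop_conf`, and both `writer_conf_*` helpers are hypothetical names, and it assumes a session-level configuration is the context-level one with session overrides layered on top.

```python
class FakeContext:
    """Hypothetical stand-in for a SparkContext holding one shared Hadoop conf."""

    def __init__(self, hadoop_conf):
        self.hadoop_conf = dict(hadoop_conf)


class FakeSession:
    """Hypothetical stand-in for a SparkSession with per-session overrides."""

    def __init__(self, context, overrides=None):
        self.context = context
        self.overrides = dict(overrides or {})

    def new_hadoop_conf(self):
        # Assumption: copy the shared settings, then overlay the session-level ones.
        conf = dict(self.context.hadoop_conf)
        conf.update(self.overrides)
        return conf


def writer_conf_from_context(session):
    """Behavior before the change: every session sees the context-level conf."""
    return session.context.hadoop_conf


def writer_conf_from_session(session):
    """Behavior after the change: each session resolves its own conf."""
    return session.new_hadoop_conf()


ctx = FakeContext({"fs.defaultFS": "hdfs://shared"})
session_a = FakeSession(ctx, {"fs.defaultFS": "hdfs://cluster-a"})
session_b = FakeSession(ctx, {"fs.defaultFS": "hdfs://cluster-b"})

# Context-based lookup cannot tell the two sessions apart.
same_for_both = writer_conf_from_context(session_a) == writer_conf_from_context(session_b)
# Session-based lookup keeps them separate.
conf_a = writer_conf_from_session(session_a)
conf_b = writer_conf_from_session(session_b)
```

In PySpark itself, the session a writer uses is chosen with `write().session(sparkSession)`, which this file defines on `BaseReadWrite` and `JavaMLWriter`.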

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Tested with `test_default_read_write` in `pyspark.ml.tests.test_persistence.PersistenceTest`.

Closes #25505 from helenyugithub/SPARK-28776.

Authored-by: heleny <heleny@palantir.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
2019-08-22 09:27:05 -05:00

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import sys
import os
import time
import uuid
import warnings

# Python 3 compatibility aliases. Compare the parsed version number rather
# than the version string, which would misorder a two-digit major version.
if sys.version_info[0] >= 3:
    basestring = str
    unicode = str
    long = int

from pyspark import SparkContext, since
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession
from pyspark.util import VersionUtils


def _jvm():
    """
    Returns the JVM view associated with SparkContext. Must be called
    after SparkContext is initialized.
    """
    jvm = SparkContext._jvm
    if jvm:
        return jvm
    else:
        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")


class Identifiable(object):
    """
    Object with a unique ID.
    """

    def __init__(self):
        #: A unique id for the object.
        self.uid = self._randomUID()

    def __repr__(self):
        return self.uid

    @classmethod
    def _randomUID(cls):
        """
        Generate a unique unicode id for the object. The default implementation
        concatenates the class name, "_", and 12 random hex chars.
        """
        return unicode(cls.__name__ + "_" + uuid.uuid4().hex[-12:])


@inherit_doc
class BaseReadWrite(object):
    """
    Base class for MLWriter and MLReader. Stores information about the SparkContext
    and SparkSession.

    .. versionadded:: 2.3.0
    """

    def __init__(self):
        self._sparkSession = None

    def session(self, sparkSession):
        """
        Sets the Spark Session to use for saving/loading.
        """
        self._sparkSession = sparkSession
        return self

    @property
    def sparkSession(self):
        """
        Returns the user-specified Spark Session or the default.
        """
        if self._sparkSession is None:
            self._sparkSession = SparkSession.builder.getOrCreate()
        return self._sparkSession

    @property
    def sc(self):
        """
        Returns the underlying `SparkContext`.
        """
        return self.sparkSession.sparkContext


@inherit_doc
class MLWriter(BaseReadWrite):
    """
    Utility class that can save ML instances.

    .. versionadded:: 2.0.0
    """

    def __init__(self):
        super(MLWriter, self).__init__()
        self.shouldOverwrite = False

    def _handleOverwrite(self, path):
        from pyspark.ml.wrapper import JavaWrapper

        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.util.FileSystemOverwrite")
        wrapper = JavaWrapper(_java_obj)
        wrapper._call_java("handleOverwrite", path, True, self.sparkSession._jsparkSession)

    def save(self, path):
        """Save the ML instance to the input path."""
        if self.shouldOverwrite:
            self._handleOverwrite(path)
        self.saveImpl(path)

    def saveImpl(self, path):
        """
        save() handles overwriting and then calls this method. Subclasses should override this
        method to implement the actual saving of the instance.
        """
        raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))

    def overwrite(self):
        """Overwrites if the output path already exists."""
        self.shouldOverwrite = True
        return self


@inherit_doc
class GeneralMLWriter(MLWriter):
    """
    Utility class that can save ML instances in different formats.

    .. versionadded:: 2.4.0
    """

    def format(self, source):
        """
        Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class
        name for export).
        """
        self.source = source
        return self


@inherit_doc
class JavaMLWriter(MLWriter):
    """
    (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types
    """

    def __init__(self, instance):
        super(JavaMLWriter, self).__init__()
        _java_obj = instance._to_java()
        self._jwrite = _java_obj.write()

    def save(self, path):
        """Save the ML instance to the input path."""
        if not isinstance(path, basestring):
            raise TypeError("path should be a basestring, got type %s" % type(path))
        self._jwrite.save(path)

    def overwrite(self):
        """Overwrites if the output path already exists."""
        self._jwrite.overwrite()
        return self

    def option(self, key, value):
        self._jwrite.option(key, value)
        return self

    def session(self, sparkSession):
        """Sets the Spark Session to use for saving."""
        self._jwrite.session(sparkSession._jsparkSession)
        return self


@inherit_doc
class GeneralJavaMLWriter(JavaMLWriter):
    """
    (Private) Specialization of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types
    """

    def __init__(self, instance):
        super(GeneralJavaMLWriter, self).__init__(instance)

    def format(self, source):
        """
        Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class
        name for export).
        """
        self._jwrite.format(source)
        return self


@inherit_doc
class MLWritable(object):
    """
    Mixin for ML instances that provide :py:class:`MLWriter`.

    .. versionadded:: 2.0.0
    """

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self))

    def save(self, path):
        """Save this ML instance to the given path, a shortcut of 'write().save(path)'."""
        self.write().save(path)


@inherit_doc
class JavaMLWritable(MLWritable):
    """
    (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
    """

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)


@inherit_doc
class GeneralJavaMLWritable(JavaMLWritable):
    """
    (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`.
    """

    def write(self):
        """Returns a GeneralJavaMLWriter instance for this ML instance."""
        return GeneralJavaMLWriter(self)


@inherit_doc
class MLReader(BaseReadWrite):
    """
    Utility class that can load ML instances.

    .. versionadded:: 2.0.0
    """

    def __init__(self):
        super(MLReader, self).__init__()

    def load(self, path):
        """Load the ML instance from the input path."""
        raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))


@inherit_doc
class JavaMLReader(MLReader):
    """
    (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
    """

    def __init__(self, clazz):
        super(JavaMLReader, self).__init__()
        self._clazz = clazz
        self._jread = self._load_java_obj(clazz).read()

    def load(self, path):
        """Load the ML instance from the input path."""
        if not isinstance(path, basestring):
            raise TypeError("path should be a basestring, got type %s" % type(path))
        java_obj = self._jread.load(path)
        if not hasattr(self._clazz, "_from_java"):
            raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r"
                                      % self._clazz)
        return self._clazz._from_java(java_obj)

    def session(self, sparkSession):
        """Sets the Spark Session to use for loading."""
        self._jread.session(sparkSession._jsparkSession)
        return self

    @classmethod
    def _java_loader_class(cls, clazz):
        """
        Returns the full class name of the Java ML instance. The default
        implementation replaces "pyspark" by "org.apache.spark" in
        the Python full class name.
        """
        java_package = clazz.__module__.replace("pyspark", "org.apache.spark")
        if clazz.__name__ in ("Pipeline", "PipelineModel"):
            # Remove the last package name "pipeline" for Pipeline and PipelineModel.
            java_package = ".".join(java_package.split(".")[0:-1])
        return java_package + "." + clazz.__name__

    @classmethod
    def _load_java_obj(cls, clazz):
        """Load the peer Java object of the ML instance."""
        java_class = cls._java_loader_class(clazz)
        java_obj = _jvm()
        for name in java_class.split("."):
            java_obj = getattr(java_obj, name)
        return java_obj


@inherit_doc
class MLReadable(object):
    """
    Mixin for instances that provide :py:class:`MLReader`.

    .. versionadded:: 2.0.0
    """

    @classmethod
    def read(cls):
        """Returns an MLReader instance for this class."""
        raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls)

    @classmethod
    def load(cls, path):
        """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
        return cls.read().load(path)


@inherit_doc
class JavaMLReadable(MLReadable):
    """
    (Private) Mixin for instances that provide JavaMLReader.
    """

    @classmethod
    def read(cls):
        """Returns an MLReader instance for this class."""
        return JavaMLReader(cls)


@inherit_doc
class JavaPredictionModel():
    """
    (Private) Java Model for prediction tasks (regression and classification).
    To be mixed in with :py:class:`pyspark.ml.JavaModel`
    """

    @property
    @since("2.1.0")
    def numFeatures(self):
        """
        Returns the number of features the model was trained on. If unknown, returns -1
        """
        return self._call_java("numFeatures")


@inherit_doc
class DefaultParamsWritable(MLWritable):
    """
    .. note:: DeveloperApi

    Helper trait for making simple :py:class:`Params` types writable. If a :py:class:`Params`
    class stores all data as :py:class:`Param` values, then extending this trait will provide
    a default implementation of writing saved instances of the class.
    This only handles simple :py:class:`Param` types; e.g., it will not handle
    :py:class:`Dataset`. See :py:class:`DefaultParamsReadable`, the counterpart to this trait.

    .. versionadded:: 2.3.0
    """

    def write(self):
        """Returns a DefaultParamsWriter instance for this class."""
        from pyspark.ml.param import Params

        if isinstance(self, Params):
            return DefaultParamsWriter(self)
        else:
            # Interpolate the type into the message; passing it as a second
            # argument to TypeError would leave the "%s" unformatted.
            raise TypeError("Cannot use DefaultParamsWritable with type %s because it does not "
                            "extend Params." % type(self))


@inherit_doc
class DefaultParamsWriter(MLWriter):
    """
    .. note:: DeveloperApi

    Specialization of :py:class:`MLWriter` for :py:class:`Params` types

    Class for writing Estimators and Transformers whose parameters are JSON-serializable.

    .. versionadded:: 2.3.0
    """

    def __init__(self, instance):
        super(DefaultParamsWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path):
        DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)

    @staticmethod
    def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
        """
        Saves metadata + Params to: path + "/metadata"

        - class
        - timestamp
        - sparkVersion
        - uid
        - paramMap
        - defaultParamMap (since 2.4.0)
        - (optionally, extra metadata)

        :param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc.
        :param paramMap: If given, this is saved in the "paramMap" field.
        """
        metadataPath = os.path.join(path, "metadata")
        metadataJson = DefaultParamsWriter._get_metadata_to_save(instance,
                                                                 sc,
                                                                 extraMetadata,
                                                                 paramMap)
        sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath)

    @staticmethod
    def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
        """
        Helper for :py:meth:`DefaultParamsWriter.saveMetadata` which extracts the JSON to save.
        This is useful for ensemble models which need to save metadata for many sub-models.

        .. note:: See :py:meth:`DefaultParamsWriter.saveMetadata` for details on what this
            includes.
        """
        uid = instance.uid
        cls = instance.__module__ + '.' + instance.__class__.__name__

        # User-supplied param values
        params = instance._paramMap
        jsonParams = {}
        if paramMap is not None:
            jsonParams = paramMap
        else:
            for p in params:
                jsonParams[p.name] = params[p]

        # Default param values
        jsonDefaultParams = {}
        for p in instance._defaultParamMap:
            jsonDefaultParams[p.name] = instance._defaultParamMap[p]

        basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)),
                         "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams,
                         "defaultParamMap": jsonDefaultParams}
        if extraMetadata is not None:
            basicMetadata.update(extraMetadata)
        return json.dumps(basicMetadata, separators=(',', ':'))


@inherit_doc
class DefaultParamsReadable(MLReadable):
    """
    .. note:: DeveloperApi

    Helper trait for making simple :py:class:`Params` types readable.
    If a :py:class:`Params` class stores all data as :py:class:`Param` values,
    then extending this trait will provide a default implementation of reading saved
    instances of the class. This only handles simple :py:class:`Param` types;
    e.g., it will not handle :py:class:`Dataset`. See :py:class:`DefaultParamsWritable`,
    the counterpart to this trait.

    .. versionadded:: 2.3.0
    """

    @classmethod
    def read(cls):
        """Returns a DefaultParamsReader instance for this class."""
        return DefaultParamsReader(cls)


@inherit_doc
class DefaultParamsReader(MLReader):
    """
    .. note:: DeveloperApi

    Specialization of :py:class:`MLReader` for :py:class:`Params` types

    Default :py:class:`MLReader` implementation for transformers and estimators that
    contain basic (json-serializable) params and no data. This will not handle
    more complex params or types with data (e.g., models with coefficients).

    .. versionadded:: 2.3.0
    """

    def __init__(self, cls):
        super(DefaultParamsReader, self).__init__()
        self.cls = cls

    @staticmethod
    def __get_class(clazz):
        """
        Loads Python class from its name.
        """
        parts = clazz.split('.')
        module = ".".join(parts[:-1])
        # __import__ returns the top-level package, so walk down the remaining parts.
        m = __import__(module)
        for comp in parts[1:]:
            m = getattr(m, comp)
        return m

    def load(self, path):
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        py_type = DefaultParamsReader.__get_class(metadata['class'])
        instance = py_type()
        instance._resetUid(metadata['uid'])
        DefaultParamsReader.getAndSetParams(instance, metadata)
        return instance

    @staticmethod
    def loadMetadata(path, sc, expectedClassName=""):
        """
        Load metadata saved using :py:meth:`DefaultParamsWriter.saveMetadata`

        :param expectedClassName: If non empty, this is checked against the loaded metadata.
        """
        metadataPath = os.path.join(path, "metadata")
        metadataStr = sc.textFile(metadataPath, 1).first()
        loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName)
        return loadedVals

    @staticmethod
    def _parseMetaData(metadataStr, expectedClassName=""):
        """
        Parse metadata JSON string produced by :py:meth:`DefaultParamsWriter._get_metadata_to_save`.
        This is a helper function for :py:meth:`DefaultParamsReader.loadMetadata`.

        :param metadataStr: JSON string of metadata
        :param expectedClassName: If non empty, this is checked against the loaded metadata.
        """
        metadata = json.loads(metadataStr)
        className = metadata['class']
        if len(expectedClassName) > 0:
            assert className == expectedClassName, "Error loading metadata: Expected " + \
                "class name {} but found class name {}".format(expectedClassName, className)
        return metadata

    @staticmethod
    def getAndSetParams(instance, metadata):
        """
        Extract Params from metadata, and set them in the instance.
        """
        # Set user-supplied param values
        for paramName in metadata['paramMap']:
            param = instance.getParam(paramName)
            paramValue = metadata['paramMap'][paramName]
            instance.set(param, paramValue)

        # Set default param values
        majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion'])
        major = majorAndMinorVersions[0]
        minor = majorAndMinorVersions[1]

        # For metadata file prior to Spark 2.4, there is no default section.
        if major > 2 or (major == 2 and minor >= 4):
            assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \
                "`defaultParamMap` section not found"

            for paramName in metadata['defaultParamMap']:
                paramValue = metadata['defaultParamMap'][paramName]
                instance._setDefault(**{paramName: paramValue})

    @staticmethod
    def loadParamsInstance(path, sc):
        """
        Load a :py:class:`Params` instance from the given path, and return it.
        This assumes the instance inherits from :py:class:`MLReadable`.
        """
        metadata = DefaultParamsReader.loadMetadata(path, sc)
        pythonClassName = metadata['class'].replace("org.apache.spark", "pyspark")
        py_type = DefaultParamsReader.__get_class(pythonClassName)
        instance = py_type.load(path)
        return instance


@inherit_doc
class HasTrainingSummary(object):
    """
    Base class for models that provide a training summary.

    .. versionadded:: 3.0.0
    """

    @property
    @since("2.1.0")
    def hasSummary(self):
        """
        Indicates whether a training summary exists for this model
        instance.
        """
        return self._call_java("hasSummary")

    @property
    @since("2.1.0")
    def summary(self):
        """
        Gets summary of the model trained on the training set. An exception is thrown if
        no summary exists.
        """
        return self._call_java("summary")
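
The metadata layout written by `DefaultParamsWriter.saveMetadata` and read back by `DefaultParamsReader._parseMetaData` can be tried without a running Spark session. The following is a minimal standalone sketch under stated assumptions, not the module's own code: `build_metadata`, `parse_metadata`, and `has_default_section` are hypothetical helpers, the class name and uid are made-up values, and the version check is a simplified stand-in for `VersionUtils.majorMinorVersion`.

```python
import json
import time


def build_metadata(cls_name, uid, param_map, default_param_map, spark_version,
                   extra_metadata=None):
    """Build a compact JSON string with the fields listed in saveMetadata."""
    metadata = {
        "class": cls_name,
        "timestamp": int(round(time.time() * 1000)),  # milliseconds since the epoch
        "sparkVersion": spark_version,
        "uid": uid,
        "paramMap": param_map,
        "defaultParamMap": default_param_map,
    }
    if extra_metadata is not None:
        # Extra metadata sits at the same level as uid, paramMap, etc.
        metadata.update(extra_metadata)
    return json.dumps(metadata, separators=(",", ":"))


def parse_metadata(metadata_str, expected_class_name=""):
    """Parse the JSON string and optionally check the recorded class name."""
    metadata = json.loads(metadata_str)
    if expected_class_name and metadata["class"] != expected_class_name:
        raise ValueError("Expected class name {} but found class name {}".format(
            expected_class_name, metadata["class"]))
    return metadata


def has_default_section(spark_version):
    """Simplified check: metadata from Spark 2.4 onward has a defaultParamMap."""
    major, minor = (int(part) for part in spark_version.split(".")[:2])
    return major > 2 or (major == 2 and minor >= 4)


# Hypothetical transformer with one user-set param and one default.
metadata_str = build_metadata(
    cls_name="my_package.MyTransformer",
    uid="MyTransformer_0123456789ab",
    param_map={"threshold": 0.5},
    default_param_map={"threshold": 0.0, "outputCol": "output"},
    spark_version="3.0.0",
)
loaded = parse_metadata(metadata_str, expected_class_name="my_package.MyTransformer")
```

A reader built this way would apply `loaded["paramMap"]` as user-supplied values and, only when `has_default_section(loaded["sparkVersion"])` is true, apply `loaded["defaultParamMap"]` as defaults, which is the order `getAndSetParams` follows.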