spark-instrumented-optimizer/python/pyspark/__init__.py
reidy-p 4ab9aa0305 [SPARK-33017][PYTHON] Add getCheckpointDir method to PySpark Context
### What changes were proposed in this pull request?

Add a `getCheckpointDir` method to the PySpark `SparkContext` that returns the checkpoint directory, matching the Scala API.

### Why are the changes needed?

To make the Scala and Python APIs consistent, and to remove the need to go through the underlying JavaObject (`sc._jsc`) to read the checkpoint directory.

### Does this PR introduce _any_ user-facing change?

Yes, there is a new method which makes it easier to get the checkpoint directory directly rather than using the JavaObject

#### Previous behaviour:
```python
>>> spark.sparkContext.setCheckpointDir('/tmp/spark/checkpoint/')
>>> sc._jsc.sc().getCheckpointDir().get()
'file:/tmp/spark/checkpoint/63f7b67c-e5dc-4d11-a70c-33554a71717a'
```
This call raises a confusing Scala error (`java.util.NoSuchElementException: None.get`) if the checkpoint directory has not been set:
```python
>>> sc._jsc.sc().getCheckpointDir().get()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/paul/Desktop/spark/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py", line 1305, in __call__
  File "/home/paul/Desktop/spark/python/pyspark/sql/utils.py", line 111, in deco
    return f(*a, **kw)
  File "/home/paul/Desktop/spark/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py", line 328, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o25.get.
: java.util.NoSuchElementException: None.get
        at scala.None$.get(Option.scala:529)
        at scala.None$.get(Option.scala:527)
        at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
        at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.lang.reflect.Method.invoke(Method.java:498)
        at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
        at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
        at py4j.Gateway.invoke(Gateway.java:282)
        at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
        at py4j.commands.CallCommand.execute(CallCommand.java:79)
        at py4j.GatewayConnection.run(GatewayConnection.java:238)
        at java.lang.Thread.run(Thread.java:748)

```

#### New method:
```python
>>> spark.sparkContext.setCheckpointDir('/tmp/spark/checkpoint/')
>>> spark.sparkContext.getCheckpointDir()
'file:/tmp/spark/checkpoint/b38aca2e-8ace-44fc-a4c4-f4e36c2da2a7'
```

``getCheckpointDir()`` returns ``None`` if the checkpoint directory has not been set:
```python
>>> print(spark.sparkContext.getCheckpointDir())
None
```
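
As a rough illustration of the behaviour change, the sketch below contrasts unwrapping the option directly with checking ``isEmpty()`` first. ``FakeOption`` and ``checkpoint_dir_or_none`` are hypothetical stand-ins so that the snippet runs without a JVM; they are not the code added by this PR, and the path is a made-up example.

```python
class FakeOption:
    """Hypothetical stand-in for the py4j proxy of a Scala Option[String]."""

    def __init__(self, value=None):
        self._value = value

    def isEmpty(self):
        return self._value is None

    def get(self):
        if self._value is None:
            # Plays the role of the JVM-side NoSuchElementException: None.get
            raise LookupError("None.get")
        return self._value


def checkpoint_dir_or_none(option):
    """Return the wrapped path, or None when no checkpoint directory is set."""
    if option.isEmpty():
        return None
    return option.get()


# Old pattern: unwrapping an empty option directly raises an error.
try:
    FakeOption().get()
    direct_get_failed = False
except LookupError:
    direct_get_failed = True

# New pattern: an empty option is mapped to None instead.
unset = checkpoint_dir_or_none(FakeOption())
configured = checkpoint_dir_or_none(FakeOption("file:/tmp/spark/checkpoint/example"))

print(direct_get_failed)
print(unset)
print(configured)
```

Returning ``None`` lets callers use an ordinary ``is None`` check rather than catching a ``Py4JJavaError``.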

### How was this patch tested?

Added assertions to the existing unit tests. I am not sure how to add a test for the case where ``getCheckpointDir()`` should return ``None``, because, as far as I can tell, the existing checkpoint tests set the checkpoint directory in the ``setUp`` method before any tests are run.

Closes #29918 from reidy-p/SPARK-33017.

Authored-by: reidy-p <paul_reidy@outlook.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
2020-10-05 11:48:28 +09:00

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
PySpark is the Python API for Spark.

Public classes:

  - :class:`SparkContext`:
      Main entry point for Spark functionality.
  - :class:`RDD`:
      A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
  - :class:`Broadcast`:
      A broadcast variable that gets reused across tasks.
  - :class:`Accumulator`:
      An "add-only" shared variable that tasks can only add values to.
  - :class:`SparkConf`:
      For configuring Spark.
  - :class:`SparkFiles`:
      Access files shipped with jobs.
  - :class:`StorageLevel`:
      Finer-grained cache persistence levels.
  - :class:`TaskContext`:
      Information about the currently running task, available on the workers (experimental).
  - :class:`RDDBarrier`:
      Wraps an RDD under a barrier stage for barrier execution.
  - :class:`BarrierTaskContext`:
      A :class:`TaskContext` that provides extra info and tooling for barrier execution.
  - :class:`BarrierTaskInfo`:
      Information about a barrier task.
  - :class:`InheritableThread`:
      An inheritable thread to use in Spark when the pinned thread mode is on.
"""
from functools import wraps
import types

from pyspark.conf import SparkConf
from pyspark.rdd import RDD, RDDBarrier
from pyspark.files import SparkFiles
from pyspark.status import StatusTracker, SparkJobInfo, SparkStageInfo
from pyspark.util import InheritableThread
from pyspark.storagelevel import StorageLevel
from pyspark.accumulators import Accumulator, AccumulatorParam
from pyspark.broadcast import Broadcast
from pyspark.serializers import MarshalSerializer, PickleSerializer
from pyspark.taskcontext import TaskContext, BarrierTaskContext, BarrierTaskInfo
from pyspark.profiler import Profiler, BasicProfiler
from pyspark.version import __version__  # noqa: F401
from pyspark._globals import _NoValue  # noqa: F401


def since(version):
    """
    A decorator that annotates a function by appending the version of Spark
    in which the function was added.
    """
    import re
    indent_p = re.compile(r'\n( +)')

    def deco(f):
        indents = indent_p.findall(f.__doc__)
        indent = ' ' * (min(len(m) for m in indents) if indents else 0)
        f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
        return f
    return deco


def copy_func(f, name=None, sinceversion=None, doc=None):
    """
    Returns a function with the same code, globals, defaults, closure, and
    name (or a new name, if one is provided).
    """
    # See
    # http://stackoverflow.com/questions/6527633/how-can-i-make-a-deepcopy-of-a-function-in-python
    fn = types.FunctionType(f.__code__, f.__globals__, name or f.__name__, f.__defaults__,
                            f.__closure__)
    # in case f was given attrs (note this dict is a shallow copy):
    fn.__dict__.update(f.__dict__)
    if doc is not None:
        fn.__doc__ = doc
    if sinceversion is not None:
        fn = since(sinceversion)(fn)
    return fn


def keyword_only(func):
    """
    A decorator that forces keyword arguments in the wrapped method
    and saves the actual input keyword arguments in `_input_kwargs`.

    .. note:: Should only be used to wrap a method where the first arg is `self`
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if len(args) > 0:
            raise TypeError("Method %s forces keyword arguments." % func.__name__)
        self._input_kwargs = kwargs
        return func(self, **kwargs)
    return wrapper


# To avoid circular dependencies
from pyspark.context import SparkContext

# for back compatibility
from pyspark.sql import SQLContext, HiveContext, Row  # noqa: F401

__all__ = [
    "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
    "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
    "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", "TaskContext",
    "RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo", "InheritableThread",
]