[SPARK-35946][PYTHON] Respect Py4J server in InheritableThread API

### What changes were proposed in this pull request?

Currently ,we sets the environment variable `PYSPARK_PIN_THREAD` at the client side of `InhertiableThread` API for Py4J (`python/pyspark/util.py`). If the Py4J gateway is created somewhere else (e.g., Zeppelin, etc), it could introduce a breakage at:

```python
from pyspark import SparkContext
jvm = SparkContext._jvm
thread_connection = jvm._gateway_client.get_thread_connection()
# `AttributeError: 'GatewayClient' object has no attribute 'get_thread_connection'` (non-pinned thread mode)
# `get_thread_connection` is only in 'ClientServer' (pinned thread mode)
```

This PR proposes to check the given gateway created, and do the pinned thread mode behaviour accordingly so we can avoid any breakage when Py4J server/gateway is created separately from somewhere else without a pinned thread mode.

### Why are the changes needed?

To avoid any potential breakage.

### Does this PR introduce _any_ user-facing change?

No, the change happened only in the master (fdd7ca5f4e).

### How was this patch tested?

This is actually a partial revert of fdd7ca5f4e. As long as the existing tests pass, I guess we're all good.

I also manually tested to make doubly sure:

**Before**:

```python
>>> from pyspark import InheritableThread, inheritable_thread_target
>>> InheritableThread(lambda: 1).start()
>>> inheritable_thread_target(lambda: 1)()
Traceback (most recent call last):
  File "/.../python3.8/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/.../python3.8/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/.../spark/python/pyspark/util.py", line 361, in copy_local_properties
    InheritableThread._clean_py4j_conn_for_current_thread()
  File "/.../spark/python/pyspark/util.py", line 381, in _clean_py4j_conn_for_current_thread
    thread_connection = jvm._gateway_client.get_thread_connection()
AttributeError: 'GatewayClient' object has no attribute 'get_thread_connection'

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../spark/python/pyspark/util.py", line 324, in wrapped
    InheritableThread._clean_py4j_conn_for_current_thread()
  File "/.../spark/python/pyspark/util.py", line 381, in _clean_py4j_conn_for_current_thread
    thread_connection = jvm._gateway_client.get_thread_connection()
AttributeError: 'GatewayClient' object has no attribute 'get_thread_connection'
```

**After**:

```python
>>> from pyspark import InheritableThread, inheritable_thread_target
>>> InheritableThread(lambda: 1).start()
>>> inheritable_thread_target(lambda: 1)()
1
```

Closes #33147 from HyukjinKwon/SPARK-35946.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
This commit is contained in:
Hyukjin Kwon 2021-06-29 22:18:54 -07:00 committed by Dongjoon Hyun
parent 064230de97
commit 8d28839689

View file

@ -26,6 +26,8 @@ import threading
import traceback
import types
from py4j.clientserver import ClientServer
__all__ = [] # type: ignore
@ -308,7 +310,9 @@ def inheritable_thread_target(f):
"""
from pyspark import SparkContext
if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
if isinstance(SparkContext._gateway, ClientServer):
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
# NOTICE the internal difference vs `InheritableThread`. `InheritableThread`
# copies local properties when the thread starts but `inheritable_thread_target`
# copies when the function is wrapped.
@ -350,7 +354,8 @@ class InheritableThread(threading.Thread):
def __init__(self, target, *args, **kwargs):
from pyspark import SparkContext
if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
if isinstance(SparkContext._gateway, ClientServer):
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
def copy_local_properties(*a, **k):
# self._props is set before starting the thread to match the behavior with JVM.
assert hasattr(self, "_props")
@ -368,7 +373,9 @@ class InheritableThread(threading.Thread):
def start(self, *args, **kwargs):
from pyspark import SparkContext
if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
if isinstance(SparkContext._gateway, ClientServer):
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
# Local property copy should happen in Thread.start to mimic JVM's behavior.
self._props = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone()
return super(InheritableThread, self).start(*args, **kwargs)