[SPARK-35946][PYTHON] Respect Py4J server in InheritableThread API
### What changes were proposed in this pull request? Currently ,we sets the environment variable `PYSPARK_PIN_THREAD` at the client side of `InhertiableThread` API for Py4J (`python/pyspark/util.py`). If the Py4J gateway is created somewhere else (e.g., Zeppelin, etc), it could introduce a breakage at: ```python from pyspark import SparkContext jvm = SparkContext._jvm thread_connection = jvm._gateway_client.get_thread_connection() # `AttributeError: 'GatewayClient' object has no attribute 'get_thread_connection'` (non-pinned thread mode) # `get_thread_connection` is only in 'ClientServer' (pinned thread mode) ``` This PR proposes to check the given gateway created, and do the pinned thread mode behaviour accordingly so we can avoid any breakage when Py4J server/gateway is created separately from somewhere else without a pinned thread mode. ### Why are the changes needed? To avoid any potential breakage. ### Does this PR introduce _any_ user-facing change? No, the change happened only in the master (fdd7ca5f4e
). ### How was this patch tested? This is actually a partial revert offdd7ca5f4e
. As long as the existing tests pass, I guess we're all good. I also manually tested to make doubly sure: **Before**: ```python >>> from pyspark import InheritableThread, inheritable_thread_target >>> InheritableThread(lambda: 1).start() >>> inheritable_thread_target(lambda: 1)() Traceback (most recent call last): File "/.../python3.8/lib/python3.8/threading.py", line 932, in _bootstrap_inner self.run() File "/.../python3.8/lib/python3.8/threading.py", line 870, in run self._target(*self._args, **self._kwargs) File "/.../spark/python/pyspark/util.py", line 361, in copy_local_properties InheritableThread._clean_py4j_conn_for_current_thread() File "/.../spark/python/pyspark/util.py", line 381, in _clean_py4j_conn_for_current_thread thread_connection = jvm._gateway_client.get_thread_connection() AttributeError: 'GatewayClient' object has no attribute 'get_thread_connection' Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/.../spark/python/pyspark/util.py", line 324, in wrapped InheritableThread._clean_py4j_conn_for_current_thread() File "/.../spark/python/pyspark/util.py", line 381, in _clean_py4j_conn_for_current_thread thread_connection = jvm._gateway_client.get_thread_connection() AttributeError: 'GatewayClient' object has no attribute 'get_thread_connection' ``` **After**: ```python >>> from pyspark import InheritableThread, inheritable_thread_target >>> InheritableThread(lambda: 1).start() >>> inheritable_thread_target(lambda: 1)() 1 ``` Closes #33147 from HyukjinKwon/SPARK-35946. Authored-by: Hyukjin Kwon <gurwls223@apache.org> Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
This commit is contained in:
parent
064230de97
commit
8d28839689
|
@ -26,6 +26,8 @@ import threading
|
|||
import traceback
|
||||
import types
|
||||
|
||||
from py4j.clientserver import ClientServer
|
||||
|
||||
__all__ = [] # type: ignore
|
||||
|
||||
|
||||
|
@ -308,7 +310,9 @@ def inheritable_thread_target(f):
|
|||
"""
|
||||
from pyspark import SparkContext
|
||||
|
||||
if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
|
||||
if isinstance(SparkContext._gateway, ClientServer):
|
||||
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
|
||||
|
||||
# NOTICE the internal difference vs `InheritableThread`. `InheritableThread`
|
||||
# copies local properties when the thread starts but `inheritable_thread_target`
|
||||
# copies when the function is wrapped.
|
||||
|
@ -350,7 +354,8 @@ class InheritableThread(threading.Thread):
|
|||
def __init__(self, target, *args, **kwargs):
|
||||
from pyspark import SparkContext
|
||||
|
||||
if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
|
||||
if isinstance(SparkContext._gateway, ClientServer):
|
||||
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
|
||||
def copy_local_properties(*a, **k):
|
||||
# self._props is set before starting the thread to match the behavior with JVM.
|
||||
assert hasattr(self, "_props")
|
||||
|
@ -368,7 +373,9 @@ class InheritableThread(threading.Thread):
|
|||
def start(self, *args, **kwargs):
|
||||
from pyspark import SparkContext
|
||||
|
||||
if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
|
||||
if isinstance(SparkContext._gateway, ClientServer):
|
||||
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
|
||||
|
||||
# Local property copy should happen in Thread.start to mimic JVM's behavior.
|
||||
self._props = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone()
|
||||
return super(InheritableThread, self).start(*args, **kwargs)
|
||||
|
|
Loading…
Reference in a new issue