diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 5b4a0b3c55..e075b0460f 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -26,6 +26,8 @@ import threading import traceback import types +from py4j.clientserver import ClientServer + __all__ = [] # type: ignore @@ -308,7 +310,9 @@ def inheritable_thread_target(f): """ from pyspark import SparkContext - if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true": + if isinstance(SparkContext._gateway, ClientServer): + # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on. + # NOTICE the internal difference vs `InheritableThread`. `InheritableThread` # copies local properties when the thread starts but `inheritable_thread_target` # copies when the function is wrapped. @@ -350,7 +354,8 @@ class InheritableThread(threading.Thread): def __init__(self, target, *args, **kwargs): from pyspark import SparkContext - if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true": + if isinstance(SparkContext._gateway, ClientServer): + # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on. def copy_local_properties(*a, **k): # self._props is set before starting the thread to match the behavior with JVM. assert hasattr(self, "_props") @@ -368,7 +373,9 @@ class InheritableThread(threading.Thread): def start(self, *args, **kwargs): from pyspark import SparkContext - if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true": + if isinstance(SparkContext._gateway, ClientServer): + # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on. + # Local property copy should happen in Thread.start to mimic JVM's behavior. self._props = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone() return super(InheritableThread, self).start(*args, **kwargs)