[SPARK-35834][PYTHON] Use the same cleanup logic as Py4J in inheritable thread API
### What changes were proposed in this pull request?

This PR fixes the cleanup logic in the inheritable thread API by following the Py4J cleanup logic at https://github.com/bartdag/py4j/blob/master/py4j-python/src/py4j/clientserver.py#L269-L278.

Currently the tests that use `inheritable_thread_target` are flaky (https://github.com/apache/spark/runs/2870944288):

```
======================================================================
ERROR [71.813s]: test_save_load_pipeline_estimator (pyspark.ml.tests.test_tuning.CrossValidatorTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/__w/spark/spark/python/pyspark/ml/tests/test_tuning.py", line 589, in test_save_load_pipeline_estimator
    self._run_test_save_load_pipeline_estimator(DummyLogisticRegression)
  File "/__w/spark/spark/python/pyspark/ml/tests/test_tuning.py", line 572, in _run_test_save_load_pipeline_estimator
    cvModel2 = crossval2.fit(training)
  File "/__w/spark/spark/python/pyspark/ml/base.py", line 161, in fit
    return self._fit(dataset)
  File "/__w/spark/spark/python/pyspark/ml/tuning.py", line 747, in _fit
    bestModel = est.fit(dataset, epm[bestIndex])
  File "/__w/spark/spark/python/pyspark/ml/base.py", line 159, in fit
    return self.copy(params)._fit(dataset)
  File "/__w/spark/spark/python/pyspark/ml/pipeline.py", line 114, in _fit
    model = stage.fit(dataset)
  File "/__w/spark/spark/python/pyspark/ml/base.py", line 161, in fit
    return self._fit(dataset)
  File "/__w/spark/spark/python/pyspark/ml/pipeline.py", line 114, in _fit
    model = stage.fit(dataset)
  File "/__w/spark/spark/python/pyspark/ml/base.py", line 161, in fit
    return self._fit(dataset)
  File "/__w/spark/spark/python/pyspark/ml/classification.py", line 2924, in _fit
    models = pool.map(inheritable_thread_target(trainSingleClass), range(numClasses))
  File "/__t/Python/3.6.13/x64/lib/python3.6/multiprocessing/pool.py", line 266, in map
    return self._map_async(func, iterable, mapstar, chunksize).get()
  File "/__t/Python/3.6.13/x64/lib/python3.6/multiprocessing/pool.py", line 644, in get
    raise self._value
  File "/__t/Python/3.6.13/x64/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "/__t/Python/3.6.13/x64/lib/python3.6/multiprocessing/pool.py", line 44, in mapstar
    return list(map(*args))
  File "/__w/spark/spark/python/pyspark/util.py", line 324, in wrapped
    InheritableThread._clean_py4j_conn_for_current_thread()
  File "/__w/spark/spark/python/pyspark/util.py", line 389, in _clean_py4j_conn_for_current_thread
    del connections[i]
IndexError: deque index out of range
----------------------------------------------------------------------
```

This seems to be because the connection deque `jvm._gateway_client.deque` is accessed and modified by other threads, so the number of connections in it can change while the cleanup loop is iterating over it by index. Holding `SparkContext._lock` does not protect against this, because Py4J can update the deque on every Java instance access, outside that lock.

This PR proposes to use the atomic `deque.remove` on the shared deque, along with a try/except on `ValueError` in case the connection has already been [deleted by Py4J](https://github.com/bartdag/py4j/blob/master/py4j-python/src/py4j/clientserver.py#L269-L278).

### Why are the changes needed?

To fix the flakiness in the tests, and to avoid possible breakage in user applications that use this API.

### Does this PR introduce _any_ user-facing change?

Users who depend on `InheritableThread` with pinned thread mode on might have hit this issue intermittently. This PR fixes it.

### How was this patch tested?

Manually tested. CI should test it out too.

Closes #32989 from HyukjinKwon/SPARK-35834.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
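The failure mode described above can be reproduced without Spark or Py4J. The following is a minimal sketch, not the PySpark code: `remove_by_index`, `remove_atomically`, and the `interfere` callback are hypothetical names, and the callback merely stands in for another thread shrinking the deque in the middle of the scan.

```python
from collections import deque


def remove_by_index(connections, target, interfere=None):
    """Old pattern: scan by index, then delete by index.

    `interfere` is a hypothetical hook that simulates another thread
    mutating the deque between two steps of the scan.
    """
    for i in range(len(connections)):  # length is captured once, up front
        if interfere is not None:
            interfere(i)
        if connections[i] is target:  # may raise IndexError if the deque shrank
            del connections[i]
            return True
    return False


def remove_atomically(connections, target):
    """New pattern: a single deque.remove call, tolerating a missing item."""
    try:
        connections.remove(target)
        return True
    except ValueError:
        # The item was already removed by someone else.
        return False


if __name__ == "__main__":
    a, b, c = object(), object(), object()
    shared = deque([a, b, c])
    try:
        # Simulate a concurrent popleft() just before index 2 is read.
        remove_by_index(shared, c, interfere=lambda i: shared.popleft() if i == 2 else None)
    except IndexError as exc:
        print("index-based cleanup failed:", exc)
    print("atomic cleanup removed item:", remove_atomically(shared, c))
```

The index-based version depends on the deque keeping its length for the whole loop, which nothing guarantees once other threads touch it. The single `remove` call has no such window, and the only remaining case to handle is the item being gone already.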
parent 653be9d774
commit 248fda3ead
@@ -378,20 +378,18 @@ class InheritableThread(threading.Thread):
         from pyspark import SparkContext
 
         jvm = SparkContext._jvm
-        thread_connection = jvm._gateway_client.thread_connection.connection()
+        thread_connection = jvm._gateway_client.get_thread_connection()
         if thread_connection is not None:
-            connections = jvm._gateway_client.deque
-            # Reuse the lock for Py4J in PySpark
-            with SparkContext._lock:
-                for i in range(len(connections)):
-                    if connections[i] is thread_connection:
-                        connections[i].close()
-                        del connections[i]
-                        break
-                else:
-                    # Just in case the connection was not closed but removed from the
-                    # queue.
-                    thread_connection.close()
+            try:
+                # Dequeue is shared across other threads but it's thread-safe.
+                # If this function has to be invoked one more time in the same thead
+                # Py4J will create a new connection automatically.
+                jvm._gateway_client.deque.remove(thread_connection)
+            except ValueError:
+                # Should never reach this point
+                return
+            finally:
+                thread_connection.close()
 
 
 if __name__ == "__main__":
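The control flow of the added lines can be exercised on plain Python objects. This is a rough sketch under stated assumptions: `FakeConnection` and `clean_connection` are hypothetical stand-ins, not Py4J or PySpark APIs, and the string return values exist only to make the paths observable. It illustrates that the `finally` clause closes the connection on both the normal path and the `ValueError` path, including when the `except` branch returns early.

```python
from collections import deque


class FakeConnection:
    """Hypothetical stand-in for a Py4J connection; it only records close()."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


def clean_connection(connections, thread_connection):
    """Mirror the patched try/except/finally shape on a plain deque."""
    if thread_connection is None:
        return "no-connection"
    try:
        # One call removes the connection; no index bookkeeping is needed.
        connections.remove(thread_connection)
    except ValueError:
        # The connection was already taken out of the deque elsewhere.
        return "already-removed"
    finally:
        # Runs on every path above, including the early return.
        thread_connection.close()
    return "removed"


if __name__ == "__main__":
    mine, other = FakeConnection(), FakeConnection()
    pool = deque([other, mine])
    print(clean_connection(pool, mine), mine.closed)
    stale = FakeConnection()
    print(clean_connection(pool, stale), stale.closed)
```

Compared with the removed lines, there is no separate branch for "found in the deque" versus "not found": removal is attempted once, and closing happens regardless of the outcome.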