[SPARK-35834][PYTHON] Use the same cleanup logic as Py4J in inheritable thread API

### What changes were proposed in this pull request?

This PR fixes the cleanup logic in inheritable thread API by following Py4J cleanup logic at https://github.com/bartdag/py4j/blob/master/py4j-python/src/py4j/clientserver.py#L269-L278.

Currently the tests that use `inheritable_thread_target` are flaky (https://github.com/apache/spark/runs/2870944288):

```
======================================================================
ERROR [71.813s]: test_save_load_pipeline_estimator (pyspark.ml.tests.test_tuning.CrossValidatorTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/__w/spark/spark/python/pyspark/ml/tests/test_tuning.py", line 589, in test_save_load_pipeline_estimator
    self._run_test_save_load_pipeline_estimator(DummyLogisticRegression)
  File "/__w/spark/spark/python/pyspark/ml/tests/test_tuning.py", line 572, in _run_test_save_load_pipeline_estimator
    cvModel2 = crossval2.fit(training)
  File "/__w/spark/spark/python/pyspark/ml/base.py", line 161, in fit
    return self._fit(dataset)
  File "/__w/spark/spark/python/pyspark/ml/tuning.py", line 747, in _fit
    bestModel = est.fit(dataset, epm[bestIndex])
  File "/__w/spark/spark/python/pyspark/ml/base.py", line 159, in fit
    return self.copy(params)._fit(dataset)
  File "/__w/spark/spark/python/pyspark/ml/pipeline.py", line 114, in _fit
    model = stage.fit(dataset)
  File "/__w/spark/spark/python/pyspark/ml/base.py", line 161, in fit
    return self._fit(dataset)
  File "/__w/spark/spark/python/pyspark/ml/pipeline.py", line 114, in _fit
    model = stage.fit(dataset)
  File "/__w/spark/spark/python/pyspark/ml/base.py", line 161, in fit
    return self._fit(dataset)
  File "/__w/spark/spark/python/pyspark/ml/classification.py", line 2924, in _fit
    models = pool.map(inheritable_thread_target(trainSingleClass), range(numClasses))
  File "/__t/Python/3.6.13/x64/lib/python3.6/multiprocessing/pool.py", line 266, in map
    return self._map_async(func, iterable, mapstar, chunksize).get()
  File "/__t/Python/3.6.13/x64/lib/python3.6/multiprocessing/pool.py", line 644, in get
    raise self._value
  File "/__t/Python/3.6.13/x64/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "/__t/Python/3.6.13/x64/lib/python3.6/multiprocessing/pool.py", line 44, in mapstar
    return list(map(*args))
  File "/__w/spark/spark/python/pyspark/util.py", line 324, in wrapped
    InheritableThread._clean_py4j_conn_for_current_thread()
  File "/__w/spark/spark/python/pyspark/util.py", line 389, in _clean_py4j_conn_for_current_thread
    del connections[i]
IndexError: deque index out of range

----------------------------------------------------------------------
```

This seems to be because the connection deque `jvm._gateway_client.deque` is accessed, and modified by other threads. Therefore, the number of threads could be changed in the middle. Using `SparkContext._lock` doesn't protect because the deque can be updated for every Java instance access in Py4J.

This PR proposes to use the atomic `deque.remove` in the problematic dequeue alone with try-catch on `ValueError` in case it's [deleted by Py4J](https://github.com/bartdag/py4j/blob/master/py4j-python/src/py4j/clientserver.py#L269-L278).

### Why are the changes needed?

To fix the flakiness in the tests, and avoid possible breakage in user application by using this API.

### Does this PR introduce _any_ user-facing change?

If users were dependent on InheritableThread with pinned thread mode on, they might have faced such issues intermittently. This PR fixes it.

### How was this patch tested?

Manually tested. CI should test it out too.

Closes #32989 from HyukjinKwon/SPARK-35834.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
This commit is contained in:
Hyukjin Kwon 2021-06-21 12:00:16 +09:00
parent 653be9d774
commit 248fda3ead

View file

@ -378,20 +378,18 @@ class InheritableThread(threading.Thread):
from pyspark import SparkContext
jvm = SparkContext._jvm
thread_connection = jvm._gateway_client.thread_connection.connection()
thread_connection = jvm._gateway_client.get_thread_connection()
if thread_connection is not None:
connections = jvm._gateway_client.deque
# Reuse the lock for Py4J in PySpark
with SparkContext._lock:
for i in range(len(connections)):
if connections[i] is thread_connection:
connections[i].close()
del connections[i]
break
else:
# Just in case the connection was not closed but removed from the
# queue.
thread_connection.close()
try:
# Dequeue is shared across other threads but it's thread-safe.
# If this function has to be invoked one more time in the same thead
# Py4J will create a new connection automatically.
jvm._gateway_client.deque.remove(thread_connection)
except ValueError:
# Should never reach this point
return
finally:
thread_connection.close()
if __name__ == "__main__":