diff --git a/python/pyspark/util.py b/python/pyspark/util.py index ee9aee20fa..5b4a0b3c55 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -378,20 +378,18 @@ class InheritableThread(threading.Thread): from pyspark import SparkContext jvm = SparkContext._jvm - thread_connection = jvm._gateway_client.thread_connection.connection() + thread_connection = jvm._gateway_client.get_thread_connection() if thread_connection is not None: - connections = jvm._gateway_client.deque - # Reuse the lock for Py4J in PySpark - with SparkContext._lock: - for i in range(len(connections)): - if connections[i] is thread_connection: - connections[i].close() - del connections[i] - break - else: - # Just in case the connection was not closed but removed from the - # queue. - thread_connection.close() + try: + # Deque is shared across other threads but it's thread-safe. + # If this function has to be invoked one more time in the same thread + # Py4J will create a new connection automatically. + jvm._gateway_client.deque.remove(thread_connection) + except ValueError: + # Should never reach this point + return + finally: + thread_connection.close() if __name__ == "__main__":