[SPARK-25253][PYSPARK] Refactor local connection & auth code

This eliminates some duplication in the code that connects to a server on localhost to talk directly to the JVM. It also gives consistent IPv6 and error handling. Two other incidental changes that shouldn't matter:
1) Python barrier tasks perform authentication immediately (rather than waiting for the BARRIER_FUNCTION indicator).
2) For `rdd._load_from_socket`, the timeout is only increased after authentication.

Closes #22247 from squito/py_connection_refactor.

Authored-by: Imran Rashid <irashid@cloudera.com>
Signed-off-by: hyukjinkwon <gurwls223@apache.org>
Date: 2018-08-29 09:47:38 +08:00
Commit: 38391c9aa8 (parent: 68ec207a32)
5 changed files with 40 additions and 61 deletions
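
The "consistent ipv6 and error handling" point means that every call site now goes through the same getaddrinfo-based connection loop instead of each file rolling its own (pyspark/worker.py, in the last hunk below, previously hard-coded AF_INET). The general pattern the shared helper uses is sketched here with the Python standard library only; the port value is a placeholder, and this is an illustration of the loop, not the exact helper:

    import socket

    port = 12345  # placeholder; the real port is handed over by the JVM

    # Sketch only: try each (family, socktype, address) getaddrinfo returns for the
    # local endpoint, keeping the first one that connects.
    sock = None
    for af, socktype, proto, _, sa in socket.getaddrinfo(
            "127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
        try:
            sock = socket.socket(af, socktype, proto)
            sock.connect(sa)
            break                      # first family that connects wins
        except socket.error:
            sock.close()
            sock = None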

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

@@ -216,6 +216,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
           sock = serverSocket.get.accept()
           // Wait for function call from python side.
           sock.setSoTimeout(10000)
+          authHelper.authClient(sock)
           val input = new DataInputStream(sock.getInputStream())
           input.readInt() match {
             case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
@@ -334,8 +335,6 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     def barrierAndServe(sock: Socket): Unit = {
       require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.")
-      authHelper.authClient(sock)
       val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
       try {
         context.asInstanceOf[BarrierTaskContext].barrier()

python/pyspark/java_gateway.py

@@ -134,7 +134,7 @@ def launch_gateway(conf=None):
     return gateway
 
 
-def do_server_auth(conn, auth_secret):
+def _do_server_auth(conn, auth_secret):
     """
     Performs the authentication protocol defined by the SocketAuthHelper class on the given
     file-like object 'conn'.
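
The body of `_do_server_auth` is not part of this diff, so the following is only an approximation of the client side of the SocketAuthHelper handshake, pieced together from the docstring above and the "Unexpected reply" error in the next hunk: it assumes the exchange is a length-prefixed UTF-8 secret answered by an "ok" reply, with the helpers coming from pyspark.serializers.

    from pyspark.serializers import write_with_length, UTF8Deserializer

    def _do_server_auth(conn, auth_secret):
        # Sketch only: send the shared secret, then expect "ok" back from the JVM.
        write_with_length(auth_secret.encode("utf-8"), conn)
        conn.flush()
        reply = UTF8Deserializer().loads(conn)
        if reply != "ok":
            conn.close()
            raise Exception("Unexpected reply from iterator server.")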
@@ -147,6 +147,36 @@ def do_server_auth(conn, auth_secret):
     raise Exception("Unexpected reply from iterator server.")
 
 
+def local_connect_and_auth(port, auth_secret):
+    """
+    Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection.
+    Handles IPV4 & IPV6, does some error handling.
+    :param port
+    :param auth_secret
+    :return: a tuple with (sockfile, sock)
+    """
+    sock = None
+    errors = []
+    # Support for both IPv4 and IPv6.
+    # On most of IPv6-ready systems, IPv6 will take precedence.
+    for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+        af, socktype, proto, _, sa = res
+        try:
+            sock = socket.socket(af, socktype, proto)
+            sock.settimeout(15)
+            sock.connect(sa)
+            sockfile = sock.makefile("rwb", 65536)
+            _do_server_auth(sockfile, auth_secret)
+            return (sockfile, sock)
+        except socket.error as e:
+            emsg = _exception_message(e)
+            errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
+            sock.close()
+            sock = None
+    else:
+        raise Exception("could not open socket: %s" % errors)
+
+
 def ensure_callback_server_started(gw):
     """
     Start callback server if not already started. The callback server is needed if the Java
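
Seen from a call site, the contract of the new helper is: connect (with a 15 second connect timeout), authenticate, and hand back a buffered file object plus the raw socket so the caller can pick its own read timeout afterwards. A minimal usage sketch, with port and auth_secret as placeholders for the values the JVM normally hands over, as in the callers below:

    from pyspark.java_gateway import local_connect_and_auth

    port, auth_secret = 12345, "secret"    # placeholders; normally provided by the JVM

    # Sketch only: auth happens up front, then the caller chooses its own read timeout.
    sockfile, sock = local_connect_and_auth(port, auth_secret)
    sock.settimeout(None)                  # e.g. unbounded reads, as rdd._load_from_socket does
    data = sockfile.read()                 # read whatever the JVM streams back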

python/pyspark/rdd.py

@@ -39,7 +39,7 @@ if sys.version > '3':
 else:
     from itertools import imap as map, ifilter as filter
 
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
     PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \
@@ -141,33 +141,10 @@ def _parse_memory(s):
 
 def _load_from_socket(sock_info, serializer):
-    port, auth_secret = sock_info
-    sock = None
-    errors = []
-    # Support for both IPv4 and IPv6.
-    # On most of IPv6-ready systems, IPv6 will take precedence.
-    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
-        af, socktype, proto, canonname, sa = res
-        sock = socket.socket(af, socktype, proto)
-        try:
-            sock.settimeout(15)
-            sock.connect(sa)
-        except socket.error as e:
-            emsg = _exception_message(e)
-            errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
-            sock.close()
-            sock = None
-            continue
-        break
-    if not sock:
-        raise Exception("could not open socket: %s" % errors)
+    (sockfile, sock) = local_connect_and_auth(*sock_info)
     # The RDD materialization time is unpredicable, if we set a timeout for socket reading
     # operation, it will very possibly fail. See SPARK-18281.
     sock.settimeout(None)
-    sockfile = sock.makefile("rwb", 65536)
-    do_server_auth(sockfile, auth_secret)
     # The socket will be automatically closed when garbage-collected.
     return serializer.load_stream(sockfile)

python/pyspark/taskcontext.py

@@ -18,7 +18,7 @@
 from __future__ import print_function
 import socket
 
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import write_int, UTF8Deserializer
@@ -108,38 +108,14 @@ def _load_from_socket(port, auth_secret):
     """
     Load data from a given socket, this is a blocking method thus only return when the socket
     connection has been closed.
-
-    This is copied from context.py, while modified the message protocol.
     """
-    sock = None
-    # Support for both IPv4 and IPv6.
-    # On most of IPv6-ready systems, IPv6 will take precedence.
-    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
-        af, socktype, proto, canonname, sa = res
-        sock = socket.socket(af, socktype, proto)
-        try:
-            # Do not allow timeout for socket reading operation.
-            sock.settimeout(None)
-            sock.connect(sa)
-        except socket.error:
-            sock.close()
-            sock = None
-            continue
-        break
-    if not sock:
-        raise Exception("could not open socket")
-
-    # We don't really need a socket file here, it's just for convenience that we can reuse the
-    # do_server_auth() function and data serialization methods.
-    sockfile = sock.makefile("rwb", 65536)
+    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
+    # The barrier() call may block forever, so no timeout
+    sock.settimeout(None)
 
     # Make a barrier() function call.
     write_int(BARRIER_FUNCTION, sockfile)
     sockfile.flush()
-
-    # Do server auth.
-    do_server_auth(sockfile, auth_secret)
 
     # Collect result.
     res = UTF8Deserializer().loads(sockfile)
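
Putting the two sides together: after this change, authentication completes inside local_connect_and_auth, so by the time the JVM's readInt() sees BARRIER_FUNCTION the client is already authenticated, which is what the authHelper.authClient(sock) move in the Scala hunk above relies on. A sketch of the Python side, assembled from the lines in this hunk; the constant value and the port/secret are placeholders:

    from pyspark.java_gateway import local_connect_and_auth
    from pyspark.serializers import write_int, UTF8Deserializer

    BARRIER_FUNCTION = 1                    # assumption: the protocol constant used in this module
    port, auth_secret = 12345, "secret"     # placeholders; the real values come from the JVM

    # Sketch only: the Python side of the barrier() call after this refactor.
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)   # connect + authenticate
    sock.settimeout(None)                       # barrier() may block indefinitely
    write_int(BARRIER_FUNCTION, sockfile)       # tell the JVM which call this is
    sockfile.flush()
    res = UTF8Deserializer().loads(sockfile)    # wait for the JVM's reply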

python/pyspark/worker.py

@@ -28,7 +28,7 @@ import traceback
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.taskcontext import BarrierTaskContext, TaskContext
 from pyspark.files import SparkFiles
 from pyspark.rdd import PythonEvalType
@@ -387,8 +387,5 @@ if __name__ == '__main__':
     # Read information about how to connect back to the JVM from the environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    sock.connect(("127.0.0.1", java_port))
-    sock_file = sock.makefile("rwb", 65536)
-    do_server_auth(sock_file, auth_secret)
+    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
     main(sock_file, sock_file)