[SPARK-25253][PYSPARK] Refactor local connection & auth code
This eliminates some duplication in the code to connect to a server on localhost to talk directly to the jvm. Also it gives consistent ipv6 and error handling. Two other incidental changes, that shouldn't matter: 1) python barrier tasks perform authentication immediately (rather than waiting for the BARRIER_FUNCTION indicator) 2) for `rdd._load_from_socket`, the timeout is only increased after authentication. Closes #22247 from squito/py_connection_refactor. Authored-by: Imran Rashid <irashid@cloudera.com> Signed-off-by: hyukjinkwon <gurwls223@apache.org>
This commit is contained in:
parent
68ec207a32
commit
38391c9aa8
|
@ -216,6 +216,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
|
||||||
sock = serverSocket.get.accept()
|
sock = serverSocket.get.accept()
|
||||||
// Wait for function call from python side.
|
// Wait for function call from python side.
|
||||||
sock.setSoTimeout(10000)
|
sock.setSoTimeout(10000)
|
||||||
|
authHelper.authClient(sock)
|
||||||
val input = new DataInputStream(sock.getInputStream())
|
val input = new DataInputStream(sock.getInputStream())
|
||||||
input.readInt() match {
|
input.readInt() match {
|
||||||
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
|
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
|
||||||
|
@ -334,8 +335,6 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
|
||||||
def barrierAndServe(sock: Socket): Unit = {
|
def barrierAndServe(sock: Socket): Unit = {
|
||||||
require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.")
|
require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.")
|
||||||
|
|
||||||
authHelper.authClient(sock)
|
|
||||||
|
|
||||||
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
|
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
|
||||||
try {
|
try {
|
||||||
context.asInstanceOf[BarrierTaskContext].barrier()
|
context.asInstanceOf[BarrierTaskContext].barrier()
|
||||||
|
|
|
@ -134,7 +134,7 @@ def launch_gateway(conf=None):
|
||||||
return gateway
|
return gateway
|
||||||
|
|
||||||
|
|
||||||
def do_server_auth(conn, auth_secret):
|
def _do_server_auth(conn, auth_secret):
|
||||||
"""
|
"""
|
||||||
Performs the authentication protocol defined by the SocketAuthHelper class on the given
|
Performs the authentication protocol defined by the SocketAuthHelper class on the given
|
||||||
file-like object 'conn'.
|
file-like object 'conn'.
|
||||||
|
@ -147,6 +147,36 @@ def do_server_auth(conn, auth_secret):
|
||||||
raise Exception("Unexpected reply from iterator server.")
|
raise Exception("Unexpected reply from iterator server.")
|
||||||
|
|
||||||
|
|
||||||
|
def local_connect_and_auth(port, auth_secret):
|
||||||
|
"""
|
||||||
|
Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection.
|
||||||
|
Handles IPV4 & IPV6, does some error handling.
|
||||||
|
:param port
|
||||||
|
:param auth_secret
|
||||||
|
:return: a tuple with (sockfile, sock)
|
||||||
|
"""
|
||||||
|
sock = None
|
||||||
|
errors = []
|
||||||
|
# Support for both IPv4 and IPv6.
|
||||||
|
# On most of IPv6-ready systems, IPv6 will take precedence.
|
||||||
|
for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
|
||||||
|
af, socktype, proto, _, sa = res
|
||||||
|
try:
|
||||||
|
sock = socket.socket(af, socktype, proto)
|
||||||
|
sock.settimeout(15)
|
||||||
|
sock.connect(sa)
|
||||||
|
sockfile = sock.makefile("rwb", 65536)
|
||||||
|
_do_server_auth(sockfile, auth_secret)
|
||||||
|
return (sockfile, sock)
|
||||||
|
except socket.error as e:
|
||||||
|
emsg = _exception_message(e)
|
||||||
|
errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
|
||||||
|
sock.close()
|
||||||
|
sock = None
|
||||||
|
else:
|
||||||
|
raise Exception("could not open socket: %s" % errors)
|
||||||
|
|
||||||
|
|
||||||
def ensure_callback_server_started(gw):
|
def ensure_callback_server_started(gw):
|
||||||
"""
|
"""
|
||||||
Start callback server if not already started. The callback server is needed if the Java
|
Start callback server if not already started. The callback server is needed if the Java
|
||||||
|
|
|
@ -39,7 +39,7 @@ if sys.version > '3':
|
||||||
else:
|
else:
|
||||||
from itertools import imap as map, ifilter as filter
|
from itertools import imap as map, ifilter as filter
|
||||||
|
|
||||||
from pyspark.java_gateway import do_server_auth
|
from pyspark.java_gateway import local_connect_and_auth
|
||||||
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
|
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
|
||||||
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
|
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
|
||||||
PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \
|
PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \
|
||||||
|
@ -141,33 +141,10 @@ def _parse_memory(s):
|
||||||
|
|
||||||
|
|
||||||
def _load_from_socket(sock_info, serializer):
|
def _load_from_socket(sock_info, serializer):
|
||||||
port, auth_secret = sock_info
|
(sockfile, sock) = local_connect_and_auth(*sock_info)
|
||||||
sock = None
|
|
||||||
errors = []
|
|
||||||
# Support for both IPv4 and IPv6.
|
|
||||||
# On most of IPv6-ready systems, IPv6 will take precedence.
|
|
||||||
for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
|
|
||||||
af, socktype, proto, canonname, sa = res
|
|
||||||
sock = socket.socket(af, socktype, proto)
|
|
||||||
try:
|
|
||||||
sock.settimeout(15)
|
|
||||||
sock.connect(sa)
|
|
||||||
except socket.error as e:
|
|
||||||
emsg = _exception_message(e)
|
|
||||||
errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
|
|
||||||
sock.close()
|
|
||||||
sock = None
|
|
||||||
continue
|
|
||||||
break
|
|
||||||
if not sock:
|
|
||||||
raise Exception("could not open socket: %s" % errors)
|
|
||||||
# The RDD materialization time is unpredicable, if we set a timeout for socket reading
|
# The RDD materialization time is unpredicable, if we set a timeout for socket reading
|
||||||
# operation, it will very possibly fail. See SPARK-18281.
|
# operation, it will very possibly fail. See SPARK-18281.
|
||||||
sock.settimeout(None)
|
sock.settimeout(None)
|
||||||
|
|
||||||
sockfile = sock.makefile("rwb", 65536)
|
|
||||||
do_server_auth(sockfile, auth_secret)
|
|
||||||
|
|
||||||
# The socket will be automatically closed when garbage-collected.
|
# The socket will be automatically closed when garbage-collected.
|
||||||
return serializer.load_stream(sockfile)
|
return serializer.load_stream(sockfile)
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
from pyspark.java_gateway import do_server_auth
|
from pyspark.java_gateway import local_connect_and_auth
|
||||||
from pyspark.serializers import write_int, UTF8Deserializer
|
from pyspark.serializers import write_int, UTF8Deserializer
|
||||||
|
|
||||||
|
|
||||||
|
@ -108,38 +108,14 @@ def _load_from_socket(port, auth_secret):
|
||||||
"""
|
"""
|
||||||
Load data from a given socket, this is a blocking method thus only return when the socket
|
Load data from a given socket, this is a blocking method thus only return when the socket
|
||||||
connection has been closed.
|
connection has been closed.
|
||||||
|
|
||||||
This is copied from context.py, while modified the message protocol.
|
|
||||||
"""
|
"""
|
||||||
sock = None
|
(sockfile, sock) = local_connect_and_auth(port, auth_secret)
|
||||||
# Support for both IPv4 and IPv6.
|
# The barrier() call may block forever, so no timeout
|
||||||
# On most of IPv6-ready systems, IPv6 will take precedence.
|
|
||||||
for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
|
|
||||||
af, socktype, proto, canonname, sa = res
|
|
||||||
sock = socket.socket(af, socktype, proto)
|
|
||||||
try:
|
|
||||||
# Do not allow timeout for socket reading operation.
|
|
||||||
sock.settimeout(None)
|
sock.settimeout(None)
|
||||||
sock.connect(sa)
|
|
||||||
except socket.error:
|
|
||||||
sock.close()
|
|
||||||
sock = None
|
|
||||||
continue
|
|
||||||
break
|
|
||||||
if not sock:
|
|
||||||
raise Exception("could not open socket")
|
|
||||||
|
|
||||||
# We don't really need a socket file here, it's just for convenience that we can reuse the
|
|
||||||
# do_server_auth() function and data serialization methods.
|
|
||||||
sockfile = sock.makefile("rwb", 65536)
|
|
||||||
|
|
||||||
# Make a barrier() function call.
|
# Make a barrier() function call.
|
||||||
write_int(BARRIER_FUNCTION, sockfile)
|
write_int(BARRIER_FUNCTION, sockfile)
|
||||||
sockfile.flush()
|
sockfile.flush()
|
||||||
|
|
||||||
# Do server auth.
|
|
||||||
do_server_auth(sockfile, auth_secret)
|
|
||||||
|
|
||||||
# Collect result.
|
# Collect result.
|
||||||
res = UTF8Deserializer().loads(sockfile)
|
res = UTF8Deserializer().loads(sockfile)
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ import traceback
|
||||||
|
|
||||||
from pyspark.accumulators import _accumulatorRegistry
|
from pyspark.accumulators import _accumulatorRegistry
|
||||||
from pyspark.broadcast import Broadcast, _broadcastRegistry
|
from pyspark.broadcast import Broadcast, _broadcastRegistry
|
||||||
from pyspark.java_gateway import do_server_auth
|
from pyspark.java_gateway import local_connect_and_auth
|
||||||
from pyspark.taskcontext import BarrierTaskContext, TaskContext
|
from pyspark.taskcontext import BarrierTaskContext, TaskContext
|
||||||
from pyspark.files import SparkFiles
|
from pyspark.files import SparkFiles
|
||||||
from pyspark.rdd import PythonEvalType
|
from pyspark.rdd import PythonEvalType
|
||||||
|
@ -387,8 +387,5 @@ if __name__ == '__main__':
|
||||||
# Read information about how to connect back to the JVM from the environment.
|
# Read information about how to connect back to the JVM from the environment.
|
||||||
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
|
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
|
||||||
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
|
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
|
||||||
sock.connect(("127.0.0.1", java_port))
|
|
||||||
sock_file = sock.makefile("rwb", 65536)
|
|
||||||
do_server_auth(sock_file, auth_secret)
|
|
||||||
main(sock_file, sock_file)
|
main(sock_file, sock_file)
|
||||||
|
|
Loading…
Reference in a new issue