[SPARK-2313] Use socket to communicate GatewayServer port back to Python driver

This patch changes PySpark so that the GatewayServer's port is communicated back to the Python process that launches it over a local socket instead of a pipe. The old pipe-based approach was brittle and could fail if `spark-submit` printed unexpected output to stdout.

To accomplish this, I wrote a custom `PythonGatewayServer.main()` function to use in place of Py4J's `GatewayServer.main()`.

Closes #3424.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #4603 from JoshRosen/SPARK-2313 and squashes the following commits:

6a7740b [Josh Rosen] Remove EchoOutputThread since it's no longer needed
0db501f [Josh Rosen] Use select() so that we don't block if GatewayServer dies.
9bdb4b6 [Josh Rosen] Handle case where getListeningPort returns -1
3fb7ed1 [Josh Rosen] Remove stdout=PIPE
2458934 [Josh Rosen] Use underscore to mark env var. as private
d12c95d [Josh Rosen] Use Logging and Utils.tryOrExit()
e5f9730 [Josh Rosen] Wrap everything in a giant try-block
2f70689 [Josh Rosen] Use stdin PIPE to share fate with driver
8bf956e [Josh Rosen] Initial cut at passing Py4J gateway port back to driver via socket
This commit is contained in:
Josh Rosen 2015-02-16 15:25:11 -08:00
parent c01c4ebcfe
commit 0cfda8461f
3 changed files with 97 additions and 43 deletions

View file

@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.python
import java.io.DataOutputStream
import java.net.Socket
import py4j.GatewayServer
import org.apache.spark.Logging
import org.apache.spark.util.Utils
/**
 * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port
 * back to its caller via a callback port specified by the caller.
 *
 * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py).
 */
private[spark] object PythonGatewayServer extends Logging {

  /**
   * Entry point. Performs three steps in order:
   *
   *  1. Starts a Py4J `GatewayServer` on port 0 and reads back the port it is listening on.
   *     If the reported port is -1 the process logs an error and exits with status 1.
   *  2. Connects to the host/port given by the `_PYSPARK_DRIVER_CALLBACK_HOST` and
   *     `_PYSPARK_DRIVER_CALLBACK_PORT` environment variables and writes the gateway port
   *     to that socket as a single int (via `DataOutputStream.writeInt`).
   *  3. Blocks reading stdin until end-of-stream, then exits with status 0.
   *
   * The whole body runs inside `Utils.tryOrExit`; a missing environment variable, a
   * non-numeric callback port, or a socket failure surfaces as an exception there.
   *
   * @param args command-line arguments; not read by this method
   */
  def main(args: Array[String]): Unit = Utils.tryOrExit {
    // Start a GatewayServer on an ephemeral port (port 0).
    val gatewayServer: GatewayServer = new GatewayServer(null, 0)
    gatewayServer.start()
    val boundPort: Int = gatewayServer.getListeningPort
    if (boundPort == -1) {
      logError("GatewayServer failed to bind; exiting")
      System.exit(1)
    } else {
      logDebug(s"Started PythonGatewayServer on port $boundPort")
    }

    // Communicate the bound port back to the caller via the caller-specified callback port.
    // sys.env(...) throws if the variable is absent, and toInt throws on a malformed port.
    val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
    val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
    logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
    val callbackSocket = new Socket(callbackHost, callbackPort)
    try {
      val dos = new DataOutputStream(callbackSocket.getOutputStream)
      dos.writeInt(boundPort)
      dos.close()
    } finally {
      // Always release the socket, including when obtaining the stream or the write fails.
      // Closing an already-closed socket is harmless, so this is safe after dos.close().
      callbackSocket.close()
    }

    // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
    while (System.in.read() != -1) {
      // Do nothing; bytes read from stdin are discarded.
    }
    logDebug("Exiting due to broken pipe from Python driver")
    System.exit(0)
  }
}

View file

@ -39,7 +39,6 @@ import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver}
import org.apache.spark.SPARK_VERSION import org.apache.spark.SPARK_VERSION
import org.apache.spark.deploy.rest._ import org.apache.spark.deploy.rest._
import org.apache.spark.executor._
import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils}
/** /**
@ -284,8 +283,7 @@ object SparkSubmit {
// If we're running a python app, set the main class to our specific python runner // If we're running a python app, set the main class to our specific python runner
if (args.isPython && deployMode == CLIENT) { if (args.isPython && deployMode == CLIENT) {
if (args.primaryResource == PYSPARK_SHELL) { if (args.primaryResource == PYSPARK_SHELL) {
args.mainClass = "py4j.GatewayServer" args.mainClass = "org.apache.spark.api.python.PythonGatewayServer"
args.childArgs = ArrayBuffer("--die-on-broken-pipe", "0")
} else { } else {
// If a python file is provided, add it to the child arguments and list of files to deploy. // If a python file is provided, add it to the child arguments and list of files to deploy.
// Usage: PythonAppRunner <main python file> <extra python files> [app arguments] // Usage: PythonAppRunner <main python file> <extra python files> [app arguments]

View file

@ -17,19 +17,20 @@
import atexit import atexit
import os import os
import sys import select
import signal import signal
import shlex import shlex
import socket
import platform import platform
from subprocess import Popen, PIPE from subprocess import Popen, PIPE
from threading import Thread
from py4j.java_gateway import java_import, JavaGateway, GatewayClient from py4j.java_gateway import java_import, JavaGateway, GatewayClient
from pyspark.serializers import read_int
def launch_gateway(): def launch_gateway():
SPARK_HOME = os.environ["SPARK_HOME"] SPARK_HOME = os.environ["SPARK_HOME"]
gateway_port = -1
if "PYSPARK_GATEWAY_PORT" in os.environ: if "PYSPARK_GATEWAY_PORT" in os.environ:
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
else: else:
@ -41,36 +42,42 @@ def launch_gateway():
submit_args = submit_args if submit_args is not None else "" submit_args = submit_args if submit_args is not None else ""
submit_args = shlex.split(submit_args) submit_args = shlex.split(submit_args)
command = [os.path.join(SPARK_HOME, script)] + submit_args + ["pyspark-shell"] command = [os.path.join(SPARK_HOME, script)] + submit_args + ["pyspark-shell"]
# Start a socket that will be used by PythonGatewayServer to communicate its port to us
callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
callback_socket.bind(('127.0.0.1', 0))
callback_socket.listen(1)
callback_host, callback_port = callback_socket.getsockname()
env = dict(os.environ)
env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)
# Launch the Java gateway.
# We open a pipe to stdin so that the Java gateway can die when the pipe is broken
if not on_windows: if not on_windows:
# Don't send ctrl-c / SIGINT to the Java gateway: # Don't send ctrl-c / SIGINT to the Java gateway:
def preexec_func(): def preexec_func():
signal.signal(signal.SIGINT, signal.SIG_IGN) signal.signal(signal.SIGINT, signal.SIG_IGN)
env = dict(os.environ)
env["IS_SUBPROCESS"] = "1" # tell JVM to exit after python exits env["IS_SUBPROCESS"] = "1" # tell JVM to exit after python exits
proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func, env=env) proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
else: else:
# preexec_fn not supported on Windows # preexec_fn not supported on Windows
proc = Popen(command, stdout=PIPE, stdin=PIPE) proc = Popen(command, stdin=PIPE, env=env)
try: gateway_port = None
# We use select() here in order to avoid blocking indefinitely if the subprocess dies
# before connecting
while gateway_port is None and proc.poll() is None:
timeout = 1 # (seconds)
readable, _, _ = select.select([callback_socket], [], [], timeout)
if callback_socket in readable:
gateway_connection = callback_socket.accept()[0]
# Determine which ephemeral port the server started on: # Determine which ephemeral port the server started on:
gateway_port = proc.stdout.readline() gateway_port = read_int(gateway_connection.makefile())
gateway_port = int(gateway_port) gateway_connection.close()
except ValueError: callback_socket.close()
# Grab the remaining lines of stdout if gateway_port is None:
(stdout, _) = proc.communicate() raise Exception("Java gateway process exited before sending the driver its port number")
exit_code = proc.poll()
error_msg = "Launching GatewayServer failed"
error_msg += " with exit code %d!\n" % exit_code if exit_code else "!\n"
error_msg += "Warning: Expected GatewayServer to output a port, but found "
if gateway_port == "" and stdout == "":
error_msg += "no output.\n"
else:
error_msg += "the following:\n\n"
error_msg += "--------------------------------------------------------------\n"
error_msg += gateway_port + stdout
error_msg += "--------------------------------------------------------------\n"
raise Exception(error_msg)
# In Windows, ensure the Java child processes do not linger after Python has exited. # In Windows, ensure the Java child processes do not linger after Python has exited.
# In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
@ -88,21 +95,6 @@ def launch_gateway():
Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)]) Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
atexit.register(killChild) atexit.register(killChild)
# Create a thread to echo output from the GatewayServer, which is required
# for Java log output to show up:
class EchoOutputThread(Thread):
def __init__(self, stream):
Thread.__init__(self)
self.daemon = True
self.stream = stream
def run(self):
while True:
line = self.stream.readline()
sys.stderr.write(line)
EchoOutputThread(proc.stdout).start()
# Connect to the gateway # Connect to the gateway
gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False) gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False)