From ec31e68d5df259e6df001529235d8c906ff02a6f Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 1 Jul 2013 06:20:14 +0000
Subject: [PATCH] Fixed PySpark perf regression by not using socket.makefile(),
 and improved debuggability by letting "print" statements show up in the
 executor's stderr

Conflicts:
	core/src/main/scala/spark/api/python/PythonRDD.scala
---
 .../scala/spark/api/python/PythonRDD.scala    | 10 +++--
 .../api/python/PythonWorkerFactory.scala      | 20 ++++++++-
 python/pyspark/daemon.py                      | 42 +++++++++++--------
 3 files changed, 50 insertions(+), 22 deletions(-)

diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 3f283afa62..31d8ea89d4 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -22,6 +22,8 @@ private[spark] class PythonRDD[T: ClassManifest](
     accumulator: Accumulator[JList[Array[Byte]]])
   extends RDD[Array[Byte]](parent) {
 
+  val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
   def this(parent: RDD[T], command: String, envVars: JMap[String, String],
@@ -45,7 +47,7 @@ private[spark] class PythonRDD[T: ClassManifest](
     new Thread("stdin writer for " + pythonExec) {
       override def run() {
         SparkEnv.set(env)
-        val stream = new BufferedOutputStream(worker.getOutputStream)
+        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
         val dataOut = new DataOutputStream(stream)
         val printOut = new PrintWriter(stream)
         // Partition index
@@ -76,7 +78,7 @@ private[spark] class PythonRDD[T: ClassManifest](
     }.start()
 
     // Return an iterator that read lines from the process's stdout
-    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream))
+    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
     return new Iterator[Array[Byte]] {
       def next(): Array[Byte] = {
         val obj = _nextObj
@@ -276,6 +278,8 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
   extends AccumulatorParam[JList[Array[Byte]]] {
 
   Utils.checkHost(serverHost, "Expected hostname")
+
+  val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
 
   override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
 
@@ -289,7 +293,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
       // This happens on the master, where we pass the updates to Python through a socket
       val socket = new Socket(serverHost, serverPort)
       val in = socket.getInputStream
-      val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream))
+      val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
       out.writeInt(val2.size)
       for (array <- val2) {
         out.writeInt(array.length)
diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala
index 8844411d73..85d1dfeac8 100644
--- a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala
@@ -51,7 +51,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
         val workerEnv = pb.environment()
         workerEnv.putAll(envVars)
         daemon = pb.start()
-        daemonPort = new DataInputStream(daemon.getInputStream).readInt()
 
         // Redirect the stderr to ours
         new Thread("stderr reader for " + pythonExec) {
@@ -69,6 +68,25 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
             }
           }
         }.start()
+
+        val in = new DataInputStream(daemon.getInputStream)
+        daemonPort = in.readInt()
+
+        // Redirect further stdout output to our stderr
+        new Thread("stdout reader for " + pythonExec) {
+          override def run() {
+            scala.util.control.Exception.ignoring(classOf[IOException]) {
+              // FIXME HACK: We copy the stream on the level of bytes to
+              // attempt to dodge encoding problems.
+              var buf = new Array[Byte](1024)
+              var len = in.read(buf)
+              while (len != -1) {
+                System.err.write(buf, 0, len)
+                len = in.read(buf)
+              }
+            }
+          }
+        }.start()
       } catch {
         case e => {
           stopDaemon()
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 78a2da1e18..78c9457b84 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -1,10 +1,13 @@
 import os
+import signal
+import socket
 import sys
+import traceback
 import multiprocessing
 from ctypes import c_bool
 from errno import EINTR, ECHILD
-from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN
-from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
+from socket import AF_INET, SOCK_STREAM, SOMAXCONN
+from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
 from pyspark.worker import main as worker_main
 from pyspark.serializers import write_int
 
@@ -33,11 +36,12 @@ def compute_real_exit_code(exit_code):
 def worker(listen_sock):
     # Redirect stdout to stderr
     os.dup2(2, 1)
+    sys.stdout = sys.stderr  # The sys.stdout object is different from file descriptor 1
 
     # Manager sends SIGHUP to request termination of workers in the pool
     def handle_sighup(*args):
         assert should_exit()
-    signal(SIGHUP, handle_sighup)
+    signal.signal(SIGHUP, handle_sighup)
 
     # Cleanup zombie children
     def handle_sigchld(*args):
@@ -51,7 +55,7 @@ def worker(listen_sock):
                 handle_sigchld()
             elif err.errno != ECHILD:
                 raise
-    signal(SIGCHLD, handle_sigchld)
+    signal.signal(SIGCHLD, handle_sigchld)
 
     # Handle clients
     while not should_exit():
@@ -70,19 +74,22 @@ def worker(listen_sock):
         # never receives SIGCHLD unless a worker crashes.
         if os.fork() == 0:
             # Leave the worker pool
-            signal(SIGHUP, SIG_DFL)
+            signal.signal(SIGHUP, SIG_DFL)
             listen_sock.close()
-            # Handle the client then exit
-            sockfile = sock.makefile()
+            # Read the socket using fdopen instead of socket.makefile() because the latter
+            # seems to be very slow; note that we need to dup() the file descriptor because
+            # otherwise writes also cause a seek that makes us miss data on the read side.
+            infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
+            outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
             exit_code = 0
             try:
-                worker_main(sockfile, sockfile)
+                worker_main(infile, outfile)
             except SystemExit as exc:
-                exit_code = exc.code
+                exit_code = exc.code
             finally:
-                sockfile.close()
-                sock.close()
-            os._exit(compute_real_exit_code(exit_code))
+                outfile.flush()
+                sock.close()
+                os._exit(compute_real_exit_code(exit_code))
         else:
             sock.close()
 
@@ -92,7 +99,6 @@ def launch_worker(listen_sock):
     try:
         worker(listen_sock)
     except Exception as err:
-        import traceback
         traceback.print_exc()
         os._exit(1)
     else:
@@ -105,7 +111,7 @@ def manager():
     os.setpgid(0, 0)
 
     # Create a listening socket on the AF_INET loopback interface
-    listen_sock = socket(AF_INET, SOCK_STREAM)
+    listen_sock = socket.socket(AF_INET, SOCK_STREAM)
     listen_sock.bind(('127.0.0.1', 0))
     listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN))
     listen_host, listen_port = listen_sock.getsockname()
@@ -121,8 +127,8 @@ def manager():
         exit_flag.value = True
 
     # Gracefully exit on SIGTERM, don't die on SIGHUP
-    signal(SIGTERM, lambda signum, frame: shutdown())
-    signal(SIGHUP, SIG_IGN)
+    signal.signal(SIGTERM, lambda signum, frame: shutdown())
+    signal.signal(SIGHUP, SIG_IGN)
 
     # Cleanup zombie children
     def handle_sigchld(*args):
@@ -133,7 +139,7 @@ def manager():
         except EnvironmentError as err:
             if err.errno not in (ECHILD, EINTR):
                 raise
-    signal(SIGCHLD, handle_sigchld)
+    signal.signal(SIGCHLD, handle_sigchld)
 
     # Initialization complete
     sys.stdout.close()
@@ -148,7 +154,7 @@ def manager():
             shutdown()
             raise
     finally:
-        signal(SIGTERM, SIG_DFL)
+        signal.signal(SIGTERM, SIG_DFL)
         exit_flag.value = True
         # Send SIGHUP to notify workers of shutdown
         os.kill(0, SIGHUP)
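Addendum (not part of the patch above): a minimal standalone sketch of the two techniques the commit relies on, namely redirecting the worker's stdout to stderr so "print" output reaches the executor's stderr, and wrapping the client socket with os.fdopen()/os.dup() instead of the slower socket.makefile(). The handle_client() name, the echo logic, and the "rb"/"wb" modes are illustrative assumptions, not PySpark's actual API; the 65536 buffer size mirrors the patch's spark.buffer.size default.

import os
import socket
import sys

def handle_client(sock):
    # Hypothetical per-connection handler, standing in for pyspark.worker.main.
    # Make "print" output land in the parent's stderr: repoint both file
    # descriptor 1 and the Python-level sys.stdout object.
    os.dup2(2, 1)
    sys.stdout = sys.stderr

    # Wrap the socket with os.fdopen() instead of sock.makefile(); dup() the
    # descriptor so the read-side and write-side file objects do not disturb
    # each other's position, as the patch's comment describes.
    infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536)
    outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536)
    try:
        line = infile.readline()                 # read one request line from the client
        print("handling %d bytes" % len(line))   # goes to stderr, not to the socket
        outfile.write(line)                      # echo the line back over the socket
    finally:
        outfile.flush()                          # push buffered bytes before closing
        infile.close()
        outfile.close()
        sock.close()

A parent process would accept connections on a loopback listening socket and hand each accepted sock to a function like this in a forked child, which is the shape of the loop in daemon.py.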