Fixed PySpark perf regression by not using socket.makefile(), and improved
debuggability by letting "print" statements show up in the executor's stderr

Conflicts:
	core/src/main/scala/spark/api/python/PythonRDD.scala
Author: root
Date:   2013-07-01 06:20:14 +00:00
commit ec31e68d5d
parent 3296d132b6
3 changed files with 50 additions and 22 deletions

core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -22,6 +22,8 @@ private[spark] class PythonRDD[T: ClassManifest](
     accumulator: Accumulator[JList[Array[Byte]]])
   extends RDD[Array[Byte]](parent) {
 
+  val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
   def this(parent: RDD[T], command: String, envVars: JMap[String, String],
@@ -45,7 +47,7 @@ private[spark] class PythonRDD[T: ClassManifest](
     new Thread("stdin writer for " + pythonExec) {
       override def run() {
         SparkEnv.set(env)
-        val stream = new BufferedOutputStream(worker.getOutputStream)
+        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
         val dataOut = new DataOutputStream(stream)
         val printOut = new PrintWriter(stream)
         // Partition index
@@ -76,7 +78,7 @@ private[spark] class PythonRDD[T: ClassManifest](
     }.start()
 
     // Return an iterator that read lines from the process's stdout
-    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream))
+    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
     return new Iterator[Array[Byte]] {
       def next(): Array[Byte] = {
         val obj = _nextObj
@@ -276,6 +278,8 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
   extends AccumulatorParam[JList[Array[Byte]]] {
 
   Utils.checkHost(serverHost, "Expected hostname")
 
+  val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
   override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
@@ -289,7 +293,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
       // This happens on the master, where we pass the updates to Python through a socket
       val socket = new Socket(serverHost, serverPort)
       val in = socket.getInputStream
-      val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream))
+      val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
       out.writeInt(val2.size)
       for (array <- val2) {
         out.writeInt(array.length)
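Both hunks above read the same tunable; a minimal Python sketch of the pattern, assuming a hypothetical SPARK_BUFFER_SIZE environment variable standing in for the spark.buffer.size system property:

    import io
    import os
    import socket

    # The buffer size becomes a knob with a default, rather than the library's
    # built-in default, mirroring System.getProperty("spark.buffer.size", "65536").
    buffer_size = int(os.environ.get("SPARK_BUFFER_SIZE", "65536"))

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # An explicitly sized buffered writer over the socket, analogous to
    # new BufferedOutputStream(socket.getOutputStream, bufferSize):
    writer = io.BufferedWriter(io.FileIO(os.dup(sock.fileno()), "w"), buffer_size)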

core/src/main/scala/spark/api/python/PythonWorkerFactory.scala
@@ -51,7 +51,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
        val workerEnv = pb.environment()
        workerEnv.putAll(envVars)
        daemon = pb.start()
-       daemonPort = new DataInputStream(daemon.getInputStream).readInt()
 
        // Redirect the stderr to ours
        new Thread("stderr reader for " + pythonExec) {
@@ -69,6 +68,25 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
            }
          }
        }.start()
+
+       val in = new DataInputStream(daemon.getInputStream)
+       daemonPort = in.readInt()
+
+       // Redirect further stdout output to our stderr
+       new Thread("stdout reader for " + pythonExec) {
+         override def run() {
+           scala.util.control.Exception.ignoring(classOf[IOException]) {
+             // FIXME HACK: We copy the stream on the level of bytes to
+             // attempt to dodge encoding problems.
+             var buf = new Array[Byte](1024)
+             var len = in.read(buf)
+             while (len != -1) {
+               System.err.write(buf, 0, len)
+               len = in.read(buf)
+             }
+           }
+         }
+       }.start()
      } catch {
        case e => {
          stopDaemon()

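The stdout-reader thread added above copies raw bytes rather than decoded text, which sidesteps encoding problems in the child's output. The same pump in Python, as a hedged sketch assuming Python 3 and a binary stream such as subprocess.Popen(..., stdout=PIPE).stdout:

    import sys
    import threading

    def forward_to_stderr(stream, bufsize=1024):
        # Copy a binary stream to our stderr byte-for-byte, as the Scala
        # thread above does; no decoding means no encoding errors.
        def pump():
            while True:
                buf = stream.read(bufsize)
                if not buf:  # b"" signals EOF: the child closed its stdout
                    break
                sys.stderr.buffer.write(buf)
                sys.stderr.buffer.flush()
        t = threading.Thread(target=pump, name="stdout reader")
        t.daemon = True  # don't let this pump keep the process alive
        t.start()
        return t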
python/pyspark/daemon.py
@@ -1,10 +1,13 @@
 import os
+import signal
+import socket
 import sys
+import traceback
 import multiprocessing
 from ctypes import c_bool
 from errno import EINTR, ECHILD
-from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN
-from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
+from socket import AF_INET, SOCK_STREAM, SOMAXCONN
+from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
 from pyspark.worker import main as worker_main
 from pyspark.serializers import write_int
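Dropping "from socket import socket" and "from signal import signal" is what forces every call site below onto the qualified signal.signal(...) / socket.socket(...) form; the old from-imports rebound those names to a class and a function, shadowing the modules themselves. A small self-contained sketch:

    import signal
    import socket

    def handler(signum, frame):
        pass  # hypothetical handler, for illustration only

    # With the modules imported whole, call sites are unambiguous and
    # module attributes stay reachable:
    signal.signal(signal.SIGTERM, handler)
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)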
@@ -33,11 +36,12 @@ def compute_real_exit_code(exit_code):
 def worker(listen_sock):
     # Redirect stdout to stderr
     os.dup2(2, 1)
+    sys.stdout = sys.stderr  # The sys.stdout object is different from file descriptor 1
 
     # Manager sends SIGHUP to request termination of workers in the pool
     def handle_sighup(*args):
         assert should_exit()
-    signal(SIGHUP, handle_sighup)
+    signal.signal(SIGHUP, handle_sighup)
 
     # Cleanup zombie children
     def handle_sigchld(*args):
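Both halves of that redirect are needed: os.dup2(2, 1) rewires file descriptor 1, which is what forked children and raw writes see, while rebinding sys.stdout catches writes that go through the existing Python-level file object. A minimal sketch:

    import os
    import sys

    os.dup2(2, 1)            # fd 1 now aliases fd 2: low-level writes hit stderr
    sys.stdout = sys.stderr  # and so do Python-level "print" statements

    print("lands on stderr")                  # via the rebound sys.stdout
    os.write(1, b"so does this raw write\n")  # via the duplicated descriptor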
@@ -51,7 +55,7 @@ def worker(listen_sock):
                 handle_sigchld()
             elif err.errno != ECHILD:
                 raise
-    signal(SIGCHLD, handle_sigchld)
+    signal.signal(SIGCHLD, handle_sigchld)
 
     # Handle clients
     while not should_exit():
@@ -70,19 +74,22 @@ def worker(listen_sock):
         # never receives SIGCHLD unless a worker crashes.
         if os.fork() == 0:
             # Leave the worker pool
-            signal(SIGHUP, SIG_DFL)
+            signal.signal(SIGHUP, SIG_DFL)
             listen_sock.close()
-            # Handle the client then exit
-            sockfile = sock.makefile()
+
+            # Read the socket using fdopen instead of socket.makefile() because the latter
+            # seems to be very slow; note that we need to dup() the file descriptor because
+            # otherwise writes also cause a seek that makes us miss data on the read side.
+            infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
+            outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
             exit_code = 0
             try:
-                worker_main(sockfile, sockfile)
+                worker_main(infile, outfile)
             except SystemExit as exc:
-                exit_code = exc.code
+                exit_code = exc.code
             finally:
-                sockfile.close()
-                sock.close()
-                os._exit(compute_real_exit_code(exit_code))
+                outfile.flush()
+                sock.close()
+                os._exit(compute_real_exit_code(exit_code))
         else:
             sock.close()
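This is the perf fix named in the commit message: socket.makefile() turned out to be very slow, so the worker wraps dup()'d descriptors with os.fdopen() instead; dup() matters because a shared file position would let buffered writes seek the reader past unread data. A self-contained sketch of the pattern; the patch opens both files in "a+" under Python 2, while this sketch uses separate "rb"/"wb" modes as the closest Python 3 analogue:

    import os

    def buffered_files(sock, bufsize=65536):
        # Separately buffered reader/writer over dup()'d descriptors,
        # instead of sock.makefile(); each file object owns its own fd,
        # so closing one side doesn't tear down the other.
        infile = os.fdopen(os.dup(sock.fileno()), "rb", bufsize)
        outfile = os.fdopen(os.dup(sock.fileno()), "wb", bufsize)
        return infile, outfile

    # Usage, server side, after: sock, addr = listen_sock.accept()
    #   infile, outfile = buffered_files(sock)
    #   worker_main(infile, outfile)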
@@ -92,7 +99,6 @@ def launch_worker(listen_sock):
     try:
         worker(listen_sock)
     except Exception as err:
-        import traceback
         traceback.print_exc()
         os._exit(1)
     else:
@@ -105,7 +111,7 @@ def manager():
     os.setpgid(0, 0)
 
     # Create a listening socket on the AF_INET loopback interface
-    listen_sock = socket(AF_INET, SOCK_STREAM)
+    listen_sock = socket.socket(AF_INET, SOCK_STREAM)
     listen_sock.bind(('127.0.0.1', 0))
     listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN))
     listen_host, listen_port = listen_sock.getsockname()
@@ -121,8 +127,8 @@ def manager():
         exit_flag.value = True
 
     # Gracefully exit on SIGTERM, don't die on SIGHUP
-    signal(SIGTERM, lambda signum, frame: shutdown())
-    signal(SIGHUP, SIG_IGN)
+    signal.signal(SIGTERM, lambda signum, frame: shutdown())
+    signal.signal(SIGHUP, SIG_IGN)
 
     # Cleanup zombie children
     def handle_sigchld(*args):
@@ -133,7 +139,7 @@ def manager():
         except EnvironmentError as err:
             if err.errno not in (ECHILD, EINTR):
                 raise
-    signal(SIGCHLD, handle_sigchld)
+    signal.signal(SIGCHLD, handle_sigchld)
 
     # Initialization complete
     sys.stdout.close()
@@ -148,7 +154,7 @@ def manager():
             shutdown()
             raise
     finally:
-        signal(SIGTERM, SIG_DFL)
+        signal.signal(SIGTERM, SIG_DFL)
         exit_flag.value = True
         # Send SIGHUP to notify workers of shutdown
         os.kill(0, SIGHUP)