Fixed PySpark perf regression by not using socket.makefile(), and improved
debuggability by letting "print" statements show up in the executor's stderr.

Conflicts: core/src/main/scala/spark/api/python/PythonRDD.scala
This commit is contained in:
parent
3296d132b6
commit
ec31e68d5d
|
@ -22,6 +22,8 @@ private[spark] class PythonRDD[T: ClassManifest](
|
|||
accumulator: Accumulator[JList[Array[Byte]]])
|
||||
extends RDD[Array[Byte]](parent) {
|
||||
|
||||
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
|
||||
|
||||
// Similar to Runtime.exec(), if we are given a single string, split it into words
|
||||
// using a standard StringTokenizer (i.e. by spaces)
|
||||
def this(parent: RDD[T], command: String, envVars: JMap[String, String],
|
||||
|
@ -45,7 +47,7 @@ private[spark] class PythonRDD[T: ClassManifest](
|
|||
new Thread("stdin writer for " + pythonExec) {
|
||||
override def run() {
|
||||
SparkEnv.set(env)
|
||||
val stream = new BufferedOutputStream(worker.getOutputStream)
|
||||
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
|
||||
val dataOut = new DataOutputStream(stream)
|
||||
val printOut = new PrintWriter(stream)
|
||||
// Partition index
|
||||
|
@ -76,7 +78,7 @@ private[spark] class PythonRDD[T: ClassManifest](
|
|||
}.start()
|
||||
|
||||
// Return an iterator that read lines from the process's stdout
|
||||
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream))
|
||||
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
|
||||
return new Iterator[Array[Byte]] {
|
||||
def next(): Array[Byte] = {
|
||||
val obj = _nextObj
|
||||
|
@ -276,6 +278,8 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
|
|||
extends AccumulatorParam[JList[Array[Byte]]] {
|
||||
|
||||
Utils.checkHost(serverHost, "Expected hostname")
|
||||
|
||||
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
|
||||
|
||||
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
|
||||
|
||||
|
@ -289,7 +293,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
|
|||
// This happens on the master, where we pass the updates to Python through a socket
|
||||
val socket = new Socket(serverHost, serverPort)
|
||||
val in = socket.getInputStream
|
||||
val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream))
|
||||
val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
|
||||
out.writeInt(val2.size)
|
||||
for (array <- val2) {
|
||||
out.writeInt(array.length)
|
||||
|
|
|
@ -51,7 +51,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
|
|||
val workerEnv = pb.environment()
|
||||
workerEnv.putAll(envVars)
|
||||
daemon = pb.start()
|
||||
daemonPort = new DataInputStream(daemon.getInputStream).readInt()
|
||||
|
||||
// Redirect the stderr to ours
|
||||
new Thread("stderr reader for " + pythonExec) {
|
||||
|
@ -69,6 +68,25 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
|
|||
}
|
||||
}
|
||||
}.start()
|
||||
|
||||
val in = new DataInputStream(daemon.getInputStream)
|
||||
daemonPort = in.readInt()
|
||||
|
||||
// Redirect further stdout output to our stderr
|
||||
new Thread("stdout reader for " + pythonExec) {
|
||||
override def run() {
|
||||
scala.util.control.Exception.ignoring(classOf[IOException]) {
|
||||
// FIXME HACK: We copy the stream on the level of bytes to
|
||||
// attempt to dodge encoding problems.
|
||||
var buf = new Array[Byte](1024)
|
||||
var len = in.read(buf)
|
||||
while (len != -1) {
|
||||
System.err.write(buf, 0, len)
|
||||
len = in.read(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
}.start()
|
||||
} catch {
|
||||
case e => {
|
||||
stopDaemon()
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
import os
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
import traceback
|
||||
import multiprocessing
|
||||
from ctypes import c_bool
|
||||
from errno import EINTR, ECHILD
|
||||
from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN
|
||||
from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
|
||||
from socket import AF_INET, SOCK_STREAM, SOMAXCONN
|
||||
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
|
||||
from pyspark.worker import main as worker_main
|
||||
from pyspark.serializers import write_int
|
||||
|
||||
|
@ -33,11 +36,12 @@ def compute_real_exit_code(exit_code):
|
|||
def worker(listen_sock):
|
||||
# Redirect stdout to stderr
|
||||
os.dup2(2, 1)
|
||||
sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1
|
||||
|
||||
# Manager sends SIGHUP to request termination of workers in the pool
|
||||
def handle_sighup(*args):
|
||||
assert should_exit()
|
||||
signal(SIGHUP, handle_sighup)
|
||||
signal.signal(SIGHUP, handle_sighup)
|
||||
|
||||
# Cleanup zombie children
|
||||
def handle_sigchld(*args):
|
||||
|
@ -51,7 +55,7 @@ def worker(listen_sock):
|
|||
handle_sigchld()
|
||||
elif err.errno != ECHILD:
|
||||
raise
|
||||
signal(SIGCHLD, handle_sigchld)
|
||||
signal.signal(SIGCHLD, handle_sigchld)
|
||||
|
||||
# Handle clients
|
||||
while not should_exit():
|
||||
|
@ -70,19 +74,22 @@ def worker(listen_sock):
|
|||
# never receives SIGCHLD unless a worker crashes.
|
||||
if os.fork() == 0:
|
||||
# Leave the worker pool
|
||||
signal(SIGHUP, SIG_DFL)
|
||||
signal.signal(SIGHUP, SIG_DFL)
|
||||
listen_sock.close()
|
||||
# Handle the client then exit
|
||||
sockfile = sock.makefile()
|
||||
# Read the socket using fdopen instead of socket.makefile() because the latter
|
||||
# seems to be very slow; note that we need to dup() the file descriptor because
|
||||
# otherwise writes also cause a seek that makes us miss data on the read side.
|
||||
infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
|
||||
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
|
||||
exit_code = 0
|
||||
try:
|
||||
worker_main(sockfile, sockfile)
|
||||
worker_main(infile, outfile)
|
||||
except SystemExit as exc:
|
||||
exit_code = exc.code
|
||||
exit_code = exc.code
|
||||
finally:
|
||||
sockfile.close()
|
||||
sock.close()
|
||||
os._exit(compute_real_exit_code(exit_code))
|
||||
outfile.flush()
|
||||
sock.close()
|
||||
os._exit(compute_real_exit_code(exit_code))
|
||||
else:
|
||||
sock.close()
|
||||
|
||||
|
@ -92,7 +99,6 @@ def launch_worker(listen_sock):
|
|||
try:
|
||||
worker(listen_sock)
|
||||
except Exception as err:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
os._exit(1)
|
||||
else:
|
||||
|
@ -105,7 +111,7 @@ def manager():
|
|||
os.setpgid(0, 0)
|
||||
|
||||
# Create a listening socket on the AF_INET loopback interface
|
||||
listen_sock = socket(AF_INET, SOCK_STREAM)
|
||||
listen_sock = socket.socket(AF_INET, SOCK_STREAM)
|
||||
listen_sock.bind(('127.0.0.1', 0))
|
||||
listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN))
|
||||
listen_host, listen_port = listen_sock.getsockname()
|
||||
|
@ -121,8 +127,8 @@ def manager():
|
|||
exit_flag.value = True
|
||||
|
||||
# Gracefully exit on SIGTERM, don't die on SIGHUP
|
||||
signal(SIGTERM, lambda signum, frame: shutdown())
|
||||
signal(SIGHUP, SIG_IGN)
|
||||
signal.signal(SIGTERM, lambda signum, frame: shutdown())
|
||||
signal.signal(SIGHUP, SIG_IGN)
|
||||
|
||||
# Cleanup zombie children
|
||||
def handle_sigchld(*args):
|
||||
|
@ -133,7 +139,7 @@ def manager():
|
|||
except EnvironmentError as err:
|
||||
if err.errno not in (ECHILD, EINTR):
|
||||
raise
|
||||
signal(SIGCHLD, handle_sigchld)
|
||||
signal.signal(SIGCHLD, handle_sigchld)
|
||||
|
||||
# Initialization complete
|
||||
sys.stdout.close()
|
||||
|
@ -148,7 +154,7 @@ def manager():
|
|||
shutdown()
|
||||
raise
|
||||
finally:
|
||||
signal(SIGTERM, SIG_DFL)
|
||||
signal.signal(SIGTERM, SIG_DFL)
|
||||
exit_flag.value = True
|
||||
# Send SIGHUP to notify workers of shutdown
|
||||
os.kill(0, SIGHUP)
|
||||
|
|
Loading…
Reference in a new issue