Fix performance bug with new Python code not using buffered streams

root 2013-07-01 02:45:00 +00:00
parent 39ae073b5c
commit 3296d132b6
2 changed files with 19 additions and 17 deletions

@@ -59,7 +59,8 @@ class SparkEnv (
   def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
     synchronized {
-      pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create()
+      val key = (pythonExec, envVars)
+      pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
     }
   }
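
The refactor above only names the lookup key; SparkEnv still keeps at most one PythonWorkerFactory per (pythonExec, envVars) pair. A minimal sketch of that pooling pattern, with an illustrative WorkerFactory stand-in rather than Spark's actual class:

import scala.collection.mutable

class WorkerFactory(exec: String, envVars: Map[String, String]) {
  def create(): java.net.Socket = new java.net.Socket("localhost", 8888) // placeholder endpoint
}

object WorkerPool {
  private val workers = mutable.HashMap.empty[(String, Map[String, String]), WorkerFactory]

  def createWorker(exec: String, envVars: Map[String, String]): java.net.Socket =
    synchronized {
      // getOrElseUpdate evaluates its second argument only on a cache miss,
      // so at most one factory exists per (exec, envVars) combination.
      val key = (exec, envVars)
      workers.getOrElseUpdate(key, new WorkerFactory(exec, envVars)).create()
    }
}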

@@ -45,37 +45,38 @@ private[spark] class PythonRDD[T: ClassManifest](
     new Thread("stdin writer for " + pythonExec) {
       override def run() {
         SparkEnv.set(env)
-        val out = new PrintWriter(worker.getOutputStream)
-        val dOut = new DataOutputStream(worker.getOutputStream)
+        val stream = new BufferedOutputStream(worker.getOutputStream)
+        val dataOut = new DataOutputStream(stream)
+        val printOut = new PrintWriter(stream)
         // Partition index
-        dOut.writeInt(split.index)
+        dataOut.writeInt(split.index)
         // sparkFilesDir
-        PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
+        PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
         // Broadcast variables
-        dOut.writeInt(broadcastVars.length)
+        dataOut.writeInt(broadcastVars.length)
         for (broadcast <- broadcastVars) {
-          dOut.writeLong(broadcast.id)
-          dOut.writeInt(broadcast.value.length)
-          dOut.write(broadcast.value)
-          dOut.flush()
+          dataOut.writeLong(broadcast.id)
+          dataOut.writeInt(broadcast.value.length)
+          dataOut.write(broadcast.value)
         }
+        dataOut.flush()
         // Serialized user code
         for (elem <- command) {
-          out.println(elem)
+          printOut.println(elem)
         }
-        out.flush()
+        printOut.flush()
         // Data values
         for (elem <- parent.iterator(split, context)) {
-          PythonRDD.writeAsPickle(elem, dOut)
+          PythonRDD.writeAsPickle(elem, dataOut)
         }
-        dOut.flush()
-        out.flush()
+        dataOut.flush()
+        printOut.flush()
         worker.shutdownOutput()
       }
     }.start()
     // Return an iterator that read lines from the process's stdout
-    val stream = new DataInputStream(worker.getInputStream)
+    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream))
     return new Iterator[Array[Byte]] {
       def next(): Array[Byte] = {
         val obj = _nextObj
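
The core of the fix is visible above: the binary DataOutputStream and the text PrintWriter now share a single BufferedOutputStream, and reads from the worker go through a BufferedInputStream, so each small write or read no longer hits the socket directly. A minimal sketch of that pattern, assuming a plain java.net.Socket (the payloads written here are placeholders):

import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, PrintWriter}
import java.net.Socket

def setUpStreams(worker: Socket): DataInputStream = {
  // Both writers share one BufferedOutputStream, so many small writes
  // coalesce into a few large socket writes instead of one syscall each.
  val stream = new BufferedOutputStream(worker.getOutputStream)
  val dataOut = new DataOutputStream(stream) // unbuffered; writes pass straight into `stream`
  val printOut = new PrintWriter(stream)     // buffers internally on top of `stream`

  dataOut.writeInt(42)        // binary framing goes directly into the shared buffer
  printOut.println("command") // text lands in PrintWriter's own buffer first
  printOut.flush()            // must flush before binary writes resume, or bytes interleave out of order
  dataOut.flush()             // pushes the shared buffer out to the socket

  // Buffer the read side too: each read on a raw socket stream can be a syscall.
  new DataInputStream(new BufferedInputStream(worker.getInputStream))
}

Because PrintWriter keeps its own internal buffer on top of the shared stream, each writer has to be flushed before the other takes over, which is why the diff flushes dataOut after the broadcast variables, flushes printOut after the serialized commands, and flushes both at the end.
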
@@ -288,7 +289,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
     // This happens on the master, where we pass the updates to Python through a socket
     val socket = new Socket(serverHost, serverPort)
     val in = socket.getInputStream
-    val out = new DataOutputStream(socket.getOutputStream)
+    val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream))
     out.writeInt(val2.size)
     for (array <- val2) {
       out.writeInt(array.length)
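
The same one-line wrap is applied to the accumulator socket, where updates are sent as many small writeInt and write calls. A rough micro-benchmark sketch of the cost being fixed, assuming a hypothetical local sink listening on port 9999; actual timings will vary by machine:

import java.io.{BufferedOutputStream, DataOutputStream}
import java.net.Socket

object BufferDemo {
  def time(label: String)(body: => Unit): Unit = {
    val start = System.nanoTime()
    body
    println(s"$label: ${(System.nanoTime() - start) / 1e6} ms")
  }

  def main(args: Array[String]): Unit = {
    time("unbuffered") {
      val out = new DataOutputStream(new Socket("localhost", 9999).getOutputStream)
      for (i <- 1 to 100000) out.writeInt(i) // every 4-byte write can hit the socket directly
      out.flush()
    }
    time("buffered") {
      val out = new DataOutputStream(
        new BufferedOutputStream(new Socket("localhost", 9999).getOutputStream))
      for (i <- 1 to 100000) out.writeInt(i) // writes coalesce in the default 8 KB buffer
      out.flush()
    }
  }
}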