Rename PythonWorker to PythonWorkerFactory
This commit is contained in:
parent
62c4781400
commit
edb18ca928
|
@ -12,7 +12,7 @@ import spark.storage.BlockManagerMaster
|
|||
import spark.network.ConnectionManager
|
||||
import spark.serializer.{Serializer, SerializerManager}
|
||||
import spark.util.AkkaUtils
|
||||
import spark.api.python.PythonWorker
|
||||
import spark.api.python.PythonWorkerFactory
|
||||
|
||||
|
||||
/**
|
||||
|
@ -41,7 +41,7 @@ class SparkEnv (
|
|||
// If executorId is NOT found, return defaultHostPort
|
||||
var executorIdToHostPort: Option[(String, String) => String]) {
|
||||
|
||||
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorker]()
|
||||
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
|
||||
|
||||
def stop() {
|
||||
pythonWorkers.foreach { case(key, worker) => worker.stop() }
|
||||
|
@ -57,9 +57,9 @@ class SparkEnv (
|
|||
actorSystem.awaitTermination()
|
||||
}
|
||||
|
||||
def getPythonWorker(pythonExec: String, envVars: Map[String, String]): PythonWorker = {
|
||||
def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
|
||||
synchronized {
|
||||
pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorker(pythonExec, envVars))
|
||||
pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -38,8 +38,8 @@ private[spark] class PythonRDD[T: ClassManifest](
|
|||
|
||||
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
|
||||
val startTime = System.currentTimeMillis
|
||||
val worker = SparkEnv.get.getPythonWorker(pythonExec, envVars.toMap).create
|
||||
val env = SparkEnv.get
|
||||
val worker = env.createPythonWorker(pythonExec, envVars.toMap)
|
||||
|
||||
// Start a thread to feed the process input from our parent's iterator
|
||||
new Thread("stdin writer for " + pythonExec) {
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
package spark.api.python
|
||||
|
||||
import java.io.DataInputStream
|
||||
import java.io.{DataInputStream, IOException}
|
||||
import java.net.{Socket, SocketException, InetAddress}
|
||||
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
import spark._
|
||||
|
||||
private[spark] class PythonWorker(pythonExec: String, envVars: Map[String, String])
|
||||
private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
|
||||
extends Logging {
|
||||
var daemon: Process = null
|
||||
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
|
||||
|
@ -56,14 +56,16 @@ private[spark] class PythonWorker(pythonExec: String, envVars: Map[String, Strin
|
|||
// Redirect the stderr to ours
|
||||
new Thread("stderr reader for " + pythonExec) {
|
||||
override def run() {
|
||||
// FIXME HACK: We copy the stream on the level of bytes to
|
||||
// attempt to dodge encoding problems.
|
||||
val in = daemon.getErrorStream
|
||||
var buf = new Array[Byte](1024)
|
||||
var len = in.read(buf)
|
||||
while (len != -1) {
|
||||
System.err.write(buf, 0, len)
|
||||
len = in.read(buf)
|
||||
scala.util.control.Exception.ignoring(classOf[IOException]) {
|
||||
// FIXME HACK: We copy the stream on the level of bytes to
|
||||
// attempt to dodge encoding problems.
|
||||
val in = daemon.getErrorStream
|
||||
var buf = new Array[Byte](1024)
|
||||
var len = in.read(buf)
|
||||
while (len != -1) {
|
||||
System.err.write(buf, 0, len)
|
||||
len = in.read(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
}.start()
|
Loading…
Reference in a new issue