Send PySpark commands as bytes instead of strings.
parent cbb7f04aef
commit ffa5bedf46
PythonRDD.scala

@@ -27,13 +27,12 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.rdd.PipedRDD
 import org.apache.spark.util.Utils
 
 
 private[spark] class PythonRDD[T: ClassManifest](
     parent: RDD[T],
-    command: Seq[String],
+    command: Array[Byte],
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     preservePartitoning: Boolean,

@@ -44,21 +43,10 @@ private[spark] class PythonRDD[T: ClassManifest](
 
   val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
 
-  // Similar to Runtime.exec(), if we are given a single string, split it into words
-  // using a standard StringTokenizer (i.e. by spaces)
-  def this(parent: RDD[T], command: String, envVars: JMap[String, String],
-      pythonIncludes: JList[String],
-      preservePartitoning: Boolean, pythonExec: String,
-      broadcastVars: JList[Broadcast[Array[Byte]]],
-      accumulator: Accumulator[JList[Array[Byte]]]) =
-    this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec,
-      broadcastVars, accumulator)
-
   override def getPartitions = parent.partitions
 
   override val partitioner = if (preservePartitoning) parent.partitioner else None
 
 
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
     val startTime = System.currentTimeMillis
     val env = SparkEnv.get

@@ -71,7 +59,6 @@ private[spark] class PythonRDD[T: ClassManifest](
         SparkEnv.set(env)
         val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
         val dataOut = new DataOutputStream(stream)
-        val printOut = new PrintWriter(stream)
         // Partition index
         dataOut.writeInt(split.index)
         // sparkFilesDir

@@ -87,17 +74,14 @@ private[spark] class PythonRDD[T: ClassManifest](
         dataOut.writeInt(pythonIncludes.length)
         pythonIncludes.foreach(dataOut.writeUTF)
         dataOut.flush()
-        // Serialized user code
-        for (elem <- command) {
-          printOut.println(elem)
-        }
-        printOut.flush()
+        // Serialized command:
+        dataOut.writeInt(command.length)
+        dataOut.write(command)
         // Data values
         for (elem <- parent.iterator(split, context)) {
           PythonRDD.writeToStream(elem, dataOut)
         }
         dataOut.flush()
-        printOut.flush()
         worker.shutdownOutput()
       } catch {
         case e: IOException =>
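Note on the Scala side: the old code pushed each base64-encoded pickle through a separate PrintWriter as its own text line, while the new code writes the pickled command on the same DataOutputStream as everything else, as a 4-byte big-endian length followed by the raw bytes (DataOutputStream.writeInt is big-endian). A minimal Python sketch of that framing; write_with_length and read_with_length are illustrative names here, not Spark's own helpers:

import io
import struct

def write_with_length(data, stream):
    # Mirrors dataOut.writeInt(command.length) followed by dataOut.write(command).
    stream.write(struct.pack("!i", len(data)))  # big-endian 4-byte length prefix
    stream.write(data)

def read_with_length(stream):
    length = struct.unpack("!i", stream.read(4))[0]
    return stream.read(length)

buf = io.BytesIO()
write_with_length(b"pickled command bytes", buf)
buf.seek(0)
assert read_with_length(buf) == b"pickled command bytes"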
pyspark/rdd.py

@@ -27,9 +27,8 @@ from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
 
-from pyspark import cloudpickle
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
-    BatchedSerializer, pack_long
+    BatchedSerializer, CloudPickleSerializer, pack_long
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter

@@ -970,8 +969,8 @@ class PipelinedRDD(RDD):
             serializer = NoOpSerializer()
         else:
             serializer = self.ctx.serializer
-        cmds = [self.func, self._prev_jrdd_deserializer, serializer]
-        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
+        command = (self.func, self._prev_jrdd_deserializer, serializer)
+        pickled_command = CloudPickleSerializer()._dumps(command)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)

@@ -982,8 +981,9 @@ class PipelinedRDD(RDD):
         includes = ListConverter().convert(self.ctx._python_includes,
                                            self.ctx._gateway._gateway_client)
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
-            pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
-            broadcast_vars, self.ctx._javaAccumulator, class_manifest)
+            bytearray(pickled_command), env, includes, self.preservesPartitioning,
+            self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator,
+            class_manifest)
         self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
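Driver-side, the three stage components (function, input deserializer, output serializer) are now pickled once as a single tuple rather than as three space-joined base64 strings, and the result is passed to Py4J wrapped in bytearray(), which Py4J maps to a JVM byte[]. A self-contained sketch of the old versus new encoding; func, deserializer, and serializer are stand-in placeholders, and the standard pickle module substitutes for cloudpickle so the snippet runs anywhere:

import pickle
from base64 import b64encode

func = len                   # stand-in for the user's function
deserializer = "input-ser"   # stand-ins for the stage's serializers
serializer = "output-ser"

# Old scheme: pickle each piece separately, base64-encode, join with spaces.
old_command = ' '.join(b64encode(pickle.dumps(f)).decode('ascii')
                       for f in [func, deserializer, serializer])

# New scheme: pickle the whole tuple once and ship the raw bytes.
new_command = pickle.dumps((func, deserializer, serializer), 2)

print(len(old_command), len(new_command))  # base64 adds roughly 33% size overhead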
pyspark/serializers.py

@@ -64,6 +64,7 @@ import cPickle
 from itertools import chain, izip, product
 import marshal
 import struct
+from pyspark import cloudpickle
 
 
 __all__ = ["PickleSerializer", "MarshalSerializer"]

@@ -244,6 +245,10 @@ class PickleSerializer(FramedSerializer):
     def _dumps(self, obj): return cPickle.dumps(obj, 2)
     _loads = cPickle.loads
 
+class CloudPickleSerializer(PickleSerializer):
+
+    def _dumps(self, obj): return cloudpickle.dumps(obj, 2)
+
 
 class MarshalSerializer(FramedSerializer):
     """
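The new serializer only overrides the dump hook: it inherits PickleSerializer's framing and _loads, but dumps with cloudpickle, which can serialize lambdas and closures that cPickle rejects. A rough standalone analogue of the pattern, assuming the third-party cloudpickle package is installed; the Sketch classes and dumps/loads names are illustrative, not Spark's FramedSerializer API:

import pickle
import cloudpickle

class PickleSerializerSketch(object):
    def dumps(self, obj):
        return pickle.dumps(obj, 2)
    def loads(self, data):
        return pickle.loads(data)  # plain pickle can load cloudpickle output

class CloudPickleSerializerSketch(PickleSerializerSketch):
    def dumps(self, obj):
        return cloudpickle.dumps(obj, 2)  # handles lambdas/closures pickle can't

ser = CloudPickleSerializerSketch()
add_one = ser.loads(ser.dumps(lambda x: x + 1))
assert add_one(41) == 42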
pyspark/worker.py

@@ -23,7 +23,6 @@ import sys
 import time
 import socket
 import traceback
-from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
 # copy_reg module.
 from pyspark.accumulators import _accumulatorRegistry

@@ -38,11 +37,6 @@ pickleSer = PickleSerializer()
 mutf8_deserializer = MUTF8Deserializer()
 
 
-def load_obj(infile):
-    decoded = standard_b64decode(infile.readline().strip())
-    return pickleSer._loads(decoded)
-
-
 def report_times(outfile, boot, init, finish):
     write_int(SpecialLengths.TIMING_DATA, outfile)
     write_long(1000 * boot, outfile)

@@ -75,10 +69,8 @@ def main(infile, outfile):
         filename = mutf8_deserializer._loads(infile)
         sys.path.append(os.path.join(spark_files_dir, filename))
 
-    # Load this stage's function and serializer:
-    func = load_obj(infile)
-    deserializer = load_obj(infile)
-    serializer = load_obj(infile)
+    command = pickleSer._read_with_length(infile)
+    (func, deserializer, serializer) = command
     init_time = time.time()
     try:
         iterator = deserializer.load_stream(infile)
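Worker-side, three base64 line reads through load_obj become one framed read followed by tuple unpacking. A sketch of that handshake, with io.BytesIO standing in for the worker's socket stream and plain pickle for pickleSer; read_command is an illustrative helper, not Spark's:

import io
import pickle
import struct

def read_command(infile):
    # One length-prefixed frame holds the whole (func, deserializer,
    # serializer) tuple, replacing three separate load_obj() line reads.
    length = struct.unpack("!i", infile.read(4))[0]
    return pickle.loads(infile.read(length))

payload = pickle.dumps((len, "input-ser", "output-ser"), 2)
infile = io.BytesIO(struct.pack("!i", len(payload)) + payload)

(func, deserializer, serializer) = read_command(infile)
assert func is len and serializer == "output-ser"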