Add custom serializer support to PySpark.

For now, this only adds MarshalSerializer, but it lays the groundwork
for supporting other custom serializers. Many of these mechanisms
can also be used to support deserialization of different data formats
sent by Java, such as data encoded with MsgPack.

This also fixes a bug in SparkContext.union().
Author: Josh Rosen
Date:   2013-11-05 17:52:39 -08:00
Parent: 7d68a81a8e
Commit: cbb7f04aef

9 changed files with 364 additions and 171 deletions
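
As an illustrative sketch (not part of this commit) of how the new groundwork could be used: a custom serializer only needs to subclass FramedSerializer from pyspark.serializers, implement _dumps and _loads, and be passed to SparkContext. The JsonSerializer below is hypothetical; JSON supports fewer types than pickle (no tuples or arbitrary classes), which is one reason PickleSerializer stays the default.

# Hypothetical custom serializer built on the classes added in pyspark/serializers.py;
# JsonSerializer is not part of this commit.
import json

from pyspark.context import SparkContext
from pyspark.serializers import FramedSerializer


class JsonSerializer(FramedSerializer):
    """Serializes objects as JSON; supports fewer types than pickle."""

    def _dumps(self, obj):
        # When batching is enabled, obj is a list of records.
        return json.dumps(obj)

    def _loads(self, obj):
        return json.loads(obj)


sc = SparkContext('local', 'test', serializer=JsonSerializer())
print sc.parallelize([{"x": 1}, {"x": 2}]).map(lambda d: d["x"]).collect()  # expected: [1, 2]
sc.stop()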

PythonRDD.scala

@ -221,18 +221,6 @@ private[spark] object PythonRDD {
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
def writeStringAsPickle(elem: String, dOut: DataOutputStream) {
val s = elem.getBytes("UTF-8")
val length = 2 + 1 + 4 + s.length + 1
dOut.writeInt(length)
dOut.writeByte(Pickle.PROTO)
dOut.writeByte(Pickle.TWO)
dOut.write(Pickle.BINUNICODE)
dOut.writeInt(Integer.reverseBytes(s.length))
dOut.write(s)
dOut.writeByte(Pickle.STOP)
}
def writeToStream(elem: Any, dataOut: DataOutputStream) {
elem match {
case bytes: Array[Byte] =>
@ -244,9 +232,7 @@ private[spark] object PythonRDD {
dataOut.writeInt(pair._2.length)
dataOut.write(pair._2)
case str: String =>
// Until we've implemented full custom serializer support, we need to return
// strings as Pickles to properly support union() and cartesian():
writeStringAsPickle(str, dataOut)
dataOut.writeUTF(str)
case other =>
throw new SparkException("Unexpected element type " + other.getClass)
}
@ -271,13 +257,6 @@ private[spark] object PythonRDD {
}
}
private object Pickle {
val PROTO: Byte = 0x80.toByte
val TWO: Byte = 0x02.toByte
val BINUNICODE: Byte = 'X'
val STOP: Byte = '.'
}
private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}
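
For context on the writeUTF() change above: DataOutputStream.writeUTF() emits a 2-byte big-endian length followed by (modified) UTF-8 bytes, which is exactly the framing the new MUTF8Deserializer in pyspark/serializers.py parses. A minimal Python sketch of that framing (helper names are illustrative; for strings without NULs or supplementary characters, modified UTF-8 coincides with standard UTF-8):

import struct
from StringIO import StringIO  # Python 2, matching the codebase

def write_utf(s, stream):
    # Mimic java.io.DataOutputStream.writeUTF() for plain BMP strings.
    data = s.encode('utf-8')
    stream.write(struct.pack('>H', len(data)))  # unsigned 16-bit, big-endian length
    stream.write(data)

def read_utf(stream):
    # Mirrors MUTF8Deserializer._loads() added below.
    length = struct.unpack('>H', stream.read(2))[0]
    return stream.read(length).decode('utf8')

buf = StringIO()
write_utf(u"Hello", buf)
buf.seek(0)
assert read_utf(buf) == u"Hello"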

python/epydoc.conf

@ -32,6 +32,6 @@ target: docs/
private: no
exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
exclude: pyspark.cloudpickle pyspark.worker pyspark.join
pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
pyspark.rddsampler pyspark.daemon

pyspark/accumulators.py

@ -90,9 +90,11 @@ import struct
import SocketServer
import threading
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import read_int, read_with_length, load_pickle
from pyspark.serializers import read_int, PickleSerializer
pickleSer = PickleSerializer()
# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
_accumulatorRegistry = {}
@ -211,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
from pyspark.accumulators import _accumulatorRegistry
num_updates = read_int(self.rfile)
for _ in range(num_updates):
(aid, update) = load_pickle(read_with_length(self.rfile))
(aid, update) = pickleSer._read_with_length(self.rfile)
_accumulatorRegistry[aid] += update
# Write a byte in acknowledgement
self.wfile.write(struct.pack("!b", 1))

pyspark/context.py

@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length, batched
from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
@ -51,7 +51,7 @@ class SparkContext(object):
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024):
environment=None, batchSize=1024, serializer=PickleSerializer()):
"""
Create a new SparkContext.
@ -67,6 +67,7 @@ class SparkContext(object):
@param batchSize: The number of Python objects represented as a single
Java object. Set 1 to disable batching or -1 to use an
unlimited batch size.
@param serializer: The serializer for RDDs.
>>> from pyspark.context import SparkContext
@ -83,7 +84,13 @@ class SparkContext(object):
self.jobName = jobName
self.sparkHome = sparkHome or None # None becomes null in Py4J
self.environment = environment or {}
self.batchSize = batchSize # -1 represents a unlimited batch size
self._batchSize = batchSize # -1 represents an unlimited batch size
self._unbatched_serializer = serializer
if batchSize == 1:
self.serializer = self._unbatched_serializer
else:
self.serializer = BatchedSerializer(self._unbatched_serializer,
batchSize)
# Create the Java SparkContext through Py4J
empty_string_array = self._gateway.new_array(self._jvm.String, 0)
@ -184,15 +191,17 @@ class SparkContext(object):
# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
batchSize = min(len(c) // numSlices, self.batchSize)
batchSize = min(len(c) // numSlices, self._batchSize)
if batchSize > 1:
c = batched(c, batchSize)
for x in c:
write_with_length(dump_pickle(x), tempFile)
serializer = BatchedSerializer(self._unbatched_serializer,
batchSize)
else:
serializer = self._unbatched_serializer
serializer.dump_stream(c, tempFile)
tempFile.close()
readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
return RDD(jrdd, self)
return RDD(jrdd, self, serializer)
def textFile(self, name, minSplits=None):
"""
@ -201,21 +210,39 @@ class SparkContext(object):
RDD of Strings.
"""
minSplits = minSplits or min(self.defaultParallelism, 2)
jrdd = self._jsc.textFile(name, minSplits)
return RDD(jrdd, self)
return RDD(self._jsc.textFile(name, minSplits), self,
MUTF8Deserializer())
def _checkpointFile(self, name):
def _checkpointFile(self, name, input_deserializer):
jrdd = self._jsc.checkpointFile(name)
return RDD(jrdd, self)
return RDD(jrdd, self, input_deserializer)
def union(self, rdds):
"""
Build the union of a list of RDDs.
This supports unions() of RDDs with different serialized formats,
although this forces them to be reserialized using the default
serializer:
>>> path = os.path.join(tempdir, "union-text.txt")
>>> with open(path, "w") as testFile:
... testFile.write("Hello")
>>> textFile = sc.textFile(path)
>>> textFile.collect()
[u'Hello']
>>> parallelized = sc.parallelize(["World!"])
>>> sorted(sc.union([textFile, parallelized]).collect())
[u'Hello', 'World!']
"""
first_jrdd_deserializer = rdds[0]._jrdd_deserializer
if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
rdds = [x._reserialize() for x in rdds]
first = rdds[0]._jrdd
rest = [x._jrdd for x in rdds[1:]]
rest = ListConverter().convert(rest, self.gateway._gateway_client)
return RDD(self._jsc.union(first, rest), self)
rest = ListConverter().convert(rest, self._gateway._gateway_client)
return RDD(self._jsc.union(first, rest), self,
rdds[0]._jrdd_deserializer)
def broadcast(self, value):
"""
@ -223,7 +250,9 @@ class SparkContext(object):
object for reading it in distributed functions. The variable will be
sent to each cluster only once.
"""
jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
pickleSer = PickleSerializer()
pickled = pickleSer._dumps(value)
jbroadcast = self._jsc.broadcast(bytearray(pickled))
return Broadcast(jbroadcast.id(), value, jbroadcast,
self._pickled_broadcast_vars)
@ -235,7 +264,7 @@ class SparkContext(object):
and floating-point numbers if you do not provide one. For other types,
a custom AccumulatorParam can be used.
"""
if accum_param == None:
if accum_param is None:
if isinstance(value, int):
accum_param = accumulators.INT_ACCUMULATOR_PARAM
elif isinstance(value, float):

pyspark/rdd.py

@ -18,7 +18,7 @@
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
from itertools import chain, ifilter, imap, product
from itertools import chain, ifilter, imap
import operator
import os
import sys
@ -28,8 +28,8 @@ from tempfile import NamedTemporaryFile
from threading import Thread
from pyspark import cloudpickle
from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
read_from_pickle_file, pack_long
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, pack_long
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@ -48,13 +48,12 @@ class RDD(object):
operated on in parallel.
"""
def __init__(self, jrdd, ctx):
def __init__(self, jrdd, ctx, jrdd_deserializer):
self._jrdd = jrdd
self.is_cached = False
self.is_checkpointed = False
self.ctx = ctx
self._partitionFunc = None
self._stage_input_is_pairs = False
self._jrdd_deserializer = jrdd_deserializer
@property
def context(self):
@ -248,7 +247,23 @@ class RDD(object):
>>> rdd.union(rdd).collect()
[1, 1, 2, 3, 1, 1, 2, 3]
"""
return RDD(self._jrdd.union(other._jrdd), self.ctx)
if self._jrdd_deserializer == other._jrdd_deserializer:
rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
self._jrdd_deserializer)
return rdd
else:
# These RDDs contain data in different serialized formats, so we
# must normalize them to the default serializer.
self_copy = self._reserialize()
other_copy = other._reserialize()
return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
self.ctx.serializer)
def _reserialize(self):
if self._jrdd_deserializer == self.ctx.serializer:
return self
else:
return self.map(lambda x: x, preservesPartitioning=True)
def __add__(self, other):
"""
@ -335,18 +350,9 @@ class RDD(object):
[(1, 1), (1, 2), (2, 1), (2, 2)]
"""
# Due to batching, we can't use the Java cartesian method.
java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
def unpack_batches(pair):
(x, y) = pair
if type(x) == Batch or type(y) == Batch:
xs = x.items if type(x) == Batch else [x]
ys = y.items if type(y) == Batch else [y]
for pair in product(xs, ys):
yield pair
else:
yield pair
java_cartesian._stage_input_is_pairs = True
return java_cartesian.flatMap(unpack_batches)
deserializer = CartesianDeserializer(self._jrdd_deserializer,
other._jrdd_deserializer)
return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
def groupBy(self, f, numPartitions=None):
"""
@ -405,7 +411,7 @@ class RDD(object):
self.ctx._writeToFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
for item in read_from_pickle_file(tempFile):
for item in self._jrdd_deserializer.load_stream(tempFile):
yield item
os.unlink(tempFile.name)
@ -573,7 +579,7 @@ class RDD(object):
items = []
for partition in range(mapped._jrdd.splits().size()):
iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
items.extend(self._collect_iterator_through_file(iterator))
items.extend(mapped._collect_iterator_through_file(iterator))
if len(items) >= num:
break
return items[:num]
@ -737,6 +743,7 @@ class RDD(object):
# Transferring O(n) objects to Java is too expensive. Instead, we'll
# form the hash buckets in Python, transferring O(numPartitions) objects
# to Java. Each object is a (splitNumber, [objects]) pair.
outputSerializer = self.ctx._unbatched_serializer
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
@ -745,14 +752,14 @@ class RDD(object):
buckets[partitionFunc(k) % numPartitions].append((k, v))
for (split, items) in buckets.iteritems():
yield pack_long(split)
yield dump_pickle(Batch(items))
yield outputSerializer._dumps(items)
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
rdd = RDD(jrdd, self.ctx)
rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
# This is required so that id(partitionFunc) remains unique, even if
# partitionFunc is a lambda:
rdd._partitionFunc = partitionFunc
@ -789,7 +796,8 @@ class RDD(object):
numPartitions = self.ctx.defaultParallelism
def combineLocally(iterator):
combiners = {}
for (k, v) in iterator:
for x in iterator:
(k, v) = x
if k not in combiners:
combiners[k] = createCombiner(v)
else:
@ -931,38 +939,38 @@ class PipelinedRDD(RDD):
20
"""
def __init__(self, prev, func, preservesPartitioning=False):
if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
# This transformation is the first in its stage:
self.func = func
self.preservesPartitioning = preservesPartitioning
self._prev_jrdd = prev._jrdd
self._prev_jrdd_deserializer = prev._jrdd_deserializer
else:
prev_func = prev.func
def pipeline_func(split, iterator):
return func(split, prev_func(split, iterator))
self.func = pipeline_func
self.preservesPartitioning = \
prev.preservesPartitioning and preservesPartitioning
self._prev_jrdd = prev._prev_jrdd
else:
self.func = func
self.preservesPartitioning = preservesPartitioning
self._prev_jrdd = prev._jrdd
self._stage_input_is_pairs = prev._stage_input_is_pairs
self._prev_jrdd = prev._prev_jrdd # maintain the pipeline
self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
self.is_cached = False
self.is_checkpointed = False
self.ctx = prev.ctx
self.prev = prev
self._jrdd_val = None
self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
@property
def _jrdd(self):
if self._jrdd_val:
return self._jrdd_val
func = self.func
if not self._bypass_serializer and self.ctx.batchSize != 1:
oldfunc = self.func
batchSize = self.ctx.batchSize
def batched_func(split, iterator):
return batched(oldfunc(split, iterator), batchSize)
func = batched_func
cmds = [func, self._bypass_serializer, self._stage_input_is_pairs]
if self._bypass_serializer:
serializer = NoOpSerializer()
else:
serializer = self.ctx.serializer
cmds = [self.func, self._prev_jrdd_deserializer, serializer]
pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
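
To make the new cartesian() path above concrete, here is a small, illustrative simulation of how CartesianDeserializer (added in pyspark/serializers.py below) pairs the alternating key and value frames that the JavaRDD cartesian() hands back; the StringIO buffer stands in for that byte stream:

from StringIO import StringIO  # Python 2, matching the codebase
from pyspark.serializers import (BatchedSerializer, CartesianDeserializer,
                                 PickleSerializer)

# Fake the stream produced by Java: one framed batch of keys,
# followed by one framed batch of values.
buf = StringIO()
PickleSerializer().dump_stream([[1, 2], ['a', 'b']], buf)
buf.seek(0)

deser = CartesianDeserializer(BatchedSerializer(PickleSerializer(), 2),
                              BatchedSerializer(PickleSerializer(), 2))
print list(deser.load_stream(buf))
# [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]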

pyspark/serializers.py

@ -15,8 +15,58 @@
# limitations under the License.
#
import struct
"""
PySpark supports custom serializers for transferring data; this can improve
performance.
By default, PySpark uses L{PickleSerializer} to serialize objects using Python's
C{cPickle} serializer, which can serialize nearly any Python object.
Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be
faster.
The serializer is chosen when creating L{SparkContext}:
>>> from pyspark.context import SparkContext
>>> from pyspark.serializers import MarshalSerializer
>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
>>> sc.stop()
By default, PySpark serializes objects in batches; the batch size can be
controlled through SparkContext's C{batchSize} parameter
(the default size is 1024 objects):
>>> sc = SparkContext('local', 'test', batchSize=2)
>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
Behind the scenes, this creates a JavaRDD with four partitions, each of
which contains two batches of two objects:
>>> rdd.glom().collect()
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
>>> rdd._jrdd.count()
8L
>>> sc.stop()
A batch size of -1 uses an unlimited batch size, and a size of 1 disables
batching:
>>> sc = SparkContext('local', 'test', batchSize=1)
>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
>>> rdd.glom().collect()
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
>>> rdd._jrdd.count()
16L
"""
import cPickle
from itertools import chain, izip, product
import marshal
import struct
__all__ = ["PickleSerializer", "MarshalSerializer"]
class SpecialLengths(object):
@ -25,41 +75,206 @@ class SpecialLengths(object):
TIMING_DATA = -3
class Batch(object):
class Serializer(object):
def dump_stream(self, iterator, stream):
"""
Serialize an iterator of objects to the output stream.
"""
raise NotImplementedError
def load_stream(self, stream):
"""
Return an iterator of deserialized objects from the input stream.
"""
raise NotImplementedError
def _load_stream_without_unbatching(self, stream):
return self.load_stream(stream)
# Note: our notion of "equality" is that output generated by
# equal serializers can be deserialized using the same serializer.
# This default implementation handles the simple cases;
# subclasses should override __eq__ as appropriate.
def __eq__(self, other):
return isinstance(other, self.__class__)
def __ne__(self, other):
return not self.__eq__(other)
class FramedSerializer(Serializer):
"""
Used to store multiple RDD entries as a single Java object.
This relieves us from having to explicitly track whether an RDD
is stored as batches of objects and avoids problems when processing
the union() of batched and unbatched RDDs (e.g. the union() of textFile()
with another RDD).
Serializer that writes objects as a stream of (length, data) pairs,
where C{length} is a 32-bit integer and data is C{length} bytes.
"""
def __init__(self, items):
self.items = items
def dump_stream(self, iterator, stream):
for obj in iterator:
self._write_with_length(obj, stream)
def load_stream(self, stream):
while True:
try:
yield self._read_with_length(stream)
except EOFError:
return
def _write_with_length(self, obj, stream):
serialized = self._dumps(obj)
write_int(len(serialized), stream)
stream.write(serialized)
def _read_with_length(self, stream):
length = read_int(stream)
obj = stream.read(length)
if obj == "":
raise EOFError
return self._loads(obj)
def _dumps(self, obj):
"""
Serialize an object into a byte array.
When batching is used, this will be called with an array of objects.
"""
raise NotImplementedError
def _loads(self, obj):
"""
Deserialize an object from a byte array.
"""
raise NotImplementedError
def batched(iterator, batchSize):
if batchSize == -1: # unlimited batch size
yield Batch(list(iterator))
else:
items = []
count = 0
for item in iterator:
items.append(item)
count += 1
if count == batchSize:
yield Batch(items)
items = []
count = 0
if items:
yield Batch(items)
class BatchedSerializer(Serializer):
"""
Serializes a stream of objects in batches by calling its wrapped
Serializer with streams of objects.
"""
UNLIMITED_BATCH_SIZE = -1
def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
self.serializer = serializer
self.batchSize = batchSize
def _batched(self, iterator):
if self.batchSize == self.UNLIMITED_BATCH_SIZE:
yield list(iterator)
else:
items = []
count = 0
for item in iterator:
items.append(item)
count += 1
if count == self.batchSize:
yield items
items = []
count = 0
if items:
yield items
def dump_stream(self, iterator, stream):
if isinstance(iterator, basestring):
iterator = [iterator]
self.serializer.dump_stream(self._batched(iterator), stream)
def load_stream(self, stream):
return chain.from_iterable(self._load_stream_without_unbatching(stream))
def _load_stream_without_unbatching(self, stream):
return self.serializer.load_stream(stream)
def __eq__(self, other):
return isinstance(other, BatchedSerializer) and \
other.serializer == self.serializer
def __str__(self):
return "BatchedSerializer<%s>" % str(self.serializer)
def dump_pickle(obj):
return cPickle.dumps(obj, 2)
class CartesianDeserializer(FramedSerializer):
"""
Deserializes the JavaRDD cartesian() of two PythonRDDs.
"""
def __init__(self, key_ser, val_ser):
self.key_ser = key_ser
self.val_ser = val_ser
def load_stream(self, stream):
key_stream = self.key_ser._load_stream_without_unbatching(stream)
val_stream = self.val_ser._load_stream_without_unbatching(stream)
key_is_batched = isinstance(self.key_ser, BatchedSerializer)
val_is_batched = isinstance(self.val_ser, BatchedSerializer)
for (keys, vals) in izip(key_stream, val_stream):
keys = keys if key_is_batched else [keys]
vals = vals if val_is_batched else [vals]
for pair in product(keys, vals):
yield pair
def __eq__(self, other):
return isinstance(other, CartesianDeserializer) and \
self.key_ser == other.key_ser and self.val_ser == other.val_ser
def __str__(self):
return "CartesianDeserializer<%s, %s>" % \
(str(self.key_ser), str(self.val_ser))
load_pickle = cPickle.loads
class NoOpSerializer(FramedSerializer):
def _loads(self, obj): return obj
def _dumps(self, obj): return obj
class PickleSerializer(FramedSerializer):
"""
Serializes objects using Python's cPickle serializer:
http://docs.python.org/2/library/pickle.html
This serializer supports nearly any Python object, but may
not be as fast as more specialized serializers.
"""
def _dumps(self, obj): return cPickle.dumps(obj, 2)
_loads = cPickle.loads
class MarshalSerializer(FramedSerializer):
"""
Serializes objects using Python's Marshal serializer:
http://docs.python.org/2/library/marshal.html
This serializer is faster than PickleSerializer but supports fewer datatypes.
"""
_dumps = marshal.dumps
_loads = marshal.loads
class MUTF8Deserializer(Serializer):
"""
Deserializes streams written by Java's DataOutputStream.writeUTF().
"""
def _loads(self, stream):
length = struct.unpack('>H', stream.read(2))[0]
return stream.read(length).decode('utf8')
def load_stream(self, stream):
while True:
try:
yield self._loads(stream)
except struct.error:
return
except EOFError:
return
def read_long(stream):
@ -90,43 +305,4 @@ def write_int(value, stream):
def write_with_length(obj, stream):
write_int(len(obj), stream)
stream.write(obj)
def read_mutf8(stream):
"""
Read a string written with Java's DataOutputStream.writeUTF() method.
"""
length = struct.unpack('>H', stream.read(2))[0]
return stream.read(length).decode('utf8')
def read_with_length(stream):
length = read_int(stream)
obj = stream.read(length)
if obj == "":
raise EOFError
return obj
def read_from_pickle_file(stream):
try:
while True:
obj = load_pickle(read_with_length(stream))
if type(obj) == Batch: # We don't care about inheritance
for item in obj.items:
yield item
else:
yield obj
except EOFError:
return
def read_pairs_from_pickle_file(stream):
try:
while True:
a = load_pickle(read_with_length(stream))
b = load_pickle(read_with_length(stream))
yield (a, b)
except EOFError:
return
stream.write(obj)
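
As a quick, illustrative sanity check of how the classes above compose: BatchedSerializer groups objects into lists and delegates the (length, data) framing to the wrapped FramedSerializer, so a round trip through a stream recovers the original objects:

from StringIO import StringIO  # Python 2, matching the codebase
from pyspark.serializers import BatchedSerializer, MarshalSerializer

ser = BatchedSerializer(MarshalSerializer(), 2)
buf = StringIO()
ser.dump_stream(iter(range(5)), buf)   # writes three frames: [0, 1], [2, 3], [4]
buf.seek(0)
assert list(ser.load_stream(buf)) == range(5)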

pyspark/tests.py

@ -86,7 +86,8 @@ class TestCheckpoint(PySparkTestCase):
time.sleep(1) # 1 second
self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
flatMappedRDD._jrdd_deserializer)
self.assertEquals([1, 2, 3, 4], recovered.collect())

pyspark/worker.py

@ -30,13 +30,17 @@ from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, read_with_length, write_int, \
read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \
SpecialLengths, read_mutf8, read_pairs_from_pickle_file
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
pickleSer = PickleSerializer()
mutf8_deserializer = MUTF8Deserializer()
def load_obj(infile):
return load_pickle(standard_b64decode(infile.readline().strip()))
decoded = standard_b64decode(infile.readline().strip())
return pickleSer._loads(decoded)
def report_times(outfile, boot, init, finish):
@ -53,7 +57,7 @@ def main(infile, outfile):
return
# fetch name of workdir
spark_files_dir = read_mutf8(infile)
spark_files_dir = mutf8_deserializer._loads(infile)
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
@ -61,31 +65,24 @@ def main(infile, outfile):
num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
value = read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
value = pickleSer._read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, value)
# fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
for _ in range(num_python_includes):
sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile)))
filename = mutf8_deserializer._loads(infile)
sys.path.append(os.path.join(spark_files_dir, filename))
# now load function
# Load this stage's function and serializer:
func = load_obj(infile)
bypassSerializer = load_obj(infile)
stageInputIsPairs = load_obj(infile)
if bypassSerializer:
dumps = lambda x: x
else:
dumps = dump_pickle
deserializer = load_obj(infile)
serializer = load_obj(infile)
init_time = time.time()
if stageInputIsPairs:
iterator = read_pairs_from_pickle_file(infile)
else:
iterator = read_from_pickle_file(infile)
try:
for obj in func(split_index, iterator):
write_with_length(dumps(obj), outfile)
iterator = deserializer.load_stream(infile)
serializer.dump_stream(func(split_index, iterator), outfile)
except Exception as e:
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
write_with_length(traceback.format_exc(), outfile)
@ -96,7 +93,7 @@ def main(infile, outfile):
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
write_int(len(_accumulatorRegistry), outfile)
for (aid, accum) in _accumulatorRegistry.items():
write_with_length(dump_pickle((aid, accum._value)), outfile)
pickleSer._write_with_length((aid, accum._value), outfile)
if __name__ == '__main__':

python/run-tests

@ -37,6 +37,7 @@ run_test "pyspark/rdd.py"
run_test "pyspark/context.py"
run_test "-m doctest pyspark/broadcast.py"
run_test "-m doctest pyspark/accumulators.py"
run_test "-m doctest pyspark/serializers.py"
run_test "pyspark/tests.py"
if [[ $FAILED != 0 ]]; then