Add custom serializer support to PySpark.

For now, this only adds MarshalSerializer, but it lays the groundwork for supporting other custom serializers. Many of these mechanisms can also be used to support deserialization of different data formats sent by Java, such as data encoded by MsgPack. This also fixes a bug in SparkContext.union().
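For context on the MarshalSerializer tradeoff, here is a standalone sketch (no Spark required; the names and timings are illustrative only, not part of this commit) comparing the two stdlib codecs that back PickleSerializer and MarshalSerializer:

import cPickle
import marshal
import timeit

data = [(i, "value-%d" % i) for i in range(1000)]
print "cPickle: %.4fs" % timeit.timeit(lambda: cPickle.dumps(data, 2), number=100)
print "marshal: %.4fs" % timeit.timeit(lambda: marshal.dumps(data), number=100)

# marshal is typically faster but supports fewer datatypes; for example,
# it rejects instances of user-defined classes:
class Point(object):
    pass

cPickle.dumps(Point(), 2)   # fine
try:
    marshal.dumps(Point())
except ValueError:
    print "marshal cannot serialize custom classes"

This is why PickleSerializer remains the default and MarshalSerializer is opt-in.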
parent 7d68a81a8e
commit cbb7f04aef
@@ -221,18 +221,6 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeStringAsPickle(elem: String, dOut: DataOutputStream) {
-    val s = elem.getBytes("UTF-8")
-    val length = 2 + 1 + 4 + s.length + 1
-    dOut.writeInt(length)
-    dOut.writeByte(Pickle.PROTO)
-    dOut.writeByte(Pickle.TWO)
-    dOut.write(Pickle.BINUNICODE)
-    dOut.writeInt(Integer.reverseBytes(s.length))
-    dOut.write(s)
-    dOut.writeByte(Pickle.STOP)
-  }
-
   def writeToStream(elem: Any, dataOut: DataOutputStream) {
     elem match {
       case bytes: Array[Byte] =>
@@ -244,9 +232,7 @@ private[spark] object PythonRDD {
         dataOut.writeInt(pair._2.length)
         dataOut.write(pair._2)
       case str: String =>
-        // Until we've implemented full custom serializer support, we need to return
-        // strings as Pickles to properly support union() and cartesian():
-        writeStringAsPickle(str, dataOut)
+        dataOut.writeUTF(str)
       case other =>
         throw new SparkException("Unexpected element type " + other.getClass)
     }
@@ -271,13 +257,6 @@ private[spark] object PythonRDD {
     }
   }
 
-  private object Pickle {
-    val PROTO: Byte = 0x80.toByte
-    val TWO: Byte = 0x02.toByte
-    val BINUNICODE: Byte = 'X'
-    val STOP: Byte = '.'
-  }
-
   private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
     override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
   }
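The deleted Scala code hand-assembled a pickle protocol-2 stream for every string (framed by a dOut.writeInt(length) prefix, with length = 2 + 1 + 4 + s.length + 1); strings are now written with plain writeUTF() and decoded on the Python side by MUTF8Deserializer. A minimal sketch rebuilding what writeStringAsPickle emitted, to show the two wire formats agree with Python's pickle module:

# Illustrative reconstruction of the removed Scala constants; BINUNICODE takes
# a little-endian length, which is why the Scala used Integer.reverseBytes.
import cPickle
import struct

def string_as_pickle(s):
    utf8 = s.encode("utf-8")
    return ("\x80\x02"                                   # Pickle.PROTO, Pickle.TWO
            + "X" + struct.pack("<I", len(utf8)) + utf8  # Pickle.BINUNICODE
            + ".")                                       # Pickle.STOP

assert cPickle.loads(string_as_pickle(u"Hello")) == u"Hello"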
@@ -32,6 +32,6 @@ target: docs/
 private: no
 
-exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
+exclude: pyspark.cloudpickle pyspark.worker pyspark.join
          pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
          pyspark.rddsampler pyspark.daemon
 
@@ -90,9 +90,11 @@ import struct
 import SocketServer
 import threading
 from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import read_int, read_with_length, load_pickle
+from pyspark.serializers import read_int, PickleSerializer
 
 
+pickleSer = PickleSerializer()
+
 # Holds accumulators registered on the current machine, keyed by ID. This is then used to send
 # the local accumulator updates back to the driver program at the end of a task.
 _accumulatorRegistry = {}
@@ -211,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
         from pyspark.accumulators import _accumulatorRegistry
         num_updates = read_int(self.rfile)
         for _ in range(num_updates):
-            (aid, update) = load_pickle(read_with_length(self.rfile))
+            (aid, update) = pickleSer._read_with_length(self.rfile)
             _accumulatorRegistry[aid] += update
         # Write a byte in acknowledgement
         self.wfile.write(struct.pack("!b", 1))
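The _read_with_length helper consumes the same frame layout as the old read_with_length/load_pickle pair: a big-endian 32-bit length followed by a pickled payload. A self-contained round trip of one (aid, update) record (StringIO stands in for the accumulator socket; the helper name here is illustrative):

import cPickle
import struct
from StringIO import StringIO

def write_frame(payload, stream):
    stream.write(struct.pack("!i", len(payload)))  # big-endian int, as Java writes
    stream.write(payload)

buf = StringIO()
write_frame(cPickle.dumps((42, 1.5), 2), buf)
buf.seek(0)
length = struct.unpack("!i", buf.read(4))[0]
print cPickle.loads(buf.read(length))              # (42, 1.5)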
@@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 
@@ -51,7 +51,7 @@ class SparkContext(object):
 
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
-            environment=None, batchSize=1024):
+            environment=None, batchSize=1024, serializer=PickleSerializer()):
         """
         Create a new SparkContext.
 
@@ -67,6 +67,7 @@ class SparkContext(object):
         @param batchSize: The number of Python objects represented as a single
                Java object. Set 1 to disable batching or -1 to use an
                unlimited batch size.
+        @param serializer: The serializer for RDDs.
 
 
         >>> from pyspark.context import SparkContext
@@ -83,7 +84,13 @@ class SparkContext(object):
         self.jobName = jobName
         self.sparkHome = sparkHome or None # None becomes null in Py4J
         self.environment = environment or {}
-        self.batchSize = batchSize  # -1 represents a unlimited batch size
+        self._batchSize = batchSize  # -1 represents an unlimited batch size
+        self._unbatched_serializer = serializer
+        if batchSize == 1:
+            self.serializer = self._unbatched_serializer
+        else:
+            self.serializer = BatchedSerializer(self._unbatched_serializer,
+                                                batchSize)
 
         # Create the Java SparkContext through Py4J
         empty_string_array = self._gateway.new_array(self._jvm.String, 0)
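Unless batchSize is 1, every user-supplied serializer is wrapped in a BatchedSerializer so that each Java object carries a list of Python objects. A standalone re-creation of the chunking this applies (simplified from the diff's _batched; batchSize == -1 would instead yield one unlimited batch):

def batched(iterator, batchSize):
    items = []
    for item in iterator:
        items.append(item)
        if len(items) == batchSize:
            yield items
            items = []
    if items:
        yield items

print list(batched(range(5), 2))   # [[0, 1], [2, 3], [4]]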
@@ -184,15 +191,17 @@ class SparkContext(object):
         # Make sure we distribute data evenly if it's smaller than self.batchSize
         if "__len__" not in dir(c):
             c = list(c)    # Make it a list so we can compute its length
-        batchSize = min(len(c) // numSlices, self.batchSize)
+        batchSize = min(len(c) // numSlices, self._batchSize)
         if batchSize > 1:
-            c = batched(c, batchSize)
-        for x in c:
-            write_with_length(dump_pickle(x), tempFile)
+            serializer = BatchedSerializer(self._unbatched_serializer,
+                                           batchSize)
+        else:
+            serializer = self._unbatched_serializer
+        serializer.dump_stream(c, tempFile)
         tempFile.close()
         readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
         jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
-        return RDD(jrdd, self)
+        return RDD(jrdd, self, serializer)
 
     def textFile(self, name, minSplits=None):
         """
@@ -201,21 +210,39 @@ class SparkContext(object):
         RDD of Strings.
         """
         minSplits = minSplits or min(self.defaultParallelism, 2)
-        jrdd = self._jsc.textFile(name, minSplits)
-        return RDD(jrdd, self)
+        return RDD(self._jsc.textFile(name, minSplits), self,
+                   MUTF8Deserializer())
 
-    def _checkpointFile(self, name):
+    def _checkpointFile(self, name, input_deserializer):
         jrdd = self._jsc.checkpointFile(name)
-        return RDD(jrdd, self)
+        return RDD(jrdd, self, input_deserializer)
 
     def union(self, rdds):
         """
         Build the union of a list of RDDs.
+
+        This supports unions() of RDDs with different serialized formats,
+        although this forces them to be reserialized using the default
+        serializer:
+
+        >>> path = os.path.join(tempdir, "union-text.txt")
+        >>> with open(path, "w") as testFile:
+        ...    testFile.write("Hello")
+        >>> textFile = sc.textFile(path)
+        >>> textFile.collect()
+        [u'Hello']
+        >>> parallelized = sc.parallelize(["World!"])
+        >>> sorted(sc.union([textFile, parallelized]).collect())
+        [u'Hello', 'World!']
         """
+        first_jrdd_deserializer = rdds[0]._jrdd_deserializer
+        if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
+            rdds = [x._reserialize() for x in rdds]
         first = rdds[0]._jrdd
         rest = [x._jrdd for x in rdds[1:]]
-        rest = ListConverter().convert(rest, self.gateway._gateway_client)
-        return RDD(self._jsc.union(first, rest), self)
+        rest = ListConverter().convert(rest, self._gateway._gateway_client)
+        return RDD(self._jsc.union(first, rest), self,
+                   rdds[0]._jrdd_deserializer)
 
     def broadcast(self, value):
         """
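The SparkContext.union() bug fix mentioned in the commit message is the rename from self.gateway to self._gateway above: the context only ever stored the Py4J gateway under _gateway, so union() raised AttributeError before doing any work. A toy reproduction of the failure mode (Ctx is a hypothetical stand-in, not Spark code):

class Ctx(object):
    def __init__(self):
        self._gateway = "py4j gateway"

try:
    Ctx().gateway   # the attribute union() used to reference
except AttributeError as e:
    print e         # 'Ctx' object has no attribute 'gateway'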
@@ -223,7 +250,9 @@ class SparkContext(object):
         object for reading it in distributed functions. The variable will be
         sent to each cluster only once.
         """
-        jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
+        pickleSer = PickleSerializer()
+        pickled = pickleSer._dumps(value)
+        jbroadcast = self._jsc.broadcast(bytearray(pickled))
         return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)
 
@@ -235,7 +264,7 @@ class SparkContext(object):
         and floating-point numbers if you do not provide one. For other types,
         a custom AccumulatorParam can be used.
         """
-        if accum_param == None:
+        if accum_param is None:
             if isinstance(value, int):
                 accum_param = accumulators.INT_ACCUMULATOR_PARAM
             elif isinstance(value, float):
@@ -18,7 +18,7 @@
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
-from itertools import chain, ifilter, imap, product
+from itertools import chain, ifilter, imap
 import operator
 import os
 import sys
@@ -28,8 +28,8 @@ from tempfile import NamedTemporaryFile
 from threading import Thread
 
 from pyspark import cloudpickle
-from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
-    read_from_pickle_file, pack_long
+from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
+    BatchedSerializer, pack_long
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -48,13 +48,12 @@ class RDD(object):
     operated on in parallel.
     """
 
-    def __init__(self, jrdd, ctx):
+    def __init__(self, jrdd, ctx, jrdd_deserializer):
         self._jrdd = jrdd
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = ctx
-        self._partitionFunc = None
-        self._stage_input_is_pairs = False
+        self._jrdd_deserializer = jrdd_deserializer
 
     @property
     def context(self):
@@ -248,7 +247,23 @@ class RDD(object):
         >>> rdd.union(rdd).collect()
         [1, 1, 2, 3, 1, 1, 2, 3]
         """
-        return RDD(self._jrdd.union(other._jrdd), self.ctx)
+        if self._jrdd_deserializer == other._jrdd_deserializer:
+            rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
+                      self._jrdd_deserializer)
+            return rdd
+        else:
+            # These RDDs contain data in different serialized formats, so we
+            # must normalize them to the default serializer.
+            self_copy = self._reserialize()
+            other_copy = other._reserialize()
+            return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
+                       self.ctx.serializer)
+
+    def _reserialize(self):
+        if self._jrdd_deserializer == self.ctx.serializer:
+            return self
+        else:
+            return self.map(lambda x: x, preservesPartitioning=True)
 
     def __add__(self, other):
         """
@@ -335,18 +350,9 @@ class RDD(object):
         [(1, 1), (1, 2), (2, 1), (2, 2)]
         """
         # Due to batching, we can't use the Java cartesian method.
-        java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
-        def unpack_batches(pair):
-            (x, y) = pair
-            if type(x) == Batch or type(y) == Batch:
-                xs = x.items if type(x) == Batch else [x]
-                ys = y.items if type(y) == Batch else [y]
-                for pair in product(xs, ys):
-                    yield pair
-            else:
-                yield pair
-        java_cartesian._stage_input_is_pairs = True
-        return java_cartesian.flatMap(unpack_batches)
+        deserializer = CartesianDeserializer(self._jrdd_deserializer,
+                                             other._jrdd_deserializer)
+        return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
 
     def groupBy(self, f, numPartitions=None):
         """
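A minimal sketch of how CartesianDeserializer pairs up the two element streams Java writes for cartesian(): a batched stream yields a list per frame, while an unbatched one yields single elements that must be wrapped before taking the product (the in-memory lists here stand in for the real framed streams):

from itertools import izip, product

key_stream = [[1, 2], [3, 4]]     # e.g. frames from a BatchedSerializer
val_stream = ["a", "b"]           # e.g. frames from an unbatched serializer

pairs = []
for keys, vals in izip(key_stream, val_stream):
    pairs.extend(product(keys, [vals]))   # wrap the unbatched element
print pairs   # [(1, 'a'), (2, 'a'), (3, 'b'), (4, 'b')]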
@@ -405,7 +411,7 @@ class RDD(object):
         self.ctx._writeToFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
-            for item in read_from_pickle_file(tempFile):
+            for item in self._jrdd_deserializer.load_stream(tempFile):
                 yield item
         os.unlink(tempFile.name)
 
@@ -573,7 +579,7 @@ class RDD(object):
         items = []
         for partition in range(mapped._jrdd.splits().size()):
             iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
-            items.extend(self._collect_iterator_through_file(iterator))
+            items.extend(mapped._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
         return items[:num]
@@ -737,6 +743,7 @@ class RDD(object):
         # Transferring O(n) objects to Java is too expensive. Instead, we'll
         # form the hash buckets in Python, transferring O(numPartitions) objects
         # to Java. Each object is a (splitNumber, [objects]) pair.
+        outputSerializer = self.ctx._unbatched_serializer
         def add_shuffle_key(split, iterator):
 
             buckets = defaultdict(list)
@@ -745,14 +752,14 @@ class RDD(object):
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
             for (split, items) in buckets.iteritems():
                 yield pack_long(split)
-                yield dump_pickle(Batch(items))
+                yield outputSerializer._dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
         partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                       id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
-        rdd = RDD(jrdd, self.ctx)
+        rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         # This is required so that id(partitionFunc) remains unique, even if
         # partitionFunc is a lambda:
         rdd._partitionFunc = partitionFunc
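The shuffle stream that add_shuffle_key produces is just alternating frames of a packed split id and a serialized bucket. A small sketch of that shape (repr stands in for outputSerializer._dumps; the layout of pack_long matches pyspark.serializers.pack_long):

import struct

def pack_long(value):
    return struct.pack("!q", value)   # big-endian 64-bit split id

buckets = {0: [("a", 1)], 1: [("b", 2)]}
stream = []
for split, items in buckets.iteritems():
    stream.append(pack_long(split))
    stream.append(repr(items))        # stand-in for the serialized bucket
print stream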
@@ -789,7 +796,8 @@ class RDD(object):
             numPartitions = self.ctx.defaultParallelism
         def combineLocally(iterator):
             combiners = {}
-            for (k, v) in iterator:
+            for x in iterator:
+                (k, v) = x
                 if k not in combiners:
                     combiners[k] = createCombiner(v)
                 else:
@@ -931,38 +939,38 @@ class PipelinedRDD(RDD):
     20
     """
     def __init__(self, prev, func, preservesPartitioning=False):
-        if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
+        if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
+            # This transformation is the first in its stage:
+            self.func = func
+            self.preservesPartitioning = preservesPartitioning
+            self._prev_jrdd = prev._jrdd
+            self._prev_jrdd_deserializer = prev._jrdd_deserializer
+        else:
             prev_func = prev.func
             def pipeline_func(split, iterator):
                 return func(split, prev_func(split, iterator))
             self.func = pipeline_func
             self.preservesPartitioning = \
                 prev.preservesPartitioning and preservesPartitioning
-            self._prev_jrdd = prev._prev_jrdd
-        else:
-            self.func = func
-            self.preservesPartitioning = preservesPartitioning
-            self._prev_jrdd = prev._jrdd
-        self._stage_input_is_pairs = prev._stage_input_is_pairs
+            self._prev_jrdd = prev._prev_jrdd  # maintain the pipeline
+            self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = prev.ctx
         self.prev = prev
         self._jrdd_val = None
+        self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
 
     @property
     def _jrdd(self):
         if self._jrdd_val:
             return self._jrdd_val
-        func = self.func
-        if not self._bypass_serializer and self.ctx.batchSize != 1:
-            oldfunc = self.func
-            batchSize = self.ctx.batchSize
-            def batched_func(split, iterator):
-                return batched(oldfunc(split, iterator), batchSize)
-            func = batched_func
-        cmds = [func, self._bypass_serializer, self._stage_input_is_pairs]
+        if self._bypass_serializer:
+            serializer = NoOpSerializer()
+        else:
+            serializer = self.ctx.serializer
+        cmds = [self.func, self._prev_jrdd_deserializer, serializer]
         pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
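Each stage now ships three components to the worker: the pipelined function, the input deserializer, and the output serializer. Each is pickled, base64-encoded, and space-joined into the pipe command. A dependency-free sketch of that handshake (the real code uses cloudpickle so that lambdas survive pickling; plain cPickle and placeholder strings keep this runnable standalone):

from base64 import standard_b64encode as b64enc, standard_b64decode as b64dec
import cPickle

cmds = ["func placeholder", "input deserializer", "output serializer"]
pipe_command = ' '.join(b64enc(cPickle.dumps(c, 2)) for c in cmds)

# The worker decodes the components in the same order it received them:
print [cPickle.loads(b64dec(tok)) for tok in pipe_command.split(' ')]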
@@ -15,8 +15,58 @@
 # limitations under the License.
 #
 
-import struct
+"""
+PySpark supports custom serializers for transferring data; this can improve
+performance.
+
+By default, PySpark uses L{PickleSerializer} to serialize objects using Python's
+C{cPickle} serializer, which can serialize nearly any Python object.
+Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be
+faster.
+
+The serializer is chosen when creating L{SparkContext}:
+
+>>> from pyspark.context import SparkContext
+>>> from pyspark.serializers import MarshalSerializer
+>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
+>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
+[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+>>> sc.stop()
+
+By default, PySpark serializes objects in batches; the batch size can be
+controlled through SparkContext's C{batchSize} parameter
+(the default size is 1024 objects):
+
+>>> sc = SparkContext('local', 'test', batchSize=2)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+
+Behind the scenes, this creates a JavaRDD with four partitions, each of
+which contains two batches of two objects:
+
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+8L
+>>> sc.stop()
+
+A batch size of -1 uses an unlimited batch size, and a size of 1 disables
+batching:
+
+>>> sc = SparkContext('local', 'test', batchSize=1)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+16L
+"""
+
+import cPickle
+from itertools import chain, izip, product
+import marshal
+import struct
 
 
+__all__ = ["PickleSerializer", "MarshalSerializer"]
+
+
 class SpecialLengths(object):
@@ -25,41 +75,206 @@ class SpecialLengths(object):
     TIMING_DATA = -3
 
 
-class Batch(object):
+class Serializer(object):
+
+    def dump_stream(self, iterator, stream):
+        """
+        Serialize an iterator of objects to the output stream.
+        """
+        raise NotImplementedError
+
+    def load_stream(self, stream):
+        """
+        Return an iterator of deserialized objects from the input stream.
+        """
+        raise NotImplementedError
+
+
+    def _load_stream_without_unbatching(self, stream):
+        return self.load_stream(stream)
+
+    # Note: our notion of "equality" is that output generated by
+    # equal serializers can be deserialized using the same serializer.
+
+    # This default implementation handles the simple cases;
+    # subclasses should override __eq__ as appropriate.
+
+    def __eq__(self, other):
+        return isinstance(other, self.__class__)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+class FramedSerializer(Serializer):
     """
-    Used to store multiple RDD entries as a single Java object.
-
-    This relieves us from having to explicitly track whether an RDD
-    is stored as batches of objects and avoids problems when processing
-    the union() of batched and unbatched RDDs (e.g. the union() of textFile()
-    with another RDD).
+    Serializer that writes objects as a stream of (length, data) pairs,
+    where C{length} is a 32-bit integer and data is C{length} bytes.
     """
-    def __init__(self, items):
-        self.items = items
+
+    def dump_stream(self, iterator, stream):
+        for obj in iterator:
+            self._write_with_length(obj, stream)
+
+    def load_stream(self, stream):
+        while True:
+            try:
+                yield self._read_with_length(stream)
+            except EOFError:
+                return
+
+    def _write_with_length(self, obj, stream):
+        serialized = self._dumps(obj)
+        write_int(len(serialized), stream)
+        stream.write(serialized)
+
+    def _read_with_length(self, stream):
+        length = read_int(stream)
+        obj = stream.read(length)
+        if obj == "":
+            raise EOFError
+        return self._loads(obj)
+
+    def _dumps(self, obj):
+        """
+        Serialize an object into a byte array.
+        When batching is used, this will be called with an array of objects.
+        """
+        raise NotImplementedError
+
+    def _loads(self, obj):
+        """
+        Deserialize an object from a byte array.
+        """
+        raise NotImplementedError
 
 
-def batched(iterator, batchSize):
-    if batchSize == -1: # unlimited batch size
-        yield Batch(list(iterator))
-    else:
-        items = []
-        count = 0
-        for item in iterator:
-            items.append(item)
-            count += 1
-            if count == batchSize:
-                yield Batch(items)
-                items = []
-                count = 0
-        if items:
-            yield Batch(items)
+class BatchedSerializer(Serializer):
+    """
+    Serializes a stream of objects in batches by calling its wrapped
+    Serializer with streams of objects.
+    """
+
+    UNLIMITED_BATCH_SIZE = -1
+
+    def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
+        self.serializer = serializer
+        self.batchSize = batchSize
+
+    def _batched(self, iterator):
+        if self.batchSize == self.UNLIMITED_BATCH_SIZE:
+            yield list(iterator)
+        else:
+            items = []
+            count = 0
+            for item in iterator:
+                items.append(item)
+                count += 1
+                if count == self.batchSize:
+                    yield items
+                    items = []
+                    count = 0
+            if items:
+                yield items
+
+    def dump_stream(self, iterator, stream):
+        if isinstance(iterator, basestring):
+            iterator = [iterator]
+        self.serializer.dump_stream(self._batched(iterator), stream)
+
+    def load_stream(self, stream):
+        return chain.from_iterable(self._load_stream_without_unbatching(stream))
+
+    def _load_stream_without_unbatching(self, stream):
+        return self.serializer.load_stream(stream)
+
+    def __eq__(self, other):
+        return isinstance(other, BatchedSerializer) and \
+            other.serializer == self.serializer
+
+    def __str__(self):
+        return "BatchedSerializer<%s>" % str(self.serializer)
 
 
-def dump_pickle(obj):
-    return cPickle.dumps(obj, 2)
+class CartesianDeserializer(FramedSerializer):
+    """
+    Deserializes the JavaRDD cartesian() of two PythonRDDs.
+    """
+
+    def __init__(self, key_ser, val_ser):
+        self.key_ser = key_ser
+        self.val_ser = val_ser
+
+    def load_stream(self, stream):
+        key_stream = self.key_ser._load_stream_without_unbatching(stream)
+        val_stream = self.val_ser._load_stream_without_unbatching(stream)
+        key_is_batched = isinstance(self.key_ser, BatchedSerializer)
+        val_is_batched = isinstance(self.val_ser, BatchedSerializer)
+        for (keys, vals) in izip(key_stream, val_stream):
+            keys = keys if key_is_batched else [keys]
+            vals = vals if val_is_batched else [vals]
+            for pair in product(keys, vals):
+                yield pair
+
+    def __eq__(self, other):
+        return isinstance(other, CartesianDeserializer) and \
+            self.key_ser == other.key_ser and self.val_ser == other.val_ser
+
+    def __str__(self):
+        return "CartesianDeserializer<%s, %s>" % \
+            (str(self.key_ser), str(self.val_ser))
 
 
-load_pickle = cPickle.loads
+class NoOpSerializer(FramedSerializer):
+
+    def _loads(self, obj): return obj
+    def _dumps(self, obj): return obj
+
+
+class PickleSerializer(FramedSerializer):
+    """
+    Serializes objects using Python's cPickle serializer:
+
+        http://docs.python.org/2/library/pickle.html
+
+    This serializer supports nearly any Python object, but may
+    not be as fast as more specialized serializers.
+    """
+
+    def _dumps(self, obj): return cPickle.dumps(obj, 2)
+    _loads = cPickle.loads
+
+
+class MarshalSerializer(FramedSerializer):
+    """
+    Serializes objects using Python's Marshal serializer:
+
+        http://docs.python.org/2/library/marshal.html
+
+    This serializer is faster than PickleSerializer but supports fewer datatypes.
+    """
+
+    _dumps = marshal.dumps
+    _loads = marshal.loads
+
+
+class MUTF8Deserializer(Serializer):
+    """
+    Deserializes streams written by Java's DataOutputStream.writeUTF().
+    """
+
+    def _loads(self, stream):
+        length = struct.unpack('>H', stream.read(2))[0]
+        return stream.read(length).decode('utf8')
+
+    def load_stream(self, stream):
+        while True:
+            try:
+                yield self._loads(stream)
+            except struct.error:
+                return
+            except EOFError:
+                return
 
 
 def read_long(stream):
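A round trip of the writeUTF() layout that MUTF8Deserializer expects: a 2-byte big-endian length followed by (modified) UTF-8 bytes. Plain UTF-8 decoding suffices for the strings Spark sends here; the writer function below is an illustrative stand-in for Java's DataOutputStream.writeUTF():

import struct
from StringIO import StringIO

def java_write_utf(s, stream):
    utf8 = s.encode('utf-8')
    stream.write(struct.pack('>H', len(utf8)))   # 2-byte big-endian length
    stream.write(utf8)

buf = StringIO()
java_write_utf(u"spark", buf)
buf.seek(0)
length = struct.unpack('>H', buf.read(2))[0]
print buf.read(length).decode('utf8')            # spark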
@@ -90,43 +305,4 @@ def write_int(value, stream):
 
 def write_with_length(obj, stream):
     write_int(len(obj), stream)
     stream.write(obj)
-
-
-def read_mutf8(stream):
-    """
-    Read a string written with Java's DataOutputStream.writeUTF() method.
-    """
-    length = struct.unpack('>H', stream.read(2))[0]
-    return stream.read(length).decode('utf8')
-
-
-def read_with_length(stream):
-    length = read_int(stream)
-    obj = stream.read(length)
-    if obj == "":
-        raise EOFError
-    return obj
-
-
-def read_from_pickle_file(stream):
-    try:
-        while True:
-            obj = load_pickle(read_with_length(stream))
-            if type(obj) == Batch:  # We don't care about inheritance
-                for item in obj.items:
-                    yield item
-            else:
-                yield obj
-    except EOFError:
-        return
-
-
-def read_pairs_from_pickle_file(stream):
-    try:
-        while True:
-            a = load_pickle(read_with_length(stream))
-            b = load_pickle(read_with_length(stream))
-            yield (a, b)
-    except EOFError:
-        return
@@ -86,7 +86,8 @@ class TestCheckpoint(PySparkTestCase):
         time.sleep(1)  # 1 second
 
         self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
-        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
+                                            flatMappedRDD._jrdd_deserializer)
         self.assertEquals([1, 2, 3, 4], recovered.collect())
 
 
@@ -30,13 +30,17 @@ from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
-from pyspark.serializers import write_with_length, read_with_length, write_int, \
-    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \
-    SpecialLengths, read_mutf8, read_pairs_from_pickle_file
+from pyspark.serializers import write_with_length, write_int, read_long, \
+    write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
+
+
+pickleSer = PickleSerializer()
+mutf8_deserializer = MUTF8Deserializer()
 
 
 def load_obj(infile):
-    return load_pickle(standard_b64decode(infile.readline().strip()))
+    decoded = standard_b64decode(infile.readline().strip())
+    return pickleSer._loads(decoded)
 
 
 def report_times(outfile, boot, init, finish):
@@ -53,7 +57,7 @@ def main(infile, outfile):
         return
 
     # fetch name of workdir
-    spark_files_dir = read_mutf8(infile)
+    spark_files_dir = mutf8_deserializer._loads(infile)
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
 
@@ -61,31 +65,24 @@ def main(infile, outfile):
     num_broadcast_variables = read_int(infile)
     for _ in range(num_broadcast_variables):
         bid = read_long(infile)
-        value = read_with_length(infile)
-        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+        value = pickleSer._read_with_length(infile)
+        _broadcastRegistry[bid] = Broadcast(bid, value)
 
     # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
     sys.path.append(spark_files_dir) # *.py files that were added will be copied here
     num_python_includes = read_int(infile)
     for _ in range(num_python_includes):
-        sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile)))
+        filename = mutf8_deserializer._loads(infile)
+        sys.path.append(os.path.join(spark_files_dir, filename))
 
-    # now load function
+    # Load this stage's function and serializer:
     func = load_obj(infile)
-    bypassSerializer = load_obj(infile)
-    stageInputIsPairs = load_obj(infile)
-    if bypassSerializer:
-        dumps = lambda x: x
-    else:
-        dumps = dump_pickle
+    deserializer = load_obj(infile)
+    serializer = load_obj(infile)
     init_time = time.time()
-    if stageInputIsPairs:
-        iterator = read_pairs_from_pickle_file(infile)
-    else:
-        iterator = read_from_pickle_file(infile)
     try:
-        for obj in func(split_index, iterator):
-            write_with_length(dumps(obj), outfile)
+        iterator = deserializer.load_stream(infile)
+        serializer.dump_stream(func(split_index, iterator), outfile)
     except Exception as e:
         write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
         write_with_length(traceback.format_exc(), outfile)
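The worker main loop now reduces to: pull objects through the stage's input deserializer, apply the stage function, and push results through the output serializer. A toy end-to-end version of that loop (LineSerializer is an illustrative stand-in, not part of this commit; the real serializers are the framed ones above):

from StringIO import StringIO

class LineSerializer(object):
    def load_stream(self, stream):
        for line in stream:
            yield line.rstrip('\n')
    def dump_stream(self, iterator, stream):
        for obj in iterator:
            stream.write('%s\n' % obj)

def func(split_index, iterator):
    return (x.upper() for x in iterator)

infile, outfile = StringIO("a\nb\n"), StringIO()
ser = LineSerializer()
ser.dump_stream(func(0, ser.load_stream(infile)), outfile)
print repr(outfile.getvalue())   # 'A\nB\n'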
@@ -96,7 +93,7 @@ def main(infile, outfile):
     write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
     write_int(len(_accumulatorRegistry), outfile)
     for (aid, accum) in _accumulatorRegistry.items():
-        write_with_length(dump_pickle((aid, accum._value)), outfile)
+        pickleSer._write_with_length((aid, accum._value), outfile)
 
 
 if __name__ == '__main__':
@@ -37,6 +37,7 @@ run_test "pyspark/rdd.py"
 run_test "pyspark/context.py"
 run_test "-m doctest pyspark/broadcast.py"
 run_test "-m doctest pyspark/accumulators.py"
+run_test "-m doctest pyspark/serializers.py"
 run_test "pyspark/tests.py"
 
 if [[ $FAILED != 0 ]]; then