Add broadcast variables to Python API.

This commit is contained in:
Josh Rosen 2012-08-25 13:59:01 -07:00
parent 65e8406029
commit f79a1e4d2a
5 changed files with 110 additions and 29 deletions

View file

@ -7,14 +7,13 @@ import scala.collection.JavaConversions._
import scala.io.Source import scala.io.Source
import spark._ import spark._
import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import scala.{collection, Some} import broadcast.Broadcast
import collection.parallel.mutable
import scala.collection import scala.collection
import scala.Some
trait PythonRDDBase { trait PythonRDDBase {
def compute[T](split: Split, envVars: Map[String, String], def compute[T](split: Split, envVars: Map[String, String],
command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[Array[Byte]] = { command: Seq[String], parent: RDD[T], pythonExec: String,
broadcastVars: java.util.List[Broadcast[Array[Byte]]]): Iterator[Array[Byte]] = {
val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py"))
@ -42,11 +41,18 @@ trait PythonRDDBase {
override def run() { override def run() {
SparkEnv.set(env) SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream) val out = new PrintWriter(proc.getOutputStream)
val dOut = new DataOutputStream(proc.getOutputStream)
out.println(broadcastVars.length)
for (broadcast <- broadcastVars) {
out.print(broadcast.uuid.toString)
dOut.writeInt(broadcast.value.length)
dOut.write(broadcast.value)
dOut.flush()
}
for (elem <- command) { for (elem <- command) {
out.println(elem) out.println(elem)
} }
out.flush() out.flush()
val dOut = new DataOutputStream(proc.getOutputStream)
for (elem <- parent.iterator(split)) { for (elem <- parent.iterator(split)) {
if (elem.isInstanceOf[Array[Byte]]) { if (elem.isInstanceOf[Array[Byte]]) {
val arr = elem.asInstanceOf[Array[Byte]] val arr = elem.asInstanceOf[Array[Byte]]
@ -121,16 +127,17 @@ trait PythonRDDBase {
class PythonRDD[T: ClassManifest]( class PythonRDD[T: ClassManifest](
parent: RDD[T], command: Seq[String], envVars: Map[String, String], parent: RDD[T], command: Seq[String], envVars: Map[String, String],
preservePartitoning: Boolean, pythonExec: String) preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]])
extends RDD[Array[Byte]](parent.context) with PythonRDDBase { extends RDD[Array[Byte]](parent.context) with PythonRDDBase {
def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean,
this(parent, command, Map(), preservePartitoning, pythonExec) pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars)
// Similar to Runtime.exec(), if we are given a single string, split it into words // Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces) // using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) = def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec) this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars)
override def splits = parent.splits override def splits = parent.splits
@ -139,23 +146,25 @@ class PythonRDD[T: ClassManifest](
override val partitioner = if (preservePartitoning) parent.partitioner else None override val partitioner = if (preservePartitoning) parent.partitioner else None
override def compute(split: Split): Iterator[Array[Byte]] = override def compute(split: Split): Iterator[Array[Byte]] =
compute(split, envVars, command, parent, pythonExec) compute(split, envVars, command, parent, pythonExec, broadcastVars)
val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
} }
class PythonPairRDD[T: ClassManifest] ( class PythonPairRDD[T: ClassManifest] (
parent: RDD[T], command: Seq[String], envVars: Map[String, String], parent: RDD[T], command: Seq[String], envVars: Map[String, String],
preservePartitoning: Boolean, pythonExec: String) preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]])
extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase { extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase {
def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean,
this(parent, command, Map(), preservePartitoning, pythonExec) pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars)
// Similar to Runtime.exec(), if we are given a single string, split it into words // Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces) // using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) = def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String,
this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec) broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars)
override def splits = parent.splits override def splits = parent.splits
@ -164,7 +173,7 @@ class PythonPairRDD[T: ClassManifest] (
override val partitioner = if (preservePartitoning) parent.partitioner else None override val partitioner = if (preservePartitoning) parent.partitioner else None
override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = { override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = {
compute(split, envVars, command, parent, pythonExec).grouped(2).map { compute(split, envVars, command, parent, pythonExec, broadcastVars).grouped(2).map {
case Seq(a, b) => (a, b) case Seq(a, b) => (a, b)
case x => throw new Exception("PythonPairRDD: unexpected value: " + x) case x => throw new Exception("PythonPairRDD: unexpected value: " + x)
} }

View file

@ -0,0 +1,46 @@
"""
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
>>> b = sc.broadcast([1, 2, 3, 4, 5])
>>> b.value
[1, 2, 3, 4, 5]
>>> from pyspark.broadcast import _broadcastRegistry
>>> _broadcastRegistry[b.uuid] = b
>>> from cPickle import dumps, loads
>>> loads(dumps(b)).value
[1, 2, 3, 4, 5]
>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
"""
# Holds broadcasted data received from Java, keyed by UUID.
_broadcastRegistry = {}
def _from_uuid(uuid):
from pyspark.broadcast import _broadcastRegistry
if uuid not in _broadcastRegistry:
raise Exception("Broadcast variable '%s' not loaded!" % uuid)
return _broadcastRegistry[uuid]
class Broadcast(object):
def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None):
self.value = value
self.uuid = uuid
self._jbroadcast = java_broadcast
self._pickle_registry = pickle_registry
def __reduce__(self):
self._pickle_registry.add(self)
return (_from_uuid, (self.uuid, ))
def _test():
import doctest
doctest.testmod()
if __name__ == "__main__":
_test()

View file

@ -2,6 +2,7 @@ import os
import atexit import atexit
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from pyspark.broadcast import Broadcast
from pyspark.java_gateway import launch_gateway from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, dumps from pyspark.serializers import PickleSerializer, dumps
from pyspark.rdd import RDD from pyspark.rdd import RDD
@ -24,6 +25,11 @@ class SparkContext(object):
self.defaultParallelism = \ self.defaultParallelism = \
defaultParallelism or self._jsc.sc().defaultParallelism() defaultParallelism or self._jsc.sc().defaultParallelism()
self.pythonExec = pythonExec self.pythonExec = pythonExec
# Broadcast's __reduce__ method stores Broadcast instances here.
# This allows other code to determine which Broadcast instances have
# been pickled, so it can determine which Java broadcast objects to
# send.
self._pickled_broadcast_vars = set()
def __del__(self): def __del__(self):
if self._jsc: if self._jsc:
@ -52,7 +58,12 @@ class SparkContext(object):
jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices) jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
return RDD(jrdd, self) return RDD(jrdd, self)
def textFile(self, name, numSlices=None): def textFile(self, name, minSplits=None):
numSlices = numSlices or self.defaultParallelism minSplits = minSplits or min(self.defaultParallelism, 2)
jrdd = self._jsc.textFile(name, numSlices) jrdd = self._jsc.textFile(name, minSplits)
return RDD(jrdd, self) return RDD(jrdd, self)
def broadcast(self, value):
jbroadcast = self._jsc.broadcast(bytearray(PickleSerializer.dumps(value)))
return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
self._pickled_broadcast_vars)

View file

@ -6,6 +6,8 @@ from pyspark.serializers import PickleSerializer
from pyspark.join import python_join, python_left_outer_join, \ from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup python_right_outer_join, python_cogroup
from py4j.java_collections import ListConverter
class RDD(object): class RDD(object):
@ -15,11 +17,15 @@ class RDD(object):
self.ctx = ctx self.ctx = ctx
@classmethod @classmethod
def _get_pipe_command(cls, command, functions): def _get_pipe_command(cls, ctx, command, functions):
worker_args = [command] worker_args = [command]
for f in functions: for f in functions:
worker_args.append(b64enc(cloudpickle.dumps(f))) worker_args.append(b64enc(cloudpickle.dumps(f)))
return " ".join(worker_args) broadcast_vars = [x._jbroadcast for x in ctx._pickled_broadcast_vars]
broadcast_vars = ListConverter().convert(broadcast_vars,
ctx.gateway._gateway_client)
ctx._pickled_broadcast_vars.clear()
return (" ".join(worker_args), broadcast_vars)
def cache(self): def cache(self):
self.is_cached = True self.is_cached = True
@ -52,9 +58,10 @@ class RDD(object):
def _pipe(self, functions, command): def _pipe(self, functions, command):
class_manifest = self._jrdd.classManifest() class_manifest = self._jrdd.classManifest()
pipe_command = RDD._get_pipe_command(command, functions) (pipe_command, broadcast_vars) = \
RDD._get_pipe_command(self.ctx, command, functions)
python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command,
False, self.ctx.pythonExec, class_manifest) False, self.ctx.pythonExec, broadcast_vars, class_manifest)
return python_rdd.asJavaRDD() return python_rdd.asJavaRDD()
def distinct(self): def distinct(self):
@ -249,10 +256,12 @@ class RDD(object):
def shuffle(self, numSplits, hashFunc=hash): def shuffle(self, numSplits, hashFunc=hash):
if numSplits is None: if numSplits is None:
numSplits = self.ctx.defaultParallelism numSplits = self.ctx.defaultParallelism
pipe_command = RDD._get_pipe_command('shuffle_map_step', [hashFunc]) (pipe_command, broadcast_vars) = \
RDD._get_pipe_command(self.ctx, 'shuffle_map_step', [hashFunc])
class_manifest = self._jrdd.classManifest() class_manifest = self._jrdd.classManifest()
python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(),
pipe_command, False, self.ctx.pythonExec, class_manifest) pipe_command, False, self.ctx.pythonExec, broadcast_vars,
class_manifest)
partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner) jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner)
jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
@ -360,12 +369,12 @@ class PipelinedRDD(RDD):
@property @property
def _jrdd(self): def _jrdd(self):
if not self._jrdd_val: if not self._jrdd_val:
funcs = [self.func] (pipe_command, broadcast_vars) = \
pipe_command = RDD._get_pipe_command("pipeline", funcs) RDD._get_pipe_command(self.ctx, "pipeline", [self.func])
class_manifest = self._prev_jrdd.classManifest() class_manifest = self._prev_jrdd.classManifest()
python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
pipe_command, self.preservesPartitioning, self.ctx.pythonExec, pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
class_manifest) broadcast_vars, class_manifest)
self._jrdd_val = python_rdd.asJavaRDD() self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val return self._jrdd_val

View file

@ -5,6 +5,7 @@ import sys
from base64 import standard_b64decode from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the # CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module. # copy_reg module.
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import dumps, loads, PickleSerializer from pyspark.serializers import dumps, loads, PickleSerializer
import cPickle import cPickle
@ -63,6 +64,11 @@ def do_shuffle_map_step():
def main(): def main():
num_broadcast_variables = int(sys.stdin.readline().strip())
for _ in range(num_broadcast_variables):
uuid = sys.stdin.read(36)
value = loads(sys.stdin)
_broadcastRegistry[uuid] = Broadcast(uuid, cPickle.loads(value))
command = sys.stdin.readline().strip() command = sys.stdin.readline().strip()
if command == "pipeline": if command == "pipeline":
do_pipeline() do_pipeline()