Add broadcast variables to Python API.
commit f79a1e4d2a
parent 65e8406029
PythonRDD.scala

@@ -7,14 +7,13 @@ import scala.collection.JavaConversions._
 import scala.io.Source
 import spark._
 import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
-import scala.{collection, Some}
-import collection.parallel.mutable
+import broadcast.Broadcast
 import scala.collection
-import scala.Some

 trait PythonRDDBase {
   def compute[T](split: Split, envVars: Map[String, String],
-      command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[Array[Byte]] = {
+      command: Seq[String], parent: RDD[T], pythonExec: String,
+      broadcastVars: java.util.List[Broadcast[Array[Byte]]]): Iterator[Array[Byte]] = {
     val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")

     val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py"))
@@ -42,11 +41,18 @@ trait PythonRDDBase {
       override def run() {
         SparkEnv.set(env)
         val out = new PrintWriter(proc.getOutputStream)
+        val dOut = new DataOutputStream(proc.getOutputStream)
+        out.println(broadcastVars.length)
+        for (broadcast <- broadcastVars) {
+          out.print(broadcast.uuid.toString)
+          dOut.writeInt(broadcast.value.length)
+          dOut.write(broadcast.value)
+          dOut.flush()
+        }
         for (elem <- command) {
           out.println(elem)
         }
         out.flush()
-        val dOut = new DataOutputStream(proc.getOutputStream)
         for (elem <- parent.iterator(split)) {
           if (elem.isInstanceOf[Array[Byte]]) {
             val arr = elem.asInstanceOf[Array[Byte]]
@@ -121,16 +127,17 @@ trait PythonRDDBase {

 class PythonRDD[T: ClassManifest](
   parent: RDD[T], command: Seq[String], envVars: Map[String, String],
-  preservePartitoning: Boolean, pythonExec: String)
+  preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]])
   extends RDD[Array[Byte]](parent.context) with PythonRDDBase {

-  def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) =
-    this(parent, command, Map(), preservePartitoning, pythonExec)
+  def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean,
+    pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+    this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars)

   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
-  def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) =
-    this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec)
+  def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+    this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars)

   override def splits = parent.splits

@@ -139,23 +146,25 @@ class PythonRDD[T: ClassManifest](
   override val partitioner = if (preservePartitoning) parent.partitioner else None

   override def compute(split: Split): Iterator[Array[Byte]] =
-    compute(split, envVars, command, parent, pythonExec)
+    compute(split, envVars, command, parent, pythonExec, broadcastVars)

   val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 }

 class PythonPairRDD[T: ClassManifest] (
   parent: RDD[T], command: Seq[String], envVars: Map[String, String],
-  preservePartitoning: Boolean, pythonExec: String)
+  preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]])
   extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase {

-  def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) =
-    this(parent, command, Map(), preservePartitoning, pythonExec)
+  def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean,
+    pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+    this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars)

   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
-  def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) =
-    this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec)
+  def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String,
+    broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+    this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars)

   override def splits = parent.splits

@@ -164,7 +173,7 @@ class PythonPairRDD[T: ClassManifest] (
   override val partitioner = if (preservePartitoning) parent.partitioner else None

   override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = {
-    compute(split, envVars, command, parent, pythonExec).grouped(2).map {
+    compute(split, envVars, command, parent, pythonExec, broadcastVars).grouped(2).map {
       case Seq(a, b) => (a, b)
       case x => throw new Exception("PythonPairRDD: unexpected value: " + x)
     }
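Note (illustration, not part of the commit): the writer thread above frames each broadcast for the Python worker as a line with the count, then per variable a 36-character UUID with no trailing newline, followed by the pickled bytes prefixed with a 4-byte big-endian length (DataOutputStream.writeInt). A minimal Python sketch of that framing, with hypothetical helper names:

# Hypothetical sketch of the framing the Scala writer produces.
import io
import pickle
import struct
import uuid


def write_broadcasts(stream, broadcasts):
    # broadcasts: list of (uuid_string, python_value) pairs
    stream.write(("%d\n" % len(broadcasts)).encode("ascii"))  # count line
    for bid, value in broadcasts:
        payload = pickle.dumps(value)
        stream.write(bid.encode("ascii"))               # 36-char UUID, no newline
        stream.write(struct.pack(">i", len(payload)))   # 4-byte big-endian length
        stream.write(payload)                           # pickled bytes


buf = io.BytesIO()
write_broadcasts(buf, [(str(uuid.uuid4()), [1, 2, 3, 4, 5])])
print(len(buf.getvalue()))   # count line + 36 + 4 + pickled payload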
pyspark/pyspark/broadcast.py (new file, 46 lines)

@@ -0,0 +1,46 @@
+"""
+>>> from pyspark.context import SparkContext
+>>> sc = SparkContext('local', 'test')
+>>> b = sc.broadcast([1, 2, 3, 4, 5])
+>>> b.value
+[1, 2, 3, 4, 5]
+
+>>> from pyspark.broadcast import _broadcastRegistry
+>>> _broadcastRegistry[b.uuid] = b
+>>> from cPickle import dumps, loads
+>>> loads(dumps(b)).value
+[1, 2, 3, 4, 5]
+
+>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
+[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
+"""
+# Holds broadcasted data received from Java, keyed by UUID.
+_broadcastRegistry = {}
+
+
+def _from_uuid(uuid):
+    from pyspark.broadcast import _broadcastRegistry
+    if uuid not in _broadcastRegistry:
+        raise Exception("Broadcast variable '%s' not loaded!" % uuid)
+    return _broadcastRegistry[uuid]
+
+
+class Broadcast(object):
+    def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None):
+        self.value = value
+        self.uuid = uuid
+        self._jbroadcast = java_broadcast
+        self._pickle_registry = pickle_registry
+
+    def __reduce__(self):
+        self._pickle_registry.add(self)
+        return (_from_uuid, (self.uuid, ))
+
+
+def _test():
+    import doctest
+    doctest.testmod()
+
+
+if __name__ == "__main__":
+    _test()
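The key piece of broadcast.py is __reduce__: a pickled Broadcast carries only its UUID, and unpickling resolves it through _broadcastRegistry, which the worker pre-populates. A standalone sketch of the same pattern (made-up names, no Spark dependency):

# Minimal sketch of the __reduce__/registry pattern: pickling stores only a key,
# and unpickling looks the real object up in a registry the receiver pre-fills.
import pickle

_registry = {}


def _from_key(key):
    if key not in _registry:
        raise Exception("Value '%s' not loaded!" % key)
    return _registry[key]


class Handle(object):
    def __init__(self, key, value):
        self.key = key
        self.value = value

    def __reduce__(self):
        # Serialize as a call to _from_key(key); the value itself never
        # travels inside the pickle.
        return (_from_key, (self.key,))


h = Handle("abc", [1, 2, 3])
_registry["abc"] = h                         # what the worker does on receipt
print(pickle.loads(pickle.dumps(h)).value)   # [1, 2, 3]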
pyspark/pyspark/context.py

@@ -2,6 +2,7 @@ import os
 import atexit
 from tempfile import NamedTemporaryFile

+from pyspark.broadcast import Broadcast
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import PickleSerializer, dumps
 from pyspark.rdd import RDD
@@ -24,6 +25,11 @@ class SparkContext(object):
         self.defaultParallelism = \
             defaultParallelism or self._jsc.sc().defaultParallelism()
         self.pythonExec = pythonExec
+        # Broadcast's __reduce__ method stores Broadcast instances here.
+        # This allows other code to determine which Broadcast instances have
+        # been pickled, so it can determine which Java broadcast objects to
+        # send.
+        self._pickled_broadcast_vars = set()

     def __del__(self):
         if self._jsc:
@@ -52,7 +58,12 @@ class SparkContext(object):
         jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)

-    def textFile(self, name, numSlices=None):
-        numSlices = numSlices or self.defaultParallelism
-        jrdd = self._jsc.textFile(name, numSlices)
+    def textFile(self, name, minSplits=None):
+        minSplits = minSplits or min(self.defaultParallelism, 2)
+        jrdd = self._jsc.textFile(name, minSplits)
         return RDD(jrdd, self)
+
+    def broadcast(self, value):
+        jbroadcast = self._jsc.broadcast(bytearray(PickleSerializer.dumps(value)))
+        return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
+                         self._pickled_broadcast_vars)
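SparkContext.broadcast() above ships the value to the JVM as pickled bytes wrapped in a bytearray; the worker later rebuilds the value from those same bytes. A rough round-trip illustration, with plain pickle standing in for PickleSerializer (an assumption about its behavior):

# Illustrative only: driver-side serialization and worker-side reconstruction.
import pickle

value = {"weights": [0.1, 0.2, 0.3]}
payload = bytearray(pickle.dumps(value))   # what the driver hands to _jsc.broadcast(...)
restored = pickle.loads(bytes(payload))    # what the worker reconstructs
print(restored == value)                   # True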
pyspark/pyspark/rdd.py

@@ -6,6 +6,8 @@ from pyspark.serializers import PickleSerializer
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup

+from py4j.java_collections import ListConverter
+

 class RDD(object):

@@ -15,11 +17,15 @@ class RDD(object):
         self.ctx = ctx

     @classmethod
-    def _get_pipe_command(cls, command, functions):
+    def _get_pipe_command(cls, ctx, command, functions):
         worker_args = [command]
         for f in functions:
             worker_args.append(b64enc(cloudpickle.dumps(f)))
-        return " ".join(worker_args)
+        broadcast_vars = [x._jbroadcast for x in ctx._pickled_broadcast_vars]
+        broadcast_vars = ListConverter().convert(broadcast_vars,
+                                                 ctx.gateway._gateway_client)
+        ctx._pickled_broadcast_vars.clear()
+        return (" ".join(worker_args), broadcast_vars)

     def cache(self):
         self.is_cached = True
@@ -52,9 +58,10 @@ class RDD(object):

     def _pipe(self, functions, command):
         class_manifest = self._jrdd.classManifest()
-        pipe_command = RDD._get_pipe_command(command, functions)
+        (pipe_command, broadcast_vars) = \
+            RDD._get_pipe_command(self.ctx, command, functions)
         python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command,
-            False, self.ctx.pythonExec, class_manifest)
+            False, self.ctx.pythonExec, broadcast_vars, class_manifest)
         return python_rdd.asJavaRDD()

     def distinct(self):
@@ -249,10 +256,12 @@ class RDD(object):
     def shuffle(self, numSplits, hashFunc=hash):
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
-        pipe_command = RDD._get_pipe_command('shuffle_map_step', [hashFunc])
+        (pipe_command, broadcast_vars) = \
+            RDD._get_pipe_command(self.ctx, 'shuffle_map_step', [hashFunc])
         class_manifest = self._jrdd.classManifest()
         python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(),
-            pipe_command, False, self.ctx.pythonExec, class_manifest)
+            pipe_command, False, self.ctx.pythonExec, broadcast_vars,
+            class_manifest)
         partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
         jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
@@ -360,12 +369,12 @@ class PipelinedRDD(RDD):
     @property
     def _jrdd(self):
         if not self._jrdd_val:
-            funcs = [self.func]
-            pipe_command = RDD._get_pipe_command("pipeline", funcs)
+            (pipe_command, broadcast_vars) = \
+                RDD._get_pipe_command(self.ctx, "pipeline", [self.func])
             class_manifest = self._prev_jrdd.classManifest()
             python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
                 pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
-                class_manifest)
+                broadcast_vars, class_manifest)
             self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val

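_get_pipe_command now returns the Java broadcast handles alongside the command string, using ctx._pickled_broadcast_vars to see which Broadcast objects were actually referenced while the functions were pickled. A tiny standalone sketch of that bookkeeping (hypothetical names, plain pickle instead of cloudpickle):

# Pickling a job's functions triggers __reduce__ on any Broadcast they touch,
# which records the instance in a shared set, so only those broadcasts ship.
import pickle

pickled_broadcast_vars = set()   # plays the role of SparkContext._pickled_broadcast_vars


class FakeBroadcast(object):
    def __init__(self, key):
        self.key = key

    def __reduce__(self):
        pickled_broadcast_vars.add(self.key)   # note that this broadcast was pickled
        return (str, (self.key,))              # placeholder reconstruction


b1, b2 = FakeBroadcast("b1"), FakeBroadcast("b2")
pickle.dumps(("job args", b1))               # stands in for cloudpickle.dumps(functions)
print(sorted(pickled_broadcast_vars))        # ['b1']: only b1 needs to be sent
pickled_broadcast_vars.clear()               # mirrors ctx._pickled_broadcast_vars.clear()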
pyspark/pyspark/worker.py

@@ -5,6 +5,7 @@ import sys
 from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
 # copy_reg module.
+from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.serializers import dumps, loads, PickleSerializer
 import cPickle
@@ -63,6 +64,11 @@ def do_shuffle_map_step():


 def main():
+    num_broadcast_variables = int(sys.stdin.readline().strip())
+    for _ in range(num_broadcast_variables):
+        uuid = sys.stdin.read(36)
+        value = loads(sys.stdin)
+        _broadcastRegistry[uuid] = Broadcast(uuid, cPickle.loads(value))
     command = sys.stdin.readline().strip()
     if command == "pipeline":
         do_pipeline()
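The worker's read side mirrors the Scala writer: a count line, then per broadcast a 36-character UUID and a length-prefixed pickled payload. A sketch against an in-memory stream, assuming pyspark.serializers.loads reads a 4-byte big-endian length followed by that many bytes (matching writeInt/write on the Scala side):

# Hypothetical read-side sketch mirroring worker.main().
import io
import pickle
import struct

payload = pickle.dumps([1, 2, 3, 4, 5])
stream = io.BytesIO(
    b"1\n" +                                    # number of broadcast variables
    b"0" * 36 +                                 # 36-character UUID (dummy)
    struct.pack(">i", len(payload)) + payload   # length-prefixed pickled bytes
)

registry = {}
num_broadcast_variables = int(stream.readline().strip())
for _ in range(num_broadcast_variables):
    uuid = stream.read(36).decode("ascii")
    length = struct.unpack(">i", stream.read(4))[0]
    registry[uuid] = pickle.loads(stream.read(length))

print(registry["0" * 36])   # [1, 2, 3, 4, 5]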