Update Python API for v0.6.0 compatibility.

Josh Rosen 2012-10-19 10:24:49 -07:00
parent e21eb6e00d
commit 52989c8a2c
7 changed files with 42 additions and 27 deletions


@@ -5,11 +5,15 @@ import java.io._
 import scala.collection.Map
 import scala.collection.JavaConversions._
 import scala.io.Source
-import spark._
-import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
-import broadcast.Broadcast
-import scala.collection
-import java.nio.charset.Charset
+import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import spark.broadcast.Broadcast
+import spark.SparkEnv
+import spark.Split
+import spark.RDD
+import spark.OneToOneDependency
+import spark.rdd.PipedRDD

 trait PythonRDDBase {
   def compute[T](split: Split, envVars: Map[String, String],
@@ -43,9 +47,9 @@ trait PythonRDDBase {
       SparkEnv.set(env)
       val out = new PrintWriter(proc.getOutputStream)
       val dOut = new DataOutputStream(proc.getOutputStream)
-      out.println(broadcastVars.length)
+      dOut.writeInt(broadcastVars.length)
       for (broadcast <- broadcastVars) {
-        out.print(broadcast.uuid.toString)
+        dOut.writeLong(broadcast.id)
         dOut.writeInt(broadcast.value.length)
         dOut.write(broadcast.value)
         dOut.flush()
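With this change the Java-to-Python handshake becomes fully binary: a 4-byte big-endian count, then for each broadcast variable an 8-byte id followed by a length-prefixed payload, matching java.io.DataOutputStream's encoding. A minimal sketch of the resulting wire format (the frame_broadcast helper is illustrative, not part of this commit):

    import struct

    def frame_broadcast(bid, value_bytes):
        # writeLong(broadcast.id) + writeInt(broadcast.value.length) + write(broadcast.value)
        return struct.pack("!q", bid) + struct.pack("!i", len(value_bytes)) + value_bytes

    # One variable with id 7 and a 3-byte payload, preceded by the count:
    frame = struct.pack("!i", 1) + frame_broadcast(7, b"abc")
    assert frame == b"\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00\x03abc"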


@@ -5,7 +5,7 @@ import java.util.concurrent.atomic.AtomicLong
 import spark._

-abstract class Broadcast[T](id: Long) extends Serializable {
+abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
   def value: T

   // We cannot have an abstract readObject here due to some weird issues with


@@ -6,7 +6,7 @@
 [1, 2, 3, 4, 5]
 >>> from pyspark.broadcast import _broadcastRegistry
->>> _broadcastRegistry[b.uuid] = b
+>>> _broadcastRegistry[b.bid] = b
 >>> from cPickle import dumps, loads
 >>> loads(dumps(b)).value
 [1, 2, 3, 4, 5]
@@ -14,27 +14,27 @@
 >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
 [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
 """
-# Holds broadcasted data received from Java, keyed by UUID.
+# Holds broadcasted data received from Java, keyed by its id.
 _broadcastRegistry = {}


-def _from_uuid(uuid):
+def _from_id(bid):
     from pyspark.broadcast import _broadcastRegistry
-    if uuid not in _broadcastRegistry:
-        raise Exception("Broadcast variable '%s' not loaded!" % uuid)
-    return _broadcastRegistry[uuid]
+    if bid not in _broadcastRegistry:
+        raise Exception("Broadcast variable '%s' not loaded!" % bid)
+    return _broadcastRegistry[bid]


 class Broadcast(object):
-    def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None):
+    def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
         self.value = value
-        self.uuid = uuid
+        self.bid = bid
         self._jbroadcast = java_broadcast
         self._pickle_registry = pickle_registry

     def __reduce__(self):
         self._pickle_registry.add(self)
-        return (_from_uuid, (self.uuid, ))
+        return (_from_id, (self.bid, ))


 def _test():
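__reduce__ now ships only the numeric id across the pickle boundary; on the receiving side _from_id resolves it against _broadcastRegistry, which the worker populates before unpickling any tasks. A standalone sketch of the round trip under Python 2 (using pickle rather than cPickle; the id 42 is arbitrary):

    import pickle
    from pyspark.broadcast import Broadcast, _broadcastRegistry

    b = Broadcast(42, [1, 2, 3], pickle_registry=set())
    _broadcastRegistry[42] = b            # what worker.py does for each received variable
    restored = pickle.loads(pickle.dumps(b))
    assert restored.value == [1, 2, 3]    # only the id crossed the pickle boundary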


@@ -66,5 +66,5 @@ class SparkContext(object):

     def broadcast(self, value):
         jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
-        return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
+        return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)


@@ -7,7 +7,8 @@ SPARK_HOME = os.environ["SPARK_HOME"]
 assembly_jar = glob.glob(os.path.join(SPARK_HOME, "core/target") + \
-    "/spark-core-assembly-*-SNAPSHOT.jar")[0]
+    "/spark-core-assembly-*.jar")[0]
+# TODO: what if multiple assembly jars are found?


 def launch_gateway():
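The widened glob drops the -SNAPSHOT assumption so release jars match too, at the cost of possibly matching several assembly jars (hence the new TODO). One way to surface that ambiguity rather than silently taking [0] — purely illustrative, not what this commit does:

    import glob
    import os

    SPARK_HOME = os.environ["SPARK_HOME"]
    jars = glob.glob(os.path.join(SPARK_HOME, "core/target", "spark-core-assembly-*.jar"))
    if len(jars) != 1:
        raise RuntimeError("expected exactly one assembly jar, found: %s" % jars)
    assembly_jar = jars[0]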


@@ -9,16 +9,26 @@ def dump_pickle(obj):
 load_pickle = cPickle.loads


+def read_long(stream):
+    length = stream.read(8)
+    if length == "":
+        raise EOFError
+    return struct.unpack("!q", length)[0]
+
+
+def read_int(stream):
+    length = stream.read(4)
+    if length == "":
+        raise EOFError
+    return struct.unpack("!i", length)[0]
+
+
 def write_with_length(obj, stream):
     stream.write(struct.pack("!i", len(obj)))
     stream.write(obj)


 def read_with_length(stream):
-    length = stream.read(4)
-    if length == "":
-        raise EOFError
-    length = struct.unpack("!i", length)[0]
+    length = read_int(stream)
     obj = stream.read(length)
     if obj == "":
         raise EOFError
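read_int and read_long centralize the big-endian framing shared with the JVM: struct's "!i" and "!q" formats match DataOutputStream.writeInt and writeLong. A quick check against an in-memory stream standing in for the worker's stdin (Python 2, illustrative only):

    import struct
    from io import BytesIO
    from pyspark.serializers import read_int, read_long

    stream = BytesIO(struct.pack("!q", 12345) + struct.pack("!i", 678))
    assert read_long(stream) == 12345   # consumes 8 bytes
    assert read_int(stream) == 678      # consumes the next 4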


@@ -8,7 +8,7 @@ from base64 import standard_b64decode
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.serializers import write_with_length, read_with_length, \
-    dump_pickle, load_pickle
+    read_long, read_int, dump_pickle, load_pickle


 # Redirect stdout to stderr so that users must return values from functions.
@@ -29,11 +29,11 @@ def read_input():

 def main():
-    num_broadcast_variables = int(sys.stdin.readline().strip())
+    num_broadcast_variables = read_int(sys.stdin)
     for _ in range(num_broadcast_variables):
-        uuid = sys.stdin.read(36)
+        bid = read_long(sys.stdin)
         value = read_with_length(sys.stdin)
-        _broadcastRegistry[uuid] = Broadcast(uuid, load_pickle(value))
+        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
     func = load_obj()
     bypassSerializer = load_obj()
     if bypassSerializer:
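Each read here pairs one-for-one with a write on the Scala side: read_int with writeInt(broadcastVars.length), read_long with writeLong(broadcast.id), and read_with_length with the writeInt/write pair. A self-contained simulation of the worker's broadcast-loading loop (BytesIO stands in for stdin; the id 99 and list payload are made up):

    import pickle
    import struct
    from io import BytesIO

    payload = pickle.dumps([1, 2, 3])
    fake_stdin = BytesIO(struct.pack("!i", 1) +            # number of broadcast variables
                         struct.pack("!q", 99) +           # broadcast id
                         struct.pack("!i", len(payload)) + # payload length
                         payload)

    registry = {}
    (num,) = struct.unpack("!i", fake_stdin.read(4))
    for _ in range(num):
        (bid,) = struct.unpack("!q", fake_stdin.read(8))
        (length,) = struct.unpack("!i", fake_stdin.read(4))
        registry[bid] = pickle.loads(fake_stdin.read(length))
    assert registry == {99: [1, 2, 3]}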