Update Python API for v0.6.0 compatibility.
parent e21eb6e00d
commit 52989c8a2c
@@ -5,11 +5,15 @@ import java.io._
 import scala.collection.Map
 import scala.collection.JavaConversions._
 import scala.io.Source
-import spark._
-import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
-import broadcast.Broadcast
-import scala.collection
-import java.nio.charset.Charset
+
+import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import spark.broadcast.Broadcast
+import spark.SparkEnv
+import spark.Split
+import spark.RDD
+import spark.OneToOneDependency
+import spark.rdd.PipedRDD
+
 
 trait PythonRDDBase {
   def compute[T](split: Split, envVars: Map[String, String],
@@ -43,9 +47,9 @@ trait PythonRDDBase {
       SparkEnv.set(env)
       val out = new PrintWriter(proc.getOutputStream)
       val dOut = new DataOutputStream(proc.getOutputStream)
-      out.println(broadcastVars.length)
+      dOut.writeInt(broadcastVars.length)
       for (broadcast <- broadcastVars) {
-        out.print(broadcast.uuid.toString)
+        dOut.writeLong(broadcast.id)
         dOut.writeInt(broadcast.value.length)
         dOut.write(broadcast.value)
         dOut.flush()
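Java's DataOutputStream writes big-endian values, so the writeInt/writeLong calls above line up byte-for-byte with Python's struct format codes "!i" and "!q". A rough sketch, not part of the commit, of what a single-broadcast handshake looks like on the wire (the id 42 and the payload are made up):

    import struct
    from cPickle import dumps

    payload = dumps([1, 2, 3, 4, 5], 2)        # stand-in for one pickled broadcast value
    handshake = (
        struct.pack("!i", 1) +                 # dOut.writeInt(broadcastVars.length)
        struct.pack("!q", 42) +                # dOut.writeLong(broadcast.id); 42 is made up
        struct.pack("!i", len(payload)) +      # dOut.writeInt(broadcast.value.length)
        payload                                # dOut.write(broadcast.value)
    )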
@@ -5,7 +5,7 @@ import java.util.concurrent.atomic.AtomicLong
 
 import spark._
 
-abstract class Broadcast[T](id: Long) extends Serializable {
+abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
   def value: T
 
   // We cannot have an abstract readObject here due to some weird issues with
@@ -6,7 +6,7 @@
 [1, 2, 3, 4, 5]
 
 >>> from pyspark.broadcast import _broadcastRegistry
->>> _broadcastRegistry[b.uuid] = b
+>>> _broadcastRegistry[b.bid] = b
 >>> from cPickle import dumps, loads
 >>> loads(dumps(b)).value
 [1, 2, 3, 4, 5]
@@ -14,27 +14,27 @@
 >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
 [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
 """
-# Holds broadcasted data received from Java, keyed by UUID.
+# Holds broadcasted data received from Java, keyed by its id.
 _broadcastRegistry = {}
 
 
-def _from_uuid(uuid):
+def _from_id(bid):
     from pyspark.broadcast import _broadcastRegistry
-    if uuid not in _broadcastRegistry:
-        raise Exception("Broadcast variable '%s' not loaded!" % uuid)
-    return _broadcastRegistry[uuid]
+    if bid not in _broadcastRegistry:
+        raise Exception("Broadcast variable '%s' not loaded!" % bid)
+    return _broadcastRegistry[bid]
 
 
 class Broadcast(object):
-    def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None):
+    def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
         self.value = value
-        self.uuid = uuid
+        self.bid = bid
         self._jbroadcast = java_broadcast
         self._pickle_registry = pickle_registry
 
     def __reduce__(self):
         self._pickle_registry.add(self)
-        return (_from_uuid, (self.uuid, ))
+        return (_from_id, (self.bid, ))
 
 
 def _test():
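The __reduce__/_from_id pair is what keeps broadcast values from being re-pickled into every closure: pickling a Broadcast emits only its id, and unpickling resolves that id against the registry the worker's main() fills in (see the last hunk below). A minimal standalone sketch of the pattern, with illustrative names rather than PySpark's own:

    from cPickle import dumps, loads

    _registry = {}

    def _lookup(bid):
        return _registry[bid]

    class Handle(object):
        def __init__(self, bid, value):
            self.bid, self.value = bid, value
        def __reduce__(self):
            # Pickle only the id; the receiving side resolves it via _lookup.
            return (_lookup, (self.bid,))

    _registry[7] = Handle(7, "some large value")   # as if the worker had registered id 7
    assert loads(dumps(Handle(7, None))).value == "some large value"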
@@ -66,5 +66,5 @@ class SparkContext(object):
 
     def broadcast(self, value):
         jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
-        return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
+        return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)
@ -7,7 +7,8 @@ SPARK_HOME = os.environ["SPARK_HOME"]
|
||||||
|
|
||||||
|
|
||||||
assembly_jar = glob.glob(os.path.join(SPARK_HOME, "core/target") + \
|
assembly_jar = glob.glob(os.path.join(SPARK_HOME, "core/target") + \
|
||||||
"/spark-core-assembly-*-SNAPSHOT.jar")[0]
|
"/spark-core-assembly-*.jar")[0]
|
||||||
|
# TODO: what if multiple assembly jars are found?
|
||||||
|
|
||||||
|
|
||||||
def launch_gateway():
|
def launch_gateway():
|
||||||
|
|
|
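One hypothetical way to answer that TODO, not part of this commit, would be to prefer the most recently built assembly instead of whichever match the glob happens to return first:

    import glob, os

    SPARK_HOME = os.environ["SPARK_HOME"]
    # Illustrative only; java_gateway.py currently just takes the first glob match.
    jars = glob.glob(os.path.join(SPARK_HOME, "core/target", "spark-core-assembly-*.jar"))
    if not jars:
        raise Exception("No spark-core assembly jar found in core/target")
    assembly_jar = max(jars, key=os.path.getmtime)   # newest build wins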
@@ -9,16 +9,26 @@ def dump_pickle(obj):
 load_pickle = cPickle.loads
 
 
+def read_long(stream):
+    length = stream.read(8)
+    if length == "":
+        raise EOFError
+    return struct.unpack("!q", length)[0]
+
+
+def read_int(stream):
+    length = stream.read(4)
+    if length == "":
+        raise EOFError
+    return struct.unpack("!i", length)[0]
+
 def write_with_length(obj, stream):
     stream.write(struct.pack("!i", len(obj)))
     stream.write(obj)
 
 
 def read_with_length(stream):
-    length = stream.read(4)
-    if length == "":
-        raise EOFError
-    length = struct.unpack("!i", length)[0]
+    length = read_int(stream)
     obj = stream.read(length)
     if obj == "":
         raise EOFError
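An illustrative round-trip of the new helpers over an in-memory buffer (assumes pyspark is importable; cStringIO stands in for the worker's stdin):

    from cStringIO import StringIO
    import struct
    from pyspark.serializers import read_long, read_with_length, write_with_length

    buf = StringIO()
    buf.write(struct.pack("!q", 1234567890123))   # what read_long expects: 8 big-endian bytes
    write_with_length("hello", buf)               # 4-byte length prefix, then the payload
    buf.seek(0)
    assert read_long(buf) == 1234567890123
    assert read_with_length(buf) == "hello"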
@@ -8,7 +8,7 @@ from base64 import standard_b64decode
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.serializers import write_with_length, read_with_length, \
-    dump_pickle, load_pickle
+    read_long, read_int, dump_pickle, load_pickle
 
 
 # Redirect stdout to stderr so that users must return values from functions.
@@ -29,11 +29,11 @@ def read_input():
 
 
 def main():
-    num_broadcast_variables = int(sys.stdin.readline().strip())
+    num_broadcast_variables = read_int(sys.stdin)
     for _ in range(num_broadcast_variables):
-        uuid = sys.stdin.read(36)
+        bid = read_long(sys.stdin)
         value = read_with_length(sys.stdin)
-        _broadcastRegistry[uuid] = Broadcast(uuid, load_pickle(value))
+        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
     func = load_obj()
     bypassSerializer = load_obj()
     if bypassSerializer: