Merge branch 'master' of https://github.com/mesos/spark into indexed_rdd

Joseph E. Gonzalez 2013-08-19 13:05:59 -07:00
commit 0598c10eb1
6 changed files with 45 additions and 6 deletions

core/src/main/scala/spark/api/python/PythonRDD.scala

@@ -33,6 +33,7 @@ private[spark] class PythonRDD[T: ClassManifest](
     parent: RDD[T],
     command: Seq[String],
     envVars: JMap[String, String],
+    pythonIncludes: JList[String],
     preservePartitoning: Boolean,
     pythonExec: String,
     broadcastVars: JList[Broadcast[Array[Byte]]],
@@ -44,10 +45,11 @@ private[spark] class PythonRDD[T: ClassManifest](
   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
   def this(parent: RDD[T], command: String, envVars: JMap[String, String],
+      pythonIncludes: JList[String],
       preservePartitoning: Boolean, pythonExec: String,
       broadcastVars: JList[Broadcast[Array[Byte]]],
       accumulator: Accumulator[JList[Array[Byte]]]) =
-    this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
+    this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec,
       broadcastVars, accumulator)
 
   override def getPartitions = parent.partitions
@@ -79,6 +81,11 @@ private[spark] class PythonRDD[T: ClassManifest](
           dataOut.writeInt(broadcast.value.length)
           dataOut.write(broadcast.value)
         }
+        // Python includes (*.zip and *.egg files)
+        dataOut.writeInt(pythonIncludes.length)
+        for (f <- pythonIncludes) {
+          PythonRDD.writeAsPickle(f, dataOut)
+        }
         dataOut.flush()
         // Serialized user code
         for (elem <- command) {
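
The wire format here matches the broadcast block just above it: a 4-byte big-endian count written with writeInt, then one pickled filename per include via PythonRDD.writeAsPickle. A minimal, self-contained sketch of the corresponding read side in Python, assuming writeAsPickle frames each value as a length prefix followed by the pickled bytes (the helper names below are illustrative, not PySpark's):

import os
import pickle
import struct
import sys

def _read_int(stream):
    # Mirrors java.io.DataOutputStream.writeInt: 4-byte big-endian signed int.
    return struct.unpack(">i", stream.read(4))[0]

def _read_pickled(stream):
    # Assumed framing: <4-byte length><pickled payload>.
    return pickle.loads(stream.read(_read_int(stream)))

def add_includes_to_path(stream, spark_files_dir):
    # Read the include count, then each filename, and extend the module search path.
    for _ in range(_read_int(stream)):
        sys.path.append(os.path.join(spark_files_dir, _read_pickled(stream)))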

python/pyspark/context.py

@@ -46,6 +46,7 @@ class SparkContext(object):
     _next_accum_id = 0
     _active_spark_context = None
     _lock = Lock()
+    _python_includes = None # zip and egg files that need to be added to PYTHONPATH
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
         environment=None, batchSize=1024):
@@ -103,11 +104,14 @@ class SparkContext(object):
         # send.
         self._pickled_broadcast_vars = set()
 
+        SparkFiles._sc = self
+        root_dir = SparkFiles.getRootDirectory()
+        sys.path.append(root_dir)
+
         # Deploy any code dependencies specified in the constructor
+        self._python_includes = list()
         for path in (pyFiles or []):
             self.addPyFile(path)
-        SparkFiles._sc = self
-        sys.path.append(SparkFiles.getRootDirectory())
 
         # Create a temporary directory inside spark.local.dir:
         local_dir = self._jvm.spark.Utils.getLocalDir()
@@ -257,7 +261,11 @@ class SparkContext(object):
         HTTP, HTTPS or FTP URI.
         """
         self.addFile(path)
-        filename = path.split("/")[-1]
+        (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix
+        if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
+            self._python_includes.append(filename)
+            sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) # for tests in local mode
 
     def setCheckpointDir(self, dirName, useExisting=False):
         """

python/pyspark/rdd.py

@@ -758,8 +758,10 @@ class PipelinedRDD(RDD):
         class_manifest = self._prev_jrdd.classManifest()
         env = MapConverter().convert(self.ctx.environment,
                                      self.ctx._gateway._gateway_client)
+        includes = ListConverter().convert(self.ctx._python_includes,
+                                     self.ctx._gateway._gateway_client)
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
-            pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
+            pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
             broadcast_vars, self.ctx._javaAccumulator, class_manifest)
         self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
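
ListConverter mirrors the MapConverter call directly above it: it copies a Python list into a java.util.List on the JVM side so the new pythonIncludes constructor argument can be satisfied over Py4J. A tiny sketch of the conversion in isolation, assuming an already-connected Py4J gateway client is passed in rather than taken from a SparkContext:

from py4j.java_collections import ListConverter

def to_java_list(python_list, gateway_client):
    # Returns a proxy to a java.util.List populated with a copy of python_list.
    return ListConverter().convert(python_list, gateway_client)

# e.g. includes = to_java_list(self.ctx._python_includes, self.ctx._gateway._gateway_client)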

python/pyspark/tests.py

@@ -125,6 +125,17 @@ class TestAddFile(PySparkTestCase):
         from userlibrary import UserClass
         self.assertEqual("Hello World!", UserClass().hello())
 
+    def test_add_egg_file_locally(self):
+        # To ensure that we're actually testing addPyFile's effects, check that
+        # this fails due to `userlib` not being on the Python path:
+        def func():
+            from userlib import UserClass
+        self.assertRaises(ImportError, func)
+        path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1-py2.7.egg")
+        self.sc.addPyFile(path)
+        from userlib import UserClass
+        self.assertEqual("Hello World from inside a package!", UserClass().hello())
+
 
 class TestIO(PySparkTestCase):
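
The new test imports from a pre-built egg that this commit checks in under python/test_support (the binary file at the end of the diff). For reference, an egg like userlib-0.1-py2.7.egg can be produced from a minimal setuptools project; the sketch below is one plausible layout, not necessarily the exact source of the checked-in artifact:

# setup.py -- build with: python setup.py bdist_egg
from setuptools import setup

setup(
    name="userlib",
    version="0.1",
    packages=["userlib"],
)

# userlib/__init__.py -- what the test imports
class UserClass(object):
    def hello(self):
        return "Hello World from inside a package!"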

python/pyspark/worker.py

@@ -49,15 +49,26 @@ def main(infile, outfile):
     split_index = read_int(infile)
     if split_index == -1:  # for unit tests
         return
+
+    # fetch name of workdir
     spark_files_dir = load_pickle(read_with_length(infile))
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
-    sys.path.append(spark_files_dir)
+
+    # fetch names and values of broadcast variables
     num_broadcast_variables = read_int(infile)
     for _ in range(num_broadcast_variables):
         bid = read_long(infile)
         value = read_with_length(infile)
         _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+
+    # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
+    sys.path.append(spark_files_dir) # *.py files that were added will be copied here
+    num_python_includes = read_int(infile)
+    for _ in range(num_python_includes):
+        sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+
+    # now load function
     func = load_obj(infile)
     bypassSerializer = load_obj(infile)
     if bypassSerializer:
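
Nothing beyond CPython's built-in zipimport machinery is needed on the worker: once an archive path is appended to sys.path, packages inside the .zip or .egg can be imported normally. A self-contained illustration (the archive and package names are invented for the example):

import os
import sys
import tempfile
import zipfile

archive = os.path.join(tempfile.mkdtemp(), "demo_include.zip")
with zipfile.ZipFile(archive, "w") as zf:
    zf.writestr("demo_pkg/__init__.py", "GREETING = 'hello from a zipped package'\n")

sys.path.append(archive)   # same mechanism the worker uses for each include
import demo_pkg
print(demo_pkg.GREETING)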

Binary file not shown.