Fix sys.path bug in PySpark SparkContext.addPyFile

Josh Rosen 2013-01-22 17:54:11 -08:00
parent 7b9e96c992
commit 35168d9c89
4 changed files with 41 additions and 7 deletions
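For context, the usage pattern that this commit fixes and tests is roughly the following (a minimal sketch; the app name, file path, and helper function are illustrative, mirroring the test added below):

from pyspark.context import SparkContext

sc = SparkContext("local", "AddPyFileExample")

# Ship a dependency to every worker; after this fix the worker appends the
# directory containing shipped files to sys.path, so the import below works.
sc.addPyFile("python/test_support/userlibrary.py")

def use_dependency(x):
    from userlibrary import UserClass  # resolved on the worker via sys.path
    return UserClass().hello()

print(sc.parallelize(range(2)).map(use_dependency).first())  # Hello World!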

python/pyspark/context.py

@@ -215,8 +215,6 @@ class SparkContext(object):
"""
self.addFile(path)
filename = path.split("/")[-1]
os.environ["PYTHONPATH"] = \
"%s:%s" % (filename, os.environ["PYTHONPATH"])
def setCheckpointDir(self, dirName, useExisting=False):
"""

python/pyspark/tests.py

@@ -9,21 +9,32 @@ import time
import unittest
from pyspark.context import SparkContext
from pyspark.java_gateway import SPARK_HOME
class TestCheckpoint(unittest.TestCase):
class PySparkTestCase(unittest.TestCase):
def setUp(self):
self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2)
self.checkpointDir = NamedTemporaryFile(delete=False)
os.unlink(self.checkpointDir.name)
self.sc.setCheckpointDir(self.checkpointDir.name)
class_name = self.__class__.__name__
self.sc = SparkContext('local[4]', class_name, batchSize=2)
def tearDown(self):
self.sc.stop()
# To avoid Akka rebinding to the same port, since it doesn't unbind
# immediately on shutdown
self.sc.jvm.System.clearProperty("spark.master.port")
class TestCheckpoint(PySparkTestCase):
def setUp(self):
PySparkTestCase.setUp(self)
self.checkpointDir = NamedTemporaryFile(delete=False)
os.unlink(self.checkpointDir.name)
self.sc.setCheckpointDir(self.checkpointDir.name)
def tearDown(self):
PySparkTestCase.tearDown(self)
shutil.rmtree(self.checkpointDir.name)
def test_basic_checkpointing(self):
@@ -57,5 +68,22 @@ class TestCheckpoint(unittest.TestCase):
self.assertEquals([1, 2, 3, 4], recovered.collect())
class TestAddFile(PySparkTestCase):
def test_add_py_file(self):
# To ensure that we're actually testing addPyFile's effects, check that
# this job fails due to `userlibrary` not being on the Python path:
def func(x):
from userlibrary import UserClass
return UserClass().hello()
self.assertRaises(Exception,
self.sc.parallelize(range(2)).map(func).first)
# Add the file, so the job should now succeed:
path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
self.sc.addPyFile(path)
res = self.sc.parallelize(range(2)).map(func).first()
self.assertEqual("Hello World!", res)
if __name__ == "__main__":
unittest.main()
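With the shared PySparkTestCase in place, new PySpark test classes can reuse its SparkContext setup and teardown simply by subclassing it; a hypothetical example (not part of this commit):

class TestParallelize(PySparkTestCase):
    def test_count(self):
        # self.sc is created in PySparkTestCase.setUp and stopped in tearDown.
        self.assertEqual(4, self.sc.parallelize([1, 2, 3, 4]).count())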

python/pyspark/worker.py

@@ -26,6 +26,7 @@ def main():
split_index = read_int(sys.stdin)
spark_files_dir = load_pickle(read_with_length(sys.stdin))
SparkFiles._root_directory = spark_files_dir
sys.path.append(spark_files_dir)
num_broadcast_variables = read_int(sys.stdin)
for _ in range(num_broadcast_variables):
bid = read_long(sys.stdin)
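The single added line above is the actual fix: the per-job directory that shipped files are downloaded into is appended to the worker's sys.path, so modules sent with addPyFile become importable inside tasks. A self-contained sketch of the mechanism (the temporary directory and file stand in for the real shipped-files directory):

import os
import sys
import tempfile

# Stand-in for the directory that addFile/addPyFile ships files into.
spark_files_dir = tempfile.mkdtemp()
with open(os.path.join(spark_files_dir, "userlibrary.py"), "w") as f:
    f.write('class UserClass(object):\n'
            '    def hello(self):\n'
            '        return "Hello World!"\n')

# Without this append, "from userlibrary import UserClass" raises ImportError.
sys.path.append(spark_files_dir)

from userlibrary import UserClass
print(UserClass().hello())  # Hello World!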

python/test_support/userlibrary.py

@@ -0,0 +1,7 @@
"""
Used to test shipping of code dependencies with SparkContext.addPyFile().
"""
class UserClass(object):
def hello(self):
return "Hello World!"