Fix sys.path bug in PySpark SparkContext.addPyFile
This commit is contained in:
parent
7b9e96c992
commit
35168d9c89
|
@ -215,8 +215,6 @@ class SparkContext(object):
|
|||
"""
|
||||
self.addFile(path)
|
||||
filename = path.split("/")[-1]
|
||||
os.environ["PYTHONPATH"] = \
|
||||
"%s:%s" % (filename, os.environ["PYTHONPATH"])
|
||||
|
||||
def setCheckpointDir(self, dirName, useExisting=False):
|
||||
"""
|
||||
|
|
|
@ -9,21 +9,32 @@ import time
|
|||
import unittest
|
||||
|
||||
from pyspark.context import SparkContext
|
||||
from pyspark.java_gateway import SPARK_HOME
|
||||
|
||||
|
||||
class PySparkTestCase(unittest.TestCase):
    """Base class for PySpark tests: creates a local SparkContext per test.

    Each subclass's test runs against a fresh ``SparkContext`` named after the
    test class, and the context is torn down (and the Akka port property
    cleared) after every test.
    """

    def setUp(self):
        # Name the context after the concrete test class so Spark job names
        # identify which suite is running.
        class_name = self.__class__.__name__
        self.sc = SparkContext('local[4]', class_name , batchSize=2)

    def tearDown(self):
        self.sc.stop()
        # To avoid Akka rebinding to the same port, since it doesn't unbind
        # immediately on shutdown
        self.sc.jvm.System.clearProperty("spark.master.port")
|
||||
|
||||
|
||||
class TestCheckpoint(PySparkTestCase):
|
||||
|
||||
    def setUp(self):
        """Create the SparkContext, then point it at a fresh checkpoint dir.

        A ``NamedTemporaryFile`` is used only to reserve a unique path; the
        file itself is unlinked so Spark can create the directory at that
        path.
        """
        PySparkTestCase.setUp(self)
        self.checkpointDir = NamedTemporaryFile(delete=False)
        # Remove the temp file so only the unique name remains for Spark to
        # use as a directory.
        os.unlink(self.checkpointDir.name)
        self.sc.setCheckpointDir(self.checkpointDir.name)
|
||||
|
||||
    def tearDown(self):
        """Stop the SparkContext and delete the checkpoint directory."""
        PySparkTestCase.tearDown(self)
        # setCheckpointDir created a directory at this path; remove it and
        # any checkpoint data written during the test.
        shutil.rmtree(self.checkpointDir.name)
|
||||
|
||||
def test_basic_checkpointing(self):
|
||||
|
@ -57,5 +68,22 @@ class TestCheckpoint(unittest.TestCase):
|
|||
self.assertEquals([1, 2, 3, 4], recovered.collect())
|
||||
|
||||
|
||||
class TestAddFile(PySparkTestCase):
    """Tests for shipping dependencies to workers via SparkContext.addPyFile()."""

    def test_add_py_file(self):
        """A job importing ``userlibrary`` fails before addPyFile and succeeds after."""
        # To ensure that we're actually testing addPyFile's effects, check that
        # this job fails due to `userlibrary` not being on the Python path:
        def func(x):
            from userlibrary import UserClass
            return UserClass().hello()
        self.assertRaises(Exception,
                          self.sc.parallelize(range(2)).map(func).first)
        # Add the file, so the job should now succeed:
        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
        self.sc.addPyFile(path)
        res = self.sc.parallelize(range(2)).map(func).first()
        self.assertEqual("Hello World!", res)
|
||||
|
||||
|
||||
# Run the test suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()
|
||||
|
|
|
@ -26,6 +26,7 @@ def main():
|
|||
split_index = read_int(sys.stdin)
|
||||
spark_files_dir = load_pickle(read_with_length(sys.stdin))
|
||||
SparkFiles._root_directory = spark_files_dir
|
||||
sys.path.append(spark_files_dir)
|
||||
num_broadcast_variables = read_int(sys.stdin)
|
||||
for _ in range(num_broadcast_variables):
|
||||
bid = read_long(sys.stdin)
|
||||
|
|
7
python/test_support/userlibrary.py
Executable file
7
python/test_support/userlibrary.py
Executable file
|
@ -0,0 +1,7 @@
|
|||
"""
|
||||
Used to test shipping of code depenencies with SparkContext.addPyFile().
|
||||
"""
|
||||
|
||||
class UserClass(object):
|
||||
def hello(self):
|
||||
return "Hello World!"
|
Loading…
Reference in a new issue