Fix sys.path bug in PySpark SparkContext.addPyFile
Parent: 7b9e96c992
Commit: 35168d9c89
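In short, SparkContext.addPyFile() ships a Python file to every worker so that tasks can import it; the changes below move the path handling from a driver-side PYTHONPATH tweak to the worker process itself. A usage sketch, assuming a local SparkContext (the file path here is illustrative, not part of the commit):

from pyspark.context import SparkContext

sc = SparkContext("local[4]", "AddPyFileExample")
sc.addPyFile("/path/to/userlibrary.py")  # ship a dependency to every worker

def use_dependency(x):
    # This import runs inside a worker process; it only succeeds if the
    # directory holding the shipped file is on that worker's sys.path.
    from userlibrary import UserClass
    return UserClass().hello()

print(sc.parallelize(range(2)).map(use_dependency).first())
sc.stop()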
@@ -215,8 +215,6 @@ class SparkContext(object):
         """
         self.addFile(path)
         filename = path.split("/")[-1]
-        os.environ["PYTHONPATH"] = \
-            "%s:%s" % (filename, os.environ["PYTHONPATH"])
 
     def setCheckpointDir(self, dirName, useExisting=False):
         """
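With the PYTHONPATH manipulation removed, the surviving driver-side method (presumably SparkContext.addPyFile) is, judging only from the context lines above, roughly the following. A sketch, not the full file contents; the docstring is omitted:

def addPyFile(self, path):
    # Ship the file through the generic file-distribution mechanism; the
    # worker-side change further below makes the downloaded copy importable.
    self.addFile(path)
    filename = path.split("/")[-1]  # kept from the surrounding context lines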
@@ -9,21 +9,32 @@ import time
 import unittest
 
 from pyspark.context import SparkContext
+from pyspark.java_gateway import SPARK_HOME
 
 
-class TestCheckpoint(unittest.TestCase):
+class PySparkTestCase(unittest.TestCase):
 
     def setUp(self):
-        self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2)
-        self.checkpointDir = NamedTemporaryFile(delete=False)
-        os.unlink(self.checkpointDir.name)
-        self.sc.setCheckpointDir(self.checkpointDir.name)
+        class_name = self.__class__.__name__
+        self.sc = SparkContext('local[4]', class_name , batchSize=2)
 
     def tearDown(self):
         self.sc.stop()
         # To avoid Akka rebinding to the same port, since it doesn't unbind
         # immediately on shutdown
         self.sc.jvm.System.clearProperty("spark.master.port")
+
+
+class TestCheckpoint(PySparkTestCase):
+
+    def setUp(self):
+        PySparkTestCase.setUp(self)
+        self.checkpointDir = NamedTemporaryFile(delete=False)
+        os.unlink(self.checkpointDir.name)
+        self.sc.setCheckpointDir(self.checkpointDir.name)
+
+    def tearDown(self):
+        PySparkTestCase.tearDown(self)
         shutil.rmtree(self.checkpointDir.name)
 
     def test_basic_checkpointing(self):
@@ -57,5 +68,22 @@ class TestCheckpoint(unittest.TestCase):
         self.assertEquals([1, 2, 3, 4], recovered.collect())
 
 
+class TestAddFile(PySparkTestCase):
+
+    def test_add_py_file(self):
+        # To ensure that we're actually testing addPyFile's effects, check that
+        # this job fails due to `userlibrary` not being on the Python path:
+        def func(x):
+            from userlibrary import UserClass
+            return UserClass().hello()
+        self.assertRaises(Exception,
+                          self.sc.parallelize(range(2)).map(func).first)
+        # Add the file, so the job should now succeed:
+        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+        self.sc.addPyFile(path)
+        res = self.sc.parallelize(range(2)).map(func).first()
+        self.assertEqual("Hello World!", res)
+
+
 if __name__ == "__main__":
     unittest.main()
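The refactoring above extracts SparkContext setup and teardown into PySparkTestCase so that new suites, such as TestAddFile, can reuse it. A minimal hypothetical example of another suite built on the same base class (class and test names here are illustrative, not part of this commit):

class TestParallelize(PySparkTestCase):

    def test_count(self):
        # self.sc is created by PySparkTestCase.setUp() and stopped in tearDown().
        rdd = self.sc.parallelize(["a", "b", "a"])
        self.assertEqual(3, rdd.count())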
@@ -26,6 +26,7 @@ def main():
     split_index = read_int(sys.stdin)
     spark_files_dir = load_pickle(read_with_length(sys.stdin))
     SparkFiles._root_directory = spark_files_dir
+    sys.path.append(spark_files_dir)
     num_broadcast_variables = read_int(sys.stdin)
     for _ in range(num_broadcast_variables):
         bid = read_long(sys.stdin)
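This one-line addition (presumably in the worker's main loop) is the core of the fix: setting os.environ["PYTHONPATH"] in the driver, as the removed code did, cannot affect worker interpreters that are already running, whereas appending the per-job files directory to sys.path inside the worker makes files shipped via addPyFile() importable by tasks. A standalone sketch of the same idea (the function name is illustrative):

import sys

def make_shipped_files_importable(spark_files_dir):
    # Mirrors the worker change: once the directory that holds files
    # downloaded via addFile()/addPyFile() is on sys.path, statements like
    # `from userlibrary import UserClass` succeed inside tasks.
    if spark_files_dir not in sys.path:
        sys.path.append(spark_files_dir)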
python/test_support/userlibrary.py (new executable file, 7 lines)

@@ -0,0 +1,7 @@
+"""
+Used to test shipping of code dependencies with SparkContext.addPyFile().
+"""
+
+class UserClass(object):
+    def hello(self):
+        return "Hello World!"
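A quick driver-side sanity check of the helper module itself, assuming it is run from the python/ directory of a Spark checkout (a hypothetical snippet, not part of the commit):

import sys

# Putting the containing directory on sys.path makes the module importable
# as a top-level module, just as the worker does for shipped files.
sys.path.append("test_support")
from userlibrary import UserClass

assert UserClass().hello() == "Hello World!"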