[SPARK-3786] [PySpark] speedup tests

This patch try to speed up tests of PySpark, re-use the SparkContext in tests.py and mllib/tests.py to reduce the overhead of create SparkContext, remove some test cases, which did not make sense. It also improve the performance of some cases, such as MergerTests and SortTests.

before this patch:

real	21m27.320s
user	4m42.967s
sys	0m17.343s

after this patch:

real	9m47.541s
user	2m12.947s
sys	0m14.543s

It almost cut the time by half.

Author: Davies Liu <davies.liu@gmail.com>

Closes #2646 from davies/tests and squashes the following commits:

c54de60 [Davies Liu] revert change about memory limit
6a2a4b0 [Davies Liu] refactor of tests, speedup 100%
This commit is contained in:
Davies Liu 2014-10-06 14:07:53 -07:00 committed by Josh Rosen
parent 20ea54cc7a
commit 4f01265f7d
4 changed files with 82 additions and 91 deletions

View file

@ -32,7 +32,7 @@ else:
from pyspark.serializers import PickleSerializer
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
from pyspark.tests import PySparkTestCase
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
_have_scipy = False

View file

@ -396,7 +396,6 @@ class ExternalMerger(Merger):
for v in self.data.iteritems():
yield v
self.data.clear()
gc.collect()
# remove the merged partition
for j in range(self.spills):
@ -428,7 +427,7 @@ class ExternalMerger(Merger):
subdirs = [os.path.join(d, "parts", str(i))
for d in self.localdirs]
m = ExternalMerger(self.agg, self.memory_limit, self.serializer,
subdirs, self.scale * self.partitions)
subdirs, self.scale * self.partitions, self.partitions)
m.pdata = [{} for _ in range(self.partitions)]
limit = self._next_limit()
@ -486,7 +485,7 @@ class ExternalSorter(object):
goes above the limit.
"""
global MemoryBytesSpilled, DiskBytesSpilled
batch = 10
batch = 100
chunks, current_chunk = [], []
iterator = iter(iterator)
while True:

View file

@ -67,10 +67,10 @@ except:
SPARK_HOME = os.environ["SPARK_HOME"]
class TestMerger(unittest.TestCase):
class MergerTests(unittest.TestCase):
def setUp(self):
self.N = 1 << 16
self.N = 1 << 14
self.l = [i for i in xrange(self.N)]
self.data = zip(self.l, self.l)
self.agg = Aggregator(lambda x: [x],
@ -115,7 +115,7 @@ class TestMerger(unittest.TestCase):
sum(xrange(self.N)) * 3)
def test_huge_dataset(self):
m = ExternalMerger(self.agg, 10)
m = ExternalMerger(self.agg, 10, partitions=3)
m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
self.assertTrue(m.spills >= 1)
self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
@ -123,7 +123,7 @@ class TestMerger(unittest.TestCase):
m._cleanup()
class TestSorter(unittest.TestCase):
class SorterTests(unittest.TestCase):
def test_in_memory_sort(self):
l = range(1024)
random.shuffle(l)
@ -244,16 +244,25 @@ class PySparkTestCase(unittest.TestCase):
sys.path = self._old_sys_path
class TestCheckpoint(PySparkTestCase):
class ReusedPySparkTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2)
@classmethod
def tearDownClass(cls):
cls.sc.stop()
class CheckpointTests(ReusedPySparkTestCase):
def setUp(self):
PySparkTestCase.setUp(self)
self.checkpointDir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(self.checkpointDir.name)
self.sc.setCheckpointDir(self.checkpointDir.name)
def tearDown(self):
PySparkTestCase.tearDown(self)
shutil.rmtree(self.checkpointDir.name)
def test_basic_checkpointing(self):
@ -288,7 +297,7 @@ class TestCheckpoint(PySparkTestCase):
self.assertEquals([1, 2, 3, 4], recovered.collect())
class TestAddFile(PySparkTestCase):
class AddFileTests(PySparkTestCase):
def test_add_py_file(self):
# To ensure that we're actually testing addPyFile's effects, check that
@ -354,7 +363,7 @@ class TestAddFile(PySparkTestCase):
self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())
class TestRDDFunctions(PySparkTestCase):
class RDDTests(ReusedPySparkTestCase):
def test_id(self):
rdd = self.sc.parallelize(range(10))
@ -365,12 +374,6 @@ class TestRDDFunctions(PySparkTestCase):
self.assertEqual(id + 1, id2)
self.assertEqual(id2, rdd2.id())
def test_failed_sparkcontext_creation(self):
# Regression test for SPARK-1550
self.sc.stop()
self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
self.sc = SparkContext("local")
def test_save_as_textfile_with_unicode(self):
# Regression test for SPARK-970
x = u"\u00A1Hola, mundo!"
@ -636,7 +639,7 @@ class TestRDDFunctions(PySparkTestCase):
self.assertEquals(result.count(), 3)
class TestProfiler(PySparkTestCase):
class ProfilerTests(PySparkTestCase):
def setUp(self):
self._old_sys_path = list(sys.path)
@ -666,10 +669,9 @@ class TestProfiler(PySparkTestCase):
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
class TestSQL(PySparkTestCase):
class SQLTests(ReusedPySparkTestCase):
def setUp(self):
PySparkTestCase.setUp(self)
self.sqlCtx = SQLContext(self.sc)
def test_udf(self):
@ -754,27 +756,19 @@ class TestSQL(PySparkTestCase):
self.assertEqual("2", row.d)
class TestIO(PySparkTestCase):
class InputFormatTests(ReusedPySparkTestCase):
def test_stdout_redirection(self):
import subprocess
@classmethod
def setUpClass(cls):
ReusedPySparkTestCase.setUpClass()
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(cls.tempdir.name)
cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc)
def func(x):
subprocess.check_call('ls', shell=True)
self.sc.parallelize([1]).foreach(func)
class TestInputFormat(PySparkTestCase):
def setUp(self):
PySparkTestCase.setUp(self)
self.tempdir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(self.tempdir.name)
self.sc._jvm.WriteInputFormatTestDataGenerator.generateData(self.tempdir.name, self.sc._jsc)
def tearDown(self):
PySparkTestCase.tearDown(self)
shutil.rmtree(self.tempdir.name)
@classmethod
def tearDownClass(cls):
ReusedPySparkTestCase.tearDownClass()
shutil.rmtree(cls.tempdir.name)
def test_sequencefiles(self):
basepath = self.tempdir.name
@ -954,15 +948,13 @@ class TestInputFormat(PySparkTestCase):
self.assertEqual(maps, em)
class TestOutputFormat(PySparkTestCase):
class OutputFormatTests(ReusedPySparkTestCase):
def setUp(self):
PySparkTestCase.setUp(self)
self.tempdir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(self.tempdir.name)
def tearDown(self):
PySparkTestCase.tearDown(self)
shutil.rmtree(self.tempdir.name, ignore_errors=True)
def test_sequencefiles(self):
@ -1243,8 +1235,7 @@ class TestOutputFormat(PySparkTestCase):
basepath + "/malformed/sequence"))
class TestDaemon(unittest.TestCase):
class DaemonTests(unittest.TestCase):
def connect(self, port):
from socket import socket, AF_INET, SOCK_STREAM
sock = socket(AF_INET, SOCK_STREAM)
@ -1290,7 +1281,7 @@ class TestDaemon(unittest.TestCase):
self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
class TestWorker(PySparkTestCase):
class WorkerTests(PySparkTestCase):
def test_cancel_task(self):
temp = tempfile.NamedTemporaryFile(delete=True)
@ -1342,11 +1333,6 @@ class TestWorker(PySparkTestCase):
rdd = self.sc.parallelize(range(100), 1)
self.assertEqual(100, rdd.map(str).count())
def test_fd_leak(self):
N = 1100 # fd limit is 1024 by default
rdd = self.sc.parallelize(range(N), N)
self.assertEquals(N, rdd.count())
def test_after_exception(self):
def raise_exception(_):
raise Exception()
@ -1379,7 +1365,7 @@ class TestWorker(PySparkTestCase):
self.assertEqual(sum(range(100)), acc1.value)
class TestSparkSubmit(unittest.TestCase):
class SparkSubmitTests(unittest.TestCase):
def setUp(self):
self.programDir = tempfile.mkdtemp()
@ -1492,6 +1478,8 @@ class TestSparkSubmit(unittest.TestCase):
|sc = SparkContext()
|print sc.parallelize([1, 2, 3]).map(foo).collect()
""")
# this will fail if you have different spark.executor.memory
# in conf/spark-defaults.conf
proc = subprocess.Popen(
[self.sparkSubmit, "--master", "local-cluster[1,1,512]", script],
stdout=subprocess.PIPE)
@ -1500,7 +1488,11 @@ class TestSparkSubmit(unittest.TestCase):
self.assertIn("[2, 4, 6]", out)
class ContextStopTests(unittest.TestCase):
class ContextTests(unittest.TestCase):
def test_failed_sparkcontext_creation(self):
# Regression test for SPARK-1550
self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
def test_stop(self):
sc = SparkContext()

View file

@ -34,7 +34,7 @@ rm -rf metastore warehouse
function run_test() {
echo "Running test: $1"
SPARK_TESTING=1 "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log
SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log
FAILED=$((PIPESTATUS[0]||$FAILED))
@ -48,6 +48,37 @@ function run_test() {
fi
}
function run_core_tests() {
echo "Run core tests ..."
run_test "pyspark/rdd.py"
run_test "pyspark/context.py"
run_test "pyspark/conf.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/serializers.py"
run_test "pyspark/shuffle.py"
run_test "pyspark/tests.py"
}
function run_sql_tests() {
echo "Run sql tests ..."
run_test "pyspark/sql.py"
}
function run_mllib_tests() {
echo "Run mllib tests ..."
run_test "pyspark/mllib/classification.py"
run_test "pyspark/mllib/clustering.py"
run_test "pyspark/mllib/linalg.py"
run_test "pyspark/mllib/random.py"
run_test "pyspark/mllib/recommendation.py"
run_test "pyspark/mllib/regression.py"
run_test "pyspark/mllib/stat.py"
run_test "pyspark/mllib/tree.py"
run_test "pyspark/mllib/util.py"
run_test "pyspark/mllib/tests.py"
}
echo "Running PySpark tests. Output is in python/unit-tests.log."
export PYSPARK_PYTHON="python"
@ -60,29 +91,9 @@ fi
echo "Testing with Python version:"
$PYSPARK_PYTHON --version
run_test "pyspark/rdd.py"
run_test "pyspark/context.py"
run_test "pyspark/conf.py"
run_test "pyspark/sql.py"
# These tests are included in the module-level docs, and so must
# be handled on a higher level rather than within the python file.
export PYSPARK_DOC_TEST=1
run_test "pyspark/broadcast.py"
run_test "pyspark/accumulators.py"
run_test "pyspark/serializers.py"
unset PYSPARK_DOC_TEST
run_test "pyspark/shuffle.py"
run_test "pyspark/tests.py"
run_test "pyspark/mllib/classification.py"
run_test "pyspark/mllib/clustering.py"
run_test "pyspark/mllib/linalg.py"
run_test "pyspark/mllib/random.py"
run_test "pyspark/mllib/recommendation.py"
run_test "pyspark/mllib/regression.py"
run_test "pyspark/mllib/stat.py"
run_test "pyspark/mllib/tests.py"
run_test "pyspark/mllib/tree.py"
run_test "pyspark/mllib/util.py"
run_core_tests
run_sql_tests
run_mllib_tests
# Try to test with PyPy
if [ $(which pypy) ]; then
@ -90,19 +101,8 @@ if [ $(which pypy) ]; then
echo "Testing with PyPy version:"
$PYSPARK_PYTHON --version
run_test "pyspark/rdd.py"
run_test "pyspark/context.py"
run_test "pyspark/conf.py"
run_test "pyspark/sql.py"
# These tests are included in the module-level docs, and so must
# be handled on a higher level rather than within the python file.
export PYSPARK_DOC_TEST=1
run_test "pyspark/broadcast.py"
run_test "pyspark/accumulators.py"
run_test "pyspark/serializers.py"
unset PYSPARK_DOC_TEST
run_test "pyspark/shuffle.py"
run_test "pyspark/tests.py"
run_core_tests
run_sql_tests
fi
if [[ $FAILED == 0 ]]; then