spark-instrumented-optimizer/python/pyspark/streaming/tests/test_dstream.py
Hyukjin Kwon cdd694c52b [SPARK-7721][INFRA] Run and generate test coverage report from Python via Jenkins
## What changes were proposed in this pull request?

### Background

The test script that generates coverage information has already been merged
into Spark (https://github.com/apache/spark/pull/20204).

So we can already generate the coverage report and site with, for example:

```
run-tests-with-coverage --python-executables=python3 --modules=pyspark-sql
```

The script is used in the same way as the `run-tests` script in `./python`.

### Proposed change

The next step is to host this coverage report via `github.io` automatically
by Jenkins (see https://spark-test.github.io/pyspark-coverage-site/).

This uses my testing account for Spark, spark-test, which was shared with Felix and Shivaram a long time ago for testing purposes, including AppVeyor.

In short, this PR aims to run the coverage in
[spark-master-test-sbt-hadoop-2.7](https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7/).

In that specific job, it will clone the coverage site and replace its contents with the up-to-date PySpark test coverage from the latest commit, for instance as below:

```bash
# Clone PySpark coverage site.
git clone https://github.com/spark-test/pyspark-coverage-site.git

# Remove existing HTMLs.
rm -fr pyspark-coverage-site/*

# Copy generated coverage HTMLs.
cp -r .../python/test_coverage/htmlcov/* pyspark-coverage-site/

# Move into the cloned repository so the git commands below act on it.
cd pyspark-coverage-site

# Check out to a temporary branch.
git symbolic-ref HEAD refs/heads/latest_branch

# Add all the files.
git add -A

# Commit current HTMLs.
git commit -am "Coverage report at latest commit in Apache Spark"

# Delete the old branch.
git branch -D gh-pages

# Rename the temporary branch to gh-pages.
git branch -m gh-pages

# Finally, force update to our repository.
git push -f origin gh-pages
```

So a single, up-to-date coverage report is shown on the `github.io` page. The commands above were manually tested.
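
The same sequence can be written down in Python, which makes the order of the steps easy to inspect before anything is executed. This is only a sketch under stated assumptions: `coverage_publish_steps`, `render`, and the `htmlcov_dir` parameter are hypothetical names, nothing is actually run, and it assumes the git commands after the copy are meant to run inside the cloned site directory.

```python
import shlex

SITE_REPO = "https://github.com/spark-test/pyspark-coverage-site.git"


def coverage_publish_steps(htmlcov_dir, site_dir="pyspark-coverage-site"):
    """Return the publishing steps as (argv, cwd) pairs without running them.

    htmlcov_dir is a hypothetical parameter for the generated HTML directory.
    """
    steps = [
        (["git", "clone", SITE_REPO, site_dir], None),
        # The shell version relies on glob expansion ("*"); a real runner
        # would expand these itself or hand the line to a shell.
        (["rm", "-fr", site_dir + "/*"], None),
        (["cp", "-r", htmlcov_dir + "/*", site_dir + "/"], None),
    ]
    git_steps = [
        ["git", "symbolic-ref", "HEAD", "refs/heads/latest_branch"],
        ["git", "add", "-A"],
        ["git", "commit", "-am", "Coverage report at latest commit in Apache Spark"],
        ["git", "branch", "-D", "gh-pages"],
        ["git", "branch", "-m", "gh-pages"],
        ["git", "push", "-f", "origin", "gh-pages"],
    ]
    # Assumption: every remaining git command runs inside the cloned site repository.
    steps.extend((argv, site_dir) for argv in git_steps)
    return steps


def render(steps):
    """Format the steps as shell-like lines for inspection."""
    lines = []
    for argv, cwd in steps:
        prefix = "(in %s) " % cwd if cwd else ""
        lines.append(prefix + " ".join(shlex.quote(a) for a in argv))
    return lines
```

Calling `render(coverage_publish_steps("path/to/htmlcov"))` yields one printable line per step. Because the temporary branch replaces `gh-pages` and is force-pushed, the site repository keeps a single commit rather than a growing history of reports.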

### TODOs

- [x] Write a draft - HyukjinKwon
- [x] `pip install coverage` to all python implementations (pypy, python2, python3) in Jenkins workers - shaneknapp
- [x] Set hidden `SPARK_TEST_KEY` for spark-test's password in Jenkins via Jenkins's feature.
 This should be set in both the PR builder and `spark-master-test-sbt-hadoop-2.7` so that other PRs can later test and fix bugs - shaneknapp
- [x] Set an environment variable that indicates `spark-master-test-sbt-hadoop-2.7` so that that specific build can report and update the coverage site - shaneknapp
- [x] Make the PR builder's tests pass - HyukjinKwon
- [x] Fix flaky test related to coverage - HyukjinKwon
  -  6 consecutive passes out of 7 runs
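
The two Jenkins settings in the list above could gate publishing roughly as follows. This is a minimal sketch, not the actual build script: `should_publish_coverage` and the `COVERAGE_REPORTING_JOB` variable are hypothetical names, since the list does not say what the job-marker variable is called. Only `SPARK_TEST_KEY` comes from the list.

```python
import os


def should_publish_coverage(environ=None):
    """Return True only for the designated job with a credential available."""
    env = os.environ if environ is None else environ
    # HYPOTHETICAL variable name marking the designated Jenkins job.
    is_designated_job = env.get("COVERAGE_REPORTING_JOB", "").lower() == "true"
    # SPARK_TEST_KEY is the hidden credential named in the list above.
    has_key = bool(env.get("SPARK_TEST_KEY"))
    return is_designated_job and has_key
```

Passing a plain dictionary instead of reading `os.environ` directly keeps the check easy to test, for example `should_publish_coverage({"SPARK_TEST_KEY": "x"})`.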

This PR will be co-authored by me and shaneknapp.

## How was this patch tested?

It will be tested via Jenkins.

Closes #23117 from HyukjinKwon/SPARK-7721.

Lead-authored-by: Hyukjin Kwon <gurwls223@apache.org>
Co-authored-by: hyukjinkwon <gurwls223@apache.org>
Co-authored-by: shane knapp <incomplete@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
2019-02-01 10:18:08 +08:00


#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import operator
import os
import shutil
import tempfile
import time
import unittest
from functools import reduce
from itertools import chain
import platform

from pyspark import SparkConf, SparkContext, RDD
from pyspark.streaming import StreamingContext
from pyspark.testing.streamingutils import PySparkStreamingTestCase


@unittest.skipIf(
    "pypy" in platform.python_implementation().lower() and "COVERAGE_PROCESS_START" in os.environ,
    "PyPy implementation causes to hang DStream tests forever when Coverage report is used.")
class BasicOperationTests(PySparkStreamingTestCase):

    def test_map(self):
        """Basic operation test for DStream.map."""
        input = [range(1, 5), range(5, 9), range(9, 13)]

        def func(dstream):
            return dstream.map(str)
        expected = [list(map(str, x)) for x in input]
        self._test_func(input, func, expected)

    def test_flatMap(self):
        """Basic operation test for DStream.flatMap."""
        input = [range(1, 5), range(5, 9), range(9, 13)]

        def func(dstream):
            return dstream.flatMap(lambda x: (x, x * 2))
        expected = [list(chain.from_iterable((map(lambda y: [y, y * 2], x))))
                    for x in input]
        self._test_func(input, func, expected)

    def test_filter(self):
        """Basic operation test for DStream.filter."""
        input = [range(1, 5), range(5, 9), range(9, 13)]

        def func(dstream):
            return dstream.filter(lambda x: x % 2 == 0)
        expected = [[y for y in x if y % 2 == 0] for x in input]
        self._test_func(input, func, expected)

    def test_count(self):
        """Basic operation test for DStream.count."""
        input = [range(5), range(10), range(20)]

        def func(dstream):
            return dstream.count()
        expected = [[len(x)] for x in input]
        self._test_func(input, func, expected)

    def test_slice(self):
        """Basic operation test for DStream.slice."""
        import datetime as dt
        self.ssc = StreamingContext(self.sc, 1.0)
        self.ssc.remember(4.0)
        input = [[1], [2], [3], [4]]
        stream = self.ssc.queueStream([self.sc.parallelize(d, 1) for d in input])

        time_vals = []

        def get_times(t, rdd):
            if rdd and len(time_vals) < len(input):
                time_vals.append(t)

        stream.foreachRDD(get_times)

        self.ssc.start()
        self.wait_for(time_vals, 4)
        begin_time = time_vals[0]

        def get_sliced(begin_delta, end_delta):
            begin = begin_time + dt.timedelta(seconds=begin_delta)
            end = begin_time + dt.timedelta(seconds=end_delta)
            rdds = stream.slice(begin, end)
            result_list = [rdd.collect() for rdd in rdds]
            return [r for result in result_list for r in result]

        self.assertEqual(set([1]), set(get_sliced(0, 0)))
        self.assertEqual(set([2, 3]), set(get_sliced(1, 2)))
        self.assertEqual(set([2, 3, 4]), set(get_sliced(1, 4)))
        self.assertEqual(set([1, 2, 3, 4]), set(get_sliced(0, 4)))

    def test_reduce(self):
        """Basic operation test for DStream.reduce."""
        input = [range(1, 5), range(5, 9), range(9, 13)]

        def func(dstream):
            return dstream.reduce(operator.add)
        expected = [[reduce(operator.add, x)] for x in input]
        self._test_func(input, func, expected)

    def test_reduceByKey(self):
        """Basic operation test for DStream.reduceByKey."""
        input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)],
                 [("", 1), ("", 1), ("", 1), ("", 1)],
                 [(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]]

        def func(dstream):
            return dstream.reduceByKey(operator.add)
        expected = [[("a", 2), ("b", 2)], [("", 4)], [(1, 2), (2, 2), (3, 1)]]
        self._test_func(input, func, expected, sort=True)

    def test_mapValues(self):
        """Basic operation test for DStream.mapValues."""
        input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
                 [(0, 4), (1, 1), (2, 2), (3, 3)],
                 [(1, 1), (2, 1), (3, 1), (4, 1)]]

        def func(dstream):
            return dstream.mapValues(lambda x: x + 10)
        expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)],
                    [(0, 14), (1, 11), (2, 12), (3, 13)],
                    [(1, 11), (2, 11), (3, 11), (4, 11)]]
        self._test_func(input, func, expected, sort=True)

    def test_flatMapValues(self):
        """Basic operation test for DStream.flatMapValues."""
        input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
                 [(0, 4), (1, 1), (2, 1), (3, 1)],
                 [(1, 1), (2, 1), (3, 1), (4, 1)]]

        def func(dstream):
            return dstream.flatMapValues(lambda x: (x, x + 10))
        expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12),
                     ("c", 1), ("c", 11), ("d", 1), ("d", 11)],
                    [(0, 4), (0, 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)],
                    [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]]
        self._test_func(input, func, expected)

    def test_glom(self):
        """Basic operation test for DStream.glom."""
        input = [range(1, 5), range(5, 9), range(9, 13)]
        rdds = [self.sc.parallelize(r, 2) for r in input]

        def func(dstream):
            return dstream.glom()
        expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
        self._test_func(rdds, func, expected)

    def test_mapPartitions(self):
        """Basic operation test for DStream.mapPartitions."""
        input = [range(1, 5), range(5, 9), range(9, 13)]
        rdds = [self.sc.parallelize(r, 2) for r in input]

        def func(dstream):
            def f(iterator):
                yield sum(iterator)
            return dstream.mapPartitions(f)
        expected = [[3, 7], [11, 15], [19, 23]]
        self._test_func(rdds, func, expected)

    def test_countByValue(self):
        """Basic operation test for DStream.countByValue."""
        input = [list(range(1, 5)) * 2, list(range(5, 7)) + list(range(5, 9)), ["a", "a", "b", ""]]

        def func(dstream):
            return dstream.countByValue()
        expected = [[(1, 2), (2, 2), (3, 2), (4, 2)],
                    [(5, 2), (6, 2), (7, 1), (8, 1)],
                    [("a", 2), ("b", 1), ("", 1)]]
        self._test_func(input, func, expected, sort=True)

    def test_groupByKey(self):
        """Basic operation test for DStream.groupByKey."""
        input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
                 [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
                 [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]

        def func(dstream):
            return dstream.groupByKey().mapValues(list)

        expected = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])],
                    [(1, [1, 1, 1]), (2, [1, 1]), (3, [1])],
                    [("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]]
        self._test_func(input, func, expected, sort=True)

    def test_combineByKey(self):
        """Basic operation test for DStream.combineByKey."""
        input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
                 [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
                 [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]

        def func(dstream):
            def add(a, b):
                return a + str(b)
            return dstream.combineByKey(str, add, add)
        expected = [[(1, "1"), (2, "1"), (3, "1"), (4, "1")],
                    [(1, "111"), (2, "11"), (3, "1")],
                    [("a", "11"), ("b", "1"), ("", "111")]]
        self._test_func(input, func, expected, sort=True)

    def test_repartition(self):
        input = [range(1, 5), range(5, 9)]
        rdds = [self.sc.parallelize(r, 2) for r in input]

        def func(dstream):
            return dstream.repartition(1).glom()
        expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]]
        self._test_func(rdds, func, expected)

    def test_union(self):
        input1 = [range(3), range(5), range(6)]
        input2 = [range(3, 6), range(5, 6)]

        def func(d1, d2):
            return d1.union(d2)

        expected = [list(range(6)), list(range(6)), list(range(6))]
        self._test_func(input1, func, expected, input2=input2)

    def test_cogroup(self):
        input = [[(1, 1), (2, 1), (3, 1)],
                 [(1, 1), (1, 1), (1, 1), (2, 1)],
                 [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]]
        input2 = [[(1, 2)],
                  [(4, 1)],
                  [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]]

        def func(d1, d2):
            return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs)))

        expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))],
                    [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))],
                    [("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]]
        self._test_func(input, func, expected, sort=True, input2=input2)

    def test_join(self):
        input = [[('a', 1), ('b', 2)]]
        input2 = [[('b', 3), ('c', 4)]]

        def func(a, b):
            return a.join(b)

        expected = [[('b', (2, 3))]]
        self._test_func(input, func, expected, True, input2)

    def test_left_outer_join(self):
        input = [[('a', 1), ('b', 2)]]
        input2 = [[('b', 3), ('c', 4)]]

        def func(a, b):
            return a.leftOuterJoin(b)

        expected = [[('a', (1, None)), ('b', (2, 3))]]
        self._test_func(input, func, expected, True, input2)

    def test_right_outer_join(self):
        input = [[('a', 1), ('b', 2)]]
        input2 = [[('b', 3), ('c', 4)]]

        def func(a, b):
            return a.rightOuterJoin(b)

        expected = [[('b', (2, 3)), ('c', (None, 4))]]
        self._test_func(input, func, expected, True, input2)

    def test_full_outer_join(self):
        input = [[('a', 1), ('b', 2)]]
        input2 = [[('b', 3), ('c', 4)]]

        def func(a, b):
            return a.fullOuterJoin(b)

        expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]]
        self._test_func(input, func, expected, True, input2)

    def test_update_state_by_key(self):

        def updater(vs, s):
            if not s:
                s = []
            s.extend(vs)
            return s

        input = [[('k', i)] for i in range(5)]

        def func(dstream):
            return dstream.updateStateByKey(updater)

        expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
        expected = [[('k', v)] for v in expected]
        self._test_func(input, func, expected)

    def test_update_state_by_key_initial_rdd(self):

        def updater(vs, s):
            if not s:
                s = []
            s.extend(vs)
            return s

        initial = [('k', [0, 1])]
        initial = self.sc.parallelize(initial, 1)

        input = [[('k', i)] for i in range(2, 5)]

        def func(dstream):
            return dstream.updateStateByKey(updater, initialRDD=initial)

        expected = [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
        expected = [[('k', v)] for v in expected]
        self._test_func(input, func, expected)

    def test_failed_func(self):
        # Test failure in
        # TransformFunction.apply(rdd: Option[RDD[_]], time: Time)
        input = [self.sc.parallelize([d], 1) for d in range(4)]
        input_stream = self.ssc.queueStream(input)

        def failed_func(i):
            raise ValueError("This is a special error")

        input_stream.map(failed_func).pprint()
        self.ssc.start()
        try:
            self.ssc.awaitTerminationOrTimeout(10)
        except:
            import traceback
            failure = traceback.format_exc()
            self.assertTrue("This is a special error" in failure)
            return

        self.fail("a failed func should throw an error")

    def test_failed_func2(self):
        # Test failure in
        # TransformFunction.apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time)
        input = [self.sc.parallelize([d], 1) for d in range(4)]
        input_stream1 = self.ssc.queueStream(input)
        input_stream2 = self.ssc.queueStream(input)

        def failed_func(rdd1, rdd2):
            raise ValueError("This is a special error")

        input_stream1.transformWith(failed_func, input_stream2, True).pprint()
        self.ssc.start()
        try:
            self.ssc.awaitTerminationOrTimeout(10)
        except:
            import traceback
            failure = traceback.format_exc()
            self.assertTrue("This is a special error" in failure)
            return

        self.fail("a failed func should throw an error")

    def test_failed_func_with_reseting_failure(self):
        input = [self.sc.parallelize([d], 1) for d in range(4)]
        input_stream = self.ssc.queueStream(input)

        def failed_func(i):
            if i == 1:
                # Make it fail in the second batch
                raise ValueError("This is a special error")
            else:
                return i

        # We should be able to see the results of the 3rd and 4th batches even if the second batch
        # fails
        expected = [[0], [2], [3]]
        self.assertEqual(expected, self._collect(input_stream.map(failed_func), 3))
        try:
            self.ssc.awaitTerminationOrTimeout(10)
        except:
            import traceback
            failure = traceback.format_exc()
            self.assertTrue("This is a special error" in failure)
            return

        self.fail("a failed func should throw an error")


@unittest.skipIf(
    "pypy" in platform.python_implementation().lower() and "COVERAGE_PROCESS_START" in os.environ,
    "PyPy implementation causes to hang DStream tests forever when Coverage report is used.")
class WindowFunctionTests(PySparkStreamingTestCase):

    timeout = 15

    def test_window(self):
        input = [range(1), range(2), range(3), range(4), range(5)]

        def func(dstream):
            return dstream.window(1.5, .5).count()

        expected = [[1], [3], [6], [9], [12], [9], [5]]
        self._test_func(input, func, expected)

    def test_count_by_window(self):
        input = [range(1), range(2), range(3), range(4), range(5)]

        def func(dstream):
            return dstream.countByWindow(1.5, .5)

        expected = [[1], [3], [6], [9], [12], [9], [5]]
        self._test_func(input, func, expected)

    def test_count_by_window_large(self):
        input = [range(1), range(2), range(3), range(4), range(5), range(6)]

        def func(dstream):
            return dstream.countByWindow(2.5, .5)

        expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
        self._test_func(input, func, expected)

    def test_count_by_value_and_window(self):
        input = [range(1), range(2), range(3), range(4), range(5), range(6)]

        def func(dstream):
            return dstream.countByValueAndWindow(2.5, .5)

        expected = [[(0, 1)],
                    [(0, 2), (1, 1)],
                    [(0, 3), (1, 2), (2, 1)],
                    [(0, 4), (1, 3), (2, 2), (3, 1)],
                    [(0, 5), (1, 4), (2, 3), (3, 2), (4, 1)],
                    [(0, 5), (1, 5), (2, 4), (3, 3), (4, 2), (5, 1)],
                    [(0, 4), (1, 4), (2, 4), (3, 3), (4, 2), (5, 1)],
                    [(0, 3), (1, 3), (2, 3), (3, 3), (4, 2), (5, 1)],
                    [(0, 2), (1, 2), (2, 2), (3, 2), (4, 2), (5, 1)],
                    [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1)]]
        self._test_func(input, func, expected)

    def test_group_by_key_and_window(self):
        input = [[('a', i)] for i in range(5)]

        def func(dstream):
            return dstream.groupByKeyAndWindow(1.5, .5).mapValues(list)

        expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])],
                    [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
        self._test_func(input, func, expected)

    def test_reduce_by_invalid_window(self):
        input1 = [range(3), range(5), range(1), range(6)]
        d1 = self.ssc.queueStream(input1)
        self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1))
        self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1))

    def test_reduce_by_key_and_window_with_none_invFunc(self):
        input = [range(1), range(2), range(3), range(4), range(5), range(6)]

        def func(dstream):
            return dstream.map(lambda x: (x, 1))\
                .reduceByKeyAndWindow(operator.add, None, 5, 1)\
                .filter(lambda kv: kv[1] > 0).count()

        expected = [[2], [4], [6], [6], [6], [6]]
        self._test_func(input, func, expected)


@unittest.skipIf(
    "pypy" in platform.python_implementation().lower() and "COVERAGE_PROCESS_START" in os.environ,
    "PyPy implementation causes to hang DStream tests forever when Coverage report is used.")
class CheckpointTests(unittest.TestCase):

    setupCalled = False

    @staticmethod
    def tearDownClass():
        # Clean up in the JVM just in case there has been some issues in Python API
        if SparkContext._jvm is not None:
            jStreamingContextOption = \
                SparkContext._jvm.org.apache.spark.streaming.StreamingContext.getActive()
            if jStreamingContextOption.nonEmpty():
                jStreamingContextOption.get().stop()

    def setUp(self):
        self.ssc = None
        self.sc = None
        self.cpd = None

    def tearDown(self):
        if self.ssc is not None:
            self.ssc.stop(True)
        if self.sc is not None:
            self.sc.stop()
        if self.cpd is not None:
            shutil.rmtree(self.cpd)

    def test_transform_function_serializer_failure(self):
        inputd = tempfile.mkdtemp()
        self.cpd = tempfile.mkdtemp("test_transform_function_serializer_failure")

        def setup():
            conf = SparkConf().set("spark.default.parallelism", 1)
            sc = SparkContext(conf=conf)
            ssc = StreamingContext(sc, 0.5)

            # A function that cannot be serialized
            def process(time, rdd):
                sc.parallelize(range(1, 10))

            ssc.textFileStream(inputd).foreachRDD(process)
            return ssc

        self.ssc = StreamingContext.getOrCreate(self.cpd, setup)
        try:
            self.ssc.start()
        except:
            import traceback
            failure = traceback.format_exc()
            self.assertTrue(
                "It appears that you are attempting to reference SparkContext" in failure)
            return

        self.fail("using SparkContext in process should fail because it's not Serializable")

    def test_get_or_create_and_get_active_or_create(self):
        inputd = tempfile.mkdtemp()
        outputd = tempfile.mkdtemp() + "/"

        def updater(vs, s):
            return sum(vs, s or 0)

        def setup():
            conf = SparkConf().set("spark.default.parallelism", 1)
            sc = SparkContext(conf=conf)
            ssc = StreamingContext(sc, 2)
            dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1))
            wc = dstream.updateStateByKey(updater)
            wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test")
            wc.checkpoint(2)
            self.setupCalled = True
            return ssc

        # Verify that getOrCreate() calls setup() in absence of checkpoint files
        self.cpd = tempfile.mkdtemp("test_streaming_cps")
        self.setupCalled = False
        self.ssc = StreamingContext.getOrCreate(self.cpd, setup)
        self.assertTrue(self.setupCalled)

        self.ssc.start()

        def check_output(n):
            while not os.listdir(outputd):
                if self.ssc.awaitTerminationOrTimeout(0.5):
                    raise Exception("ssc stopped")
            time.sleep(1)  # make sure mtime is larger than the previous one
            with open(os.path.join(inputd, str(n)), 'w') as f:
                f.writelines(["%d\n" % i for i in range(10)])

            while True:
                if self.ssc.awaitTerminationOrTimeout(0.5):
                    raise Exception("ssc stopped")
                p = os.path.join(outputd, max(os.listdir(outputd)))
                if '_SUCCESS' not in os.listdir(p):
                    # not finished
                    continue
                ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
                d = ordd.values().map(int).collect()
                if not d:
                    continue
                self.assertEqual(10, len(d))
                s = set(d)
                self.assertEqual(1, len(s))
                m = s.pop()
                if n > m:
                    continue
                self.assertEqual(n, m)
                break

        check_output(1)
        check_output(2)

        # Verify the getOrCreate() recovers from checkpoint files
        self.ssc.stop(True, True)
        time.sleep(1)
        self.setupCalled = False
        self.ssc = StreamingContext.getOrCreate(self.cpd, setup)
        self.assertFalse(self.setupCalled)
        self.ssc.start()
        check_output(3)

        # Verify that getOrCreate() uses existing SparkContext
        self.ssc.stop(True, True)
        time.sleep(1)
        self.sc = SparkContext(conf=SparkConf())
        self.setupCalled = False
        self.ssc = StreamingContext.getOrCreate(self.cpd, setup)
        self.assertFalse(self.setupCalled)
        self.assertTrue(self.ssc.sparkContext == self.sc)

        # Verify the getActiveOrCreate() recovers from checkpoint files
        self.ssc.stop(True, True)
        time.sleep(1)
        self.setupCalled = False
        self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup)
        self.assertFalse(self.setupCalled)
        self.ssc.start()
        check_output(4)

        # Verify that getActiveOrCreate() returns active context
        self.setupCalled = False
        self.assertEqual(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc)
        self.assertFalse(self.setupCalled)

        # Verify that getActiveOrCreate() uses existing SparkContext
        self.ssc.stop(True, True)
        time.sleep(1)
        self.sc = SparkContext(conf=SparkConf())
        self.setupCalled = False
        self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup)
        self.assertFalse(self.setupCalled)
        self.assertTrue(self.ssc.sparkContext == self.sc)

        # Verify that getActiveOrCreate() calls setup() in absence of checkpoint files
        self.ssc.stop(True, True)
        shutil.rmtree(self.cpd)  # delete checkpoint directory
        time.sleep(1)
        self.setupCalled = False
        self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup)
        self.assertTrue(self.setupCalled)

        # Stop everything
        self.ssc.stop(True, True)


if __name__ == "__main__":
    from pyspark.streaming.tests.test_dstream import *

    try:
        import xmlrunner
        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
    except ImportError:
        unittest.main(verbosity=2)