## What changes were proposed in this pull request?

This PR continues breaking a single large file down into smaller files. See https://github.com/apache/spark/pull/23021. It follows the layout of https://github.com/numpy/numpy/tree/master/numpy. Concretely, this PR proposes to break down `pyspark/streaming/tests.py` into ...:

```
pyspark
├── __init__.py
...
├── streaming
│   ├── __init__.py
...
│   ├── tests
│   │   ├── __init__.py
│   │   ├── test_context.py
│   │   ├── test_dstream.py
│   │   ├── test_kinesis.py
│   │   └── test_listener.py
...
├── testing
...
│   ├── streamingutils.py
...
```

## How was this patch tested?

Existing tests should cover this: `cd python`, then `./run-tests-with-coverage`. Manually checked that they are actually being run.

Each test can (unofficially) be run via:

```bash
SPARK_TESTING=1 ./bin/pyspark pyspark.tests.test_context
```

Note that if you're using Mac and Python 3, you might have to set `OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES`.

Closes #23034 from HyukjinKwon/SPARK-26035.

Authored-by: hyukjinkwon <gurwls223@apache.org>
Signed-off-by: hyukjinkwon <gurwls223@apache.org>
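With the tests split into modules, a single module can also be run through the Python test runner. A minimal sketch, assuming the `--testnames` option of `python/run-tests` is available in this checkout:

```bash
# Assumes python/run-tests supports --testnames in this checkout.
cd python
./run-tests --testnames 'pyspark.streaming.tests.test_dstream'
```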
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import operator
import os
import shutil
import tempfile
import time
import unittest
from functools import reduce
from itertools import chain

from pyspark import SparkConf, SparkContext, RDD
from pyspark.streaming import StreamingContext
from pyspark.testing.streamingutils import PySparkStreamingTestCase


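# NOTE: each test below builds `input` as a list of per-batch data, applies
# `func` to a queue-backed input DStream, and compares the collected per-batch
# output with `expected`. The `_test_func` harness is provided by
# PySparkStreamingTestCase in pyspark/testing/streamingutils.py.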
class BasicOperationTests(PySparkStreamingTestCase):

    def test_map(self):
        """Basic operation test for DStream.map."""
        input = [range(1, 5), range(5, 9), range(9, 13)]

        def func(dstream):
            return dstream.map(str)
        expected = [list(map(str, x)) for x in input]
        self._test_func(input, func, expected)

    def test_flatMap(self):
        """Basic operation test for DStream.flatMap."""
        input = [range(1, 5), range(5, 9), range(9, 13)]

        def func(dstream):
            return dstream.flatMap(lambda x: (x, x * 2))
        expected = [list(chain.from_iterable((map(lambda y: [y, y * 2], x))))
                    for x in input]
        self._test_func(input, func, expected)

    def test_filter(self):
        """Basic operation test for DStream.filter."""
        input = [range(1, 5), range(5, 9), range(9, 13)]

        def func(dstream):
            return dstream.filter(lambda x: x % 2 == 0)
        expected = [[y for y in x if y % 2 == 0] for x in input]
        self._test_func(input, func, expected)

    def test_count(self):
        """Basic operation test for DStream.count."""
        input = [range(5), range(10), range(20)]

        def func(dstream):
            return dstream.count()
        expected = [[len(x)] for x in input]
        self._test_func(input, func, expected)

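    # slice() selects the RDDs whose batch time falls within [begin, end].
    # Batch times depend on the wall clock, so the test records the actual
    # times with foreachRDD rather than assuming them.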
    def test_slice(self):
        """Basic operation test for DStream.slice."""
        import datetime as dt
        self.ssc = StreamingContext(self.sc, 1.0)
        self.ssc.remember(4.0)
        input = [[1], [2], [3], [4]]
        stream = self.ssc.queueStream([self.sc.parallelize(d, 1) for d in input])

        time_vals = []

        def get_times(t, rdd):
            if rdd and len(time_vals) < len(input):
                time_vals.append(t)

        stream.foreachRDD(get_times)

        self.ssc.start()
        self.wait_for(time_vals, 4)
        begin_time = time_vals[0]

        def get_sliced(begin_delta, end_delta):
            begin = begin_time + dt.timedelta(seconds=begin_delta)
            end = begin_time + dt.timedelta(seconds=end_delta)
            rdds = stream.slice(begin, end)
            result_list = [rdd.collect() for rdd in rdds]
            return [r for result in result_list for r in result]

        self.assertEqual(set([1]), set(get_sliced(0, 0)))
        self.assertEqual(set([2, 3]), set(get_sliced(1, 2)))
        self.assertEqual(set([2, 3, 4]), set(get_sliced(1, 4)))
        self.assertEqual(set([1, 2, 3, 4]), set(get_sliced(0, 4)))

    def test_reduce(self):
        """Basic operation test for DStream.reduce."""
        input = [range(1, 5), range(5, 9), range(9, 13)]

        def func(dstream):
            return dstream.reduce(operator.add)
        expected = [[reduce(operator.add, x)] for x in input]
        self._test_func(input, func, expected)

    def test_reduceByKey(self):
        """Basic operation test for DStream.reduceByKey."""
        input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)],
                 [("", 1), ("", 1), ("", 1), ("", 1)],
                 [(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]]

        def func(dstream):
            return dstream.reduceByKey(operator.add)
        expected = [[("a", 2), ("b", 2)], [("", 4)], [(1, 2), (2, 2), (3, 1)]]
        self._test_func(input, func, expected, sort=True)

    def test_mapValues(self):
        """Basic operation test for DStream.mapValues."""
        input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
                 [(0, 4), (1, 1), (2, 2), (3, 3)],
                 [(1, 1), (2, 1), (3, 1), (4, 1)]]

        def func(dstream):
            return dstream.mapValues(lambda x: x + 10)
        expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)],
                    [(0, 14), (1, 11), (2, 12), (3, 13)],
                    [(1, 11), (2, 11), (3, 11), (4, 11)]]
        self._test_func(input, func, expected, sort=True)

    def test_flatMapValues(self):
        """Basic operation test for DStream.flatMapValues."""
        input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
                 [(0, 4), (1, 1), (2, 1), (3, 1)],
                 [(1, 1), (2, 1), (3, 1), (4, 1)]]

        def func(dstream):
            return dstream.flatMapValues(lambda x: (x, x + 10))
        expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12),
                     ("c", 1), ("c", 11), ("d", 1), ("d", 11)],
                    [(0, 4), (0, 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)],
                    [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]]
        self._test_func(input, func, expected)

    def test_glom(self):
        """Basic operation test for DStream.glom."""
        input = [range(1, 5), range(5, 9), range(9, 13)]
        rdds = [self.sc.parallelize(r, 2) for r in input]

        def func(dstream):
            return dstream.glom()
        expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
        self._test_func(rdds, func, expected)

    def test_mapPartitions(self):
        """Basic operation test for DStream.mapPartitions."""
        input = [range(1, 5), range(5, 9), range(9, 13)]
        rdds = [self.sc.parallelize(r, 2) for r in input]

        def func(dstream):
            def f(iterator):
                yield sum(iterator)
            return dstream.mapPartitions(f)
        expected = [[3, 7], [11, 15], [19, 23]]
        self._test_func(rdds, func, expected)

    def test_countByValue(self):
        """Basic operation test for DStream.countByValue."""
        input = [list(range(1, 5)) * 2, list(range(5, 7)) + list(range(5, 9)), ["a", "a", "b", ""]]

        def func(dstream):
            return dstream.countByValue()
        expected = [[(1, 2), (2, 2), (3, 2), (4, 2)],
                    [(5, 2), (6, 2), (7, 1), (8, 1)],
                    [("a", 2), ("b", 1), ("", 1)]]
        self._test_func(input, func, expected, sort=True)

    def test_groupByKey(self):
        """Basic operation test for DStream.groupByKey."""
        input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
                 [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
                 [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]

        def func(dstream):
            return dstream.groupByKey().mapValues(list)

        expected = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])],
                    [(1, [1, 1, 1]), (2, [1, 1]), (3, [1])],
                    [("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]]
        self._test_func(input, func, expected, sort=True)

    def test_combineByKey(self):
        """Basic operation test for DStream.combineByKey."""
        input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
                 [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
                 [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]

        def func(dstream):
            def add(a, b):
                return a + str(b)
            return dstream.combineByKey(str, add, add)
        expected = [[(1, "1"), (2, "1"), (3, "1"), (4, "1")],
                    [(1, "111"), (2, "11"), (3, "1")],
                    [("a", "11"), ("b", "1"), ("", "111")]]
        self._test_func(input, func, expected, sort=True)

    def test_repartition(self):
        input = [range(1, 5), range(5, 9)]
        rdds = [self.sc.parallelize(r, 2) for r in input]

        def func(dstream):
            return dstream.repartition(1).glom()
        expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]]
        self._test_func(rdds, func, expected)

    def test_union(self):
        input1 = [range(3), range(5), range(6)]
        input2 = [range(3, 6), range(5, 6)]

        def func(d1, d2):
            return d1.union(d2)

        expected = [list(range(6)), list(range(6)), list(range(6))]
        self._test_func(input1, func, expected, input2=input2)

    def test_cogroup(self):
        input = [[(1, 1), (2, 1), (3, 1)],
                 [(1, 1), (1, 1), (1, 1), (2, 1)],
                 [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]]
        input2 = [[(1, 2)],
                  [(4, 1)],
                  [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]]

        def func(d1, d2):
            return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs)))

        expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))],
                    [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))],
                    [("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]]
        self._test_func(input, func, expected, sort=True, input2=input2)

    def test_join(self):
        input = [[('a', 1), ('b', 2)]]
        input2 = [[('b', 3), ('c', 4)]]

        def func(a, b):
            return a.join(b)

        expected = [[('b', (2, 3))]]
        self._test_func(input, func, expected, True, input2)

    def test_left_outer_join(self):
        input = [[('a', 1), ('b', 2)]]
        input2 = [[('b', 3), ('c', 4)]]

        def func(a, b):
            return a.leftOuterJoin(b)

        expected = [[('a', (1, None)), ('b', (2, 3))]]
        self._test_func(input, func, expected, True, input2)

    def test_right_outer_join(self):
        input = [[('a', 1), ('b', 2)]]
        input2 = [[('b', 3), ('c', 4)]]

        def func(a, b):
            return a.rightOuterJoin(b)

        expected = [[('b', (2, 3)), ('c', (None, 4))]]
        self._test_func(input, func, expected, True, input2)

    def test_full_outer_join(self):
        input = [[('a', 1), ('b', 2)]]
        input2 = [[('b', 3), ('c', 4)]]

        def func(a, b):
            return a.fullOuterJoin(b)

        expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]]
        self._test_func(input, func, expected, True, input2)

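    # updateStateByKey() keeps per-key state across batches, so the expected
    # output below grows cumulatively from batch to batch.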
    def test_update_state_by_key(self):

        def updater(vs, s):
            if not s:
                s = []
            s.extend(vs)
            return s

        input = [[('k', i)] for i in range(5)]

        def func(dstream):
            return dstream.updateStateByKey(updater)

        expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
        expected = [[('k', v)] for v in expected]
        self._test_func(input, func, expected)

    def test_update_state_by_key_initial_rdd(self):

        def updater(vs, s):
            if not s:
                s = []
            s.extend(vs)
            return s

        initial = [('k', [0, 1])]
        initial = self.sc.parallelize(initial, 1)

        input = [[('k', i)] for i in range(2, 5)]

        def func(dstream):
            return dstream.updateStateByKey(updater, initialRDD=initial)

        expected = [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
        expected = [[('k', v)] for v in expected]
        self._test_func(input, func, expected)

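    # The next three tests check that an exception raised inside a Python
    # transform function reaches awaitTerminationOrTimeout() with its original
    # message intact; the last one also verifies that batches after the
    # failing one still produce output.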
    def test_failed_func(self):
        # Test failure in
        # TransformFunction.apply(rdd: Option[RDD[_]], time: Time)
        input = [self.sc.parallelize([d], 1) for d in range(4)]
        input_stream = self.ssc.queueStream(input)

        def failed_func(i):
            raise ValueError("This is a special error")

        input_stream.map(failed_func).pprint()
        self.ssc.start()
        try:
            self.ssc.awaitTerminationOrTimeout(10)
        except:
            import traceback
            failure = traceback.format_exc()
            self.assertTrue("This is a special error" in failure)
            return

        self.fail("a failed func should throw an error")

    def test_failed_func2(self):
        # Test failure in
        # TransformFunction.apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time)
        input = [self.sc.parallelize([d], 1) for d in range(4)]
        input_stream1 = self.ssc.queueStream(input)
        input_stream2 = self.ssc.queueStream(input)

        def failed_func(rdd1, rdd2):
            raise ValueError("This is a special error")

        input_stream1.transformWith(failed_func, input_stream2, True).pprint()
        self.ssc.start()
        try:
            self.ssc.awaitTerminationOrTimeout(10)
        except:
            import traceback
            failure = traceback.format_exc()
            self.assertTrue("This is a special error" in failure)
            return

        self.fail("a failed func should throw an error")

    def test_failed_func_with_reseting_failure(self):
        input = [self.sc.parallelize([d], 1) for d in range(4)]
        input_stream = self.ssc.queueStream(input)

        def failed_func(i):
            if i == 1:
                # Make it fail in the second batch
                raise ValueError("This is a special error")
            else:
                return i

        # We should be able to see the results of the 3rd and 4th batches even if the second batch
        # fails
        expected = [[0], [2], [3]]
        self.assertEqual(expected, self._collect(input_stream.map(failed_func), 3))
        try:
            self.ssc.awaitTerminationOrTimeout(10)
        except:
            import traceback
            failure = traceback.format_exc()
            self.assertTrue("This is a special error" in failure)
            return

        self.fail("a failed func should throw an error")


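# Window and slide durations below are in seconds. With the 0.5-second batch
# interval used by the test base class, window(1.5, .5) spans up to three
# batches, which is why the expected counts ramp up and then drain back down.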
class WindowFunctionTests(PySparkStreamingTestCase):

    timeout = 15

    def test_window(self):
        input = [range(1), range(2), range(3), range(4), range(5)]

        def func(dstream):
            return dstream.window(1.5, .5).count()

        expected = [[1], [3], [6], [9], [12], [9], [5]]
        self._test_func(input, func, expected)

    def test_count_by_window(self):
        input = [range(1), range(2), range(3), range(4), range(5)]

        def func(dstream):
            return dstream.countByWindow(1.5, .5)

        expected = [[1], [3], [6], [9], [12], [9], [5]]
        self._test_func(input, func, expected)

    def test_count_by_window_large(self):
        input = [range(1), range(2), range(3), range(4), range(5), range(6)]

        def func(dstream):
            return dstream.countByWindow(2.5, .5)

        expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
        self._test_func(input, func, expected)

    def test_count_by_value_and_window(self):
        input = [range(1), range(2), range(3), range(4), range(5), range(6)]

        def func(dstream):
            return dstream.countByValueAndWindow(2.5, .5)

        expected = [[(0, 1)],
                    [(0, 2), (1, 1)],
                    [(0, 3), (1, 2), (2, 1)],
                    [(0, 4), (1, 3), (2, 2), (3, 1)],
                    [(0, 5), (1, 4), (2, 3), (3, 2), (4, 1)],
                    [(0, 5), (1, 5), (2, 4), (3, 3), (4, 2), (5, 1)],
                    [(0, 4), (1, 4), (2, 4), (3, 3), (4, 2), (5, 1)],
                    [(0, 3), (1, 3), (2, 3), (3, 3), (4, 2), (5, 1)],
                    [(0, 2), (1, 2), (2, 2), (3, 2), (4, 2), (5, 1)],
                    [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1)]]
        self._test_func(input, func, expected)

    def test_group_by_key_and_window(self):
        input = [[('a', i)] for i in range(5)]

        def func(dstream):
            return dstream.groupByKeyAndWindow(1.5, .5).mapValues(list)

        expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])],
                    [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
        self._test_func(input, func, expected)

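    # Window and slide durations must be multiples of the batch interval;
    # values that are not are rejected with a ValueError.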
    def test_reduce_by_invalid_window(self):
        input1 = [range(3), range(5), range(1), range(6)]
        d1 = self.ssc.queueStream(input1)
        self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1))
        self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1))

    def test_reduce_by_key_and_window_with_none_invFunc(self):
        input = [range(1), range(2), range(3), range(4), range(5), range(6)]

        def func(dstream):
            return dstream.map(lambda x: (x, 1))\
                .reduceByKeyAndWindow(operator.add, None, 5, 1)\
                .filter(lambda kv: kv[1] > 0).count()

        expected = [[2], [4], [6], [6], [6], [6]]
        self._test_func(input, func, expected)


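# Unlike the suites above, CheckpointTests manages its own SparkContext and
# StreamingContext so it can exercise checkpoint recovery through
# getOrCreate() and getActiveOrCreate(); tearDown() stops whatever was created.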
class CheckpointTests(unittest.TestCase):

    setupCalled = False

    @staticmethod
    def tearDownClass():
        # Clean up in the JVM just in case there have been some issues in the Python API
        if SparkContext._jvm is not None:
            jStreamingContextOption = \
                SparkContext._jvm.org.apache.spark.streaming.StreamingContext.getActive()
            if jStreamingContextOption.nonEmpty():
                jStreamingContextOption.get().stop()

    def setUp(self):
        self.ssc = None
        self.sc = None
        self.cpd = None

    def tearDown(self):
        if self.ssc is not None:
            self.ssc.stop(True)
        if self.sc is not None:
            self.sc.stop()
        if self.cpd is not None:
            shutil.rmtree(self.cpd)

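    # `process` closes over the driver's SparkContext, which cannot be
    # pickled, so checkpointing the transform function must fail with a clear
    # error message.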
    def test_transform_function_serializer_failure(self):
        inputd = tempfile.mkdtemp()
        self.cpd = tempfile.mkdtemp("test_transform_function_serializer_failure")

        def setup():
            conf = SparkConf().set("spark.default.parallelism", 1)
            sc = SparkContext(conf=conf)
            ssc = StreamingContext(sc, 0.5)

            # A function that cannot be serialized
            def process(time, rdd):
                sc.parallelize(range(1, 10))

            ssc.textFileStream(inputd).foreachRDD(process)
            return ssc

        self.ssc = StreamingContext.getOrCreate(self.cpd, setup)
        try:
            self.ssc.start()
        except:
            import traceback
            failure = traceback.format_exc()
            self.assertTrue(
                "It appears that you are attempting to reference SparkContext" in failure)
            return

        self.fail("using SparkContext in process should fail because it's not Serializable")

    def test_get_or_create_and_get_active_or_create(self):
        inputd = tempfile.mkdtemp()
        outputd = tempfile.mkdtemp() + "/"

        def updater(vs, s):
            return sum(vs, s or 0)

        def setup():
            conf = SparkConf().set("spark.default.parallelism", 1)
            sc = SparkContext(conf=conf)
            ssc = StreamingContext(sc, 2)
            dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1))
            wc = dstream.updateStateByKey(updater)
            wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test")
            wc.checkpoint(2)
            self.setupCalled = True
            return ssc

        # Verify that getOrCreate() calls setup() in absence of checkpoint files
        self.cpd = tempfile.mkdtemp("test_streaming_cps")
        self.setupCalled = False
        self.ssc = StreamingContext.getOrCreate(self.cpd, setup)
        self.assertTrue(self.setupCalled)

        self.ssc.start()

        def check_output(n):
            while not os.listdir(outputd):
                if self.ssc.awaitTerminationOrTimeout(0.5):
                    raise Exception("ssc stopped")
            time.sleep(1)  # make sure mtime is larger than the previous one
            with open(os.path.join(inputd, str(n)), 'w') as f:
                f.writelines(["%d\n" % i for i in range(10)])

            while True:
                if self.ssc.awaitTerminationOrTimeout(0.5):
                    raise Exception("ssc stopped")
                p = os.path.join(outputd, max(os.listdir(outputd)))
                if '_SUCCESS' not in os.listdir(p):
                    # not finished
                    continue
                ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
                d = ordd.values().map(int).collect()
                if not d:
                    continue
                self.assertEqual(10, len(d))
                s = set(d)
                self.assertEqual(1, len(s))
                m = s.pop()
                if n > m:
                    continue
                self.assertEqual(n, m)
                break

        check_output(1)
        check_output(2)

        # Verify the getOrCreate() recovers from checkpoint files
        self.ssc.stop(True, True)
        time.sleep(1)
        self.setupCalled = False
        self.ssc = StreamingContext.getOrCreate(self.cpd, setup)
        self.assertFalse(self.setupCalled)
        self.ssc.start()
        check_output(3)

        # Verify that getOrCreate() uses existing SparkContext
        self.ssc.stop(True, True)
        time.sleep(1)
        self.sc = SparkContext(conf=SparkConf())
        self.setupCalled = False
        self.ssc = StreamingContext.getOrCreate(self.cpd, setup)
        self.assertFalse(self.setupCalled)
        self.assertTrue(self.ssc.sparkContext == self.sc)

        # Verify the getActiveOrCreate() recovers from checkpoint files
        self.ssc.stop(True, True)
        time.sleep(1)
        self.setupCalled = False
        self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup)
        self.assertFalse(self.setupCalled)
        self.ssc.start()
        check_output(4)

        # Verify that getActiveOrCreate() returns active context
        self.setupCalled = False
        self.assertEqual(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc)
        self.assertFalse(self.setupCalled)

        # Verify that getActiveOrCreate() uses existing SparkContext
        self.ssc.stop(True, True)
        time.sleep(1)
        self.sc = SparkContext(conf=SparkConf())
        self.setupCalled = False
        self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup)
        self.assertFalse(self.setupCalled)
        self.assertTrue(self.ssc.sparkContext == self.sc)

        # Verify that getActiveOrCreate() calls setup() in absence of checkpoint files
        self.ssc.stop(True, True)
        shutil.rmtree(self.cpd)  # delete checkpoint directory
        time.sleep(1)
        self.setupCalled = False
        self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup)
        self.assertTrue(self.setupCalled)

        # Stop everything
        self.ssc.stop(True, True)


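# When executed directly, re-import the tests under their canonical module
# name and emit JUnit-style XML reports if the optional xmlrunner package is
# installed.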
if __name__ == "__main__":
    from pyspark.streaming.tests.test_dstream import *

    try:
        import xmlrunner
        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
    except ImportError:
        unittest.main(verbosity=2)