spark-instrumented-optimizer/python/pyspark/streaming/tests.py
Josh Rosen 40648c56cd [SPARK-8583] [SPARK-5482] [BUILD] Refactor python/run-tests to integrate with dev/run-tests module system
This patch refactors the `python/run-tests` script:

- It's now written in Python instead of Bash.
- The descriptions of the tests to run are now stored in `dev/run-tests`'s modules.  This allows the pull request builder to skip Python tests suites that were not affected by the pull request's changes.  For example, we can now skip the PySpark Streaming test cases when only SQL files are changed.
- `python/run-tests` now supports command-line flags to make it easier to run individual test suites (this addresses SPARK-5482):

  ```
Usage: run-tests [options]

Options:
  -h, --help            show this help message and exit
  --python-executables=PYTHON_EXECUTABLES
                        A comma-separated list of Python executables to test
                        against (default: python2.6,python3.4,pypy)
  --modules=MODULES     A comma-separated list of Python modules to test
                        (default: pyspark-core,pyspark-ml,pyspark-mllib
                        ,pyspark-sql,pyspark-streaming)
   ```
- `dev/run-tests` has been split into multiple files: the module definitions and test utility functions are now stored inside of a `dev/sparktestsupport` Python module, allowing them to be re-used from the Python test runner script.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #6967 from JoshRosen/run-tests-python-modules and squashes the following commits:

f578d6d [Josh Rosen] Fix print for Python 2.x
8233d61 [Josh Rosen] Add python/run-tests.py to Python lint checks
34c98d2 [Josh Rosen] Fix universal_newlines for Python 3
8f65ed0 [Josh Rosen] Fix handling of  module in python/run-tests
37aff00 [Josh Rosen] Python 3 fix
27a389f [Josh Rosen] Skip MLLib tests for PyPy
c364ccf [Josh Rosen] Use which() to convert PYSPARK_PYTHON to an absolute path before shelling out to run tests
568a3fd [Josh Rosen] Fix hashbang
3b852ae [Josh Rosen] Fall back to PYSPARK_PYTHON when sys.executable is None (fixes a test)
f53db55 [Josh Rosen] Remove python2 flag, since the test runner script also works fine under Python 3
9c80469 [Josh Rosen] Fix passing of PYSPARK_PYTHON
d33e525 [Josh Rosen] Merge remote-tracking branch 'origin/master' into run-tests-python-modules
4f8902c [Josh Rosen] Python lint fixes.
8f3244c [Josh Rosen] Use universal_newlines to fix dev/run-tests doctest failures on Python 3.
f542ac5 [Josh Rosen] Fix lint check for Python 3
fff4d09 [Josh Rosen] Add dev/sparktestsupport to pep8 checks
2efd594 [Josh Rosen] Update dev/run-tests to use new Python test runner flags
b2ab027 [Josh Rosen] Add command-line options for running individual suites in python/run-tests
caeb040 [Josh Rosen] Fixes to PySpark test module definitions
d6a77d3 [Josh Rosen] Fix the tests of dev/run-tests
def2d8a [Josh Rosen] Two minor fixes
aec0b8f [Josh Rosen] Actually get the Kafka stuff to run properly
04015b9 [Josh Rosen] First attempt at getting PySpark Kafka test to work in new runner script
4c97136 [Josh Rosen] PYTHONPATH fixes
dcc9c09 [Josh Rosen] Fix time division
32660fc [Josh Rosen] Initial cut at Python test runner refactoring
311c6a9 [Josh Rosen] Move shell utility functions to own module.
1bdeb87 [Josh Rosen] Move module definitions to separate file.
2015-06-27 20:24:34 -07:00

697 lines
25 KiB
Python

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import glob
import os
import sys
from itertools import chain
import time
import operator
import tempfile
import random
import struct
from functools import reduce
if sys.version_info[:2] <= (2, 6):
try:
import unittest2 as unittest
except ImportError:
sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
sys.exit(1)
else:
import unittest
from pyspark.context import SparkConf, SparkContext, RDD
from pyspark.streaming.context import StreamingContext
from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition
class PySparkStreamingTestCase(unittest.TestCase):
timeout = 10 # seconds
duration = .5
@classmethod
def setUpClass(cls):
class_name = cls.__name__
conf = SparkConf().set("spark.default.parallelism", 1)
cls.sc = SparkContext(appName=class_name, conf=conf)
cls.sc.setCheckpointDir("/tmp")
@classmethod
def tearDownClass(cls):
cls.sc.stop()
def setUp(self):
self.ssc = StreamingContext(self.sc, self.duration)
def tearDown(self):
self.ssc.stop(False)
def wait_for(self, result, n):
start_time = time.time()
while len(result) < n and time.time() - start_time < self.timeout:
time.sleep(0.01)
if len(result) < n:
print("timeout after", self.timeout)
def _take(self, dstream, n):
"""
Return the first `n` elements in the stream (will start and stop).
"""
results = []
def take(_, rdd):
if rdd and len(results) < n:
results.extend(rdd.take(n - len(results)))
dstream.foreachRDD(take)
self.ssc.start()
self.wait_for(results, n)
return results
def _collect(self, dstream, n, block=True):
"""
Collect each RDDs into the returned list.
:return: list, which will have the collected items.
"""
result = []
def get_output(_, rdd):
if rdd and len(result) < n:
r = rdd.collect()
if r:
result.append(r)
dstream.foreachRDD(get_output)
if not block:
return result
self.ssc.start()
self.wait_for(result, n)
return result
def _test_func(self, input, func, expected, sort=False, input2=None):
"""
@param input: dataset for the test. This should be list of lists.
@param func: wrapped function. This function should return PythonDStream object.
@param expected: expected output for this testcase.
"""
if not isinstance(input[0], RDD):
input = [self.sc.parallelize(d, 1) for d in input]
input_stream = self.ssc.queueStream(input)
if input2 and not isinstance(input2[0], RDD):
input2 = [self.sc.parallelize(d, 1) for d in input2]
input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None
# Apply test function to stream.
if input2:
stream = func(input_stream, input_stream2)
else:
stream = func(input_stream)
result = self._collect(stream, len(expected))
if sort:
self._sort_result_based_on_key(result)
self._sort_result_based_on_key(expected)
self.assertEqual(expected, result)
def _sort_result_based_on_key(self, outputs):
"""Sort the list based on first value."""
for output in outputs:
output.sort(key=lambda x: x[0])
class BasicOperationTests(PySparkStreamingTestCase):
def test_map(self):
"""Basic operation test for DStream.map."""
input = [range(1, 5), range(5, 9), range(9, 13)]
def func(dstream):
return dstream.map(str)
expected = [list(map(str, x)) for x in input]
self._test_func(input, func, expected)
def test_flatMap(self):
"""Basic operation test for DStream.faltMap."""
input = [range(1, 5), range(5, 9), range(9, 13)]
def func(dstream):
return dstream.flatMap(lambda x: (x, x * 2))
expected = [list(chain.from_iterable((map(lambda y: [y, y * 2], x))))
for x in input]
self._test_func(input, func, expected)
def test_filter(self):
"""Basic operation test for DStream.filter."""
input = [range(1, 5), range(5, 9), range(9, 13)]
def func(dstream):
return dstream.filter(lambda x: x % 2 == 0)
expected = [[y for y in x if y % 2 == 0] for x in input]
self._test_func(input, func, expected)
def test_count(self):
"""Basic operation test for DStream.count."""
input = [range(5), range(10), range(20)]
def func(dstream):
return dstream.count()
expected = [[len(x)] for x in input]
self._test_func(input, func, expected)
def test_reduce(self):
"""Basic operation test for DStream.reduce."""
input = [range(1, 5), range(5, 9), range(9, 13)]
def func(dstream):
return dstream.reduce(operator.add)
expected = [[reduce(operator.add, x)] for x in input]
self._test_func(input, func, expected)
def test_reduceByKey(self):
"""Basic operation test for DStream.reduceByKey."""
input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)],
[("", 1), ("", 1), ("", 1), ("", 1)],
[(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]]
def func(dstream):
return dstream.reduceByKey(operator.add)
expected = [[("a", 2), ("b", 2)], [("", 4)], [(1, 2), (2, 2), (3, 1)]]
self._test_func(input, func, expected, sort=True)
def test_mapValues(self):
"""Basic operation test for DStream.mapValues."""
input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
[(0, 4), (1, 1), (2, 2), (3, 3)],
[(1, 1), (2, 1), (3, 1), (4, 1)]]
def func(dstream):
return dstream.mapValues(lambda x: x + 10)
expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)],
[(0, 14), (1, 11), (2, 12), (3, 13)],
[(1, 11), (2, 11), (3, 11), (4, 11)]]
self._test_func(input, func, expected, sort=True)
def test_flatMapValues(self):
"""Basic operation test for DStream.flatMapValues."""
input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
[(0, 4), (1, 1), (2, 1), (3, 1)],
[(1, 1), (2, 1), (3, 1), (4, 1)]]
def func(dstream):
return dstream.flatMapValues(lambda x: (x, x + 10))
expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12),
("c", 1), ("c", 11), ("d", 1), ("d", 11)],
[(0, 4), (0, 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)],
[(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]]
self._test_func(input, func, expected)
def test_glom(self):
"""Basic operation test for DStream.glom."""
input = [range(1, 5), range(5, 9), range(9, 13)]
rdds = [self.sc.parallelize(r, 2) for r in input]
def func(dstream):
return dstream.glom()
expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
self._test_func(rdds, func, expected)
def test_mapPartitions(self):
"""Basic operation test for DStream.mapPartitions."""
input = [range(1, 5), range(5, 9), range(9, 13)]
rdds = [self.sc.parallelize(r, 2) for r in input]
def func(dstream):
def f(iterator):
yield sum(iterator)
return dstream.mapPartitions(f)
expected = [[3, 7], [11, 15], [19, 23]]
self._test_func(rdds, func, expected)
def test_countByValue(self):
"""Basic operation test for DStream.countByValue."""
input = [list(range(1, 5)) * 2, list(range(5, 7)) + list(range(5, 9)), ["a", "a", "b", ""]]
def func(dstream):
return dstream.countByValue()
expected = [[4], [4], [3]]
self._test_func(input, func, expected)
def test_groupByKey(self):
"""Basic operation test for DStream.groupByKey."""
input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
[(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
[("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]
def func(dstream):
return dstream.groupByKey().mapValues(list)
expected = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])],
[(1, [1, 1, 1]), (2, [1, 1]), (3, [1])],
[("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]]
self._test_func(input, func, expected, sort=True)
def test_combineByKey(self):
"""Basic operation test for DStream.combineByKey."""
input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
[(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
[("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]
def func(dstream):
def add(a, b):
return a + str(b)
return dstream.combineByKey(str, add, add)
expected = [[(1, "1"), (2, "1"), (3, "1"), (4, "1")],
[(1, "111"), (2, "11"), (3, "1")],
[("a", "11"), ("b", "1"), ("", "111")]]
self._test_func(input, func, expected, sort=True)
def test_repartition(self):
input = [range(1, 5), range(5, 9)]
rdds = [self.sc.parallelize(r, 2) for r in input]
def func(dstream):
return dstream.repartition(1).glom()
expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]]
self._test_func(rdds, func, expected)
def test_union(self):
input1 = [range(3), range(5), range(6)]
input2 = [range(3, 6), range(5, 6)]
def func(d1, d2):
return d1.union(d2)
expected = [list(range(6)), list(range(6)), list(range(6))]
self._test_func(input1, func, expected, input2=input2)
def test_cogroup(self):
input = [[(1, 1), (2, 1), (3, 1)],
[(1, 1), (1, 1), (1, 1), (2, 1)],
[("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]]
input2 = [[(1, 2)],
[(4, 1)],
[("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]]
def func(d1, d2):
return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs)))
expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))],
[(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))],
[("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]]
self._test_func(input, func, expected, sort=True, input2=input2)
def test_join(self):
input = [[('a', 1), ('b', 2)]]
input2 = [[('b', 3), ('c', 4)]]
def func(a, b):
return a.join(b)
expected = [[('b', (2, 3))]]
self._test_func(input, func, expected, True, input2)
def test_left_outer_join(self):
input = [[('a', 1), ('b', 2)]]
input2 = [[('b', 3), ('c', 4)]]
def func(a, b):
return a.leftOuterJoin(b)
expected = [[('a', (1, None)), ('b', (2, 3))]]
self._test_func(input, func, expected, True, input2)
def test_right_outer_join(self):
input = [[('a', 1), ('b', 2)]]
input2 = [[('b', 3), ('c', 4)]]
def func(a, b):
return a.rightOuterJoin(b)
expected = [[('b', (2, 3)), ('c', (None, 4))]]
self._test_func(input, func, expected, True, input2)
def test_full_outer_join(self):
input = [[('a', 1), ('b', 2)]]
input2 = [[('b', 3), ('c', 4)]]
def func(a, b):
return a.fullOuterJoin(b)
expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]]
self._test_func(input, func, expected, True, input2)
def test_update_state_by_key(self):
def updater(vs, s):
if not s:
s = []
s.extend(vs)
return s
input = [[('k', i)] for i in range(5)]
def func(dstream):
return dstream.updateStateByKey(updater)
expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
expected = [[('k', v)] for v in expected]
self._test_func(input, func, expected)
class WindowFunctionTests(PySparkStreamingTestCase):
timeout = 15
def test_window(self):
input = [range(1), range(2), range(3), range(4), range(5)]
def func(dstream):
return dstream.window(1.5, .5).count()
expected = [[1], [3], [6], [9], [12], [9], [5]]
self._test_func(input, func, expected)
def test_count_by_window(self):
input = [range(1), range(2), range(3), range(4), range(5)]
def func(dstream):
return dstream.countByWindow(1.5, .5)
expected = [[1], [3], [6], [9], [12], [9], [5]]
self._test_func(input, func, expected)
def test_count_by_window_large(self):
input = [range(1), range(2), range(3), range(4), range(5), range(6)]
def func(dstream):
return dstream.countByWindow(2.5, .5)
expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
self._test_func(input, func, expected)
def test_count_by_value_and_window(self):
input = [range(1), range(2), range(3), range(4), range(5), range(6)]
def func(dstream):
return dstream.countByValueAndWindow(2.5, .5)
expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]]
self._test_func(input, func, expected)
def test_group_by_key_and_window(self):
input = [[('a', i)] for i in range(5)]
def func(dstream):
return dstream.groupByKeyAndWindow(1.5, .5).mapValues(list)
expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])],
[('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
self._test_func(input, func, expected)
def test_reduce_by_invalid_window(self):
input1 = [range(3), range(5), range(1), range(6)]
d1 = self.ssc.queueStream(input1)
self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1))
self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1))
class StreamingContextTests(PySparkStreamingTestCase):
duration = 0.1
def _add_input_stream(self):
inputs = [range(1, x) for x in range(101)]
stream = self.ssc.queueStream(inputs)
self._collect(stream, 1, block=False)
def test_stop_only_streaming_context(self):
self._add_input_stream()
self.ssc.start()
self.ssc.stop(False)
self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)
def test_stop_multiple_times(self):
self._add_input_stream()
self.ssc.start()
self.ssc.stop(False)
self.ssc.stop(False)
def test_queue_stream(self):
input = [list(range(i + 1)) for i in range(3)]
dstream = self.ssc.queueStream(input)
result = self._collect(dstream, 3)
self.assertEqual(input, result)
def test_text_file_stream(self):
d = tempfile.mkdtemp()
self.ssc = StreamingContext(self.sc, self.duration)
dstream2 = self.ssc.textFileStream(d).map(int)
result = self._collect(dstream2, 2, block=False)
self.ssc.start()
for name in ('a', 'b'):
time.sleep(1)
with open(os.path.join(d, name), "w") as f:
f.writelines(["%d\n" % i for i in range(10)])
self.wait_for(result, 2)
self.assertEqual([list(range(10)), list(range(10))], result)
def test_binary_records_stream(self):
d = tempfile.mkdtemp()
self.ssc = StreamingContext(self.sc, self.duration)
dstream = self.ssc.binaryRecordsStream(d, 10).map(
lambda v: struct.unpack("10b", bytes(v)))
result = self._collect(dstream, 2, block=False)
self.ssc.start()
for name in ('a', 'b'):
time.sleep(1)
with open(os.path.join(d, name), "wb") as f:
f.write(bytearray(range(10)))
self.wait_for(result, 2)
self.assertEqual([list(range(10)), list(range(10))], [list(v[0]) for v in result])
def test_union(self):
input = [list(range(i + 1)) for i in range(3)]
dstream = self.ssc.queueStream(input)
dstream2 = self.ssc.queueStream(input)
dstream3 = self.ssc.union(dstream, dstream2)
result = self._collect(dstream3, 3)
expected = [i * 2 for i in input]
self.assertEqual(expected, result)
def test_transform(self):
dstream1 = self.ssc.queueStream([[1]])
dstream2 = self.ssc.queueStream([[2]])
dstream3 = self.ssc.queueStream([[3]])
def func(rdds):
rdd1, rdd2, rdd3 = rdds
return rdd2.union(rdd3).union(rdd1)
dstream = self.ssc.transform([dstream1, dstream2, dstream3], func)
self.assertEqual([2, 3, 1], self._take(dstream, 3))
class CheckpointTests(unittest.TestCase):
def test_get_or_create(self):
inputd = tempfile.mkdtemp()
outputd = tempfile.mkdtemp() + "/"
def updater(vs, s):
return sum(vs, s or 0)
def setup():
conf = SparkConf().set("spark.default.parallelism", 1)
sc = SparkContext(conf=conf)
ssc = StreamingContext(sc, 0.5)
dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1))
wc = dstream.updateStateByKey(updater)
wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test")
wc.checkpoint(.5)
return ssc
cpd = tempfile.mkdtemp("test_streaming_cps")
ssc = StreamingContext.getOrCreate(cpd, setup)
ssc.start()
def check_output(n):
while not os.listdir(outputd):
time.sleep(0.01)
time.sleep(1) # make sure mtime is larger than the previous one
with open(os.path.join(inputd, str(n)), 'w') as f:
f.writelines(["%d\n" % i for i in range(10)])
while True:
p = os.path.join(outputd, max(os.listdir(outputd)))
if '_SUCCESS' not in os.listdir(p):
# not finished
time.sleep(0.01)
continue
ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
d = ordd.values().map(int).collect()
if not d:
time.sleep(0.01)
continue
self.assertEqual(10, len(d))
s = set(d)
self.assertEqual(1, len(s))
m = s.pop()
if n > m:
continue
self.assertEqual(n, m)
break
check_output(1)
check_output(2)
ssc.stop(True, True)
time.sleep(1)
ssc = StreamingContext.getOrCreate(cpd, setup)
ssc.start()
check_output(3)
ssc.stop(True, True)
class KafkaStreamTests(PySparkStreamingTestCase):
timeout = 20 # seconds
duration = 1
def setUp(self):
super(KafkaStreamTests, self).setUp()
kafkaTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
.loadClass("org.apache.spark.streaming.kafka.KafkaTestUtils")
self._kafkaTestUtils = kafkaTestUtilsClz.newInstance()
self._kafkaTestUtils.setup()
def tearDown(self):
if self._kafkaTestUtils is not None:
self._kafkaTestUtils.teardown()
self._kafkaTestUtils = None
super(KafkaStreamTests, self).tearDown()
def _randomTopic(self):
return "topic-%d" % random.randint(0, 10000)
def _validateStreamResult(self, sendData, stream):
result = {}
for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]),
sum(sendData.values()))):
result[i] = result.get(i, 0) + 1
self.assertEqual(sendData, result)
def _validateRddResult(self, sendData, rdd):
result = {}
for i in rdd.map(lambda x: x[1]).collect():
result[i] = result.get(i, 0) + 1
self.assertEqual(sendData, result)
def test_kafka_stream(self):
"""Test the Python Kafka stream API."""
topic = self._randomTopic()
sendData = {"a": 3, "b": 5, "c": 10}
self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(),
"test-streaming-consumer", {topic: 1},
{"auto.offset.reset": "smallest"})
self._validateStreamResult(sendData, stream)
def test_kafka_direct_stream(self):
"""Test the Python direct Kafka stream API."""
topic = self._randomTopic()
sendData = {"a": 1, "b": 2, "c": 3}
kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
"auto.offset.reset": "smallest"}
self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
self._validateStreamResult(sendData, stream)
@unittest.skipIf(sys.version >= "3", "long type not support")
def test_kafka_direct_stream_from_offset(self):
"""Test the Python direct Kafka stream API with start offset specified."""
topic = self._randomTopic()
sendData = {"a": 1, "b": 2, "c": 3}
fromOffsets = {TopicAndPartition(topic, 0): long(0)}
kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()}
self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets)
self._validateStreamResult(sendData, stream)
@unittest.skipIf(sys.version >= "3", "long type not support")
def test_kafka_rdd(self):
"""Test the Python direct Kafka RDD API."""
topic = self._randomTopic()
sendData = {"a": 1, "b": 2}
offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))]
kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()}
self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges)
self._validateRddResult(sendData, rdd)
@unittest.skipIf(sys.version >= "3", "long type not support")
def test_kafka_rdd_with_leaders(self):
"""Test the Python direct Kafka RDD API with leaders."""
topic = self._randomTopic()
sendData = {"a": 1, "b": 2, "c": 3}
offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))]
kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()}
address = self._kafkaTestUtils.brokerAddress().split(":")
leaders = {TopicAndPartition(topic, 0): Broker(address[0], int(address[1]))}
self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders)
self._validateRddResult(sendData, rdd)
if __name__ == "__main__":
SPARK_HOME = os.environ["SPARK_HOME"]
kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly")
jars = glob.glob(
os.path.join(kafka_assembly_dir, "target/scala-*/spark-streaming-kafka-assembly-*.jar"))
if not jars:
raise Exception(
("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) +
"You need to build Spark with "
"'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or "
"'build/mvn package' before running this test")
elif len(jars) > 1:
raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please "
"remove all but one") % kafka_assembly_dir)
else:
os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars[0]
unittest.main()