#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import random
import time
import tempfile
import unittest

from pyspark import SparkConf, SparkContext
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import ChunkedStream


class BroadcastTest(unittest.TestCase):

    def tearDown(self):
        if getattr(self, "sc", None) is not None:
            self.sc.stop()
            self.sc = None

    def _test_encryption_helper(self, vs):
        """
        Creates a broadcast variable for each value in vs, and runs a simple job to make sure the
        value is the same when it's read in the executors. Also makes sure there are no task
        failures.
        """
        bs = [self.sc.broadcast(value=v) for v in vs]
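        # run a small job that reads every broadcast value back on the executors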
        exec_values = self.sc.parallelize(range(2)).map(lambda x: [b.value for b in bs]).collect()
        for ev in exec_values:
            self.assertEqual(ev, vs)
        # make sure there are no task failures
        status = self.sc.statusTracker()
        for jid in status.getJobIdsForGroup():
            for sid in status.getJobInfo(jid).stageIds:
                stage_info = status.getStageInfo(sid)
                self.assertEqual(0, stage_info.numFailedTasks)

    def _test_multiple_broadcasts(self, *extra_confs):
        """
        Test that broadcast variables make it to the executors intact. Tests multiple broadcast
        variables, and also multiple jobs.
        """
        conf = SparkConf()
        for key, value in extra_confs:
            conf.set(key, value)
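        # local-cluster[2,1,1024]: two worker processes with one core and 1024 MB each,
        # so broadcast values must actually travel to separate executor processes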
        conf.setMaster("local-cluster[2,1,1024]")
        self.sc = SparkContext(conf=conf)
        self._test_encryption_helper([5])
        self._test_encryption_helper([5, 10, 20])

    def test_broadcast_with_encryption(self):
        self._test_multiple_broadcasts(("spark.io.encryption.enabled", "true"))

    def test_broadcast_no_encryption(self):
        self._test_multiple_broadcasts()

    def _test_broadcast_on_driver(self, *extra_confs):
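        # Regression coverage for SPARK-26201: with spark.io.encryption.enabled=true,
        # reading a broadcast value back on the driver used to fail with
        # "EOFError: Ran out of input".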
        conf = SparkConf()
        for key, value in extra_confs:
            conf.set(key, value)
        conf.setMaster("local-cluster[2,1,1024]")
        self.sc = SparkContext(conf=conf)
        bs = self.sc.broadcast(value=5)
        self.assertEqual(5, bs.value)

    def test_broadcast_value_driver_no_encryption(self):
        self._test_broadcast_on_driver()

    def test_broadcast_value_driver_encryption(self):
        self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true"))

    def test_broadcast_value_against_gc(self):
        # Test that a broadcast value is still readable after garbage collection.
        conf = SparkConf()
        conf.setMaster("local[1,1]")
        # shrink storage memory so the cached broadcast block is likely to be evicted
        conf.set("spark.memory.fraction", "0.0001")
        self.sc = SparkContext(conf=conf)
        b = self.sc.broadcast([100])
        try:
            res = self.sc.parallelize([0], 1).map(lambda x: 0 if x == 0 else b.value[0]).collect()
            self.assertEqual([0], res)
            self.sc._jvm.java.lang.System.gc()
            time.sleep(5)
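            # the broadcast should survive the gc: the executor must still be able
            # to re-read the original value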
            res = self.sc.parallelize([1], 1).map(lambda x: 0 if x == 0 else b.value[0]).collect()
            self.assertEqual([100], res)
        finally:
            b.destroy()


class BroadcastFrameProtocolTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        gateway = launch_gateway(SparkConf())
        cls._jvm = gateway.jvm
        cls.longMessage = True
        random.seed(42)
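        # fixed seed keeps the random test payloads reproducible across runs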

    def _test_chunked_stream(self, data, py_buf_size):
        # write data using the chunked protocol from python.
        chunked_file = tempfile.NamedTemporaryFile(delete=False)
        dechunked_file = tempfile.NamedTemporaryFile(delete=False)
        dechunked_file.close()
        try:
            out = ChunkedStream(chunked_file, py_buf_size)
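            # ChunkedStream is assumed to frame writes as length-prefixed chunks of at
            # most py_buf_size bytes, with an end-of-stream marker written on close()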
            out.write(data)
            out.close()
            # now try to read it in java
            jin = self._jvm.java.io.FileInputStream(chunked_file.name)
            jout = self._jvm.java.io.FileOutputStream(dechunked_file.name)
            self._jvm.DechunkedInputStream.dechunkAndCopyToOutput(jin, jout)
            # java should have decoded it back to the original data
            self.assertEqual(len(data), os.stat(dechunked_file.name).st_size)
            with open(dechunked_file.name, "rb") as f:
                byte = f.read(1)
                idx = 0
                while byte:
                    self.assertEqual(data[idx], bytearray(byte)[0], msg="idx = " + str(idx))
                    byte = f.read(1)
                    idx += 1
        finally:
            os.unlink(chunked_file.name)
            os.unlink(dechunked_file.name)

    def test_chunked_stream(self):
        def random_bytes(n):
            return bytearray(random.getrandbits(8) for _ in range(n))
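        # exercise payloads both smaller and much larger than the write buffer,
        # including a 1-byte buffer that forces a chunk per byte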
        for data_length in [1, 10, 100, 10000]:
            for buffer_length in [1, 2, 5, 8192]:
                self._test_chunked_stream(random_bytes(data_length), buffer_length)


if __name__ == '__main__':
    from pyspark.tests.test_broadcast import *  # noqa: F401

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)