spark-instrumented-optimizer/python/pyspark/tests/test_worker.py
Yuanjian Li dbbba80b3c [SPARK-26549][PYSPARK] Fix for python worker reuse take no effect for parallelize lazy iterable range
## What changes were proposed in this pull request?

During the follow-up work(#23435) for PySpark worker reuse scenario, we found that the worker reuse takes no effect for `sc.parallelize(xrange(...))`. It happened because of the specialize rdd.parallelize logic for xrange(introduced in #3264) generated data by lazy iterable range, which don't need to use the passed-in iterator. But this will break the end of stream checking in python worker and finally cause worker reuse takes no effect. See more details in [SPARK-26549](https://issues.apache.org/jira/browse/SPARK-26549) description.

We fix this by force using the passed-in iterator.

## How was this patch tested?
New UT in test_worker.py.

Closes #23470 from xuanyuanking/SPARK-26549.

Authored-by: Yuanjian Li <xyliyuanjian@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
2019-01-09 11:55:12 +08:00

168 lines
5.4 KiB
Python

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
import tempfile
import threading
import time
from py4j.protocol import Py4JJavaError
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest
if sys.version_info[0] >= 3:
xrange = range
class WorkerTests(ReusedPySparkTestCase):
def test_cancel_task(self):
temp = tempfile.NamedTemporaryFile(delete=True)
temp.close()
path = temp.name
def sleep(x):
import os
import time
with open(path, 'w') as f:
f.write("%d %d" % (os.getppid(), os.getpid()))
time.sleep(100)
# start job in background thread
def run():
try:
self.sc.parallelize(range(1), 1).foreach(sleep)
except Exception:
pass
import threading
t = threading.Thread(target=run)
t.daemon = True
t.start()
daemon_pid, worker_pid = 0, 0
while True:
if os.path.exists(path):
with open(path) as f:
data = f.read().split(' ')
daemon_pid, worker_pid = map(int, data)
break
time.sleep(0.1)
# cancel jobs
self.sc.cancelAllJobs()
t.join()
for i in range(50):
try:
os.kill(worker_pid, 0)
time.sleep(0.1)
except OSError:
break # worker was killed
else:
self.fail("worker has not been killed after 5 seconds")
try:
os.kill(daemon_pid, 0)
except OSError:
self.fail("daemon had been killed")
# run a normal job
rdd = self.sc.parallelize(xrange(100), 1)
self.assertEqual(100, rdd.map(str).count())
def test_after_exception(self):
def raise_exception(_):
raise Exception()
rdd = self.sc.parallelize(xrange(100), 1)
with QuietTest(self.sc):
self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
self.assertEqual(100, rdd.map(str).count())
def test_after_jvm_exception(self):
tempFile = tempfile.NamedTemporaryFile(delete=False)
tempFile.write(b"Hello World!")
tempFile.close()
data = self.sc.textFile(tempFile.name, 1)
filtered_data = data.filter(lambda x: True)
self.assertEqual(1, filtered_data.count())
os.unlink(tempFile.name)
with QuietTest(self.sc):
self.assertRaises(Exception, lambda: filtered_data.count())
rdd = self.sc.parallelize(xrange(100), 1)
self.assertEqual(100, rdd.map(str).count())
def test_accumulator_when_reuse_worker(self):
from pyspark.accumulators import INT_ACCUMULATOR_PARAM
acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x))
self.assertEqual(sum(range(100)), acc1.value)
acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x))
self.assertEqual(sum(range(100)), acc2.value)
self.assertEqual(sum(range(100)), acc1.value)
def test_reuse_worker_after_take(self):
rdd = self.sc.parallelize(xrange(100000), 1)
self.assertEqual(0, rdd.first())
def count():
try:
rdd.count()
except Exception:
pass
t = threading.Thread(target=count)
t.daemon = True
t.start()
t.join(5)
self.assertTrue(not t.isAlive())
self.assertEqual(100000, rdd.count())
def test_with_different_versions_of_python(self):
rdd = self.sc.parallelize(range(10))
rdd.count()
version = self.sc.pythonVer
self.sc.pythonVer = "2.0"
try:
with QuietTest(self.sc):
self.assertRaises(Py4JJavaError, lambda: rdd.count())
finally:
self.sc.pythonVer = version
class WorkerReuseTest(PySparkTestCase):
def test_reuse_worker_of_parallelize_xrange(self):
rdd = self.sc.parallelize(xrange(20), 8)
previous_pids = rdd.map(lambda x: os.getpid()).collect()
current_pids = rdd.map(lambda x: os.getpid()).collect()
for pid in current_pids:
self.assertTrue(pid in previous_pids)
if __name__ == "__main__":
import unittest
from pyspark.tests.test_worker import *
try:
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)