#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import sys
import time
import unittest

from pyspark.serializers import read_int


class DaemonTests(unittest.TestCase):
    """Check that ``daemon.py`` stops accepting connections once told to terminate."""

    def connect(self, port):
        """Connect to the daemon on 127.0.0.1 and send a split index of -1.

        The four 0xFF bytes encode a split index of -1, which asks the worker
        to shut down. The connection is closed before returning.

        :param port: TCP port the daemon reported on its stdout.
        :return: ``True`` once the bytes have been sent.
        :raises OSError: if the connection fails, e.g. with ``ECONNREFUSED``
            after the daemon has been terminated.
        """
        from socket import socket, AF_INET, SOCK_STREAM

        # The ``with`` block closes the socket even when ``connect`` raises,
        # which is the expected outcome once the daemon has shut down.
        with socket(AF_INET, SOCK_STREAM) as sock:
            sock.connect(('127.0.0.1', port))
            # send a split index of -1 to shutdown the worker;
            # sendall guarantees all four bytes are written
            sock.sendall(b"\xFF\xFF\xFF\xFF")
        return True

    def do_termination_test(self, terminator):
        """Start ``../daemon.py``, stop it via ``terminator``, and verify it is gone.

        :param terminator: callable that receives the daemon's ``Popen`` object
            and requests its shutdown (the tests below close its stdin or
            send it SIGTERM).
        """
        from subprocess import Popen, PIPE
        from errno import ECONNREFUSED

        # start daemon
        daemon_path = os.path.join(os.path.dirname(__file__), "..", "daemon.py")
        python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON")
        daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE)
        try:
            # read the port number
            port = read_int(daemon.stdout)

            # daemon should accept connections
            self.assertTrue(self.connect(port))

            # give the worker process spawned by the daemon time to exit
            time.sleep(1)

            # request shutdown
            terminator(daemon)
            time.sleep(1)

            # daemon should no longer accept connections
            with self.assertRaises(EnvironmentError) as context:
                self.connect(port)
            self.assertEqual(context.exception.errno, ECONNREFUSED)
        finally:
            # Never leave the daemon running or its pipes open, even when an
            # assertion above failed before the terminator was invoked.
            if daemon.poll() is None:
                daemon.kill()
            daemon.wait()
            daemon.stdout.close()
            # close() is a no-op if the terminator already closed stdin
            daemon.stdin.close()

    def test_termination_stdin(self):
        """Ensure that daemon and workers terminate when stdin is closed."""
        self.do_termination_test(lambda daemon: daemon.stdin.close())

    def test_termination_sigterm(self):
        """Ensure that daemon and workers terminate on SIGTERM."""
        from signal import SIGTERM

        self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))


if __name__ == "__main__":
    from pyspark.tests.test_daemon import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)