#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import glob
import os
import struct
import sys
import unittest

from pyspark import SparkContext, SparkConf

have_scipy = False
have_numpy = False
try:
    import scipy.sparse
    have_scipy = True
except ImportError:
    # No SciPy, but that's okay, we'll skip those tests
    pass
try:
    import numpy as np
    have_numpy = True
except ImportError:
    # No NumPy, but that's okay, we'll skip those tests
    pass

SPARK_HOME = os.environ["SPARK_HOME"]


def read_int(b):
    # Unpack a big-endian 32-bit signed integer from the given bytes.
    return struct.unpack("!i", b)[0]


def write_int(i):
    # Pack an integer as big-endian 32-bit signed bytes.
    return struct.pack("!i", i)


class QuietTest(object):
    """
    Context manager that raises the log4j root logger level to FATAL for
    the duration of the block and restores the previous level on exit.
    """

    def __init__(self, sc):
        self.log4j = sc._jvm.org.apache.log4j

    def __enter__(self):
        self.old_level = self.log4j.LogManager.getRootLogger().getLevel()
        self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.log4j.LogManager.getRootLogger().setLevel(self.old_level)


class PySparkTestCase(unittest.TestCase):
    """
    Test case that creates a fresh local SparkContext for each test and
    restores sys.path afterwards.
    """

    def setUp(self):
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        self.sc = SparkContext('local[4]', class_name)

    def tearDown(self):
        self.sc.stop()
        sys.path = self._old_sys_path


class ReusedPySparkTestCase(unittest.TestCase):
    """
    Test case that shares a single local SparkContext across all tests in
    the class.
    """

    @classmethod
    def conf(cls):
        """
        Override this in subclasses to supply a more specific conf
        """
        return SparkConf()

    @classmethod
    def setUpClass(cls):
        cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf())

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()


class ByteArrayOutput(object):
    """Minimal file-like object that accumulates written bytes in memory."""

    def __init__(self):
        self.buffer = bytearray()

    def write(self, b):
        self.buffer += b

    def close(self):
        pass


def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
    """
    Search for a single artifact jar under 'project_relative_path' built by
    either SBT or Maven. Returns its path, or None if no jar is found.
    """
    # Note that 'sbt_jar_name_prefix' and 'mvn_jar_name_prefix' are used since the prefix can
    # vary for SBT or Maven specifically. See also SPARK-26856
    project_full_path = os.path.join(
        os.environ["SPARK_HOME"], project_relative_path)

    # We should ignore the following jars
    ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar")

    # Search jar in the project dir using the jar name_prefix for both sbt build and maven
    # build because the artifact jars are in different directories.
    sbt_build = glob.glob(os.path.join(
        project_full_path, "target/scala-*/%s*.jar" % sbt_jar_name_prefix))
    maven_build = glob.glob(os.path.join(
        project_full_path, "target/%s*.jar" % mvn_jar_name_prefix))
    jar_paths = sbt_build + maven_build
    jars = [jar for jar in jar_paths if not jar.endswith(ignored_jar_suffixes)]

    if not jars:
        return None
    elif len(jars) > 1:
        raise Exception("Found multiple JARs: %s; please remove all but one" % (", ".join(jars)))
    else:
        return jars[0]
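

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original utilities): one way a test
# suite might combine ReusedPySparkTestCase and QuietTest. The class name,
# config key, and test body below are assumptions for demonstration only,
# and the block is kept commented out so importing this module stays
# side-effect free.
# ---------------------------------------------------------------------------
# class _ExampleQuietTestCase(ReusedPySparkTestCase):
#
#     @classmethod
#     def conf(cls):
#         # Supply a more specific conf, e.g. disable the Spark UI for tests.
#         return SparkConf().set("spark.ui.enabled", "false")
#
#     def test_expected_failure_is_quiet(self):
#         # QuietTest raises the log level to FATAL while the failing job runs,
#         # so the expected executor errors do not clutter the test output.
#         with QuietTest(self.sc):
#             self.assertRaises(
#                 Exception,
#                 lambda: self.sc.parallelize([1, 2, 3]).map(lambda x: 1 / 0).collect())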