#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import struct
import sys
import unittest

from pyspark import SparkContext, SparkConf


# Optional dependencies: tests that need them check these flags and skip.
have_scipy = False
have_numpy = False
try:
    import scipy.sparse
    have_scipy = True
except Exception:
    # No SciPy (or an unusable install), but that's okay, we'll skip those tests.
    # ``except Exception`` rather than a bare ``except:`` so that
    # KeyboardInterrupt / SystemExit raised during the import still propagate.
    pass
try:
    import numpy as np
    have_numpy = True
except Exception:
    # No NumPy (or an unusable install), but that's okay, we'll skip those tests.
    pass

# Read at import time; raises KeyError if SPARK_HOME is not set.
SPARK_HOME = os.environ["SPARK_HOME"]


def read_int(b):
    """Decode ``b`` as a big-endian (network byte order) signed 32-bit integer.

    ``b`` must be exactly 4 bytes long; otherwise ``struct.error`` is raised.
    """
    return struct.unpack("!i", b)[0]


def write_int(i):
    """Encode ``i`` as 4 bytes, big-endian signed 32-bit; inverse of ``read_int``.

    Raises ``struct.error`` if ``i`` does not fit in a signed 32-bit integer.
    """
    return struct.pack("!i", i)


class QuietTest(object):
    """Context manager that sets the JVM log4j root logger level to FATAL.

    The level that was active on entry is restored on exit. Exceptions raised
    inside the ``with`` body are not suppressed (``__exit__`` returns None).
    """

    def __init__(self, sc):
        # Handle to the org.apache.log4j package, reached through the
        # SparkContext's JVM gateway.
        self.log4j = sc._jvm.org.apache.log4j

    def __enter__(self):
        self.old_level = self.log4j.LogManager.getRootLogger().getLevel()
        self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.log4j.LogManager.getRootLogger().setLevel(self.old_level)


class PySparkTestCase(unittest.TestCase):
    """Test case that creates a new ``local[4]`` SparkContext for every test."""

    def setUp(self):
        # Snapshot sys.path; tearDown puts it back so per-test changes don't leak.
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        self.sc = SparkContext('local[4]', class_name)

    def tearDown(self):
        # Restore sys.path even if stopping the context raises.
        try:
            self.sc.stop()
        finally:
            sys.path = self._old_sys_path


class ReusedPySparkTestCase(unittest.TestCase):
    """Test case that shares one ``local[4]`` SparkContext across a whole class."""

    @classmethod
    def conf(cls):
        """
        Override this in subclasses to supply a more specific conf
        """
        return SparkConf()

    @classmethod
    def setUpClass(cls):
        # The app name is the test class name; configuration comes from conf().
        cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf())

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()


class ByteArrayOutput(object):
    """Minimal writable file-like object that accumulates bytes in memory.

    Everything passed to ``write`` is appended to ``self.buffer`` (a
    ``bytearray``); ``close`` does nothing, so the buffer stays readable.
    """

    def __init__(self):
        self.buffer = bytearray()

    def write(self, b):
        self.buffer += b

    def close(self):
        pass