03306a6df3
## What changes were proposed in this pull request?

This PR continues to break down a single large file into smaller files. See https://github.com/apache/spark/pull/23021. It follows the layout of https://github.com/numpy/numpy/tree/master/numpy.

Basically, this PR proposes to break down `pyspark/tests.py` into ...:

```
pyspark
...
├── testing
...
│   └── utils.py
├── tests
│   ├── __init__.py
│   ├── test_appsubmit.py
│   ├── test_broadcast.py
│   ├── test_conf.py
│   ├── test_context.py
│   ├── test_daemon.py
│   ├── test_join.py
│   ├── test_profiler.py
│   ├── test_rdd.py
│   ├── test_readwrite.py
│   ├── test_serializers.py
│   ├── test_shuffle.py
│   ├── test_taskcontext.py
│   ├── test_util.py
│   └── test_worker.py
...
```

## How was this patch tested?

Existing tests should cover this: `cd python` and `./run-tests-with-coverage`. Manually checked that the tests are actually being run.

Each test can (unofficially) be run via:

```bash
SPARK_TESTING=1 ./bin/pyspark pyspark.tests.test_context
```

Note that if you're using macOS with Python 3, you might have to set `OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES`.

Closes #23033 from HyukjinKwon/SPARK-26036.

Authored-by: hyukjinkwon <gurwls223@apache.org>
Signed-off-by: hyukjinkwon <gurwls223@apache.org>
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime
import os
import shutil
import tempfile
from contextlib import contextmanager

from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
from pyspark.testing.utils import ReusedPySparkTestCase
from pyspark.util import _exception_message

pandas_requirement_message = None
try:
    from pyspark.sql.utils import require_minimum_pandas_version
    require_minimum_pandas_version()
except ImportError as e:
    # If Pandas version requirement is not satisfied, skip related tests.
    pandas_requirement_message = _exception_message(e)

pyarrow_requirement_message = None
try:
    from pyspark.sql.utils import require_minimum_pyarrow_version
    require_minimum_pyarrow_version()
except ImportError as e:
    # If Arrow version requirement is not satisfied, skip related tests.
    pyarrow_requirement_message = _exception_message(e)

test_not_compiled_message = None
try:
    from pyspark.sql.utils import require_test_compiled
    require_test_compiled()
except Exception as e:
    test_not_compiled_message = _exception_message(e)

have_pandas = pandas_requirement_message is None
have_pyarrow = pyarrow_requirement_message is None
test_compiled = test_not_compiled_message is None
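
# Note (illustration added here, not part of the original file): these flags and
# messages are meant to feed unittest's skip decorators in dependent test modules,
# along the lines of the following sketch (the class name is hypothetical):
#
#     import unittest
#
#     @unittest.skipIf(not have_pandas, pandas_requirement_message)
#     class SomePandasDependentTests(ReusedSQLTestCase):
#         ...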


class UTCOffsetTimezone(datetime.tzinfo):
    """
    Specifies a timezone by its UTC offset.
    """

    def __init__(self, offset=0):
        self.ZERO = datetime.timedelta(hours=offset)

    def utcoffset(self, dt):
        return self.ZERO

    def dst(self, dt):
        return self.ZERO
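
# Usage sketch (illustrative, not part of the original file): a timezone-aware
# datetime at a fixed +2h offset.
#
#     ts = datetime.datetime(2018, 1, 1, tzinfo=UTCOffsetTimezone(offset=2))
#     ts.utcoffset()  # datetime.timedelta(hours=2)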


class ExamplePointUDT(UserDefinedType):
    """
    User-defined type (UDT) for ExamplePoint.
    """

    @classmethod
    def sqlType(cls):
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return 'pyspark.sql.tests'

    @classmethod
    def scalaUDT(cls):
        return 'org.apache.spark.sql.test.ExamplePointUDT'

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return ExamplePoint(datum[0], datum[1])


class ExamplePoint:
    """
    An example class to demonstrate UDT in Scala, Java, and Python.
    """

    __UDT__ = ExamplePointUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "ExamplePoint(%s,%s)" % (self.x, self.y)

    def __str__(self):
        return "(%s,%s)" % (self.x, self.y)

    def __eq__(self, other):
        return isinstance(other, self.__class__) and \
            other.x == self.x and other.y == self.y
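
# Round-trip sketch (illustrative; assumes an active SparkSession named `spark`):
# a column of ExamplePoint values is stored as array<double> via ExamplePointUDT
# and deserialized back to ExamplePoint on collect.
#
#     df = spark.createDataFrame([Row(point=ExamplePoint(1.0, 2.0))])
#     df.collect()[0].point  # ExamplePoint(1.0,2.0)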


class PythonOnlyUDT(UserDefinedType):
    """
    User-defined type (UDT) for PythonOnlyPoint.
    """

    @classmethod
    def sqlType(cls):
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return '__main__'

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return PythonOnlyPoint(datum[0], datum[1])

    @staticmethod
    def foo():
        pass

    @property
    def props(self):
        return {}


class PythonOnlyPoint(ExamplePoint):
    """
    An example class to demonstrate a UDT in Python only.
    """
    __UDT__ = PythonOnlyUDT()


class MyObject(object):
    def __init__(self, key, value):
        self.key = key
        self.value = value


class SQLTestUtils(object):
    """
    This util assumes that the instance using it has a 'spark' attribute holding a
    Spark session. It is usually mixed into the 'ReusedSQLTestCase' class, but it can
    be used anywhere you are sure the implementing class has a 'spark' attribute.
    """

    @contextmanager
    def sql_conf(self, pairs):
        """
        A convenient context manager to test some configuration-specific logic. It sets
        each given configuration `key` to its `value`, and restores the old values when
        it exits.
        """
        assert isinstance(pairs, dict), "pairs should be a dictionary."
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        keys = pairs.keys()
        new_values = pairs.values()
        old_values = [self.spark.conf.get(key, None) for key in keys]
        for key, new_value in zip(keys, new_values):
            self.spark.conf.set(key, new_value)
        try:
            yield
        finally:
            for key, old_value in zip(keys, old_values):
                if old_value is None:
                    self.spark.conf.unset(key)
                else:
                    self.spark.conf.set(key, old_value)
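
    # Usage sketch (illustrative; the configuration key is only an example):
    #
    #     with self.sql_conf({"spark.sql.shuffle.partitions": "4"}):
    #         ...  # logic that depends on the temporary setting
    #
    # The old value is restored (or unset) on exit, even if the block raises.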

    @contextmanager
    def database(self, *databases):
        """
        A convenient context manager to test with some specific databases. When it exits,
        this drops the given databases if they exist, and sets the current database back
        to "default".
        """
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        try:
            yield
        finally:
            for db in databases:
                self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db)
            self.spark.catalog.setCurrentDatabase("default")

    @contextmanager
    def table(self, *tables):
        """
        A convenient context manager to test with some specific tables. When it exits,
        this drops the given tables if they exist.
        """
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        try:
            yield
        finally:
            for t in tables:
                self.spark.sql("DROP TABLE IF EXISTS %s" % t)

    @contextmanager
    def tempView(self, *views):
        """
        A convenient context manager to test with some specific views. When it exits,
        this drops the given views if they exist.
        """
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        try:
            yield
        finally:
            for v in views:
                self.spark.catalog.dropTempView(v)

    @contextmanager
    def function(self, *functions):
        """
        A convenient context manager to test with some specific functions. When it exits,
        this drops the given functions if they exist.
        """
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        try:
            yield
        finally:
            for f in functions:
                self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)
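
# Usage sketch (illustrative; the database and table names are hypothetical):
#
#     with self.database("test_db"), self.table("test_db.test_table"):
#         self.spark.sql("CREATE DATABASE test_db")
#         ...
#
# On exit, the table is dropped first, then the database, and the current
# database is reset to "default".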


class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
    @classmethod
    def setUpClass(cls):
        super(ReusedSQLTestCase, cls).setUpClass()
        cls.spark = SparkSession(cls.sc)
        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
        os.unlink(cls.tempdir.name)
        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
        cls.df = cls.spark.createDataFrame(cls.testData)

    @classmethod
    def tearDownClass(cls):
        super(ReusedSQLTestCase, cls).tearDownClass()
        cls.spark.stop()
        shutil.rmtree(cls.tempdir.name, ignore_errors=True)

    def assertPandasEqual(self, expected, result):
        msg = ("DataFrames are not equal: " +
               "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
               "\n\nResult:\n%s\n%s" % (result, result.dtypes))
        self.assertTrue(expected.equals(result), msg=msg)
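
# Usage sketch (illustrative; assumes pandas is installed so `toPandas` works):
#
#     expected = self.df.toPandas()
#     result = self.spark.createDataFrame(expected).toPandas()
#     self.assertPandasEqual(expected, result)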