a7a331df6e
## What changes were proposed in this pull request? This is the first official attempt to break up the huge single `tests.py` file - I did it locally a few times before and gave up for various reasons. Currently, it makes the unit tests very hard to read and difficult to check. It is even bothersome to scroll down the big file. It's one single 7000-line file! This is not only a readability issue. Since one big test module takes most of the test time, the tests don't fully run in parallel - although starting and stopping the context will add some cost. We could pick one example and follow it. Given my investigation, the proposed style looks closest to the NumPy structure and looks easier to follow. Please see https://github.com/numpy/numpy/tree/master/numpy. Basically this PR proposes to break down `pyspark/sql/tests.py` into ...: ```bash pyspark ... ├── sql ... │ ├── tests # Includes all tests broken down from 'pyspark/sql/tests.py' │ │ │ # Each matches a module in 'pyspark/sql'. Additionally, some logical groups can │ │ │ # be added. For instance, 'test_arrow.py', 'test_datasources.py' ... │ │ ├── __init__.py │ │ ├── test_appsubmit.py │ │ ├── test_arrow.py │ │ ├── test_catalog.py │ │ ├── test_column.py │ │ ├── test_conf.py │ │ ├── test_context.py │ │ ├── test_dataframe.py │ │ ├── test_datasources.py │ │ ├── test_functions.py │ │ ├── test_group.py │ │ ├── test_pandas_udf.py │ │ ├── test_pandas_udf_grouped_agg.py │ │ ├── test_pandas_udf_grouped_map.py │ │ ├── test_pandas_udf_scalar.py │ │ ├── test_pandas_udf_window.py │ │ ├── test_readwriter.py │ │ ├── test_serde.py │ │ ├── test_session.py │ │ ├── test_streaming.py │ │ ├── test_types.py │ │ ├── test_udf.py │ │ └── test_utils.py ... ├── testing # Includes testing utils that can be used in unittests. │ ├── __init__.py │ └── sqlutils.py ... ``` ## How was this patch tested? Existing tests should cover this. `cd python` and `./run-tests-with-coverage`. Manually checked that they are actually being run.
Each test can (unofficially) be run via: ``` SPARK_TESTING=1 ./bin/pyspark pyspark.sql.tests.test_pandas_udf_scalar ``` Note that if you're using Mac and Python 3, you might have to set `OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES`. Closes #23021 from HyukjinKwon/SPARK-25344. Authored-by: hyukjinkwon <gurwls223@apache.org> Signed-off-by: hyukjinkwon <gurwls223@apache.org>
158 lines
6.9 KiB
Python
158 lines
6.9 KiB
Python
# -*- encoding: utf-8 -*-
|
|
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
# this work for additional information regarding copyright ownership.
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
# (the "License"); you may not use this file except in compliance with
|
|
# the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import sys
|
|
|
|
from pyspark.sql import Column, Row
|
|
from pyspark.sql.types import *
|
|
from pyspark.sql.utils import AnalysisException
|
|
from pyspark.testing.sqlutils import ReusedSQLTestCase
|
|
|
|
|
|
class ColumnTests(ReusedSQLTestCase):
    """Tests for ``Column`` construction, operators, item access and selection.

    The tests use ``self.spark``, ``self.sc``, ``self.df`` and ``self.testData``,
    none of which are defined in this class; presumably they are fixtures set up
    by ``ReusedSQLTestCase`` -- confirm there. The hard-coded row counts below
    depend on the contents of that shared ``self.df`` fixture.
    """

    def test_column_name_encoding(self):
        """Ensure that created columns has `str` type consistently."""
        # One name is given as a native string and one as a unicode literal on
        # purpose; both must come back as ``str``.
        columns = self.spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns
        self.assertEqual(columns, ['name', 'age'])
        self.assertIsInstance(columns[0], str)
        self.assertIsInstance(columns[1], str)

    def test_and_in_expression(self):
        """``&``, ``|`` and ``~`` combine columns; ``and``/``or``/``not`` raise ValueError."""
        self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
        self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
        self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count())
        self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2")
        self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count())
        self.assertRaises(ValueError, lambda: not self.df.key == 1)

    def test_validate_column_types(self):
        """``_to_java_column`` accepts strings and columns and rejects anything else."""
        from pyspark.sql.functions import udf, to_json
        from pyspark.sql.column import _to_java_column

        # Accepted inputs: native string, unicode string and an existing Column.
        self.assertIn("Column", _to_java_column("a").getClass().toString())
        self.assertIn("Column", _to_java_column(u"a").getClass().toString())
        self.assertIn("Column", _to_java_column(self.spark.range(1).id).getClass().toString())

        self.assertRaisesRegexp(
            TypeError,
            "Invalid argument, not a string or column",
            lambda: _to_java_column(1))

        class A():
            pass

        self.assertRaises(TypeError, lambda: _to_java_column(A()))
        self.assertRaises(TypeError, lambda: _to_java_column([]))

        # The same validation is expected to surface through public functions.
        self.assertRaisesRegexp(
            TypeError,
            "Invalid argument, not a string or column",
            lambda: udf(lambda x: x)(None))
        self.assertRaises(TypeError, lambda: to_json(1))

    def test_column_operators(self):
        """Arithmetic, comparison, boolean and string operators all yield ``Column``."""
        ci = self.df.key
        cs = self.df.value

        # Column-to-column equality builds an expression rather than a bool.
        self.assertIsInstance(ci == cs, Column)
        self.assertIsInstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)

        # Each group is checked item by item so that a failure points at the
        # offending expression instead of a bare "False is not true".
        rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1)
        for expr in rcc:
            self.assertIsInstance(expr, Column)

        cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7]
        for expr in cb:
            self.assertIsInstance(expr, Column)

        cbool = (ci & ci), (ci | ci), (~ci)
        for expr in cbool:
            self.assertIsInstance(expr, Column)

        css = (cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),
               cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs))
        for expr in css:
            self.assertIsInstance(expr, Column)

        self.assertIsInstance(ci.cast(LongType()), Column)
        self.assertRaisesRegexp(ValueError,
                                "Cannot apply 'in' operator against a column",
                                lambda: 1 in cs)

    def test_column_getitem(self):
        """Indexing a column by slice, int or key yields ``Column``; a stepped slice raises."""
        from pyspark.sql.functions import col

        self.assertIsInstance(col("foo")[1:3], Column)
        self.assertIsInstance(col("foo")[0], Column)
        self.assertIsInstance(col("foo")["bar"], Column)
        self.assertRaises(ValueError, lambda: col("foo")[0:10:2])

    def test_column_select(self):
        """Selecting ``*``, explicit columns, or a filtered column returns the expected rows."""
        df = self.df
        self.assertEqual(self.testData, df.select("*").collect())
        self.assertEqual(self.testData, df.select(df.key, df.value).collect())
        self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())

    def test_access_column(self):
        """Columns are reachable by attribute, name and position; bad keys raise."""
        df = self.df
        self.assertIsInstance(df.key, Column)
        self.assertIsInstance(df['key'], Column)
        self.assertIsInstance(df[0], Column)
        self.assertRaises(IndexError, lambda: df[2])
        self.assertRaises(AnalysisException, lambda: df["bad_key"])
        self.assertRaises(TypeError, lambda: df[{}])

    def test_column_name_with_non_ascii(self):
        """A non-ASCII column name survives schema creation, repr, dtypes and select."""
        # Compare the version tuple, not the version string: string comparison
        # is lexicographic and would misorder a two-digit major version.
        if sys.version_info >= (3,):
            columnName = "数量"
            self.assertIsInstance(columnName, str)
        else:
            # Python 2 only: the literal is a byte string and must be decoded.
            columnName = unicode("数量", "utf-8")
            self.assertIsInstance(columnName, unicode)
        schema = StructType([StructField(columnName, LongType(), True)])
        df = self.spark.createDataFrame([(1,)], schema)
        self.assertEqual(schema, df.schema)
        self.assertEqual("DataFrame[数量: bigint]", str(df))
        self.assertEqual([("数量", 'bigint')], df.dtypes)
        self.assertEqual(1, df.select("数量").first()[0])
        self.assertEqual(1, df.select(df["数量"]).first()[0])

    def test_field_accessor(self):
        """List items, struct fields (by key or dotted path) and map values are accessible."""
        df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
        self.assertEqual(1, df.select(df.l[0]).first()[0])
        self.assertEqual(1, df.select(df.r["a"]).first()[0])
        self.assertEqual(1, df.select(df["r.a"]).first()[0])
        self.assertEqual("b", df.select(df.r["b"]).first()[0])
        self.assertEqual("b", df.select(df["r.b"]).first()[0])
        self.assertEqual("v", df.select(df.d["k"]).first()[0])

    def test_bitwise_operations(self):
        """Bitwise AND/OR/XOR/NOT on columns agree with Python's integer operators."""
        from pyspark.sql import functions
        row = Row(a=170, b=75)
        df = self.spark.createDataFrame([row])
        # The dictionary keys are the column names generated for each expression.
        result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict()
        self.assertEqual(170 & 75, result['(a & b)'])
        result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict()
        self.assertEqual(170 | 75, result['(a | b)'])
        result = df.select(df.a.bitwiseXOR(df.b)).collect()[0].asDict()
        self.assertEqual(170 ^ 75, result['(a ^ b)'])
        result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
        self.assertEqual(~75, result['~b'])
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run this module's tests, writing XML reports when the
    # optional ``xmlrunner`` package is installed.
    import unittest
    from pyspark.sql.tests.test_column import *

    # Guard only the optional import. Keeping ``unittest.main`` outside the
    # ``try`` means an unrelated ImportError raised while loading or running
    # tests is not mistaken for "xmlrunner is missing".
    try:
        import xmlrunner
    except ImportError:
        # ``None`` makes unittest fall back to its default text runner.
        testRunner = None
    else:
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    unittest.main(testRunner=testRunner, verbosity=2)
|