[SPARK-35339][PYTHON] Improve unit tests for data-type-based basic operations

### What changes were proposed in this pull request?

Improve unit tests for data-type-based basic operations by:
- removing redundant test cases
- adding `astype` tests for ExtensionDtypes

### Why are the changes needed?

Some test cases for basic operations became duplicates after data-type-based basic operations were introduced, so this PR removes the redundant cases.
In addition, `astype` is not yet tested for ExtensionDtypes; this PR adds that coverage as well.
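
For reference, the new `test_astype` cases compare the result of `astype` on a pandas-on-Spark series against the same cast on the corresponding pandas series for each available extension dtype. A minimal sketch of that comparison (hypothetical data; assumes a running Spark session and a pandas version that provides the nullable integer extension dtypes):

```python
import pandas as pd
import pyspark.pandas as ps

# Hypothetical example data; the real tests iterate over every available extension dtype.
pser = pd.Series([1, 2, 3, None], dtype="Int64")
psser = ps.from_pandas(pser)

# Cast both sides to another nullable extension dtype and compare the results
# after bringing the Spark-backed series back to pandas.
print(pser.astype("Int32"))
print(psser.astype("Int32").to_pandas())
```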

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.
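
For example, the extension-ops test classes touched in `test_num_ops` can be run on their own with plain `unittest` (a rough sketch; it assumes a local PySpark development environment is already on the `PYTHONPATH`):

```python
import unittest

from pyspark.pandas.tests.data_type_ops.test_num_ops import (
    FractionalExtensionOpsTest,
    IntegralExtensionOpsTest,
)

# Build a suite from the extension-ops test classes touched by this PR and run it;
# the classes are skipped automatically when pandas extension dtypes are unavailable.
loader = unittest.TestLoader()
suite = unittest.TestSuite(
    loader.loadTestsFromTestCase(cls)
    for cls in (IntegralExtensionOpsTest, FractionalExtensionOpsTest)
)
unittest.TextTestRunner(verbosity=2).run(suite)
```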

Closes #33095 from xinrong-databricks/datatypeops_test.

Authored-by: Xinrong Meng <xinrong.meng@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>

@@ -572,10 +572,19 @@ class BooleanExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
     def test_astype(self):
         pser = self.pser
         psser = self.psser
+        # TODO(SPARK-35976): [True, False, <NA>] is returned in pandas
         self.assert_eq(["True", "False", "None"], self.psser.astype(str).tolist())
         self.assert_eq(pser.astype("category"), psser.astype("category"))
         cat_type = CategoricalDtype(categories=[False, True])
         self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+        for dtype in self.extension_dtypes:
+            if dtype in self.fractional_extension_dtypes:
+                # A pandas boolean extension series cannot be casted to fractional extension dtypes
+                self.assert_eq([1.0, 0.0, np.nan], self.psser.astype(dtype).tolist())
+            else:
+                self.check_extension(pser.astype(dtype), psser.astype(dtype))
 
 
 if __name__ == "__main__":


@@ -322,8 +322,7 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
 class IntegralExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
     @property
     def intergral_extension_psers(self):
-        dtypes = ["Int8", "Int16", "Int32", "Int64"]
-        return [pd.Series([1, 2, 3, None], dtype=dtype) for dtype in dtypes]
+        return [pd.Series([1, 2, 3, None], dtype=dtype) for dtype in self.integral_extension_dtypes]
 
     @property
     def intergral_extension_pssers(self):

@@ -342,6 +341,11 @@ class IntegralExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         for pser, psser in self.intergral_extension_pser_psser_pairs:
             self.assert_eq(pser.isnull(), psser.isnull())
 
+    def test_astype(self):
+        for pser, psser in self.intergral_extension_pser_psser_pairs:
+            for dtype in self.extension_dtypes:
+                self.check_extension(pser.astype(dtype), psser.astype(dtype))
+
 
 @unittest.skipIf(
     not extension_float_dtypes_available, "pandas extension float dtypes are not available"

@@ -349,8 +353,10 @@ class IntegralExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
 class FractionalExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
     @property
     def fractional_extension_psers(self):
-        dtypes = ["Float32", "Float64"]
-        return [pd.Series([0.1, 0.2, 0.3, None], dtype=dtype) for dtype in dtypes]
+        return [
+            pd.Series([0.1, 0.2, 0.3, None], dtype=dtype)
+            for dtype in self.fractional_extension_dtypes
+        ]
 
     @property
     def fractional_extension_pssers(self):

@@ -369,6 +375,11 @@ class FractionalExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         for pser, psser in self.fractional_extension_pser_psser_pairs:
             self.assert_eq(pser.isnull(), psser.isnull())
 
+    def test_astype(self):
+        for pser, psser in self.fractional_extension_pser_psser_pairs:
+            for dtype in self.extension_dtypes:
+                self.check_extension(pser.astype(dtype), psser.astype(dtype))
+
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.data_type_ops.test_num_ops import *  # noqa: F401


@@ -27,6 +27,9 @@ from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils
 from pyspark.pandas.typedef.typehints import extension_object_dtypes_available
 from pyspark.testing.pandasutils import PandasOnSparkTestCase
 
+if extension_object_dtypes_available:
+    from pandas import StringDtype
+
 
 class StringOpsTest(PandasOnSparkTestCase, TestCasesUtils):
     @property

@@ -217,6 +220,20 @@ class StringExtensionOpsTest(StringOpsTest, PandasOnSparkTestCase, TestCasesUtils):
     def test_isnull(self):
         self.assert_eq(self.pser.isnull(), self.psser.isnull())
 
+    def test_astype(self):
+        pser = self.pser
+        psser = self.psser
+        # TODO(SPARK-35976): [x, y, z, <NA>] is returned in pandas
+        self.assert_eq(["x", "y", "z", "None"], self.psser.astype(str).tolist())
+        self.assert_eq(pser.astype("category"), psser.astype("category"))
+        cat_type = CategoricalDtype(categories=["x", "y"])
+        self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+        for dtype in self.object_extension_dtypes:
+            if dtype in ["string", StringDtype()]:
+                self.check_extension(pser.astype(dtype), psser.astype(dtype))
+
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.data_type_ops.test_string_ops import *  # noqa: F401


@@ -25,6 +25,21 @@ import pandas as pd
 import pyspark.pandas as ps
 from pyspark.pandas.typedef import extension_dtypes
+from pyspark.pandas.typedef.typehints import (
+    extension_dtypes_available,
+    extension_float_dtypes_available,
+    extension_object_dtypes_available,
+)
+
+if extension_dtypes_available:
+    from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype
+
+if extension_float_dtypes_available:
+    from pandas import Float32Dtype, Float64Dtype
+
+if extension_object_dtypes_available:
+    from pandas import BooleanDtype, StringDtype
+
 
 class TestCasesUtils(object):
     """A utility holding common test cases for arithmetic operations of different data types."""

@@ -81,6 +96,47 @@ class TestCasesUtils(object):
     def pser_psser_pairs(self):
         return zip(self.psers, self.pssers)
 
+    @property
+    def object_extension_dtypes(self):
+        return (
+            ["boolean", "string", BooleanDtype(), StringDtype()]
+            if extension_object_dtypes_available
+            else []
+        )
+
+    @property
+    def fractional_extension_dtypes(self):
+        return (
+            ["Float32", "Float64", Float32Dtype(), Float64Dtype()]
+            if extension_float_dtypes_available
+            else []
+        )
+
+    @property
+    def integral_extension_dtypes(self):
+        return (
+            [
+                "Int8",
+                "Int16",
+                "Int32",
+                "Int64",
+                Int8Dtype(),
+                Int16Dtype(),
+                Int32Dtype(),
+                Int64Dtype(),
+            ]
+            if extension_dtypes_available
+            else []
+        )
+
+    @property
+    def extension_dtypes(self):
+        return (
+            self.object_extension_dtypes
+            + self.fractional_extension_dtypes
+            + self.integral_extension_dtypes
+        )
+
     def check_extension(self, psser, pser):
         """
         Compare `psser` and `pser` of numeric ExtensionDtypes.


@@ -57,8 +57,6 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         pser = self.pser
         psser = self.psser
 
-        self.assert_eq(psser + 1, pser + 1)
-        self.assert_eq(1 + psser, 1 + pser)
         self.assert_eq(psser + 1 + 10 * psser, pser + 1 + 10 * pser)
         self.assert_eq(psser + 1 + 10 * psser.index, pser + 1 + 10 * pser.index)
         self.assert_eq(psser.index + 1 + 10 * psser, pser.index + 1 + 10 * pser)

@@ -92,51 +90,6 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         else:
             self.assert_eq(psser, pser)
 
-    @unittest.skipIf(not extension_dtypes_available, "pandas extension dtypes are not available")
-    def test_extension_dtypes(self):
-        for pser in [
-            pd.Series([1, 2, None, 4], dtype="Int8"),
-            pd.Series([1, 2, None, 4], dtype="Int16"),
-            pd.Series([1, 2, None, 4], dtype="Int32"),
-            pd.Series([1, 2, None, 4], dtype="Int64"),
-        ]:
-            psser = ps.from_pandas(pser)
-
-            self._check_extension(psser, pser)
-            self._check_extension(psser + psser, pser + pser)
-
-    @unittest.skipIf(
-        not extension_object_dtypes_available, "pandas extension object dtypes are not available"
-    )
-    def test_extension_object_dtypes(self):
-        # string
-        pser = pd.Series(["a", None, "c", "d"], dtype="string")
-        psser = ps.from_pandas(pser)
-
-        self._check_extension(psser, pser)
-
-        # boolean
-        pser = pd.Series([True, False, True, None], dtype="boolean")
-        psser = ps.from_pandas(pser)
-
-        self._check_extension(psser, pser)
-        self._check_extension(psser & psser, pser & pser)
-        self._check_extension(psser | psser, pser | pser)
-
-    @unittest.skipIf(
-        not extension_float_dtypes_available, "pandas extension float dtypes are not available"
-    )
-    def test_extension_float_dtypes(self):
-        for pser in [
-            pd.Series([1.0, 2.0, None, 4.0], dtype="Float32"),
-            pd.Series([1.0, 2.0, None, 4.0], dtype="Float64"),
-        ]:
-            psser = ps.from_pandas(pser)
-
-            self._check_extension(psser, pser)
-            self._check_extension(psser + 1, pser + 1)
-            self._check_extension(psser + psser, pser + pser)
-
     def test_empty_series(self):
         pser_a = pd.Series([], dtype="i1")
         pser_b = pd.Series([], dtype="str")

@@ -872,18 +825,16 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         self.assert_eq(psser.nlargest(), pser.nlargest())
         self.assert_eq((psser + 1).nlargest(), (pser + 1).nlargest())
 
-    def test_isnull(self):
+    def test_notnull(self):
         pser = pd.Series([1, 2, 3, 4, np.nan, 6], name="x")
         psser = ps.from_pandas(pser)
 
         self.assert_eq(psser.notnull(), pser.notnull())
-        self.assert_eq(psser.isnull(), pser.isnull())
 
         pser = self.pser
         psser = self.psser
 
         self.assert_eq(psser.notnull(), pser.notnull())
-        self.assert_eq(psser.isnull(), pser.isnull())
 
     def test_all(self):
         for pser in [