[SPARK-35339][PYTHON] Improve unit tests for data-type-based basic operations
### What changes were proposed in this pull request? Improve unit tests for data-type-based basic operations by: - removing redundant test cases - adding `astype` test for ExtensionDtypes ### Why are the changes needed? Some test cases for basic operations are duplicated after introducing data-type-based basic operations. The PR is proposed to remove redundant test cases. `astype` is not tested for ExtensionDtypes, which will be adjusted in this PR as well. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit tests. Closes #33095 from xinrong-databricks/datatypeops_test. Authored-by: Xinrong Meng <xinrong.meng@databricks.com> Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
This commit is contained in:
parent
fceabe2372
commit
95d94948c5
|
@ -572,10 +572,19 @@ class BooleanExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
||||||
def test_astype(self):
|
def test_astype(self):
|
||||||
pser = self.pser
|
pser = self.pser
|
||||||
psser = self.psser
|
psser = self.psser
|
||||||
|
|
||||||
|
# TODO(SPARK-35976): [True, False, <NA>] is returned in pandas
|
||||||
self.assert_eq(["True", "False", "None"], self.psser.astype(str).tolist())
|
self.assert_eq(["True", "False", "None"], self.psser.astype(str).tolist())
|
||||||
|
|
||||||
self.assert_eq(pser.astype("category"), psser.astype("category"))
|
self.assert_eq(pser.astype("category"), psser.astype("category"))
|
||||||
cat_type = CategoricalDtype(categories=[False, True])
|
cat_type = CategoricalDtype(categories=[False, True])
|
||||||
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
|
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
|
||||||
|
for dtype in self.extension_dtypes:
|
||||||
|
if dtype in self.fractional_extension_dtypes:
|
||||||
|
# A pandas boolean extension series cannot be casted to fractional extension dtypes
|
||||||
|
self.assert_eq([1.0, 0.0, np.nan], self.psser.astype(dtype).tolist())
|
||||||
|
else:
|
||||||
|
self.check_extension(pser.astype(dtype), psser.astype(dtype))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -322,8 +322,7 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
||||||
class IntegralExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
class IntegralExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
||||||
@property
|
@property
|
||||||
def intergral_extension_psers(self):
|
def intergral_extension_psers(self):
|
||||||
dtypes = ["Int8", "Int16", "Int32", "Int64"]
|
return [pd.Series([1, 2, 3, None], dtype=dtype) for dtype in self.integral_extension_dtypes]
|
||||||
return [pd.Series([1, 2, 3, None], dtype=dtype) for dtype in dtypes]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def intergral_extension_pssers(self):
|
def intergral_extension_pssers(self):
|
||||||
|
@ -342,6 +341,11 @@ class IntegralExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
||||||
for pser, psser in self.intergral_extension_pser_psser_pairs:
|
for pser, psser in self.intergral_extension_pser_psser_pairs:
|
||||||
self.assert_eq(pser.isnull(), psser.isnull())
|
self.assert_eq(pser.isnull(), psser.isnull())
|
||||||
|
|
||||||
|
def test_astype(self):
|
||||||
|
for pser, psser in self.intergral_extension_pser_psser_pairs:
|
||||||
|
for dtype in self.extension_dtypes:
|
||||||
|
self.check_extension(pser.astype(dtype), psser.astype(dtype))
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
not extension_float_dtypes_available, "pandas extension float dtypes are not available"
|
not extension_float_dtypes_available, "pandas extension float dtypes are not available"
|
||||||
|
@ -349,8 +353,10 @@ class IntegralExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
||||||
class FractionalExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
class FractionalExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
||||||
@property
|
@property
|
||||||
def fractional_extension_psers(self):
|
def fractional_extension_psers(self):
|
||||||
dtypes = ["Float32", "Float64"]
|
return [
|
||||||
return [pd.Series([0.1, 0.2, 0.3, None], dtype=dtype) for dtype in dtypes]
|
pd.Series([0.1, 0.2, 0.3, None], dtype=dtype)
|
||||||
|
for dtype in self.fractional_extension_dtypes
|
||||||
|
]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def fractional_extension_pssers(self):
|
def fractional_extension_pssers(self):
|
||||||
|
@ -369,6 +375,11 @@ class FractionalExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
||||||
for pser, psser in self.fractional_extension_pser_psser_pairs:
|
for pser, psser in self.fractional_extension_pser_psser_pairs:
|
||||||
self.assert_eq(pser.isnull(), psser.isnull())
|
self.assert_eq(pser.isnull(), psser.isnull())
|
||||||
|
|
||||||
|
def test_astype(self):
|
||||||
|
for pser, psser in self.fractional_extension_pser_psser_pairs:
|
||||||
|
for dtype in self.extension_dtypes:
|
||||||
|
self.check_extension(pser.astype(dtype), psser.astype(dtype))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from pyspark.pandas.tests.data_type_ops.test_num_ops import * # noqa: F401
|
from pyspark.pandas.tests.data_type_ops.test_num_ops import * # noqa: F401
|
||||||
|
|
|
@ -27,6 +27,9 @@ from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils
|
||||||
from pyspark.pandas.typedef.typehints import extension_object_dtypes_available
|
from pyspark.pandas.typedef.typehints import extension_object_dtypes_available
|
||||||
from pyspark.testing.pandasutils import PandasOnSparkTestCase
|
from pyspark.testing.pandasutils import PandasOnSparkTestCase
|
||||||
|
|
||||||
|
if extension_object_dtypes_available:
|
||||||
|
from pandas import StringDtype
|
||||||
|
|
||||||
|
|
||||||
class StringOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
class StringOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
||||||
@property
|
@property
|
||||||
|
@ -217,6 +220,20 @@ class StringExtensionOpsTest(StringOpsTest, PandasOnSparkTestCase, TestCasesUtil
|
||||||
def test_isnull(self):
|
def test_isnull(self):
|
||||||
self.assert_eq(self.pser.isnull(), self.psser.isnull())
|
self.assert_eq(self.pser.isnull(), self.psser.isnull())
|
||||||
|
|
||||||
|
def test_astype(self):
|
||||||
|
pser = self.pser
|
||||||
|
psser = self.psser
|
||||||
|
|
||||||
|
# TODO(SPARK-35976): [x, y, z, <NA>] is returned in pandas
|
||||||
|
self.assert_eq(["x", "y", "z", "None"], self.psser.astype(str).tolist())
|
||||||
|
|
||||||
|
self.assert_eq(pser.astype("category"), psser.astype("category"))
|
||||||
|
cat_type = CategoricalDtype(categories=["x", "y"])
|
||||||
|
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
|
||||||
|
for dtype in self.object_extension_dtypes:
|
||||||
|
if dtype in ["string", StringDtype()]:
|
||||||
|
self.check_extension(pser.astype(dtype), psser.astype(dtype))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from pyspark.pandas.tests.data_type_ops.test_string_ops import * # noqa: F401
|
from pyspark.pandas.tests.data_type_ops.test_string_ops import * # noqa: F401
|
||||||
|
|
|
@ -25,6 +25,21 @@ import pandas as pd
|
||||||
import pyspark.pandas as ps
|
import pyspark.pandas as ps
|
||||||
from pyspark.pandas.typedef import extension_dtypes
|
from pyspark.pandas.typedef import extension_dtypes
|
||||||
|
|
||||||
|
from pyspark.pandas.typedef.typehints import (
|
||||||
|
extension_dtypes_available,
|
||||||
|
extension_float_dtypes_available,
|
||||||
|
extension_object_dtypes_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
if extension_dtypes_available:
|
||||||
|
from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype
|
||||||
|
|
||||||
|
if extension_float_dtypes_available:
|
||||||
|
from pandas import Float32Dtype, Float64Dtype
|
||||||
|
|
||||||
|
if extension_object_dtypes_available:
|
||||||
|
from pandas import BooleanDtype, StringDtype
|
||||||
|
|
||||||
|
|
||||||
class TestCasesUtils(object):
|
class TestCasesUtils(object):
|
||||||
"""A utility holding common test cases for arithmetic operations of different data types."""
|
"""A utility holding common test cases for arithmetic operations of different data types."""
|
||||||
|
@ -81,6 +96,47 @@ class TestCasesUtils(object):
|
||||||
def pser_psser_pairs(self):
|
def pser_psser_pairs(self):
|
||||||
return zip(self.psers, self.pssers)
|
return zip(self.psers, self.pssers)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def object_extension_dtypes(self):
|
||||||
|
return (
|
||||||
|
["boolean", "string", BooleanDtype(), StringDtype()]
|
||||||
|
if extension_object_dtypes_available
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fractional_extension_dtypes(self):
|
||||||
|
return (
|
||||||
|
["Float32", "Float64", Float32Dtype(), Float64Dtype()]
|
||||||
|
if extension_float_dtypes_available
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def integral_extension_dtypes(self):
|
||||||
|
return (
|
||||||
|
[
|
||||||
|
"Int8",
|
||||||
|
"Int16",
|
||||||
|
"Int32",
|
||||||
|
"Int64",
|
||||||
|
Int8Dtype(),
|
||||||
|
Int16Dtype(),
|
||||||
|
Int32Dtype(),
|
||||||
|
Int64Dtype(),
|
||||||
|
]
|
||||||
|
if extension_dtypes_available
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def extension_dtypes(self):
|
||||||
|
return (
|
||||||
|
self.object_extension_dtypes
|
||||||
|
+ self.fractional_extension_dtypes
|
||||||
|
+ self.integral_extension_dtypes
|
||||||
|
)
|
||||||
|
|
||||||
def check_extension(self, psser, pser):
|
def check_extension(self, psser, pser):
|
||||||
"""
|
"""
|
||||||
Compare `psser` and `pser` of numeric ExtensionDtypes.
|
Compare `psser` and `pser` of numeric ExtensionDtypes.
|
||||||
|
|
|
@ -57,8 +57,6 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
|
||||||
pser = self.pser
|
pser = self.pser
|
||||||
psser = self.psser
|
psser = self.psser
|
||||||
|
|
||||||
self.assert_eq(psser + 1, pser + 1)
|
|
||||||
self.assert_eq(1 + psser, 1 + pser)
|
|
||||||
self.assert_eq(psser + 1 + 10 * psser, pser + 1 + 10 * pser)
|
self.assert_eq(psser + 1 + 10 * psser, pser + 1 + 10 * pser)
|
||||||
self.assert_eq(psser + 1 + 10 * psser.index, pser + 1 + 10 * pser.index)
|
self.assert_eq(psser + 1 + 10 * psser.index, pser + 1 + 10 * pser.index)
|
||||||
self.assert_eq(psser.index + 1 + 10 * psser, pser.index + 1 + 10 * pser)
|
self.assert_eq(psser.index + 1 + 10 * psser, pser.index + 1 + 10 * pser)
|
||||||
|
@ -92,51 +90,6 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
|
||||||
else:
|
else:
|
||||||
self.assert_eq(psser, pser)
|
self.assert_eq(psser, pser)
|
||||||
|
|
||||||
@unittest.skipIf(not extension_dtypes_available, "pandas extension dtypes are not available")
|
|
||||||
def test_extension_dtypes(self):
|
|
||||||
for pser in [
|
|
||||||
pd.Series([1, 2, None, 4], dtype="Int8"),
|
|
||||||
pd.Series([1, 2, None, 4], dtype="Int16"),
|
|
||||||
pd.Series([1, 2, None, 4], dtype="Int32"),
|
|
||||||
pd.Series([1, 2, None, 4], dtype="Int64"),
|
|
||||||
]:
|
|
||||||
psser = ps.from_pandas(pser)
|
|
||||||
|
|
||||||
self._check_extension(psser, pser)
|
|
||||||
self._check_extension(psser + psser, pser + pser)
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
not extension_object_dtypes_available, "pandas extension object dtypes are not available"
|
|
||||||
)
|
|
||||||
def test_extension_object_dtypes(self):
|
|
||||||
# string
|
|
||||||
pser = pd.Series(["a", None, "c", "d"], dtype="string")
|
|
||||||
psser = ps.from_pandas(pser)
|
|
||||||
|
|
||||||
self._check_extension(psser, pser)
|
|
||||||
|
|
||||||
# boolean
|
|
||||||
pser = pd.Series([True, False, True, None], dtype="boolean")
|
|
||||||
psser = ps.from_pandas(pser)
|
|
||||||
|
|
||||||
self._check_extension(psser, pser)
|
|
||||||
self._check_extension(psser & psser, pser & pser)
|
|
||||||
self._check_extension(psser | psser, pser | pser)
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
not extension_float_dtypes_available, "pandas extension float dtypes are not available"
|
|
||||||
)
|
|
||||||
def test_extension_float_dtypes(self):
|
|
||||||
for pser in [
|
|
||||||
pd.Series([1.0, 2.0, None, 4.0], dtype="Float32"),
|
|
||||||
pd.Series([1.0, 2.0, None, 4.0], dtype="Float64"),
|
|
||||||
]:
|
|
||||||
psser = ps.from_pandas(pser)
|
|
||||||
|
|
||||||
self._check_extension(psser, pser)
|
|
||||||
self._check_extension(psser + 1, pser + 1)
|
|
||||||
self._check_extension(psser + psser, pser + pser)
|
|
||||||
|
|
||||||
def test_empty_series(self):
|
def test_empty_series(self):
|
||||||
pser_a = pd.Series([], dtype="i1")
|
pser_a = pd.Series([], dtype="i1")
|
||||||
pser_b = pd.Series([], dtype="str")
|
pser_b = pd.Series([], dtype="str")
|
||||||
|
@ -872,18 +825,16 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
|
||||||
self.assert_eq(psser.nlargest(), pser.nlargest())
|
self.assert_eq(psser.nlargest(), pser.nlargest())
|
||||||
self.assert_eq((psser + 1).nlargest(), (pser + 1).nlargest())
|
self.assert_eq((psser + 1).nlargest(), (pser + 1).nlargest())
|
||||||
|
|
||||||
def test_isnull(self):
|
def test_notnull(self):
|
||||||
pser = pd.Series([1, 2, 3, 4, np.nan, 6], name="x")
|
pser = pd.Series([1, 2, 3, 4, np.nan, 6], name="x")
|
||||||
psser = ps.from_pandas(pser)
|
psser = ps.from_pandas(pser)
|
||||||
|
|
||||||
self.assert_eq(psser.notnull(), pser.notnull())
|
self.assert_eq(psser.notnull(), pser.notnull())
|
||||||
self.assert_eq(psser.isnull(), pser.isnull())
|
|
||||||
|
|
||||||
pser = self.pser
|
pser = self.pser
|
||||||
psser = self.psser
|
psser = self.psser
|
||||||
|
|
||||||
self.assert_eq(psser.notnull(), pser.notnull())
|
self.assert_eq(psser.notnull(), pser.notnull())
|
||||||
self.assert_eq(psser.isnull(), pser.isnull())
|
|
||||||
|
|
||||||
def test_all(self):
|
def test_all(self):
|
||||||
for pser in [
|
for pser in [
|
||||||
|
|
Loading…
Reference in a new issue