From 50f7686de9fdf013e93f9598a2c12087916ddf07 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Mon, 7 Jun 2021 11:12:49 +0900 Subject: [PATCH] [SPARK-35599][PYTHON] Adjust `check_exact` parameter for older pd.testing ### What changes were proposed in this pull request? Adjust the `check_exact` parameter for non-numeric columns to ensure pandas-on-Spark tests passed with all pandas versions. ### Why are the changes needed? `pd.testing` utils are utilized in pandas-on-Spark tests. Due to https://github.com/pandas-dev/pandas/issues/35446, `check_exact=True` for non-numeric columns doesn't work for older pd.testing utils, e.g. `assert_series_equal`. We wanted to adjust that to ensure pandas-on-Spark tests pass for all pandas versions. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing unit tests. Closes #32772 from xinrong-databricks/test_util. Authored-by: Xinrong Meng Signed-off-by: Hyukjin Kwon --- python/pyspark/testing/pandasutils.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py index 219447e9bf..e957a0f1fb 100644 --- a/python/pyspark/testing/pandasutils.py +++ b/python/pyspark/testing/pandasutils.py @@ -25,6 +25,7 @@ from distutils.version import LooseVersion import pandas as pd from pandas.api.types import is_list_like +from pandas.core.dtypes.common import is_numeric_dtype from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal from pyspark import pandas as ps @@ -81,6 +82,12 @@ class PandasOnSparkTestCase(unittest.TestCase, SQLTestUtils): else: kwargs = dict() + if LooseVersion(pd.__version__) < LooseVersion("1.1.1"): + # Due to https://github.com/pandas-dev/pandas/issues/35446 + check_exact = check_exact \ + and all([is_numeric_dtype(dtype) for dtype in left.dtypes]) \ + and all([is_numeric_dtype(dtype) for dtype in right.dtypes]) + assert_frame_equal( left, right, @@ -102,7 +109,11 @@ class PandasOnSparkTestCase(unittest.TestCase, SQLTestUtils): kwargs = dict(check_freq=False) else: kwargs = dict() - + if LooseVersion(pd.__version__) < LooseVersion("1.1.1"): + # Due to https://github.com/pandas-dev/pandas/issues/35446 + check_exact = check_exact \ + and is_numeric_dtype(left.dtype) \ + and is_numeric_dtype(right.dtype) assert_series_equal( left, right, @@ -119,6 +130,11 @@ class PandasOnSparkTestCase(unittest.TestCase, SQLTestUtils): raise AssertionError(msg) from e elif isinstance(left, pd.Index) and isinstance(right, pd.Index): try: + if LooseVersion(pd.__version__) < LooseVersion("1.1.1"): + # Due to https://github.com/pandas-dev/pandas/issues/35446 + check_exact = check_exact \ + and is_numeric_dtype(left.dtype) \ + and is_numeric_dtype(right.dtype) assert_index_equal(left, right, check_exact=check_exact) except AssertionError as e: msg = (