[SPARK-36505][PYTHON] Improve test coverage for frame.py
### What changes were proposed in this pull request? This PR proposes improving test coverage for pandas-on-Spark DataFrame code base, which is written in `frame.py`. This PR did the following to improve coverage: - Add unittest for untested code - Remove unused code - Add arguments to some functions for testing ### Why are the changes needed? To make the project healthier by improving coverage. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unittest. Closes #33833 from itholic/SPARK-36505. Authored-by: itholic <haejoon.lee@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
This commit is contained in:
parent
1a42aa5bd4
commit
97e7d6e667
|
@ -841,12 +841,6 @@ class DataFrame(Frame, Generic[T]):
|
|||
def __radd__(self, other: Any) -> "DataFrame":
|
||||
return self._map_series_op("radd", other)
|
||||
|
||||
def __div__(self, other: Any) -> "DataFrame":
|
||||
return self._map_series_op("div", other)
|
||||
|
||||
def __rdiv__(self, other: Any) -> "DataFrame":
|
||||
return self._map_series_op("rdiv", other)
|
||||
|
||||
def __truediv__(self, other: Any) -> "DataFrame":
|
||||
return self._map_series_op("truediv", other)
|
||||
|
||||
|
@ -3185,7 +3179,10 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
|
|||
)
|
||||
|
||||
def where(
|
||||
self, cond: DataFrameOrSeries, other: Union[DataFrameOrSeries, Any] = np.nan
|
||||
self,
|
||||
cond: DataFrameOrSeries,
|
||||
other: Union[DataFrameOrSeries, Any] = np.nan,
|
||||
axis: Axis = None,
|
||||
) -> "DataFrame":
|
||||
"""
|
||||
Replace values where the condition is False.
|
||||
|
@ -3197,6 +3194,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
|
|||
replace with corresponding value from other.
|
||||
other : scalar, DataFrame
|
||||
Entries where cond is False are replaced with corresponding value from other.
|
||||
axis : int, default None
|
||||
Can only be set to 0 at the moment for compatibility with pandas.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -3295,6 +3294,10 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
|
|||
"""
|
||||
from pyspark.pandas.series import Series
|
||||
|
||||
axis = validate_axis(axis)
|
||||
if axis != 0:
|
||||
raise NotImplementedError('axis should be either 0 or "index" currently.')
|
||||
|
||||
tmp_cond_col_name = "__tmp_cond_col_{}__".format
|
||||
tmp_other_col_name = "__tmp_other_col_{}__".format
|
||||
|
||||
|
@ -8744,10 +8747,6 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
|
|||
index = labels
|
||||
elif axis == 1:
|
||||
columns = labels
|
||||
else:
|
||||
raise ValueError(
|
||||
"No axis named %s for object type %s." % (axis, type(axis).__name__)
|
||||
)
|
||||
|
||||
if index is not None and not is_list_like(index):
|
||||
raise TypeError(
|
||||
|
@ -9911,8 +9910,6 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
|
|||
col = midx_col
|
||||
else:
|
||||
col = col | midx_col
|
||||
else:
|
||||
raise ValueError("Single or multi index must be specified.")
|
||||
return DataFrame(self._internal.with_filter(col))
|
||||
else:
|
||||
return self[items]
|
||||
|
@ -10098,11 +10095,6 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
|
|||
)
|
||||
elif axis == 1:
|
||||
columns_mapper_fn, _, _ = gen_mapper_fn(mapper)
|
||||
else:
|
||||
raise ValueError(
|
||||
"argument axis should be either the axis name "
|
||||
"(‘index’, ‘columns’) or number (0, 1)"
|
||||
)
|
||||
else:
|
||||
if index:
|
||||
index_mapper_fn, index_mapper_ret_dtype, index_mapper_ret_stype = gen_mapper_fn(
|
||||
|
|
|
@ -217,6 +217,11 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
'"column" should be a scalar value or tuple that contains scalar values',
|
||||
lambda: psdf.insert(0, list("abc"), psser),
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"loc must be int",
|
||||
lambda: psdf.insert((1,), "b", 10),
|
||||
)
|
||||
self.assertRaises(ValueError, lambda: psdf.insert(0, "e", [7, 8, 9, 10]))
|
||||
self.assertRaises(ValueError, lambda: psdf.insert(0, "f", ps.Series([7, 8])))
|
||||
self.assertRaises(AssertionError, lambda: psdf.insert(100, "y", psser))
|
||||
|
@ -424,6 +429,10 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
self.assert_eq(psdf, pdf)
|
||||
self.assert_eq(psser, pser)
|
||||
|
||||
pdf.columns = ["index", "b"]
|
||||
psdf.columns = ["index", "b"]
|
||||
self.assert_eq(psdf.reset_index(), pdf.reset_index())
|
||||
|
||||
def test_reset_index_with_default_index_types(self):
|
||||
pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=np.random.rand(3))
|
||||
psdf = ps.from_pandas(pdf)
|
||||
|
@ -676,6 +685,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
self.assert_eq(psdf.head(0), pdf.head(0))
|
||||
self.assert_eq(psdf.head(-3), pdf.head(-3))
|
||||
self.assert_eq(psdf.head(-10), pdf.head(-10))
|
||||
with option_context("compute.ordered_head", True):
|
||||
self.assert_eq(psdf.head(), pdf.head())
|
||||
|
||||
def test_attributes(self):
|
||||
psdf = self.psdf
|
||||
|
@ -835,6 +846,19 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
psdf5 = psdf3 + 1
|
||||
self.assert_eq(psdf5.rename(index=str_lower), pdf5.rename(index=str_lower))
|
||||
|
||||
msg = "Either `index` or `columns` should be provided."
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
psdf1.rename()
|
||||
msg = "`mapper` or `index` or `columns` should be either dict-like or function type."
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
psdf1.rename(mapper=[str_lower], axis=1)
|
||||
msg = "Mapper dict should have the same value type."
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
psdf1.rename({"A": "a", "B": 2}, axis=1)
|
||||
msg = r"level should be an integer between \[0, column_labels_level\)"
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
psdf2.rename(columns=str_lower, level=2)
|
||||
|
||||
def test_rename_axis(self):
|
||||
index = pd.Index(["A", "B", "C"], name="index")
|
||||
columns = pd.Index(["numbers", "values"], name="cols")
|
||||
|
@ -1002,6 +1026,12 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
result = psdf.rename_axis(index=str.upper, columns=str.upper).sort_index()
|
||||
self.assert_eq(expected, result)
|
||||
|
||||
def test_dot(self):
|
||||
psdf = self.psdf
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "Unsupported type DataFrame"):
|
||||
psdf.dot(psdf)
|
||||
|
||||
def test_dot_in_column_name(self):
|
||||
self.assert_eq(
|
||||
ps.DataFrame(ps.range(1)._internal.spark_frame.selectExpr("1L as `a.b`"))["a.b"],
|
||||
|
@ -1086,6 +1116,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
self.assert_eq(psdf.timestamp.min(), pdf.timestamp.min())
|
||||
self.assert_eq(psdf.timestamp.max(), pdf.timestamp.max())
|
||||
|
||||
self.assertRaises(ValueError, lambda: psdf.agg(("sum", "min")))
|
||||
|
||||
def test_droplevel(self):
|
||||
pdf = (
|
||||
pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
|
||||
|
@ -1283,8 +1315,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
psdf2 = psdf.copy()
|
||||
pser = pdf2[pdf2.columns[0]]
|
||||
psser = psdf2[psdf2.columns[0]]
|
||||
pdf2.dropna(inplace=True)
|
||||
psdf2.dropna(inplace=True)
|
||||
pdf2.dropna(inplace=True, axis=axis)
|
||||
psdf2.dropna(inplace=True, axis=axis)
|
||||
self.assert_eq(psdf2, pdf2)
|
||||
self.assert_eq(psser, pser)
|
||||
|
||||
|
@ -1362,6 +1394,12 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
|
||||
self._test_dropna(pdf, axis=1)
|
||||
|
||||
psdf = ps.from_pandas(pdf)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The length of each subset must be the same as the index size."
|
||||
):
|
||||
psdf.dropna(subset=(["x", "y"]), axis=1)
|
||||
|
||||
# empty
|
||||
pdf = pd.DataFrame({"x": [], "y": [], "z": []})
|
||||
psdf = ps.from_pandas(pdf)
|
||||
|
@ -1782,6 +1820,9 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
msg = r"'Key length \(4\) exceeds index depth \(3\)'"
|
||||
with self.assertRaisesRegex(KeyError, msg):
|
||||
psdf.xs(("mammal", "dog", "walks", "foo"))
|
||||
msg = "'key' should be a scalar value or tuple that contains scalar values"
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
psdf.xs(["mammal", "dog", "walks", "foo"])
|
||||
|
||||
self.assertRaises(IndexError, lambda: psdf.xs("foo", level=-4))
|
||||
self.assertRaises(IndexError, lambda: psdf.xs("foo", level=3))
|
||||
|
@ -1935,6 +1976,7 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
|
||||
check(lambda left, right: left.merge(right))
|
||||
check(lambda left, right: left.merge(right, on="value"))
|
||||
check(lambda left, right: left.merge(right, on=("value",)))
|
||||
check(lambda left, right: left.merge(right, left_on="lkey", right_on="rkey"))
|
||||
check(lambda left, right: left.set_index("lkey").merge(right.set_index("rkey")))
|
||||
check(
|
||||
|
@ -2339,6 +2381,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
psdf = ps.from_pandas(pdf)
|
||||
|
||||
self.assert_eq(psdf + psdf.copy(), pdf + pdf.copy())
|
||||
self.assert_eq(psdf + psdf.loc[:, ["A", "B"]], pdf + pdf.loc[:, ["A", "B"]])
|
||||
self.assert_eq(psdf.loc[:, ["A", "B"]] + psdf, pdf.loc[:, ["A", "B"]] + pdf)
|
||||
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
|
@ -2352,6 +2396,14 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
lambda: ps.range(10).add(ps.range(10).id),
|
||||
)
|
||||
|
||||
psdf_other = psdf.copy()
|
||||
psdf_other.columns = pd.MultiIndex.from_tuples([("A", "Z"), ("B", "X"), ("C", "C")])
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"cannot join with no overlapping index names",
|
||||
lambda: psdf.add(psdf_other),
|
||||
)
|
||||
|
||||
def test_binary_operator_add(self):
|
||||
# Positive
|
||||
pdf = pd.DataFrame({"a": ["x"], "b": ["y"], "c": [1], "d": [2]})
|
||||
|
@ -2640,6 +2692,10 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
|
||||
with self.assertRaisesRegex(ValueError, "Length of to_replace and value must be same"):
|
||||
psdf.replace(to_replace=["Ironman"], value=["Spiderman", "Doctor Strange"])
|
||||
with self.assertRaisesRegex(TypeError, "Unsupported type function"):
|
||||
psdf.replace("Ironman", lambda x: "Spiderman")
|
||||
with self.assertRaisesRegex(TypeError, "Unsupported type function"):
|
||||
psdf.replace(lambda x: "Ironman", "Spiderman")
|
||||
|
||||
self.assert_eq(psdf.replace("Ironman", "Spiderman"), pdf.replace("Ironman", "Spiderman"))
|
||||
self.assert_eq(
|
||||
|
@ -3122,6 +3178,14 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
index=["c"], columns="A", values=["b", "e"], aggfunc={"b": "mean", "e": "sum"}
|
||||
)
|
||||
|
||||
msg = "values should be one column or list of columns."
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
psdf.pivot_table(columns="a", values=(["b"], ["c"]))
|
||||
|
||||
msg = "aggfunc must be a dict mapping from column name to aggregate functions"
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
psdf.pivot_table(columns="a", values="b", aggfunc={"a": lambda x: sum(x)})
|
||||
|
||||
psdf = ps.DataFrame(
|
||||
{
|
||||
"A": ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"],
|
||||
|
@ -3414,6 +3478,11 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
psdf.reindex(columns=["numbers"]).sort_index(),
|
||||
)
|
||||
|
||||
self.assert_eq(
|
||||
pdf.reindex(columns=["numbers"], copy=True).sort_index(),
|
||||
psdf.reindex(columns=["numbers"], copy=True).sort_index(),
|
||||
)
|
||||
|
||||
# Using float as fill_value to avoid int64/32 clash
|
||||
self.assert_eq(
|
||||
pdf.reindex(columns=["numbers", "2", "3"], fill_value=0.0).sort_index(),
|
||||
|
@ -3464,6 +3533,7 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
|
||||
self.assertRaises(TypeError, lambda: psdf.reindex(columns=["numbers", "2", "3"], axis=1))
|
||||
self.assertRaises(TypeError, lambda: psdf.reindex(columns=["numbers", "2", "3"], axis=2))
|
||||
self.assertRaises(TypeError, lambda: psdf.reindex(columns="numbers"))
|
||||
self.assertRaises(TypeError, lambda: psdf.reindex(index=["A", "B", "C"], axis=1))
|
||||
self.assertRaises(TypeError, lambda: psdf.reindex(index=123))
|
||||
|
||||
|
@ -4552,6 +4622,10 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
psdf.quantile(q="a")
|
||||
with self.assertRaisesRegex(TypeError, "q must be a float or an array of floats;"):
|
||||
psdf.quantile(q=["a"])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"percentiles should all be in the interval \[0, 1\]"
|
||||
):
|
||||
psdf.quantile(q=[1.1])
|
||||
|
||||
self.assert_eq(
|
||||
psdf.quantile(0.5, numeric_only=False), pdf.quantile(0.5, numeric_only=False)
|
||||
|
@ -4596,7 +4670,13 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
self.assert_eq(psdf.pct_change().sum(), pdf.pct_change().sum(), check_exact=False)
|
||||
|
||||
def test_where(self):
|
||||
psdf = ps.from_pandas(self.pdf)
|
||||
pdf, psdf = self.df_pair
|
||||
|
||||
# pandas requires `axis` argument when the `other` is Series.
|
||||
# `axis` is not fully supported yet in pandas-on-Spark.
|
||||
self.assert_eq(
|
||||
psdf.where(psdf > 2, psdf.a + 10, axis=0), pdf.where(pdf > 2, pdf.a + 10, axis=0)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "type of cond must be a DataFrame or Series"):
|
||||
psdf.where(1)
|
||||
|
@ -5501,6 +5581,7 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
|
||||
self.assertRaises(ValueError, lambda: psdf1.align(psdf1, join="unknown"))
|
||||
self.assertRaises(ValueError, lambda: psdf1.align(psdf1["b"]))
|
||||
self.assertRaises(TypeError, lambda: psdf1.align(["b"]))
|
||||
self.assertRaises(NotImplementedError, lambda: psdf1.align(psdf1["b"], axis=1))
|
||||
|
||||
pdf2 = pd.DataFrame({"a": [4, 5, 6], "d": ["d", "e", "f"]}, index=[10, 11, 12])
|
||||
|
@ -5613,6 +5694,38 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
with self.assertRaisesRegex(TypeError, "Index must be DatetimeIndex"):
|
||||
psdf.at_time("0:15")
|
||||
|
||||
def test_astype(self):
|
||||
psdf = self.psdf
|
||||
|
||||
msg = "Only a column name can be used for the key in a dtype mappings argument."
|
||||
with self.assertRaisesRegex(KeyError, msg):
|
||||
psdf.astype({"c": float})
|
||||
|
||||
def test_describe(self):
|
||||
psdf = self.psdf
|
||||
|
||||
msg = r"Percentiles should all be in the interval \[0, 1\]"
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
psdf.describe(percentiles=[1.1])
|
||||
|
||||
psdf = ps.DataFrame({"A": ["a", "b", "c"], "B": ["d", "e", "f"]})
|
||||
|
||||
msg = "Cannot describe a DataFrame without columns"
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
psdf.describe()
|
||||
|
||||
def test_getitem_with_none_key(self):
|
||||
psdf = self.psdf
|
||||
|
||||
with self.assertRaisesRegex(KeyError, "none key"):
|
||||
psdf[None]
|
||||
|
||||
def test_iter_dataframe(self):
|
||||
pdf, psdf = self.df_pair
|
||||
|
||||
for value_psdf, value_pdf in zip(psdf, pdf):
|
||||
self.assert_eq(value_psdf, value_pdf)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pyspark.pandas.tests.test_dataframe import * # noqa: F401
|
||||
|
|
Loading…
Reference in a new issue