[SPARK-35499][PYTHON] Apply black to pandas API on Spark codes

### What changes were proposed in this pull request?

This PR proposes applying `black` to the pandas API on Spark code, to improve static analysis.

Executing `./dev/reformat-python` in the Spark home directory fixes all pandas API on Spark code according to the static analysis rules.
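For reference, the pinned formatter can be installed and run locally like this (a sketch based on the scripts in this commit; `dev/reformat-python` pins black to 21.5b2 and targets only `python/pyspark/pandas`):

```bash
# Install the same black version that the CI workflow pins below.
pip install 'black==21.5b2'

# Rewrite the pandas API on Spark sources in place.
./dev/reformat-python
```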

### Why are the changes needed?

This reduces the cost of static analysis during development.

`black` has been used continuously for about a year in the Koalas project, where its convenience has been proven.

### Does this PR introduce _any_ user-facing change?

No, it's dev-only.

### How was this patch tested?

Manually reformatted the pandas API on Spark code by running `./dev/reformat-python`, and checked that `./dev/lint-python` passes.
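Roughly, the verification sequence was (a sketch; per the `dev/lint-python` diff below, black runs in `--check` mode alongside the other linters):

```bash
./dev/reformat-python   # reformat python/pyspark/pandas with black
./dev/lint-python       # compile check, black --check, pycodestyle, flake8, mypy
```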

Closes #32779 from itholic/SPARK-35499.

Authored-by: itholic <haejoon.lee@databricks.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
Committed 2021-06-06 17:30:07 -07:00
commit b8740a1d1e (parent d4e32c896a)
54 changed files with 428 additions and 228 deletions


@ -366,7 +366,7 @@ jobs:
# See also https://github.com/sphinx-doc/sphinx/issues/7551.
# Jinja2 3.0.0+ causes error when building with Sphinx.
# See also https://issues.apache.org/jira/browse/SPARK-35375.
python3.6 -m pip install flake8 pydata_sphinx_theme mypy numpydoc 'jinja2<3.0.0'
python3.6 -m pip install flake8 pydata_sphinx_theme mypy numpydoc 'jinja2<3.0.0' 'black==21.5b2'
- name: Install R linter dependencies and SparkR
run: |
apt-get install -y libcurl4-openssl-dev libgit2-dev libssl-dev libxml2-dev


@ -25,6 +25,8 @@ MINIMUM_PYCODESTYLE="2.7.0"
PYTHON_EXECUTABLE="python3"
BLACK_BUILD="$PYTHON_EXECUTABLE -m black"
function satisfies_min_version {
local provided_version="$1"
local expected_version="$2"
@ -185,6 +187,35 @@ flake8 checks failed."
fi
}
function black_test {
local BLACK_REPORT=
local BLACK_STATUS=
# Skip check if black is not installed.
$BLACK_BUILD 2> /dev/null
if [ $? -ne 0 ]; then
echo "The $BLACK_BUILD command was not found. Skipping black checks for now."
echo
return
fi
echo "starting black test..."
# Black is only applied for pandas API on Spark for now.
BLACK_REPORT=$( ($BLACK_BUILD python/pyspark/pandas --line-length 100 --check ) 2>&1)
BLACK_STATUS=$?
if [ "$BLACK_STATUS" -ne 0 ]; then
echo "black checks failed:"
echo "$BLACK_REPORT"
echo "Please run 'dev/reformat-python' script."
echo "$BLACK_STATUS"
exit "$BLACK_STATUS"
else
echo "black checks passed."
echo
fi
}
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
SPARK_ROOT_DIR="$(dirname "${SCRIPT_DIR}")"
@ -194,6 +225,7 @@ pushd "$SPARK_ROOT_DIR" &> /dev/null
PYTHON_SOURCE="$(find . -path ./docs/.local_ruby_bundle -prune -false -o -name "*.py")"
compile_python_test "$PYTHON_SOURCE"
black_test
pycodestyle_test "$PYTHON_SOURCE"
flake8_test
mypy_test

dev/reformat-python (new executable file, 32 lines)

@ -0,0 +1,32 @@
#!/usr/bin/env bash
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The current directory of the script.
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
FWDIR="$( cd "$DIR"/.. && pwd )"
cd "$FWDIR"
BLACK_BUILD="python -m black"
BLACK_VERSION="21.5b2"
$BLACK_BUILD 2> /dev/null
if [ $? -ne 0 ]; then
echo "The '$BLACK_BUILD' command was not found. Please install Black, for example, via 'pip install black==$BLACK_VERSION'."
exit 1
fi
# This script is only applied for pandas API on Spark for now.
$BLACK_BUILD python/pyspark/pandas --line-length 100


@ -32,3 +32,6 @@ sphinx-plotly-directive
# Development scripts
jira
PyGithub
# pandas API on Spark Code formatter.
black


@ -14,11 +14,13 @@
# limitations under the License.
[pycodestyle]
ignore=E226,E241,E305,E402,E722,E731,E741,W503,W504
ignore=E203,E226,E241,E305,E402,E722,E731,E741,W503,W504
max-line-length=100
exclude=*/target/*,python/pyspark/cloudpickle/*.py,shared.py,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*
[flake8]
select = E901,E999,F821,F822,F823,F401,F405,B006
# Ignore F821 for plot documents in pandas API on Spark.
ignore = F821
exclude = python/docs/build/html/*,*/target/*,python/pyspark/cloudpickle/*.py,shared.py*,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*,python/out,python/pyspark/sql/pandas/functions.pyi,python/pyspark/sql/column.pyi,python/pyspark/worker.pyi,python/pyspark/java_gateway.pyi
max-line-length = 100
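The new `E203` entry in the pycodestyle ignore list above is presumably needed because black can insert whitespace before `:` in compound slices, which pycodestyle reports as E203. A minimal illustration (hypothetical demo file, not from this commit):

```bash
# black formats compound slices with spaces around ":", for example:
echo 'x = ham[lower + offset : upper + offset]' > /tmp/e203_demo.py

# Without E203 in the ignore list, pycodestyle flags the space before ":".
python -m pycodestyle --select=E203 /tmp/e203_demo.py
```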


@ -1068,9 +1068,7 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
if isinstance(self, MultiIndex):
raise NotImplementedError("notna is not defined for MultiIndex")
return (~self.isnull()).rename(
self.name # type: ignore
)
return (~self.isnull()).rename(self.name) # type: ignore
notna = notnull


@ -45,11 +45,7 @@ if TYPE_CHECKING:
from pyspark.pandas.series import Series # noqa: F401 (SPARK-34943)
def is_valid_operand_for_numeric_arithmetic(
operand: Any,
*,
allow_bool: bool = True
) -> bool:
def is_valid_operand_for_numeric_arithmetic(operand: Any, *, allow_bool: bool = True) -> bool:
"""Check whether the operand is valid for arithmetic operations against numerics."""
if isinstance(operand, numbers.Number) and not isinstance(operand, bool):
return True
@ -58,7 +54,8 @@ def is_valid_operand_for_numeric_arithmetic(
return False
else:
return isinstance(operand.spark.data_type, NumericType) or (
allow_bool and isinstance(operand.spark.data_type, BooleanType))
allow_bool and isinstance(operand.spark.data_type, BooleanType)
)
else:
return False


@ -34,7 +34,7 @@ class BinaryOps(DataTypeOps):
@property
def pretty_name(self) -> str:
return 'binaries'
return "binaries"
def add(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, IndexOpsMixin) and isinstance(right.spark.data_type, BinaryType):
@ -43,11 +43,13 @@ class BinaryOps(DataTypeOps):
return column_op(F.concat)(left, F.lit(right))
else:
raise TypeError(
"Concatenation can not be applied to %s and the given type." % self.pretty_name)
"Concatenation can not be applied to %s and the given type." % self.pretty_name
)
def radd(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, bytes):
return left._with_new_scol(F.concat(F.lit(right), left.spark.column))
else:
raise TypeError(
"Concatenation can not be applied to %s and the given type." % self.pretty_name)
"Concatenation can not be applied to %s and the given type." % self.pretty_name
)


@ -38,12 +38,13 @@ class BooleanOps(DataTypeOps):
@property
def pretty_name(self) -> str:
return 'booleans'
return "booleans"
def add(self, left, right) -> Union["Series", "Index"]:
if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False):
raise TypeError(
"Addition can not be applied to %s and the given type." % self.pretty_name)
"Addition can not be applied to %s and the given type." % self.pretty_name
)
if isinstance(right, numbers.Number) and not isinstance(right, bool):
left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right))))
@ -56,7 +57,8 @@ class BooleanOps(DataTypeOps):
def sub(self, left, right) -> Union["Series", "Index"]:
if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False):
raise TypeError(
"Subtraction can not be applied to %s and the given type." % self.pretty_name)
"Subtraction can not be applied to %s and the given type." % self.pretty_name
)
if isinstance(right, numbers.Number) and not isinstance(right, bool):
left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right))))
return left - right
@ -68,7 +70,8 @@ class BooleanOps(DataTypeOps):
def mul(self, left, right) -> Union["Series", "Index"]:
if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False):
raise TypeError(
"Multiplication can not be applied to %s and the given type." % self.pretty_name)
"Multiplication can not be applied to %s and the given type." % self.pretty_name
)
if isinstance(right, numbers.Number) and not isinstance(right, bool):
left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right))))
return left * right
@ -80,7 +83,8 @@ class BooleanOps(DataTypeOps):
def truediv(self, left, right) -> Union["Series", "Index"]:
if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False):
raise TypeError(
"True division can not be applied to %s and the given type." % self.pretty_name)
"True division can not be applied to %s and the given type." % self.pretty_name
)
if isinstance(right, numbers.Number) and not isinstance(right, bool):
left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right))))
return left / right
@ -92,7 +96,8 @@ class BooleanOps(DataTypeOps):
def floordiv(self, left, right) -> Union["Series", "Index"]:
if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False):
raise TypeError(
"Floor division can not be applied to %s and the given type." % self.pretty_name)
"Floor division can not be applied to %s and the given type." % self.pretty_name
)
if isinstance(right, numbers.Number) and not isinstance(right, bool):
left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right))))
return left // right
@ -104,7 +109,8 @@ class BooleanOps(DataTypeOps):
def mod(self, left, right) -> Union["Series", "Index"]:
if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False):
raise TypeError(
"Modulo can not be applied to %s and the given type." % self.pretty_name)
"Modulo can not be applied to %s and the given type." % self.pretty_name
)
if isinstance(right, numbers.Number) and not isinstance(right, bool):
left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right))))
return left % right
@ -116,7 +122,8 @@ class BooleanOps(DataTypeOps):
def pow(self, left, right) -> Union["Series", "Index"]:
if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False):
raise TypeError(
"Exponentiation can not be applied to %s and the given type." % self.pretty_name)
"Exponentiation can not be applied to %s and the given type." % self.pretty_name
)
if isinstance(right, numbers.Number) and not isinstance(right, bool):
left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right))))
return left ** right
@ -131,7 +138,8 @@ class BooleanOps(DataTypeOps):
return right + left
else:
raise TypeError(
"Addition can not be applied to %s and the given type." % self.pretty_name)
"Addition can not be applied to %s and the given type." % self.pretty_name
)
def rsub(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, numbers.Number) and not isinstance(right, bool):
@ -139,7 +147,8 @@ class BooleanOps(DataTypeOps):
return right - left
else:
raise TypeError(
"Subtraction can not be applied to %s and the given type." % self.pretty_name)
"Subtraction can not be applied to %s and the given type." % self.pretty_name
)
def rmul(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, numbers.Number) and not isinstance(right, bool):
@ -147,7 +156,8 @@ class BooleanOps(DataTypeOps):
return right * left
else:
raise TypeError(
"Multiplication can not be applied to %s and the given type." % self.pretty_name)
"Multiplication can not be applied to %s and the given type." % self.pretty_name
)
def rtruediv(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, numbers.Number) and not isinstance(right, bool):
@ -155,7 +165,8 @@ class BooleanOps(DataTypeOps):
return right / left
else:
raise TypeError(
"True division can not be applied to %s and the given type." % self.pretty_name)
"True division can not be applied to %s and the given type." % self.pretty_name
)
def rfloordiv(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, numbers.Number) and not isinstance(right, bool):
@ -163,7 +174,8 @@ class BooleanOps(DataTypeOps):
return right // left
else:
raise TypeError(
"Floor division can not be applied to %s and the given type." % self.pretty_name)
"Floor division can not be applied to %s and the given type." % self.pretty_name
)
def rpow(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, numbers.Number) and not isinstance(right, bool):
@ -171,7 +183,8 @@ class BooleanOps(DataTypeOps):
return right ** left
else:
raise TypeError(
"Exponentiation can not be applied to %s and the given type." % self.pretty_name)
"Exponentiation can not be applied to %s and the given type." % self.pretty_name
)
def rmod(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, numbers.Number) and not isinstance(right, bool):
@ -179,4 +192,5 @@ class BooleanOps(DataTypeOps):
return right % left
else:
raise TypeError(
"Modulo can not be applied to %s and the given type." % self.pretty_name)
"Modulo can not be applied to %s and the given type." % self.pretty_name
)


@ -25,4 +25,4 @@ class CategoricalOps(DataTypeOps):
@property
def pretty_name(self) -> str:
return 'categoricals'
return "categoricals"


@ -34,22 +34,25 @@ class ArrayOps(DataTypeOps):
@property
def pretty_name(self) -> str:
return 'arrays'
return "arrays"
def add(self, left, right) -> Union["Series", "Index"]:
if not isinstance(right, IndexOpsMixin) or (
isinstance(right, IndexOpsMixin) and not isinstance(right.spark.data_type, ArrayType)
):
raise TypeError(
"Concatenation can not be applied to %s and the given type." % self.pretty_name)
"Concatenation can not be applied to %s and the given type." % self.pretty_name
)
left_type = left.spark.data_type.elementType
right_type = right.spark.data_type.elementType
if left_type != right_type and not (
isinstance(left_type, NumericType) and isinstance(right_type, NumericType)):
isinstance(left_type, NumericType) and isinstance(right_type, NumericType)
):
raise TypeError(
"Concatenation can only be applied to %s of the same type" % self.pretty_name)
"Concatenation can only be applied to %s of the same type" % self.pretty_name
)
return column_op(F.concat)(left, right)
@ -61,7 +64,7 @@ class MapOps(DataTypeOps):
@property
def pretty_name(self) -> str:
return 'maps'
return "maps"
class StructOps(DataTypeOps):
@ -71,4 +74,4 @@ class StructOps(DataTypeOps):
@property
def pretty_name(self) -> str:
return 'structs'
return "structs"


@ -37,7 +37,7 @@ class DateOps(DataTypeOps):
@property
def pretty_name(self) -> str:
return 'dates'
return "dates"
def sub(self, left, right) -> Union["Series", "Index"]:
# Note that date subtraction casts arguments to integer. This is to mimic pandas's


@ -38,7 +38,7 @@ class DatetimeOps(DataTypeOps):
@property
def pretty_name(self) -> str:
return 'datetimes'
return "datetimes"
def sub(self, left, right) -> Union["Series", "Index"]:
# Note that timestamp subtraction casts arguments to integer. This is to mimic pandas's


@ -46,7 +46,7 @@ class NumericOps(DataTypeOps):
@property
def pretty_name(self) -> str:
return 'numerics'
return "numerics"
def add(self, left, right) -> Union["Series", "Index"]:
if (
@ -159,7 +159,7 @@ class IntegralOps(NumericOps):
@property
def pretty_name(self) -> str:
return 'integrals'
return "integrals"
def mul(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, str):
@ -211,9 +211,7 @@ class IntegralOps(NumericOps):
return F.when(F.lit(right is np.nan), np.nan).otherwise(
F.when(
F.lit(right != 0) | F.lit(right).isNull(), F.floor(left.__div__(right))
).otherwise(
F.lit(np.inf).__div__(left)
)
).otherwise(F.lit(np.inf).__div__(left))
)
return numpy_column_op(floordiv)(left, right)
@ -253,7 +251,7 @@ class FractionalOps(NumericOps):
@property
def pretty_name(self) -> str:
return 'fractions'
return "fractions"
def mul(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, str):


@ -38,7 +38,7 @@ class StringOps(DataTypeOps):
@property
def pretty_name(self) -> str:
return 'strings'
return "strings"
def add(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, IndexOpsMixin) and isinstance(right.spark.data_type, StringType):


@ -124,7 +124,9 @@ def _register_accessor(
)
warnings.warn(
msg, UserWarning, stacklevel=2,
msg,
UserWarning,
stacklevel=2,
)
setattr(cls, name, CachedAccessor(name, accessor))
return accessor


@ -5445,7 +5445,13 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
return psdf
def replace(
self, to_replace=None, value=None, inplace=False, limit=None, regex=False, method="pad",
self,
to_replace=None,
value=None,
inplace=False,
limit=None,
regex=False,
method="pad",
) -> Optional["DataFrame"]:
"""
Returns a new DataFrame replacing a value with another value.
@ -7110,7 +7116,10 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
self._internal.index_dtypes,
)
)
index_map[i], index_map[j], = index_map[j], index_map[i]
index_map[i], index_map[j], = (
index_map[j],
index_map[i],
)
index_spark_columns, index_names, index_dtypes = zip(*index_map)
internal = self._internal.copy(
index_spark_columns=list(index_spark_columns),


@ -1806,7 +1806,9 @@ class GroupBy(object, metaclass=ABCMeta):
]
psdf, groupkey_labels, _ = GroupBy._prepare_group_map_apply(
psdf, self._groupkeys, agg_columns,
psdf,
self._groupkeys,
agg_columns,
)
groupkey_scols = [psdf._internal.spark_column_for(label) for label in groupkey_labels]


@ -447,7 +447,10 @@ class MultiIndex(Index):
self._internal.index_dtypes,
)
)
index_map[i], index_map[j], = index_map[j], index_map[i]
index_map[i], index_map[j], = (
index_map[j],
index_map[i],
)
index_spark_columns, index_names, index_dtypes = zip(*index_map)
internal = self._internal.copy(
index_spark_columns=list(index_spark_columns),
@ -684,9 +687,7 @@ class MultiIndex(Index):
raise NotImplementedError("nunique is not defined for MultiIndex")
# TODO: add 'name' parameter after pd.MultiIndex.name is implemented
def copy( # type: ignore[override]
self, deep: Optional[bool] = None
) -> "MultiIndex":
def copy(self, deep: Optional[bool] = None) -> "MultiIndex": # type: ignore[override]
"""
Make a copy of this object.


@ -879,10 +879,7 @@ class InternalFrame(object):
return index_spark_columns + [
spark_column
for spark_column in self.data_spark_columns
if all(
not spark_column_equals(spark_column, scol)
for scol in index_spark_columns
)
if all(not spark_column_equals(spark_column, scol) for scol in index_spark_columns)
]
@property
@ -929,10 +926,7 @@ class InternalFrame(object):
index_spark_columns = self.index_spark_columns
data_columns = []
for spark_column in self.data_spark_columns:
if all(
not spark_column_equals(spark_column, scol)
for scol in index_spark_columns
):
if all(not spark_column_equals(spark_column, scol) for scol in index_spark_columns):
data_columns.append(spark_column)
return self.spark_frame.select(index_spark_columns + data_columns)
@ -1055,7 +1049,8 @@ class InternalFrame(object):
pdf.columns = pd.MultiIndex.from_tuples(column_labels, names=names)
else:
pdf.columns = pd.Index(
[None if label is None else label[0] for label in column_labels], name=names[0],
[None if label is None else label[0] for label in column_labels],
name=names[0],
)
return pdf
@ -1354,7 +1349,9 @@ class InternalFrame(object):
schema = StructType(
[
StructField(
name, infer_pd_series_spark_type(col, dtype), nullable=bool(col.isnull().any()),
name,
infer_pd_series_spark_type(col, dtype),
nullable=bool(col.isnull().any()),
)
for (name, col), dtype in zip(pdf.iteritems(), index_dtypes + data_dtypes)
]


@ -2212,11 +2212,19 @@ def concat(objs, axis=0, join="outer", ignore_index=False, sort=False) -> Union[
for psdf in psdfs_not_same_anchor:
if join == "inner":
concat_psdf = align_diff_frames(
resolve_func, concat_psdf, psdf, fillna=False, how="inner",
resolve_func,
concat_psdf,
psdf,
fillna=False,
how="inner",
)
elif join == "outer":
concat_psdf = align_diff_frames(
resolve_func, concat_psdf, psdf, fillna=False, how="full",
resolve_func,
concat_psdf,
psdf,
fillna=False,
how="full",
)
concat_psdf = concat_psdf[column_labels]


@ -374,11 +374,19 @@ class KdePlotBase:
if ind is None:
min_val, max_val = calc_min_max()
sample_range = max_val - min_val
ind = np.linspace(min_val - 0.5 * sample_range, max_val + 0.5 * sample_range, 1000,)
ind = np.linspace(
min_val - 0.5 * sample_range,
max_val + 0.5 * sample_range,
1000,
)
elif is_integer(ind):
min_val, max_val = calc_min_max()
sample_range = max_val - min_val
ind = np.linspace(min_val - 0.5 * sample_range, max_val + 0.5 * sample_range, ind,)
ind = np.linspace(
min_val - 0.5 * sample_range,
max_val + 0.5 * sample_range,
ind,
)
return ind
@staticmethod


@ -1592,7 +1592,11 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
else:
return first_series(psdf)
def reindex(self, index: Optional[Any] = None, fill_value: Optional[Any] = None,) -> "Series":
def reindex(
self,
index: Optional[Any] = None,
fill_value: Optional[Any] = None,
) -> "Series":
"""
Conform Series to new index with optional filling logic, placing
NA/NaN in locations having no value in the previous index. A new object
@ -3485,7 +3489,8 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
if method == "first":
window = (
Window.orderBy(
asc_func(self.spark.column), asc_func(F.col(NATURAL_ORDER_COLUMN_NAME)),
asc_func(self.spark.column),
asc_func(F.col(NATURAL_ORDER_COLUMN_NAME)),
)
.partitionBy(*part_cols)
.rowsBetween(Window.unboundedPreceding, Window.currentRow)


@ -26,7 +26,7 @@ from pyspark.testing.pandasutils import PandasOnSparkTestCase
class BinaryOpsTest(PandasOnSparkTestCase, TestCasesUtils):
@property
def pser(self):
return pd.Series([b'1', b'2', b'3'])
return pd.Series([b"1", b"2", b"3"])
@property
def psser(self):
@ -35,9 +35,9 @@ class BinaryOpsTest(PandasOnSparkTestCase, TestCasesUtils):
def test_add(self):
psser = self.psser
pser = self.pser
self.assert_eq(psser + b'1', pser + b'1')
self.assert_eq(psser + b"1", pser + b"1")
self.assert_eq(psser + psser, pser + pser)
self.assert_eq(psser + psser.astype('bytes'), pser + pser.astype('bytes'))
self.assert_eq(psser + psser.astype("bytes"), pser + pser.astype("bytes"))
self.assertRaises(TypeError, lambda: psser + "x")
self.assertRaises(TypeError, lambda: psser + 1)
@ -95,7 +95,7 @@ class BinaryOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assertRaises(TypeError, lambda: self.psser ** psser)
def test_radd(self):
self.assert_eq(b'1' + self.psser, b'1' + self.pser)
self.assert_eq(b"1" + self.psser, b"1" + self.pser)
self.assertRaises(TypeError, lambda: "x" + self.psser)
self.assertRaises(TypeError, lambda: 1 + self.psser)
@ -129,7 +129,8 @@ if __name__ == "__main__":
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)


@ -103,7 +103,8 @@ class BooleanOpsTest(PandasOnSparkTestCase, TestCasesUtils):
with option_context("compute.ops_on_diff_frames", True):
self.assert_eq(
self.pser / self.float_pser, (self.psser / self.float_psser).sort_index())
self.pser / self.float_pser, (self.psser / self.float_psser).sort_index()
)
for psser in self.non_numeric_pssers.values():
self.assertRaises(TypeError, lambda: self.psser / psser)
@ -235,7 +236,8 @@ if __name__ == "__main__":
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)


@ -122,7 +122,8 @@ if __name__ == "__main__":
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)


@ -32,16 +32,17 @@ class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils):
return [
pd.Series([[1, 2, 3]]),
pd.Series([[0.1, 0.2, 0.3]]),
pd.Series([[decimal.Decimal(1), decimal.Decimal(2), decimal.Decimal(3)]])
pd.Series([[decimal.Decimal(1), decimal.Decimal(2), decimal.Decimal(3)]]),
]
@property
def non_numeric_array_psers(self):
return {
"string": pd.Series([['x', 'y', 'z']]),
"date": pd.Series([
[datetime.date(1994, 1, 1), datetime.date(1994, 1, 2), datetime.date(1994, 1, 3)]]),
"bool": pd.Series([[True, True, False]])
"string": pd.Series([["x", "y", "z"]]),
"date": pd.Series(
[[datetime.date(1994, 1, 1), datetime.date(1994, 1, 2), datetime.date(1994, 1, 3)]]
),
"bool": pd.Series([[True, True, False]]),
}
@property
@ -80,16 +81,19 @@ class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils):
# Non-numeric array + Non-numeric array
self.assertRaises(
TypeError, lambda:
self.non_numeric_array_pssers['string'] + self.non_numeric_array_pssers['bool']
TypeError,
lambda: self.non_numeric_array_pssers["string"]
+ self.non_numeric_array_pssers["bool"],
)
self.assertRaises(
TypeError, lambda:
self.non_numeric_array_pssers['string'] + self.non_numeric_array_pssers['date']
TypeError,
lambda: self.non_numeric_array_pssers["string"]
+ self.non_numeric_array_pssers["date"],
)
self.assertRaises(
TypeError, lambda:
self.non_numeric_array_pssers['bool'] + self.non_numeric_array_pssers['date']
TypeError,
lambda: self.non_numeric_array_pssers["bool"]
+ self.non_numeric_array_pssers["date"],
)
for data_type in self.non_numeric_array_psers.keys():
@ -97,7 +101,7 @@ class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.non_numeric_array_psers.get(data_type)
+ self.non_numeric_array_psers.get(data_type),
self.non_numeric_array_pssers.get(data_type)
+ self.non_numeric_array_pssers.get(data_type)
+ self.non_numeric_array_pssers.get(data_type),
)
# Numeric array + Non-numeric array
@ -193,7 +197,8 @@ if __name__ == "__main__":
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)


@ -55,7 +55,8 @@ class DateOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assertRaises(TypeError, lambda: self.psser - "x")
self.assertRaises(TypeError, lambda: self.psser - 1)
self.assert_eq(
(self.pser - self.some_date).dt.days, self.psser - self.some_date,
(self.pser - self.some_date).dt.days,
self.psser - self.some_date,
)
with option_context("compute.ops_on_diff_frames", True):
for pser, psser in self.pser_psser_pairs:
@ -118,7 +119,8 @@ class DateOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assertRaises(TypeError, lambda: "x" - self.psser)
self.assertRaises(TypeError, lambda: 1 - self.psser)
self.assert_eq(
(self.some_date - self.pser).dt.days, self.some_date - self.psser,
(self.some_date - self.pser).dt.days,
self.some_date - self.psser,
)
def test_rmul(self):
@ -152,7 +154,8 @@ if __name__ == "__main__":
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)


@ -154,7 +154,8 @@ if __name__ == "__main__":
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)


@ -34,6 +34,7 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
returns float32.
The underlying reason is the respective Spark operations return DoubleType always.
"""
@property
def float_pser(self):
return pd.Series([1, 2, 3], dtype=float)
@ -137,7 +138,8 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assertRaises(TypeError, lambda: psser // self.non_numeric_pssers["datetime"])
self.assertRaises(TypeError, lambda: psser // self.non_numeric_pssers["date"])
self.assertRaises(
TypeError, lambda: psser // self.non_numeric_pssers["categorical"])
TypeError, lambda: psser // self.non_numeric_pssers["categorical"]
)
if LooseVersion(pd.__version__) >= LooseVersion("0.25.3"):
self.assert_eq(
(self.float_psser // self.non_numeric_pssers["bool"]).sort_index(),
@ -146,7 +148,7 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
else:
self.assert_eq(
(self.float_pser // self.non_numeric_psers["bool"]).sort_index(),
ps.Series([1.0, 2.0, np.inf])
ps.Series([1.0, 2.0, np.inf]),
)
self.assertRaises(TypeError, lambda: self.float_psser // True)
@ -181,7 +183,8 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assertRaises(TypeError, lambda: psser ** self.non_numeric_pssers["datetime"])
self.assertRaises(TypeError, lambda: psser ** self.non_numeric_pssers["date"])
self.assertRaises(
TypeError, lambda: psser ** self.non_numeric_pssers["categorical"])
TypeError, lambda: psser ** self.non_numeric_pssers["categorical"]
)
self.assert_eq(
(self.float_psser ** self.non_numeric_pssers["bool"]).sort_index(),
self.float_pser ** self.non_numeric_psers["bool"],
@ -258,7 +261,8 @@ if __name__ == "__main__":
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)


@ -45,7 +45,8 @@ class StringOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assertRaises(TypeError, lambda: self.psser + self.non_numeric_pssers["datetime"])
self.assertRaises(TypeError, lambda: self.psser + self.non_numeric_pssers["date"])
self.assertRaises(
TypeError, lambda: self.psser + self.non_numeric_pssers["categorical"])
TypeError, lambda: self.psser + self.non_numeric_pssers["categorical"]
)
self.assertRaises(TypeError, lambda: self.psser + self.non_numeric_pssers["bool"])
for psser in self.numeric_pssers:
self.assertRaises(TypeError, lambda: self.psser + psser)
@ -135,7 +136,8 @@ if __name__ == "__main__":
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)


@ -26,6 +26,7 @@ import pyspark.pandas as ps
class TestCasesUtils(object):
"""A utility holding common test cases for arithmetic operations of different data types."""
@property
def numeric_psers(self):
dtypes = [np.float32, float, int, np.int32]


@ -174,7 +174,8 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils):
# The `name` argument is added in pandas 0.24.
self.assert_eq(psidx.to_frame(name="x"), pidx.to_frame(name="x"))
self.assert_eq(
psidx.to_frame(index=False, name="x"), pidx.to_frame(index=False, name="x"),
psidx.to_frame(index=False, name="x"),
pidx.to_frame(index=False, name="x"),
)
self.assertRaises(TypeError, lambda: psidx.to_frame(name=["x"]))


@ -169,7 +169,8 @@ class DataFramePlotPlotlyTest(PandasOnSparkTestCase, TestUtils):
)
self.assertEqual(
psdf.plot(kind="pie", values="a"), express.pie(pdf, values="a"),
psdf.plot(kind="pie", values="a"),
express.pie(pdf, values="a"),
)
psdf1 = self.psdf1


@ -118,7 +118,8 @@ class SeriesPlotPlotlyTest(PandasOnSparkTestCase, TestUtils):
psdf = self.psdf1
pdf = psdf.to_pandas()
self.assertEqual(
psdf["a"].plot(kind="pie"), express.pie(pdf, values=pdf.columns[0], names=pdf.index),
psdf["a"].plot(kind="pie"),
express.pie(pdf, values=pdf.columns[0], names=pdf.index),
)
# TODO: support multi-index columns


@ -387,14 +387,16 @@ class CategoricalTest(PandasOnSparkTestCase, TestUtils):
return pdf.astype(str)
self.assert_eq(
psdf.pandas_on_spark.transform_batch(to_str).sort_index(), to_str(pdf).sort_index(),
psdf.pandas_on_spark.transform_batch(to_str).sort_index(),
to_str(pdf).sort_index(),
)
def to_codes(pdf) -> ps.Series[np.int8]:
return pdf.b.cat.codes
self.assert_eq(
psdf.pandas_on_spark.transform_batch(to_codes).sort_index(), to_codes(pdf).sort_index(),
psdf.pandas_on_spark.transform_batch(to_codes).sort_index(),
to_codes(pdf).sort_index(),
)
pdf = pd.DataFrame(


@ -1518,7 +1518,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
# Assert approximate counts
self.assert_eq(
ps.DataFrame({"A": range(100)}).nunique(approx=True), pd.Series([103], index=["A"]),
ps.DataFrame({"A": range(100)}).nunique(approx=True),
pd.Series([103], index=["A"]),
)
self.assert_eq(
ps.DataFrame({"A": range(100)}).nunique(approx=True, rsd=0.01),
@ -3354,7 +3355,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
columns2 = pd.Index(["numbers", "2", "3"], name="cols2")
self.assert_eq(
pdf.reindex(columns=columns2).sort_index(), psdf.reindex(columns=columns2).sort_index(),
pdf.reindex(columns=columns2).sort_index(),
psdf.reindex(columns=columns2).sort_index(),
)
columns = pd.Index(["numbers"], name="cols")
@ -3398,12 +3400,14 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
columns2 = pd.Index(["numbers", "2", "3"])
self.assert_eq(
pdf.reindex(columns=columns2).sort_index(), psdf.reindex(columns=columns2).sort_index(),
pdf.reindex(columns=columns2).sort_index(),
psdf.reindex(columns=columns2).sort_index(),
)
columns2 = pd.Index(["numbers", "2", "3"], name="cols2")
self.assert_eq(
pdf.reindex(columns=columns2).sort_index(), psdf.reindex(columns=columns2).sort_index(),
pdf.reindex(columns=columns2).sort_index(),
psdf.reindex(columns=columns2).sort_index(),
)
# Reindexing single Index on single Index
@ -3506,7 +3510,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
[("X", "numbers"), ("Y", "2"), ("Y", "3")], names=["cols3", "cols4"]
)
self.assert_eq(
pdf.reindex(columns=columns2).sort_index(), psdf.reindex(columns=columns2).sort_index(),
pdf.reindex(columns=columns2).sort_index(),
psdf.reindex(columns=columns2).sort_index(),
)
self.assertRaises(TypeError, lambda: psdf.reindex(columns=["X"]))
@ -3527,7 +3532,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
psdf2 = ps.from_pandas(pdf2)
self.assert_eq(
pdf.reindex_like(pdf2).sort_index(), psdf.reindex_like(psdf2).sort_index(),
pdf.reindex_like(pdf2).sort_index(),
psdf.reindex_like(psdf2).sort_index(),
)
pdf2 = pd.DataFrame({"index_level_1": ["A", "C", "I"]})
@ -3546,7 +3552,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
psdf2 = ps.from_pandas(pdf2)
self.assert_eq(
pdf.reindex_like(pdf2).sort_index(), psdf.reindex_like(psdf2).sort_index(),
pdf.reindex_like(pdf2).sort_index(),
psdf.reindex_like(psdf2).sort_index(),
)
self.assertRaises(TypeError, lambda: psdf.reindex_like(index2))
@ -3569,7 +3576,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
psdf = ps.from_pandas(pdf)
self.assert_eq(
pdf.reindex_like(pdf2).sort_index(), psdf.reindex_like(psdf2).sort_index(),
pdf.reindex_like(pdf2).sort_index(),
psdf.reindex_like(psdf2).sort_index(),
)
def test_melt(self):
@ -3953,16 +3961,20 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
self.assert_eq(pdf.duplicated().sort_index(), psdf.duplicated().sort_index())
self.assert_eq(
pdf.duplicated(keep="last").sort_index(), psdf.duplicated(keep="last").sort_index(),
pdf.duplicated(keep="last").sort_index(),
psdf.duplicated(keep="last").sort_index(),
)
self.assert_eq(
pdf.duplicated(keep=False).sort_index(), psdf.duplicated(keep=False).sort_index(),
pdf.duplicated(keep=False).sort_index(),
psdf.duplicated(keep=False).sort_index(),
)
self.assert_eq(
pdf.duplicated(subset="b").sort_index(), psdf.duplicated(subset="b").sort_index(),
pdf.duplicated(subset="b").sort_index(),
psdf.duplicated(subset="b").sort_index(),
)
self.assert_eq(
pdf.duplicated(subset=["b"]).sort_index(), psdf.duplicated(subset=["b"]).sort_index(),
pdf.duplicated(subset=["b"]).sort_index(),
psdf.duplicated(subset=["b"]).sort_index(),
)
with self.assertRaisesRegex(ValueError, "'keep' only supports 'first', 'last' and False"):
psdf.duplicated(keep="false")
@ -4009,7 +4021,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
self.assert_eq(pdf.duplicated().sort_index(), psdf.duplicated().sort_index())
self.assert_eq(
pdf.duplicated(subset=10).sort_index(), psdf.duplicated(subset=10).sort_index(),
pdf.duplicated(subset=10).sort_index(),
psdf.duplicated(subset=10).sort_index(),
)
def test_ffill(self):
@ -4651,17 +4664,20 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
psdf.take([-1, -2], axis=1).sort_index(), pdf.take([-1, -2], axis=1).sort_index()
)
self.assert_eq(
psdf.take(range(1, 3), axis=1).sort_index(), pdf.take(range(1, 3), axis=1).sort_index(),
psdf.take(range(1, 3), axis=1).sort_index(),
pdf.take(range(1, 3), axis=1).sort_index(),
)
self.assert_eq(
psdf.take(range(-1, -3), axis=1).sort_index(),
pdf.take(range(-1, -3), axis=1).sort_index(),
)
self.assert_eq(
psdf.take([2, 1], axis=1).sort_index(), pdf.take([2, 1], axis=1).sort_index(),
psdf.take([2, 1], axis=1).sort_index(),
pdf.take([2, 1], axis=1).sort_index(),
)
self.assert_eq(
psdf.take([-1, -2], axis=1).sort_index(), pdf.take([-1, -2], axis=1).sort_index(),
psdf.take([-1, -2], axis=1).sort_index(),
pdf.take([-1, -2], axis=1).sort_index(),
)
# MultiIndex columns
@ -4695,17 +4711,20 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
psdf.take([-1, -2], axis=1).sort_index(), pdf.take([-1, -2], axis=1).sort_index()
)
self.assert_eq(
psdf.take(range(1, 3), axis=1).sort_index(), pdf.take(range(1, 3), axis=1).sort_index(),
psdf.take(range(1, 3), axis=1).sort_index(),
pdf.take(range(1, 3), axis=1).sort_index(),
)
self.assert_eq(
psdf.take(range(-1, -3), axis=1).sort_index(),
pdf.take(range(-1, -3), axis=1).sort_index(),
)
self.assert_eq(
psdf.take([2, 1], axis=1).sort_index(), pdf.take([2, 1], axis=1).sort_index(),
psdf.take([2, 1], axis=1).sort_index(),
pdf.take([2, 1], axis=1).sort_index(),
)
self.assert_eq(
psdf.take([-1, -2], axis=1).sort_index(), pdf.take([-1, -2], axis=1).sort_index(),
psdf.take([-1, -2], axis=1).sort_index(),
pdf.take([-1, -2], axis=1).sort_index(),
)
# Checking the type of indices.
@ -5524,35 +5543,40 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
psdf = ps.from_pandas(pdf)
psdf.at_time("0:20")
self.assert_eq(
pdf.at_time("0:20").sort_index(), psdf.at_time("0:20").sort_index(),
pdf.at_time("0:20").sort_index(),
psdf.at_time("0:20").sort_index(),
)
# Index name is 'ts'
pdf.index.name = "ts"
psdf = ps.from_pandas(pdf)
self.assert_eq(
pdf.at_time("0:20").sort_index(), psdf.at_time("0:20").sort_index(),
pdf.at_time("0:20").sort_index(),
psdf.at_time("0:20").sort_index(),
)
# Index name is 'ts', column label is 'index'
pdf.columns = pd.Index(["index"])
psdf = ps.from_pandas(pdf)
self.assert_eq(
pdf.at_time("0:40").sort_index(), psdf.at_time("0:40").sort_index(),
pdf.at_time("0:40").sort_index(),
psdf.at_time("0:40").sort_index(),
)
# Both index name and column label are 'index'
pdf.index.name = "index"
psdf = ps.from_pandas(pdf)
self.assert_eq(
pdf.at_time("0:40").sort_index(), psdf.at_time("0:40").sort_index(),
pdf.at_time("0:40").sort_index(),
psdf.at_time("0:40").sort_index(),
)
# Index name is 'index', column label is ('X', 'A')
pdf.columns = pd.MultiIndex.from_arrays([["X"], ["A"]])
psdf = ps.from_pandas(pdf)
self.assert_eq(
pdf.at_time("0:40").sort_index(), psdf.at_time("0:40").sort_index(),
pdf.at_time("0:40").sort_index(),
psdf.at_time("0:40").sort_index(),
)
with self.assertRaisesRegex(NotImplementedError, "'asof' argument is not supported"):


@ -310,7 +310,8 @@ class DataFrameSparkIOTest(PandasOnSparkTestCase, TestUtils):
self.assert_eq(psdfs["Sheet_name_2"], pdfs1_squeezed["Sheet_name_2"])
self.assert_eq(
ps.read_excel(tmp, index_col=0, sheet_name="Sheet_name_2"), pdfs1["Sheet_name_2"],
ps.read_excel(tmp, index_col=0, sheet_name="Sheet_name_2"),
pdfs1["Sheet_name_2"],
)
for sheet_name in sheet_names:

View file

@ -96,7 +96,8 @@ class ExpandingTest(PandasOnSparkTestCase, TestUtils):
psdf = ps.DataFrame({"a": [1, 2, 3, 2], "b": [4.0, 2.0, 3.0, 1.0]}, index=idx)
psdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
expected_result = pd.DataFrame(
{("a", "x"): [None, 2.0, 3.0, 4.0], ("a", "y"): [None, 2.0, 3.0, 4.0]}, index=idx,
{("a", "x"): [None, 2.0, 3.0, 4.0], ("a", "y"): [None, 2.0, 3.0, 4.0]},
index=idx,
)
self.assert_eq(psdf.expanding(2).count().sort_index(), expected_result.sort_index())


@ -934,7 +934,8 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
expected = ps.DataFrame({"b": [2, 2]}, index=pd.Index([0, 1], name="a"))
self.assert_eq(psdf.groupby("a").nunique().sort_index(), expected)
self.assert_eq(
psdf.groupby("a").nunique(dropna=False).sort_index(), expected,
psdf.groupby("a").nunique(dropna=False).sort_index(),
expected,
)
else:
self.assert_eq(
@ -968,10 +969,12 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
if LooseVersion(pd.__version__) < LooseVersion("1.1.0"):
expected = ps.DataFrame({("y", "b"): [2, 2]}, index=pd.Index([0, 1], name=("x", "a")))
self.assert_eq(
psdf.groupby(("x", "a")).nunique().sort_index(), expected,
psdf.groupby(("x", "a")).nunique().sort_index(),
expected,
)
self.assert_eq(
psdf.groupby(("x", "a")).nunique(dropna=False).sort_index(), expected,
psdf.groupby(("x", "a")).nunique(dropna=False).sort_index(),
expected,
)
else:
self.assert_eq(
@ -1785,7 +1788,8 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
pdf.groupby("A")[["B"]].bfill().sort_index(),
)
self.assert_eq(
psdf.groupby("A")["B"].bfill().sort_index(), pdf.groupby("A")["B"].bfill().sort_index(),
psdf.groupby("A")["B"].bfill().sort_index(),
pdf.groupby("A")["B"].bfill().sort_index(),
)
self.assert_eq(
psdf.groupby("A")["B"].bfill()[idx[6]], pdf.groupby("A")["B"].bfill()[idx[6]]
@ -1893,7 +1897,8 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
pdf.groupby("b").apply(lambda x: x + x.min()).sort_index(),
)
self.assert_eq(
psdf.groupby("b").apply(len).sort_index(), pdf.groupby("b").apply(len).sort_index(),
psdf.groupby("b").apply(len).sort_index(),
pdf.groupby("b").apply(len).sort_index(),
)
self.assert_eq(
psdf.groupby("b")["a"]
@ -2556,7 +2561,8 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
psdf = ps.from_pandas(pdf)
self.assert_eq(
psdf.groupby("class").get_group("bird"), pdf.groupby("class").get_group("bird"),
psdf.groupby("class").get_group("bird"),
pdf.groupby("class").get_group("bird"),
)
self.assert_eq(
psdf.groupby("class")["name"].get_group("mammal"),
@ -2583,7 +2589,8 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
(pdf.max_speed + 1).groupby(pdf["class"]).get_group("mammal"),
)
self.assert_eq(
psdf.groupby("max_speed").get_group(80.5), pdf.groupby("max_speed").get_group(80.5),
psdf.groupby("max_speed").get_group(80.5),
pdf.groupby("max_speed").get_group(80.5),
)
self.assertRaises(KeyError, lambda: psdf.groupby("class").get_group("fish"))
@ -2646,7 +2653,8 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
lambda: psdf.groupby([("B", "class"), ("A", "name")]).get_group(("lion", "mammal")),
)
self.assertRaises(
ValueError, lambda: psdf.groupby([("B", "class"), ("A", "name")]).get_group(("lion",)),
ValueError,
lambda: psdf.groupby([("B", "class"), ("A", "name")]).get_group(("lion",)),
)
self.assertRaises(
ValueError, lambda: psdf.groupby([("B", "class"), ("A", "name")]).get_group(("mammal",))


@ -266,7 +266,10 @@ class NamespaceTest(PandasOnSparkTestCase, SQLTestUtils):
objs = [
([psdf1.A, psdf1.A.rename("B")], [pdf1.A, pdf1.A.rename("B")]),
([psdf3[("X", "A")], psdf3[("X", "B")]], [pdf3[("X", "A")], pdf3[("X", "B")]],),
(
[psdf3[("X", "A")], psdf3[("X", "B")]],
[pdf3[("X", "A")], pdf3[("X", "B")]],
),
(
[psdf3[("X", "A")], psdf3[("X", "B")].rename("ABC")],
[pdf3[("X", "A")], pdf3[("X", "B")].rename("ABC")],


@ -728,7 +728,8 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
psser1 = ps.from_pandas(pser1)
psser2 = ps.from_pandas(pser2)
self.assert_eq(
pser1.compare(pser2).sort_index(), psser1.compare(psser2).sort_index(),
pser1.compare(pser2).sort_index(),
psser1.compare(psser2).sort_index(),
)
# `keep_shape=True`
@ -757,7 +758,8 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
psser1 = ps.from_pandas(pser1)
psser2 = ps.from_pandas(pser2)
self.assert_eq(
pser1.compare(pser2).sort_index(), psser1.compare(psser2).sort_index(),
pser1.compare(pser2).sort_index(),
psser1.compare(psser2).sort_index(),
)
# `keep_shape=True` with MultiIndex
@ -790,14 +792,16 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
columns=["self", "other"],
)
self.assert_eq(
expected, psser1.compare(psser2, keep_shape=True).sort_index(),
expected,
psser1.compare(psser2, keep_shape=True).sort_index(),
)
# `keep_equal=True`
expected = ps.DataFrame(
[["b", "a"], ["g", None], [None, "h"]], index=[0, 3, 4], columns=["self", "other"]
)
self.assert_eq(
expected, psser1.compare(psser2, keep_equal=True).sort_index(),
expected,
psser1.compare(psser2, keep_equal=True).sort_index(),
)
# `keep_shape=True` and `keep_equal=True`
expected = ps.DataFrame(
@ -806,7 +810,8 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
columns=["self", "other"],
)
self.assert_eq(
expected, psser1.compare(psser2, keep_shape=True, keep_equal=True).sort_index(),
expected,
psser1.compare(psser2, keep_shape=True, keep_equal=True).sort_index(),
)
# MultiIndex
@ -838,7 +843,8 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
columns=["self", "other"],
)
self.assert_eq(
expected, psser1.compare(psser2, keep_shape=True).sort_index(),
expected,
psser1.compare(psser2, keep_shape=True).sort_index(),
)
# `keep_equal=True`
expected = ps.DataFrame(
@ -847,7 +853,8 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
columns=["self", "other"],
)
self.assert_eq(
expected, psser1.compare(psser2, keep_equal=True).sort_index(),
expected,
psser1.compare(psser2, keep_equal=True).sort_index(),
)
# `keep_shape=True` and `keep_equal=True`
expected = ps.DataFrame(
@ -858,15 +865,22 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
columns=["self", "other"],
)
self.assert_eq(
expected, psser1.compare(psser2, keep_shape=True, keep_equal=True).sort_index(),
expected,
psser1.compare(psser2, keep_shape=True, keep_equal=True).sort_index(),
)
# Different Index
with self.assertRaisesRegex(
ValueError, "Can only compare identically-labeled Series objects"
):
psser1 = ps.Series([1, 2, 3, 4, 5], index=pd.Index([1, 2, 3, 4, 5]),)
psser2 = ps.Series([2, 2, 3, 4, 1], index=pd.Index([5, 4, 3, 2, 1]),)
psser1 = ps.Series(
[1, 2, 3, 4, 5],
index=pd.Index([1, 2, 3, 4, 5]),
)
psser2 = ps.Series(
[2, 2, 3, 4, 1],
index=pd.Index([5, 4, 3, 2, 1]),
)
psser1.compare(psser2)
# Different MultiIndex
with self.assertRaisesRegex(
@ -1153,7 +1167,8 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
self.assert_eq(psdf, pdf)
with self.assertRaisesRegex(
ValueError, "shape mismatch",
ValueError,
"shape mismatch",
):
psdf.iloc[[1, 2], [1]] = -another_psdf.max_speed


@ -135,7 +135,8 @@ class OpsOnDiffFramesGroupByTest(PandasOnSparkTestCase, SQLTestUtils):
)
self.assert_eq(
psdf1.B.groupby(psdf2.A).sum().sort_index(), pdf1.B.groupby(pdf2.A).sum().sort_index(),
psdf1.B.groupby(psdf2.A).sum().sort_index(),
pdf1.B.groupby(pdf2.A).sum().sort_index(),
)
self.assert_eq(
(psdf1.B + 1).groupby(psdf2.A).sum().sort_index(),


@ -264,7 +264,8 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
psser = ps.from_pandas(pser)
self.assert_eq(
pser.rename_axis("index2").sort_index(), psser.rename_axis("index2").sort_index(),
pser.rename_axis("index2").sort_index(),
psser.rename_axis("index2").sort_index(),
)
self.assert_eq(
@ -463,7 +464,8 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
self.assert_eq(pser, psser)
self.assert_eq(
pser.reindex(["A", "B"]).sort_index(), psser.reindex(["A", "B"]).sort_index(),
pser.reindex(["A", "B"]).sort_index(),
psser.reindex(["A", "B"]).sort_index(),
)
self.assert_eq(
@ -491,7 +493,8 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
psser2 = ps.from_pandas(pser2)
self.assert_eq(
pser.reindex_like(pser2).sort_index(), psser.reindex_like(psser2).sort_index(),
pser.reindex_like(pser2).sort_index(),
psser.reindex_like(psser2).sort_index(),
)
self.assert_eq(
@ -507,7 +510,8 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
psser2 = ps.from_pandas(pser2)
self.assert_eq(
pser.reindex_like(pser2).sort_index(), psser.reindex_like(psser2).sort_index(),
pser.reindex_like(pser2).sort_index(),
psser.reindex_like(psser2).sort_index(),
)
self.assertRaises(TypeError, lambda: psser.reindex_like(index2))
@ -521,7 +525,8 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
psser = ps.from_pandas(pser)
self.assert_eq(
pser.reindex_like(pser2).sort_index(), psser.reindex_like(psser2).sort_index(),
pser.reindex_like(pser2).sort_index(),
psser.reindex_like(psser2).sort_index(),
)
# Reindexing with DataFrame
@ -532,7 +537,8 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
psdf = ps.from_pandas(pdf)
self.assert_eq(
pser.reindex_like(pdf).sort_index(), psser.reindex_like(psdf).sort_index(),
pser.reindex_like(pdf).sort_index(),
psser.reindex_like(psdf).sort_index(),
)
def test_fillna(self):
@ -2897,19 +2903,22 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
pser = pd.Series([1, 2, 3, 4], index=idx)
psser = ps.from_pandas(pser)
self.assert_eq(
pser.at_time("0:20").sort_index(), psser.at_time("0:20").sort_index(),
pser.at_time("0:20").sort_index(),
psser.at_time("0:20").sort_index(),
)
pser.index.name = "ts"
psser = ps.from_pandas(pser)
self.assert_eq(
pser.at_time("0:20").sort_index(), psser.at_time("0:20").sort_index(),
pser.at_time("0:20").sort_index(),
psser.at_time("0:20").sort_index(),
)
pser.index.name = "index"
psser = ps.from_pandas(pser)
self.assert_eq(
pser.at_time("0:20").sort_index(), psser.at_time("0:20").sort_index(),
pser.at_time("0:20").sort_index(),
psser.at_time("0:20").sort_index(),
)