[SPARK-35499][PYTHON] Apply black to pandas API on Spark codes

### What changes were proposed in this pull request?

This PR proposes applying `black` to pandas API on Spark codes, for improving static analysis.

By executing the `./dev/reformat-python` in the spark home directory, all the code of the pandas API on Spark is fixed according to the static analysis rules.

### Why are the changes needed?

This can be reduces the cost of static analysis during development.

It has been used continuously for about a year in the Koalas project and its convenience has been proven.

### Does this PR introduce _any_ user-facing change?

No, it's dev-only.

### How was this patch tested?

Manually reformat the pandas API on Spark codes by running the `./dev/reformat-python`, and checked the `./dev/lint-python` is passed.

Closes #32779 from itholic/SPARK-35499.

Authored-by: itholic <haejoon.lee@databricks.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
This commit is contained in:
itholic 2021-06-06 17:30:07 -07:00 committed by Liang-Chi Hsieh
parent d4e32c896a
commit b8740a1d1e
54 changed files with 428 additions and 228 deletions

View file

@ -366,7 +366,7 @@ jobs:
# See also https://github.com/sphinx-doc/sphinx/issues/7551. # See also https://github.com/sphinx-doc/sphinx/issues/7551.
# Jinja2 3.0.0+ causes error when building with Sphinx. # Jinja2 3.0.0+ causes error when building with Sphinx.
# See also https://issues.apache.org/jira/browse/SPARK-35375. # See also https://issues.apache.org/jira/browse/SPARK-35375.
python3.6 -m pip install flake8 pydata_sphinx_theme mypy numpydoc 'jinja2<3.0.0' python3.6 -m pip install flake8 pydata_sphinx_theme mypy numpydoc 'jinja2<3.0.0' 'black==21.5b2'
- name: Install R linter dependencies and SparkR - name: Install R linter dependencies and SparkR
run: | run: |
apt-get install -y libcurl4-openssl-dev libgit2-dev libssl-dev libxml2-dev apt-get install -y libcurl4-openssl-dev libgit2-dev libssl-dev libxml2-dev

View file

@ -25,6 +25,8 @@ MINIMUM_PYCODESTYLE="2.7.0"
PYTHON_EXECUTABLE="python3" PYTHON_EXECUTABLE="python3"
BLACK_BUILD="$PYTHON_EXECUTABLE -m black"
function satisfies_min_version { function satisfies_min_version {
local provided_version="$1" local provided_version="$1"
local expected_version="$2" local expected_version="$2"
@ -185,6 +187,35 @@ flake8 checks failed."
fi fi
} }
function black_test {
local BLACK_REPORT=
local BLACK_STATUS=
# Skip check if black is not installed.
$BLACK_BUILD 2> /dev/null
if [ $? -ne 0 ]; then
echo "The $BLACK_BUILD command was not found. Skipping black checks for now."
echo
return
fi
echo "starting black test..."
# Black is only applied for pandas API on Spark for now.
BLACK_REPORT=$( ($BLACK_BUILD python/pyspark/pandas --line-length 100 --check ) 2>&1)
BLACK_STATUS=$?
if [ "$BLACK_STATUS" -ne 0 ]; then
echo "black checks failed:"
echo "$BLACK_REPORT"
echo "Please run 'dev/reformat-python' script."
echo "$BLACK_STATUS"
exit "$BLACK_STATUS"
else
echo "black checks passed."
echo
fi
}
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
SPARK_ROOT_DIR="$(dirname "${SCRIPT_DIR}")" SPARK_ROOT_DIR="$(dirname "${SCRIPT_DIR}")"
@ -194,6 +225,7 @@ pushd "$SPARK_ROOT_DIR" &> /dev/null
PYTHON_SOURCE="$(find . -path ./docs/.local_ruby_bundle -prune -false -o -name "*.py")" PYTHON_SOURCE="$(find . -path ./docs/.local_ruby_bundle -prune -false -o -name "*.py")"
compile_python_test "$PYTHON_SOURCE" compile_python_test "$PYTHON_SOURCE"
black_test
pycodestyle_test "$PYTHON_SOURCE" pycodestyle_test "$PYTHON_SOURCE"
flake8_test flake8_test
mypy_test mypy_test

32
dev/reformat-python Executable file
View file

@ -0,0 +1,32 @@
#!/usr/bin/env bash
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The current directory of the script.
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
FWDIR="$( cd "$DIR"/.. && pwd )"
cd "$FWDIR"
BLACK_BUILD="python -m black"
BLACK_VERSION="21.5b2"
$BLACK_BUILD 2> /dev/null
if [ $? -ne 0 ]; then
echo "The '$BLACK_BUILD' command was not found. Please install Black, for example, via 'pip install black==$BLACK_VERSION'."
exit 1
fi
# This script is only applied for pandas API on Spark for now.
$BLACK_BUILD python/pyspark/pandas --line-length 100

View file

@ -32,3 +32,6 @@ sphinx-plotly-directive
# Development scripts # Development scripts
jira jira
PyGithub PyGithub
# pandas API on Spark Code formatter.
black

View file

@ -14,11 +14,13 @@
# limitations under the License. # limitations under the License.
[pycodestyle] [pycodestyle]
ignore=E226,E241,E305,E402,E722,E731,E741,W503,W504 ignore=E203,E226,E241,E305,E402,E722,E731,E741,W503,W504
max-line-length=100 max-line-length=100
exclude=*/target/*,python/pyspark/cloudpickle/*.py,shared.py,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/* exclude=*/target/*,python/pyspark/cloudpickle/*.py,shared.py,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*
[flake8] [flake8]
select = E901,E999,F821,F822,F823,F401,F405,B006 select = E901,E999,F821,F822,F823,F401,F405,B006
# Ignore F821 for plot documents in pandas API on Spark.
ignore = F821
exclude = python/docs/build/html/*,*/target/*,python/pyspark/cloudpickle/*.py,shared.py*,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*,python/out,python/pyspark/sql/pandas/functions.pyi,python/pyspark/sql/column.pyi,python/pyspark/worker.pyi,python/pyspark/java_gateway.pyi exclude = python/docs/build/html/*,*/target/*,python/pyspark/cloudpickle/*.py,shared.py*,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*,python/out,python/pyspark/sql/pandas/functions.pyi,python/pyspark/sql/column.pyi,python/pyspark/worker.pyi,python/pyspark/java_gateway.pyi
max-line-length = 100 max-line-length = 100

View file

@ -48,7 +48,7 @@ if TYPE_CHECKING:
class PandasOnSparkFrameMethods(object): class PandasOnSparkFrameMethods(object):
""" pandas-on-Spark specific features for DataFrame. """ """pandas-on-Spark specific features for DataFrame."""
def __init__(self, frame: "DataFrame"): def __init__(self, frame: "DataFrame"):
self._psdf = frame self._psdf = frame
@ -696,7 +696,7 @@ class PandasOnSparkFrameMethods(object):
class PandasOnSparkSeriesMethods(object): class PandasOnSparkSeriesMethods(object):
""" pandas-on-Spark specific features for Series. """ """pandas-on-Spark specific features for Series."""
def __init__(self, series: "Series"): def __init__(self, series: "Series"):
self._psser = series self._psser = series

View file

@ -1068,9 +1068,7 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
if isinstance(self, MultiIndex): if isinstance(self, MultiIndex):
raise NotImplementedError("notna is not defined for MultiIndex") raise NotImplementedError("notna is not defined for MultiIndex")
return (~self.isnull()).rename( return (~self.isnull()).rename(self.name) # type: ignore
self.name # type: ignore
)
notna = notnull notna = notnull

View file

@ -381,7 +381,7 @@ def _check_option(key: str) -> None:
class DictWrapper: class DictWrapper:
""" provide attribute-style access to a nested dict""" """provide attribute-style access to a nested dict"""
def __init__(self, d: Dict[str, Option], prefix: str = ""): def __init__(self, d: Dict[str, Option], prefix: str = ""):
object.__setattr__(self, "d", d) object.__setattr__(self, "d", d)

View file

@ -45,11 +45,7 @@ if TYPE_CHECKING:
from pyspark.pandas.series import Series # noqa: F401 (SPARK-34943) from pyspark.pandas.series import Series # noqa: F401 (SPARK-34943)
def is_valid_operand_for_numeric_arithmetic( def is_valid_operand_for_numeric_arithmetic(operand: Any, *, allow_bool: bool = True) -> bool:
operand: Any,
*,
allow_bool: bool = True
) -> bool:
"""Check whether the operand is valid for arithmetic operations against numerics.""" """Check whether the operand is valid for arithmetic operations against numerics."""
if isinstance(operand, numbers.Number) and not isinstance(operand, bool): if isinstance(operand, numbers.Number) and not isinstance(operand, bool):
return True return True
@ -58,7 +54,8 @@ def is_valid_operand_for_numeric_arithmetic(
return False return False
else: else:
return isinstance(operand.spark.data_type, NumericType) or ( return isinstance(operand.spark.data_type, NumericType) or (
allow_bool and isinstance(operand.spark.data_type, BooleanType)) allow_bool and isinstance(operand.spark.data_type, BooleanType)
)
else: else:
return False return False

View file

@ -34,7 +34,7 @@ class BinaryOps(DataTypeOps):
@property @property
def pretty_name(self) -> str: def pretty_name(self) -> str:
return 'binaries' return "binaries"
def add(self, left, right) -> Union["Series", "Index"]: def add(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, IndexOpsMixin) and isinstance(right.spark.data_type, BinaryType): if isinstance(right, IndexOpsMixin) and isinstance(right.spark.data_type, BinaryType):
@ -43,11 +43,13 @@ class BinaryOps(DataTypeOps):
return column_op(F.concat)(left, F.lit(right)) return column_op(F.concat)(left, F.lit(right))
else: else:
raise TypeError( raise TypeError(
"Concatenation can not be applied to %s and the given type." % self.pretty_name) "Concatenation can not be applied to %s and the given type." % self.pretty_name
)
def radd(self, left, right) -> Union["Series", "Index"]: def radd(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, bytes): if isinstance(right, bytes):
return left._with_new_scol(F.concat(F.lit(right), left.spark.column)) return left._with_new_scol(F.concat(F.lit(right), left.spark.column))
else: else:
raise TypeError( raise TypeError(
"Concatenation can not be applied to %s and the given type." % self.pretty_name) "Concatenation can not be applied to %s and the given type." % self.pretty_name
)

View file

@ -38,12 +38,13 @@ class BooleanOps(DataTypeOps):
@property @property
def pretty_name(self) -> str: def pretty_name(self) -> str:
return 'booleans' return "booleans"
def add(self, left, right) -> Union["Series", "Index"]: def add(self, left, right) -> Union["Series", "Index"]:
if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False): if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False):
raise TypeError( raise TypeError(
"Addition can not be applied to %s and the given type." % self.pretty_name) "Addition can not be applied to %s and the given type." % self.pretty_name
)
if isinstance(right, numbers.Number) and not isinstance(right, bool): if isinstance(right, numbers.Number) and not isinstance(right, bool):
left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right)))) left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right))))
@ -56,7 +57,8 @@ class BooleanOps(DataTypeOps):
def sub(self, left, right) -> Union["Series", "Index"]: def sub(self, left, right) -> Union["Series", "Index"]:
if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False): if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False):
raise TypeError( raise TypeError(
"Subtraction can not be applied to %s and the given type." % self.pretty_name) "Subtraction can not be applied to %s and the given type." % self.pretty_name
)
if isinstance(right, numbers.Number) and not isinstance(right, bool): if isinstance(right, numbers.Number) and not isinstance(right, bool):
left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right)))) left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right))))
return left - right return left - right
@ -68,7 +70,8 @@ class BooleanOps(DataTypeOps):
def mul(self, left, right) -> Union["Series", "Index"]: def mul(self, left, right) -> Union["Series", "Index"]:
if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False): if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False):
raise TypeError( raise TypeError(
"Multiplication can not be applied to %s and the given type." % self.pretty_name) "Multiplication can not be applied to %s and the given type." % self.pretty_name
)
if isinstance(right, numbers.Number) and not isinstance(right, bool): if isinstance(right, numbers.Number) and not isinstance(right, bool):
left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right)))) left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right))))
return left * right return left * right
@ -80,7 +83,8 @@ class BooleanOps(DataTypeOps):
def truediv(self, left, right) -> Union["Series", "Index"]: def truediv(self, left, right) -> Union["Series", "Index"]:
if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False): if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False):
raise TypeError( raise TypeError(
"True division can not be applied to %s and the given type." % self.pretty_name) "True division can not be applied to %s and the given type." % self.pretty_name
)
if isinstance(right, numbers.Number) and not isinstance(right, bool): if isinstance(right, numbers.Number) and not isinstance(right, bool):
left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right)))) left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right))))
return left / right return left / right
@ -92,7 +96,8 @@ class BooleanOps(DataTypeOps):
def floordiv(self, left, right) -> Union["Series", "Index"]: def floordiv(self, left, right) -> Union["Series", "Index"]:
if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False): if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False):
raise TypeError( raise TypeError(
"Floor division can not be applied to %s and the given type." % self.pretty_name) "Floor division can not be applied to %s and the given type." % self.pretty_name
)
if isinstance(right, numbers.Number) and not isinstance(right, bool): if isinstance(right, numbers.Number) and not isinstance(right, bool):
left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right)))) left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right))))
return left // right return left // right
@ -104,7 +109,8 @@ class BooleanOps(DataTypeOps):
def mod(self, left, right) -> Union["Series", "Index"]: def mod(self, left, right) -> Union["Series", "Index"]:
if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False): if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False):
raise TypeError( raise TypeError(
"Modulo can not be applied to %s and the given type." % self.pretty_name) "Modulo can not be applied to %s and the given type." % self.pretty_name
)
if isinstance(right, numbers.Number) and not isinstance(right, bool): if isinstance(right, numbers.Number) and not isinstance(right, bool):
left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right)))) left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right))))
return left % right return left % right
@ -116,7 +122,8 @@ class BooleanOps(DataTypeOps):
def pow(self, left, right) -> Union["Series", "Index"]: def pow(self, left, right) -> Union["Series", "Index"]:
if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False): if not is_valid_operand_for_numeric_arithmetic(right, allow_bool=False):
raise TypeError( raise TypeError(
"Exponentiation can not be applied to %s and the given type." % self.pretty_name) "Exponentiation can not be applied to %s and the given type." % self.pretty_name
)
if isinstance(right, numbers.Number) and not isinstance(right, bool): if isinstance(right, numbers.Number) and not isinstance(right, bool):
left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right)))) left = left.spark.transform(lambda scol: scol.cast(as_spark_type(type(right))))
return left ** right return left ** right
@ -131,7 +138,8 @@ class BooleanOps(DataTypeOps):
return right + left return right + left
else: else:
raise TypeError( raise TypeError(
"Addition can not be applied to %s and the given type." % self.pretty_name) "Addition can not be applied to %s and the given type." % self.pretty_name
)
def rsub(self, left, right) -> Union["Series", "Index"]: def rsub(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, numbers.Number) and not isinstance(right, bool): if isinstance(right, numbers.Number) and not isinstance(right, bool):
@ -139,7 +147,8 @@ class BooleanOps(DataTypeOps):
return right - left return right - left
else: else:
raise TypeError( raise TypeError(
"Subtraction can not be applied to %s and the given type." % self.pretty_name) "Subtraction can not be applied to %s and the given type." % self.pretty_name
)
def rmul(self, left, right) -> Union["Series", "Index"]: def rmul(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, numbers.Number) and not isinstance(right, bool): if isinstance(right, numbers.Number) and not isinstance(right, bool):
@ -147,7 +156,8 @@ class BooleanOps(DataTypeOps):
return right * left return right * left
else: else:
raise TypeError( raise TypeError(
"Multiplication can not be applied to %s and the given type." % self.pretty_name) "Multiplication can not be applied to %s and the given type." % self.pretty_name
)
def rtruediv(self, left, right) -> Union["Series", "Index"]: def rtruediv(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, numbers.Number) and not isinstance(right, bool): if isinstance(right, numbers.Number) and not isinstance(right, bool):
@ -155,7 +165,8 @@ class BooleanOps(DataTypeOps):
return right / left return right / left
else: else:
raise TypeError( raise TypeError(
"True division can not be applied to %s and the given type." % self.pretty_name) "True division can not be applied to %s and the given type." % self.pretty_name
)
def rfloordiv(self, left, right) -> Union["Series", "Index"]: def rfloordiv(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, numbers.Number) and not isinstance(right, bool): if isinstance(right, numbers.Number) and not isinstance(right, bool):
@ -163,7 +174,8 @@ class BooleanOps(DataTypeOps):
return right // left return right // left
else: else:
raise TypeError( raise TypeError(
"Floor division can not be applied to %s and the given type." % self.pretty_name) "Floor division can not be applied to %s and the given type." % self.pretty_name
)
def rpow(self, left, right) -> Union["Series", "Index"]: def rpow(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, numbers.Number) and not isinstance(right, bool): if isinstance(right, numbers.Number) and not isinstance(right, bool):
@ -171,7 +183,8 @@ class BooleanOps(DataTypeOps):
return right ** left return right ** left
else: else:
raise TypeError( raise TypeError(
"Exponentiation can not be applied to %s and the given type." % self.pretty_name) "Exponentiation can not be applied to %s and the given type." % self.pretty_name
)
def rmod(self, left, right) -> Union["Series", "Index"]: def rmod(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, numbers.Number) and not isinstance(right, bool): if isinstance(right, numbers.Number) and not isinstance(right, bool):
@ -179,4 +192,5 @@ class BooleanOps(DataTypeOps):
return right % left return right % left
else: else:
raise TypeError( raise TypeError(
"Modulo can not be applied to %s and the given type." % self.pretty_name) "Modulo can not be applied to %s and the given type." % self.pretty_name
)

View file

@ -25,4 +25,4 @@ class CategoricalOps(DataTypeOps):
@property @property
def pretty_name(self) -> str: def pretty_name(self) -> str:
return 'categoricals' return "categoricals"

View file

@ -34,22 +34,25 @@ class ArrayOps(DataTypeOps):
@property @property
def pretty_name(self) -> str: def pretty_name(self) -> str:
return 'arrays' return "arrays"
def add(self, left, right) -> Union["Series", "Index"]: def add(self, left, right) -> Union["Series", "Index"]:
if not isinstance(right, IndexOpsMixin) or ( if not isinstance(right, IndexOpsMixin) or (
isinstance(right, IndexOpsMixin) and not isinstance(right.spark.data_type, ArrayType) isinstance(right, IndexOpsMixin) and not isinstance(right.spark.data_type, ArrayType)
): ):
raise TypeError( raise TypeError(
"Concatenation can not be applied to %s and the given type." % self.pretty_name) "Concatenation can not be applied to %s and the given type." % self.pretty_name
)
left_type = left.spark.data_type.elementType left_type = left.spark.data_type.elementType
right_type = right.spark.data_type.elementType right_type = right.spark.data_type.elementType
if left_type != right_type and not ( if left_type != right_type and not (
isinstance(left_type, NumericType) and isinstance(right_type, NumericType)): isinstance(left_type, NumericType) and isinstance(right_type, NumericType)
):
raise TypeError( raise TypeError(
"Concatenation can only be applied to %s of the same type" % self.pretty_name) "Concatenation can only be applied to %s of the same type" % self.pretty_name
)
return column_op(F.concat)(left, right) return column_op(F.concat)(left, right)
@ -61,7 +64,7 @@ class MapOps(DataTypeOps):
@property @property
def pretty_name(self) -> str: def pretty_name(self) -> str:
return 'maps' return "maps"
class StructOps(DataTypeOps): class StructOps(DataTypeOps):
@ -71,4 +74,4 @@ class StructOps(DataTypeOps):
@property @property
def pretty_name(self) -> str: def pretty_name(self) -> str:
return 'structs' return "structs"

View file

@ -37,7 +37,7 @@ class DateOps(DataTypeOps):
@property @property
def pretty_name(self) -> str: def pretty_name(self) -> str:
return 'dates' return "dates"
def sub(self, left, right) -> Union["Series", "Index"]: def sub(self, left, right) -> Union["Series", "Index"]:
# Note that date subtraction casts arguments to integer. This is to mimic pandas's # Note that date subtraction casts arguments to integer. This is to mimic pandas's

View file

@ -38,7 +38,7 @@ class DatetimeOps(DataTypeOps):
@property @property
def pretty_name(self) -> str: def pretty_name(self) -> str:
return 'datetimes' return "datetimes"
def sub(self, left, right) -> Union["Series", "Index"]: def sub(self, left, right) -> Union["Series", "Index"]:
# Note that timestamp subtraction casts arguments to integer. This is to mimic pandas's # Note that timestamp subtraction casts arguments to integer. This is to mimic pandas's

View file

@ -46,7 +46,7 @@ class NumericOps(DataTypeOps):
@property @property
def pretty_name(self) -> str: def pretty_name(self) -> str:
return 'numerics' return "numerics"
def add(self, left, right) -> Union["Series", "Index"]: def add(self, left, right) -> Union["Series", "Index"]:
if ( if (
@ -159,7 +159,7 @@ class IntegralOps(NumericOps):
@property @property
def pretty_name(self) -> str: def pretty_name(self) -> str:
return 'integrals' return "integrals"
def mul(self, left, right) -> Union["Series", "Index"]: def mul(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, str): if isinstance(right, str):
@ -211,9 +211,7 @@ class IntegralOps(NumericOps):
return F.when(F.lit(right is np.nan), np.nan).otherwise( return F.when(F.lit(right is np.nan), np.nan).otherwise(
F.when( F.when(
F.lit(right != 0) | F.lit(right).isNull(), F.floor(left.__div__(right)) F.lit(right != 0) | F.lit(right).isNull(), F.floor(left.__div__(right))
).otherwise( ).otherwise(F.lit(np.inf).__div__(left))
F.lit(np.inf).__div__(left)
)
) )
return numpy_column_op(floordiv)(left, right) return numpy_column_op(floordiv)(left, right)
@ -253,7 +251,7 @@ class FractionalOps(NumericOps):
@property @property
def pretty_name(self) -> str: def pretty_name(self) -> str:
return 'fractions' return "fractions"
def mul(self, left, right) -> Union["Series", "Index"]: def mul(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, str): if isinstance(right, str):

View file

@ -38,7 +38,7 @@ class StringOps(DataTypeOps):
@property @property
def pretty_name(self) -> str: def pretty_name(self) -> str:
return 'strings' return "strings"
def add(self, left, right) -> Union["Series", "Index"]: def add(self, left, right) -> Union["Series", "Index"]:
if isinstance(right, IndexOpsMixin) and isinstance(right.spark.data_type, StringType): if isinstance(right, IndexOpsMixin) and isinstance(right.spark.data_type, StringType):

View file

@ -124,7 +124,9 @@ def _register_accessor(
) )
warnings.warn( warnings.warn(
msg, UserWarning, stacklevel=2, msg,
UserWarning,
stacklevel=2,
) )
setattr(cls, name, CachedAccessor(name, accessor)) setattr(cls, name, CachedAccessor(name, accessor))
return accessor return accessor

View file

@ -512,7 +512,7 @@ class DataFrame(Frame, Generic[T]):
@property @property
def _pssers(self): def _pssers(self):
""" Return a dict of column label -> Series which anchors `self`. """ """Return a dict of column label -> Series which anchors `self`."""
from pyspark.pandas.series import Series from pyspark.pandas.series import Series
if not hasattr(self, "_psseries"): if not hasattr(self, "_psseries"):
@ -2945,10 +2945,10 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
else: else:
index_spark_columns = ( index_spark_columns = (
internal.index_spark_columns[:level] internal.index_spark_columns[:level]
+ internal.index_spark_columns[level + len(key):] + internal.index_spark_columns[level + len(key) :]
) )
index_names = internal.index_names[:level] + internal.index_names[level + len(key):] index_names = internal.index_names[:level] + internal.index_names[level + len(key) :]
index_dtypes = internal.index_dtypes[:level] + internal.index_dtypes[level + len(key):] index_dtypes = internal.index_dtypes[:level] + internal.index_dtypes[level + len(key) :]
internal = internal.copy( internal = internal.copy(
index_spark_columns=index_spark_columns, index_spark_columns=index_spark_columns,
@ -5445,7 +5445,13 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
return psdf return psdf
def replace( def replace(
self, to_replace=None, value=None, inplace=False, limit=None, regex=False, method="pad", self,
to_replace=None,
value=None,
inplace=False,
limit=None,
regex=False,
method="pad",
) -> Optional["DataFrame"]: ) -> Optional["DataFrame"]:
""" """
Returns a new DataFrame replacing a value with another value. Returns a new DataFrame replacing a value with another value.
@ -7110,7 +7116,10 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
self._internal.index_dtypes, self._internal.index_dtypes,
) )
) )
index_map[i], index_map[j], = index_map[j], index_map[i] index_map[i], index_map[j], = (
index_map[j],
index_map[i],
)
index_spark_columns, index_names, index_dtypes = zip(*index_map) index_spark_columns, index_names, index_dtypes = zip(*index_map)
internal = self._internal.copy( internal = self._internal.copy(
index_spark_columns=list(index_spark_columns), index_spark_columns=list(index_spark_columns),
@ -7620,7 +7629,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
column_labels.append(label) column_labels.append(label)
for label in right_internal.column_labels: for label in right_internal.column_labels:
# recover `right_prefix` here. # recover `right_prefix` here.
col = right_internal.spark_column_name_for(label)[len(right_prefix):] col = right_internal.spark_column_name_for(label)[len(right_prefix) :]
scol = right_scol_for(label).alias(col) scol = right_scol_for(label).alias(col)
if label in duplicate_columns: if label in duplicate_columns:
spark_column_name = left_internal.spark_column_name_for(label) spark_column_name = left_internal.spark_column_name_for(label)

View file

@ -1806,7 +1806,9 @@ class GroupBy(object, metaclass=ABCMeta):
] ]
psdf, groupkey_labels, _ = GroupBy._prepare_group_map_apply( psdf, groupkey_labels, _ = GroupBy._prepare_group_map_apply(
psdf, self._groupkeys, agg_columns, psdf,
self._groupkeys,
agg_columns,
) )
groupkey_scols = [psdf._internal.spark_column_for(label) for label in groupkey_labels] groupkey_scols = [psdf._internal.spark_column_for(label) for label in groupkey_labels]

View file

@ -447,7 +447,10 @@ class MultiIndex(Index):
self._internal.index_dtypes, self._internal.index_dtypes,
) )
) )
index_map[i], index_map[j], = index_map[j], index_map[i] index_map[i], index_map[j], = (
index_map[j],
index_map[i],
)
index_spark_columns, index_names, index_dtypes = zip(*index_map) index_spark_columns, index_names, index_dtypes = zip(*index_map)
internal = self._internal.copy( internal = self._internal.copy(
index_spark_columns=list(index_spark_columns), index_spark_columns=list(index_spark_columns),
@ -684,9 +687,7 @@ class MultiIndex(Index):
raise NotImplementedError("nunique is not defined for MultiIndex") raise NotImplementedError("nunique is not defined for MultiIndex")
# TODO: add 'name' parameter after pd.MultiIndex.name is implemented # TODO: add 'name' parameter after pd.MultiIndex.name is implemented
def copy( # type: ignore[override] def copy(self, deep: Optional[bool] = None) -> "MultiIndex": # type: ignore[override]
self, deep: Optional[bool] = None
) -> "MultiIndex":
""" """
Make a copy of this object. Make a copy of this object.

View file

@ -333,35 +333,35 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
def _select_rows_by_series( def _select_rows_by_series(
self, rows_sel: "Series" self, rows_sel: "Series"
) -> Tuple[Optional[spark.Column], Optional[int], Optional[int]]: ) -> Tuple[Optional[spark.Column], Optional[int], Optional[int]]:
""" Select rows by `Series` type key. """ """Select rows by `Series` type key."""
pass pass
@abstractmethod @abstractmethod
def _select_rows_by_spark_column( def _select_rows_by_spark_column(
self, rows_sel: spark.Column self, rows_sel: spark.Column
) -> Tuple[Optional[spark.Column], Optional[int], Optional[int]]: ) -> Tuple[Optional[spark.Column], Optional[int], Optional[int]]:
""" Select rows by Spark `Column` type key. """ """Select rows by Spark `Column` type key."""
pass pass
@abstractmethod @abstractmethod
def _select_rows_by_slice( def _select_rows_by_slice(
self, rows_sel: slice self, rows_sel: slice
) -> Tuple[Optional[spark.Column], Optional[int], Optional[int]]: ) -> Tuple[Optional[spark.Column], Optional[int], Optional[int]]:
""" Select rows by `slice` type key. """ """Select rows by `slice` type key."""
pass pass
@abstractmethod @abstractmethod
def _select_rows_by_iterable( def _select_rows_by_iterable(
self, rows_sel: Iterable self, rows_sel: Iterable
) -> Tuple[Optional[spark.Column], Optional[int], Optional[int]]: ) -> Tuple[Optional[spark.Column], Optional[int], Optional[int]]:
""" Select rows by `Iterable` type key. """ """Select rows by `Iterable` type key."""
pass pass
@abstractmethod @abstractmethod
def _select_rows_else( def _select_rows_else(
self, rows_sel: Any self, rows_sel: Any
) -> Tuple[Optional[spark.Column], Optional[int], Optional[int]]: ) -> Tuple[Optional[spark.Column], Optional[int], Optional[int]]:
""" Select rows by other type key. """ """Select rows by other type key."""
pass pass
# Methods for col selection # Methods for col selection
@ -372,7 +372,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
) -> Tuple[ ) -> Tuple[
List[Tuple], Optional[List[spark.Column]], Optional[List[Dtype]], bool, Optional[Tuple] List[Tuple], Optional[List[spark.Column]], Optional[List[Dtype]], bool, Optional[Tuple]
]: ]:
""" Select columns by `Series` type key. """ """Select columns by `Series` type key."""
pass pass
@abstractmethod @abstractmethod
@ -381,7 +381,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
) -> Tuple[ ) -> Tuple[
List[Tuple], Optional[List[spark.Column]], Optional[List[Dtype]], bool, Optional[Tuple] List[Tuple], Optional[List[spark.Column]], Optional[List[Dtype]], bool, Optional[Tuple]
]: ]:
""" Select columns by Spark `Column` type key. """ """Select columns by Spark `Column` type key."""
pass pass
@abstractmethod @abstractmethod
@ -390,7 +390,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
) -> Tuple[ ) -> Tuple[
List[Tuple], Optional[List[spark.Column]], Optional[List[Dtype]], bool, Optional[Tuple] List[Tuple], Optional[List[spark.Column]], Optional[List[Dtype]], bool, Optional[Tuple]
]: ]:
""" Select columns by `slice` type key. """ """Select columns by `slice` type key."""
pass pass
@abstractmethod @abstractmethod
@ -399,7 +399,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
) -> Tuple[ ) -> Tuple[
List[Tuple], Optional[List[spark.Column]], Optional[List[Dtype]], bool, Optional[Tuple] List[Tuple], Optional[List[spark.Column]], Optional[List[Dtype]], bool, Optional[Tuple]
]: ]:
""" Select columns by `Iterable` type key. """ """Select columns by `Iterable` type key."""
pass pass
@abstractmethod @abstractmethod
@ -408,7 +408,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
) -> Tuple[ ) -> Tuple[
List[Tuple], Optional[List[spark.Column]], Optional[List[Dtype]], bool, Optional[Tuple] List[Tuple], Optional[List[spark.Column]], Optional[List[Dtype]], bool, Optional[Tuple]
]: ]:
""" Select columns by other type key. """ """Select columns by other type key."""
pass pass
def __getitem__(self, key) -> Union["Series", "DataFrame"]: def __getitem__(self, key) -> Union["Series", "DataFrame"]:
@ -1140,7 +1140,7 @@ class LocIndexer(LocIndexerLike):
def _get_from_multiindex_column( def _get_from_multiindex_column(
self, key, missing_keys, labels=None, recursed=0 self, key, missing_keys, labels=None, recursed=0
) -> Tuple[List[Tuple], Optional[List[spark.Column]], Any, bool, Optional[Tuple]]: ) -> Tuple[List[Tuple], Optional[List[spark.Column]], Any, bool, Optional[Tuple]]:
""" Select columns from multi-index columns. """ """Select columns from multi-index columns."""
assert isinstance(key, tuple) assert isinstance(key, tuple)
if labels is None: if labels is None:
labels = [(label, label) for label in self._internal.column_labels] labels = [(label, label) for label in self._internal.column_labels]

View file

@ -803,7 +803,7 @@ class InternalFrame(object):
) )
def spark_column_for(self, label: Tuple) -> spark.Column: def spark_column_for(self, label: Tuple) -> spark.Column:
""" Return Spark Column for the given column label. """ """Return Spark Column for the given column label."""
column_labels_to_scol = dict(zip(self.column_labels, self.data_spark_columns)) column_labels_to_scol = dict(zip(self.column_labels, self.data_spark_columns))
if label in column_labels_to_scol: if label in column_labels_to_scol:
return column_labels_to_scol[label] return column_labels_to_scol[label]
@ -811,7 +811,7 @@ class InternalFrame(object):
raise KeyError(name_like_string(label)) raise KeyError(name_like_string(label))
def spark_column_name_for(self, label_or_scol: Union[Tuple, spark.Column]) -> str: def spark_column_name_for(self, label_or_scol: Union[Tuple, spark.Column]) -> str:
""" Return the actual Spark column name for the given column label. """ """Return the actual Spark column name for the given column label."""
if isinstance(label_or_scol, spark.Column): if isinstance(label_or_scol, spark.Column):
scol = label_or_scol scol = label_or_scol
else: else:
@ -819,7 +819,7 @@ class InternalFrame(object):
return self.spark_frame.select(scol).columns[0] return self.spark_frame.select(scol).columns[0]
def spark_type_for(self, label_or_scol: Union[Tuple, spark.Column]) -> DataType: def spark_type_for(self, label_or_scol: Union[Tuple, spark.Column]) -> DataType:
""" Return DataType for the given column label. """ """Return DataType for the given column label."""
if isinstance(label_or_scol, spark.Column): if isinstance(label_or_scol, spark.Column):
scol = label_or_scol scol = label_or_scol
else: else:
@ -827,7 +827,7 @@ class InternalFrame(object):
return self.spark_frame.select(scol).schema[0].dataType return self.spark_frame.select(scol).schema[0].dataType
def spark_column_nullable_for(self, label_or_scol: Union[Tuple, spark.Column]) -> bool: def spark_column_nullable_for(self, label_or_scol: Union[Tuple, spark.Column]) -> bool:
""" Return nullability for the given column label. """ """Return nullability for the given column label."""
if isinstance(label_or_scol, spark.Column): if isinstance(label_or_scol, spark.Column):
scol = label_or_scol scol = label_or_scol
else: else:
@ -835,7 +835,7 @@ class InternalFrame(object):
return self.spark_frame.select(scol).schema[0].nullable return self.spark_frame.select(scol).schema[0].nullable
def dtype_for(self, label: Tuple) -> Dtype: def dtype_for(self, label: Tuple) -> Dtype:
""" Return dtype for the given column label. """ """Return dtype for the given column label."""
column_labels_to_dtype = dict(zip(self.column_labels, self.data_dtypes)) column_labels_to_dtype = dict(zip(self.column_labels, self.data_dtypes))
if label in column_labels_to_dtype: if label in column_labels_to_dtype:
return column_labels_to_dtype[label] return column_labels_to_dtype[label]
@ -844,80 +844,77 @@ class InternalFrame(object):
@property @property
def spark_frame(self) -> spark.DataFrame: def spark_frame(self) -> spark.DataFrame:
""" Return the managed Spark DataFrame. """ """Return the managed Spark DataFrame."""
return self._sdf return self._sdf
@lazy_property @lazy_property
def data_spark_column_names(self) -> List[str]: def data_spark_column_names(self) -> List[str]:
""" Return the managed column field names. """ """Return the managed column field names."""
return self.spark_frame.select(self.data_spark_columns).columns return self.spark_frame.select(self.data_spark_columns).columns
@property @property
def data_spark_columns(self) -> List[spark.Column]: def data_spark_columns(self) -> List[spark.Column]:
""" Return Spark Columns for the managed data columns. """ """Return Spark Columns for the managed data columns."""
return self._data_spark_columns return self._data_spark_columns
@property @property
def index_spark_column_names(self) -> List[str]: def index_spark_column_names(self) -> List[str]:
""" Return the managed index field names. """ """Return the managed index field names."""
return self.spark_frame.select(self.index_spark_columns).columns return self.spark_frame.select(self.index_spark_columns).columns
@property @property
def index_spark_columns(self) -> List[spark.Column]: def index_spark_columns(self) -> List[spark.Column]:
""" Return Spark Columns for the managed index columns. """ """Return Spark Columns for the managed index columns."""
return self._index_spark_columns return self._index_spark_columns
@lazy_property @lazy_property
def spark_column_names(self) -> List[str]: def spark_column_names(self) -> List[str]:
""" Return all the field names including index field names. """ """Return all the field names including index field names."""
return self.spark_frame.select(self.spark_columns).columns return self.spark_frame.select(self.spark_columns).columns
@lazy_property @lazy_property
def spark_columns(self) -> List[spark.Column]: def spark_columns(self) -> List[spark.Column]:
""" Return Spark Columns for the managed columns including index columns. """ """Return Spark Columns for the managed columns including index columns."""
index_spark_columns = self.index_spark_columns index_spark_columns = self.index_spark_columns
return index_spark_columns + [ return index_spark_columns + [
spark_column spark_column
for spark_column in self.data_spark_columns for spark_column in self.data_spark_columns
if all( if all(not spark_column_equals(spark_column, scol) for scol in index_spark_columns)
not spark_column_equals(spark_column, scol)
for scol in index_spark_columns
)
] ]
@property @property
def index_names(self) -> List[Optional[Tuple]]: def index_names(self) -> List[Optional[Tuple]]:
""" Return the managed index names. """ """Return the managed index names."""
return self._index_names return self._index_names
@lazy_property @lazy_property
def index_level(self) -> int: def index_level(self) -> int:
""" Return the level of the index. """ """Return the level of the index."""
return len(self._index_names) return len(self._index_names)
@property @property
def column_labels(self) -> List[Tuple]: def column_labels(self) -> List[Tuple]:
""" Return the managed column index. """ """Return the managed column index."""
return self._column_labels return self._column_labels
@lazy_property @lazy_property
def column_labels_level(self) -> int: def column_labels_level(self) -> int:
""" Return the level of the column index. """ """Return the level of the column index."""
return len(self._column_label_names) return len(self._column_label_names)
@property @property
def column_label_names(self) -> List[Optional[Tuple]]: def column_label_names(self) -> List[Optional[Tuple]]:
""" Return names of the index levels. """ """Return names of the index levels."""
return self._column_label_names return self._column_label_names
@property @property
def index_dtypes(self) -> List[Dtype]: def index_dtypes(self) -> List[Dtype]:
""" Return dtypes for the managed index columns. """ """Return dtypes for the managed index columns."""
return self._index_dtypes return self._index_dtypes
@property @property
def data_dtypes(self) -> List[Dtype]: def data_dtypes(self) -> List[Dtype]:
""" Return dtypes for the managed columns. """ """Return dtypes for the managed columns."""
return self._data_dtypes return self._data_dtypes
@lazy_property @lazy_property
@ -929,16 +926,13 @@ class InternalFrame(object):
index_spark_columns = self.index_spark_columns index_spark_columns = self.index_spark_columns
data_columns = [] data_columns = []
for spark_column in self.data_spark_columns: for spark_column in self.data_spark_columns:
if all( if all(not spark_column_equals(spark_column, scol) for scol in index_spark_columns):
not spark_column_equals(spark_column, scol)
for scol in index_spark_columns
):
data_columns.append(spark_column) data_columns.append(spark_column)
return self.spark_frame.select(index_spark_columns + data_columns) return self.spark_frame.select(index_spark_columns + data_columns)
@lazy_property @lazy_property
def to_pandas_frame(self) -> pd.DataFrame: def to_pandas_frame(self) -> pd.DataFrame:
""" Return as pandas DataFrame. """ """Return as pandas DataFrame."""
sdf = self.to_internal_spark_frame sdf = self.to_internal_spark_frame
pdf = sdf.toPandas() pdf = sdf.toPandas()
if len(pdf) == 0 and len(sdf.schema) > 0: if len(pdf) == 0 and len(sdf.schema) > 0:
@ -950,7 +944,7 @@ class InternalFrame(object):
@lazy_property @lazy_property
def arguments_for_restore_index(self) -> Dict: def arguments_for_restore_index(self) -> Dict:
""" Create arguments for `restore_index`. """ """Create arguments for `restore_index`."""
column_names = [] column_names = []
ext_dtypes = { ext_dtypes = {
col: dtype col: dtype
@ -1055,14 +1049,15 @@ class InternalFrame(object):
pdf.columns = pd.MultiIndex.from_tuples(column_labels, names=names) pdf.columns = pd.MultiIndex.from_tuples(column_labels, names=names)
else: else:
pdf.columns = pd.Index( pdf.columns = pd.Index(
[None if label is None else label[0] for label in column_labels], name=names[0], [None if label is None else label[0] for label in column_labels],
name=names[0],
) )
return pdf return pdf
@lazy_property @lazy_property
def resolved_copy(self) -> "InternalFrame": def resolved_copy(self) -> "InternalFrame":
""" Copy the immutable InternalFrame with the updates resolved. """ """Copy the immutable InternalFrame with the updates resolved."""
sdf = self.spark_frame.select(self.spark_columns + list(HIDDEN_COLUMNS)) sdf = self.spark_frame.select(self.spark_columns + list(HIDDEN_COLUMNS))
return self.copy( return self.copy(
spark_frame=sdf, spark_frame=sdf,
@ -1078,7 +1073,7 @@ class InternalFrame(object):
data_columns: Optional[List[str]] = None, data_columns: Optional[List[str]] = None,
data_dtypes: Optional[List[Dtype]] = None data_dtypes: Optional[List[Dtype]] = None
) -> "InternalFrame": ) -> "InternalFrame":
""" Copy the immutable InternalFrame with the updates by the specified Spark DataFrame. """Copy the immutable InternalFrame with the updates by the specified Spark DataFrame.
:param spark_frame: the new Spark DataFrame :param spark_frame: the new Spark DataFrame
:param index_dtypes: the index dtypes. If None, the original dtyeps are used. :param index_dtypes: the index dtypes. If None, the original dtyeps are used.
@ -1207,7 +1202,7 @@ class InternalFrame(object):
) )
def with_filter(self, pred: Union[spark.Column, "Series"]) -> "InternalFrame": def with_filter(self, pred: Union[spark.Column, "Series"]) -> "InternalFrame":
""" Copy the immutable InternalFrame with the updates by the predicate. """Copy the immutable InternalFrame with the updates by the predicate.
:param pred: the predicate to filter. :param pred: the predicate to filter.
:return: the copied InternalFrame. :return: the copied InternalFrame.
@ -1280,7 +1275,7 @@ class InternalFrame(object):
data_dtypes: Union[Optional[List[Dtype]], _NoValueType] = _NoValue, data_dtypes: Union[Optional[List[Dtype]], _NoValueType] = _NoValue,
column_label_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue column_label_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue
) -> "InternalFrame": ) -> "InternalFrame":
""" Copy the immutable InternalFrame. """Copy the immutable InternalFrame.
:param spark_frame: the new Spark DataFrame. If not specified, the original one is used. :param spark_frame: the new Spark DataFrame. If not specified, the original one is used.
:param index_spark_columns: the list of Spark Column. :param index_spark_columns: the list of Spark Column.
@ -1324,7 +1319,7 @@ class InternalFrame(object):
@staticmethod @staticmethod
def from_pandas(pdf: pd.DataFrame) -> "InternalFrame": def from_pandas(pdf: pd.DataFrame) -> "InternalFrame":
""" Create an immutable DataFrame from pandas DataFrame. """Create an immutable DataFrame from pandas DataFrame.
:param pdf: :class:`pd.DataFrame` :param pdf: :class:`pd.DataFrame`
:return: the created immutable DataFrame :return: the created immutable DataFrame
@ -1354,7 +1349,9 @@ class InternalFrame(object):
schema = StructType( schema = StructType(
[ [
StructField( StructField(
name, infer_pd_series_spark_type(col, dtype), nullable=bool(col.isnull().any()), name,
infer_pd_series_spark_type(col, dtype),
nullable=bool(col.isnull().any()),
) )
for (name, col), dtype in zip(pdf.iteritems(), index_dtypes + data_dtypes) for (name, col), dtype in zip(pdf.iteritems(), index_dtypes + data_dtypes)
] ]

View file

@ -1903,7 +1903,7 @@ def get_dummies(
raise KeyError(name_like_string(columns)) raise KeyError(name_like_string(columns))
if prefix is None: if prefix is None:
prefix = [ prefix = [
str(label[len(columns):]) str(label[len(columns) :])
if len(label) > len(columns) + 1 if len(label) > len(columns) + 1
else label[len(columns)] else label[len(columns)]
if len(label) == len(columns) + 1 if len(label) == len(columns) + 1
@ -2212,11 +2212,19 @@ def concat(objs, axis=0, join="outer", ignore_index=False, sort=False) -> Union[
for psdf in psdfs_not_same_anchor: for psdf in psdfs_not_same_anchor:
if join == "inner": if join == "inner":
concat_psdf = align_diff_frames( concat_psdf = align_diff_frames(
resolve_func, concat_psdf, psdf, fillna=False, how="inner", resolve_func,
concat_psdf,
psdf,
fillna=False,
how="inner",
) )
elif join == "outer": elif join == "outer":
concat_psdf = align_diff_frames( concat_psdf = align_diff_frames(
resolve_func, concat_psdf, psdf, fillna=False, how="full", resolve_func,
concat_psdf,
psdf,
fillna=False,
how="full",
) )
concat_psdf = concat_psdf[column_labels] concat_psdf = concat_psdf[column_labels]

View file

@ -374,11 +374,19 @@ class KdePlotBase:
if ind is None: if ind is None:
min_val, max_val = calc_min_max() min_val, max_val = calc_min_max()
sample_range = max_val - min_val sample_range = max_val - min_val
ind = np.linspace(min_val - 0.5 * sample_range, max_val + 0.5 * sample_range, 1000,) ind = np.linspace(
min_val - 0.5 * sample_range,
max_val + 0.5 * sample_range,
1000,
)
elif is_integer(ind): elif is_integer(ind):
min_val, max_val = calc_min_max() min_val, max_val = calc_min_max()
sample_range = max_val - min_val sample_range = max_val - min_val
ind = np.linspace(min_val - 0.5 * sample_range, max_val + 0.5 * sample_range, ind,) ind = np.linspace(
min_val - 0.5 * sample_range,
max_val + 0.5 * sample_range,
ind,
)
return ind return ind
@staticmethod @staticmethod

View file

@ -111,7 +111,7 @@ class PandasOnSparkBoxPlot(PandasBoxPlot, BoxPlotBase):
precision=None, precision=None,
): ):
def update_dict(dictionary, rc_name, properties): def update_dict(dictionary, rc_name, properties):
""" Loads properties in the dictionary from rc file if not already """Loads properties in the dictionary from rc file if not already
in the dictionary""" in the dictionary"""
rc_str = "boxplot.{0}.{1}" rc_str = "boxplot.{0}.{1}"
if dictionary is None: if dictionary is None:

View file

@ -1592,7 +1592,11 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
else: else:
return first_series(psdf) return first_series(psdf)
def reindex(self, index: Optional[Any] = None, fill_value: Optional[Any] = None,) -> "Series": def reindex(
self,
index: Optional[Any] = None,
fill_value: Optional[Any] = None,
) -> "Series":
""" """
Conform Series to new index with optional filling logic, placing Conform Series to new index with optional filling logic, placing
NA/NaN in locations having no value in the previous index. A new object NA/NaN in locations having no value in the previous index. A new object
@ -3485,7 +3489,8 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
if method == "first": if method == "first":
window = ( window = (
Window.orderBy( Window.orderBy(
asc_func(self.spark.column), asc_func(F.col(NATURAL_ORDER_COLUMN_NAME)), asc_func(self.spark.column),
asc_func(F.col(NATURAL_ORDER_COLUMN_NAME)),
) )
.partitionBy(*part_cols) .partitionBy(*part_cols)
.rowsBetween(Window.unboundedPreceding, Window.currentRow) .rowsBetween(Window.unboundedPreceding, Window.currentRow)
@ -3958,7 +3963,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
) )
internal = self._internal internal = self._internal
scols = internal.index_spark_columns[len(item):] + [self.spark.column] scols = internal.index_spark_columns[len(item) :] + [self.spark.column]
rows = [internal.spark_columns[level] == index for level, index in enumerate(item)] rows = [internal.spark_columns[level] == index for level, index in enumerate(item)]
sdf = internal.spark_frame.filter(reduce(lambda x, y: x & y, rows)).select(scols) sdf = internal.spark_frame.filter(reduce(lambda x, y: x & y, rows)).select(scols)
@ -3985,10 +3990,10 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
internal = internal.copy( internal = internal.copy(
spark_frame=sdf, spark_frame=sdf,
index_spark_columns=[ index_spark_columns=[
scol_for(sdf, col) for col in internal.index_spark_column_names[len(item):] scol_for(sdf, col) for col in internal.index_spark_column_names[len(item) :]
], ],
index_dtypes=internal.index_dtypes[len(item):], index_dtypes=internal.index_dtypes[len(item) :],
index_names=self._internal.index_names[len(item):], index_names=self._internal.index_names[len(item) :],
data_spark_columns=[scol_for(sdf, internal.data_spark_column_names[0])], data_spark_columns=[scol_for(sdf, internal.data_spark_column_names[0])],
) )
return first_series(DataFrame(internal)) return first_series(DataFrame(internal))
@ -4660,7 +4665,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
internal = self._internal internal = self._internal
scols = ( scols = (
internal.index_spark_columns[:level] internal.index_spark_columns[:level]
+ internal.index_spark_columns[level + len(key):] + internal.index_spark_columns[level + len(key) :]
+ [self.spark.column] + [self.spark.column]
) )
rows = [internal.spark_columns[lvl] == index for lvl, index in enumerate(key, level)] rows = [internal.spark_columns[lvl] == index for lvl, index in enumerate(key, level)]
@ -4675,10 +4680,10 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
index_spark_column_names = ( index_spark_column_names = (
internal.index_spark_column_names[:level] internal.index_spark_column_names[:level]
+ internal.index_spark_column_names[level + len(key):] + internal.index_spark_column_names[level + len(key) :]
) )
index_names = internal.index_names[:level] + internal.index_names[level + len(key):] index_names = internal.index_names[:level] + internal.index_names[level + len(key) :]
index_dtypes = internal.index_dtypes[:level] + internal.index_dtypes[level + len(key):] index_dtypes = internal.index_dtypes[:level] + internal.index_dtypes[level + len(key) :]
internal = internal.copy( internal = internal.copy(
spark_frame=sdf, spark_frame=sdf,

View file

@ -44,12 +44,12 @@ class SparkIndexOpsMethods(metaclass=ABCMeta):
@property @property
def data_type(self) -> DataType: def data_type(self) -> DataType:
""" Returns the data type as defined by Spark, as a Spark DataType object.""" """Returns the data type as defined by Spark, as a Spark DataType object."""
return self._data._internal.spark_type_for(self._data._column_label) return self._data._internal.spark_type_for(self._data._column_label)
@property @property
def nullable(self) -> bool: def nullable(self) -> bool:
""" Returns the nullability as defined by Spark. """ """Returns the nullability as defined by Spark."""
return self._data._internal.spark_column_nullable_for(self._data._column_label) return self._data._internal.spark_column_nullable_for(self._data._column_label)
@property @property

View file

@ -26,7 +26,7 @@ from pyspark.testing.pandasutils import PandasOnSparkTestCase
class BinaryOpsTest(PandasOnSparkTestCase, TestCasesUtils): class BinaryOpsTest(PandasOnSparkTestCase, TestCasesUtils):
@property @property
def pser(self): def pser(self):
return pd.Series([b'1', b'2', b'3']) return pd.Series([b"1", b"2", b"3"])
@property @property
def psser(self): def psser(self):
@ -35,9 +35,9 @@ class BinaryOpsTest(PandasOnSparkTestCase, TestCasesUtils):
def test_add(self): def test_add(self):
psser = self.psser psser = self.psser
pser = self.pser pser = self.pser
self.assert_eq(psser + b'1', pser + b'1') self.assert_eq(psser + b"1", pser + b"1")
self.assert_eq(psser + psser, pser + pser) self.assert_eq(psser + psser, pser + pser)
self.assert_eq(psser + psser.astype('bytes'), pser + pser.astype('bytes')) self.assert_eq(psser + psser.astype("bytes"), pser + pser.astype("bytes"))
self.assertRaises(TypeError, lambda: psser + "x") self.assertRaises(TypeError, lambda: psser + "x")
self.assertRaises(TypeError, lambda: psser + 1) self.assertRaises(TypeError, lambda: psser + 1)
@ -95,7 +95,7 @@ class BinaryOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assertRaises(TypeError, lambda: self.psser ** psser) self.assertRaises(TypeError, lambda: self.psser ** psser)
def test_radd(self): def test_radd(self):
self.assert_eq(b'1' + self.psser, b'1' + self.pser) self.assert_eq(b"1" + self.psser, b"1" + self.pser)
self.assertRaises(TypeError, lambda: "x" + self.psser) self.assertRaises(TypeError, lambda: "x" + self.psser)
self.assertRaises(TypeError, lambda: 1 + self.psser) self.assertRaises(TypeError, lambda: 1 + self.psser)
@ -129,7 +129,8 @@ if __name__ == "__main__":
try: try:
import xmlrunner # type: ignore[import] import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError: except ImportError:
testRunner = None testRunner = None
unittest.main(testRunner=testRunner, verbosity=2) unittest.main(testRunner=testRunner, verbosity=2)

View file

@@ -103,7 +103,8 @@ class BooleanOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         with option_context("compute.ops_on_diff_frames", True):
             self.assert_eq(
-                self.pser / self.float_pser, (self.psser / self.float_psser).sort_index())
+                self.pser / self.float_pser, (self.psser / self.float_psser).sort_index()
+            )
             for psser in self.non_numeric_pssers.values():
                 self.assertRaises(TypeError, lambda: self.psser / psser)
@@ -235,7 +236,8 @@ if __name__ == "__main__":
     try:
         import xmlrunner  # type: ignore[import]
-        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
     except ImportError:
         testRunner = None
     unittest.main(testRunner=testRunner, verbosity=2)

@@ -122,7 +122,8 @@ if __name__ == "__main__":
     try:
         import xmlrunner  # type: ignore[import]
-        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
     except ImportError:
         testRunner = None
     unittest.main(testRunner=testRunner, verbosity=2)

@@ -32,16 +32,17 @@ class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         return [
             pd.Series([[1, 2, 3]]),
             pd.Series([[0.1, 0.2, 0.3]]),
-            pd.Series([[decimal.Decimal(1), decimal.Decimal(2), decimal.Decimal(3)]])
+            pd.Series([[decimal.Decimal(1), decimal.Decimal(2), decimal.Decimal(3)]]),
         ]
 
     @property
     def non_numeric_array_psers(self):
         return {
-            "string": pd.Series([['x', 'y', 'z']]),
-            "date": pd.Series([
-                [datetime.date(1994, 1, 1), datetime.date(1994, 1, 2), datetime.date(1994, 1, 3)]]),
-            "bool": pd.Series([[True, True, False]])
+            "string": pd.Series([["x", "y", "z"]]),
+            "date": pd.Series(
+                [[datetime.date(1994, 1, 1), datetime.date(1994, 1, 2), datetime.date(1994, 1, 3)]]
+            ),
+            "bool": pd.Series([[True, True, False]]),
         }
 
     @property
@@ -80,16 +81,19 @@ class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         # Non-numeric array + Non-numeric array
         self.assertRaises(
-            TypeError, lambda:
-            self.non_numeric_array_pssers['string'] + self.non_numeric_array_pssers['bool']
+            TypeError,
+            lambda: self.non_numeric_array_pssers["string"]
+            + self.non_numeric_array_pssers["bool"],
         )
         self.assertRaises(
-            TypeError, lambda:
-            self.non_numeric_array_pssers['string'] + self.non_numeric_array_pssers['date']
+            TypeError,
+            lambda: self.non_numeric_array_pssers["string"]
+            + self.non_numeric_array_pssers["date"],
        )
         self.assertRaises(
-            TypeError, lambda:
-            self.non_numeric_array_pssers['bool'] + self.non_numeric_array_pssers['date']
+            TypeError,
+            lambda: self.non_numeric_array_pssers["bool"]
+            + self.non_numeric_array_pssers["date"],
        )
 
         for data_type in self.non_numeric_array_psers.keys():
@@ -97,7 +101,7 @@ class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils):
                 self.non_numeric_array_psers.get(data_type)
                 + self.non_numeric_array_psers.get(data_type),
                 self.non_numeric_array_pssers.get(data_type)
-                + self.non_numeric_array_pssers.get(data_type)
+                + self.non_numeric_array_pssers.get(data_type),
             )
 
         # Numeric array + Non-numeric array
@@ -193,7 +197,8 @@ if __name__ == "__main__":
     try:
         import xmlrunner  # type: ignore[import]
-        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
     except ImportError:
         testRunner = None
     unittest.main(testRunner=testRunner, verbosity=2)

@@ -55,7 +55,8 @@ class DateOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assertRaises(TypeError, lambda: self.psser - "x")
         self.assertRaises(TypeError, lambda: self.psser - 1)
         self.assert_eq(
-            (self.pser - self.some_date).dt.days, self.psser - self.some_date,
+            (self.pser - self.some_date).dt.days,
+            self.psser - self.some_date,
         )
         with option_context("compute.ops_on_diff_frames", True):
             for pser, psser in self.pser_psser_pairs:
@@ -118,7 +119,8 @@ class DateOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assertRaises(TypeError, lambda: "x" - self.psser)
         self.assertRaises(TypeError, lambda: 1 - self.psser)
         self.assert_eq(
-            (self.some_date - self.pser).dt.days, self.some_date - self.psser,
+            (self.some_date - self.pser).dt.days,
+            self.some_date - self.psser,
         )
 
     def test_rmul(self):
@@ -152,7 +154,8 @@ if __name__ == "__main__":
     try:
         import xmlrunner  # type: ignore[import]
-        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
     except ImportError:
         testRunner = None
     unittest.main(testRunner=testRunner, verbosity=2)

@@ -154,7 +154,8 @@ if __name__ == "__main__":
     try:
         import xmlrunner  # type: ignore[import]
-        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
     except ImportError:
         testRunner = None
     unittest.main(testRunner=testRunner, verbosity=2)

@@ -34,6 +34,7 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
     returns float32.
     The underlying reason is the respective Spark operations return DoubleType always.
     """
+
     @property
     def float_pser(self):
         return pd.Series([1, 2, 3], dtype=float)
@@ -137,7 +138,8 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
             self.assertRaises(TypeError, lambda: psser // self.non_numeric_pssers["datetime"])
             self.assertRaises(TypeError, lambda: psser // self.non_numeric_pssers["date"])
             self.assertRaises(
-                TypeError, lambda: psser // self.non_numeric_pssers["categorical"])
+                TypeError, lambda: psser // self.non_numeric_pssers["categorical"]
+            )
         if LooseVersion(pd.__version__) >= LooseVersion("0.25.3"):
             self.assert_eq(
                 (self.float_psser // self.non_numeric_pssers["bool"]).sort_index(),
@@ -146,7 +148,7 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         else:
             self.assert_eq(
                 (self.float_pser // self.non_numeric_psers["bool"]).sort_index(),
-                ps.Series([1.0, 2.0, np.inf])
+                ps.Series([1.0, 2.0, np.inf]),
             )
 
         self.assertRaises(TypeError, lambda: self.float_psser // True)
@@ -181,7 +183,8 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
             self.assertRaises(TypeError, lambda: psser ** self.non_numeric_pssers["datetime"])
             self.assertRaises(TypeError, lambda: psser ** self.non_numeric_pssers["date"])
             self.assertRaises(
-                TypeError, lambda: psser ** self.non_numeric_pssers["categorical"])
+                TypeError, lambda: psser ** self.non_numeric_pssers["categorical"]
+            )
         self.assert_eq(
             (self.float_psser ** self.non_numeric_pssers["bool"]).sort_index(),
             self.float_pser ** self.non_numeric_psers["bool"],
@@ -258,7 +261,8 @@ if __name__ == "__main__":
     try:
         import xmlrunner  # type: ignore[import]
-        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
     except ImportError:
         testRunner = None
     unittest.main(testRunner=testRunner, verbosity=2)

@@ -45,7 +45,8 @@ class StringOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assertRaises(TypeError, lambda: self.psser + self.non_numeric_pssers["datetime"])
         self.assertRaises(TypeError, lambda: self.psser + self.non_numeric_pssers["date"])
         self.assertRaises(
-            TypeError, lambda: self.psser + self.non_numeric_pssers["categorical"])
+            TypeError, lambda: self.psser + self.non_numeric_pssers["categorical"]
+        )
         self.assertRaises(TypeError, lambda: self.psser + self.non_numeric_pssers["bool"])
         for psser in self.numeric_pssers:
             self.assertRaises(TypeError, lambda: self.psser + psser)
@@ -135,7 +136,8 @@ if __name__ == "__main__":
     try:
         import xmlrunner  # type: ignore[import]
-        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
     except ImportError:
         testRunner = None
     unittest.main(testRunner=testRunner, verbosity=2)

@@ -26,6 +26,7 @@ import pyspark.pandas as ps
 class TestCasesUtils(object):
     """A utility holding common test cases for arithmetic operations of different data types."""
+
     @property
     def numeric_psers(self):
         dtypes = [np.float32, float, int, np.int32]

@@ -174,7 +174,8 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils):
             # The `name` argument is added in pandas 0.24.
             self.assert_eq(psidx.to_frame(name="x"), pidx.to_frame(name="x"))
             self.assert_eq(
-                psidx.to_frame(index=False, name="x"), pidx.to_frame(index=False, name="x"),
+                psidx.to_frame(index=False, name="x"),
+                pidx.to_frame(index=False, name="x"),
             )
 
         self.assertRaises(TypeError, lambda: psidx.to_frame(name=["x"]))

@@ -425,7 +425,7 @@ class DataFramePlotMatplotlibTest(PandasOnSparkTestCase, TestUtils):
         def moving_average(a, n=10):
             ret = np.cumsum(a, dtype=float)
             ret[n:] = ret[n:] - ret[:-n]
-            return ret[n - 1:] / n
+            return ret[n - 1 :] / n
 
         def check_kde_plot(pdf, psdf, *args, **kwargs):
             _, ax1 = plt.subplots(1, 1)

@@ -169,7 +169,8 @@ class DataFramePlotPlotlyTest(PandasOnSparkTestCase, TestUtils):
         )
         self.assertEqual(
-            psdf.plot(kind="pie", values="a"), express.pie(pdf, values="a"),
+            psdf.plot(kind="pie", values="a"),
+            express.pie(pdf, values="a"),
         )
 
         psdf1 = self.psdf1

@@ -347,7 +347,7 @@ class SeriesPlotMatplotlibTest(PandasOnSparkTestCase, TestUtils):
         def moving_average(a, n=10):
             ret = np.cumsum(a, dtype=float)
             ret[n:] = ret[n:] - ret[:-n]
-            return ret[n - 1:] / n
+            return ret[n - 1 :] / n
 
         def check_kde_plot(pdf, psdf, *args, **kwargs):
             _, ax1 = plt.subplots(1, 1)

@@ -118,7 +118,8 @@ class SeriesPlotPlotlyTest(PandasOnSparkTestCase, TestUtils):
         psdf = self.psdf1
         pdf = psdf.to_pandas()
         self.assertEqual(
-            psdf["a"].plot(kind="pie"), express.pie(pdf, values=pdf.columns[0], names=pdf.index),
+            psdf["a"].plot(kind="pie"),
+            express.pie(pdf, values=pdf.columns[0], names=pdf.index),
         )
 
         # TODO: support multi-index columns

@@ -387,14 +387,16 @@ class CategoricalTest(PandasOnSparkTestCase, TestUtils):
             return pdf.astype(str)
 
         self.assert_eq(
-            psdf.pandas_on_spark.transform_batch(to_str).sort_index(), to_str(pdf).sort_index(),
+            psdf.pandas_on_spark.transform_batch(to_str).sort_index(),
+            to_str(pdf).sort_index(),
         )
 
         def to_codes(pdf) -> ps.Series[np.int8]:
             return pdf.b.cat.codes
 
         self.assert_eq(
-            psdf.pandas_on_spark.transform_batch(to_codes).sort_index(), to_codes(pdf).sort_index(),
+            psdf.pandas_on_spark.transform_batch(to_codes).sort_index(),
+            to_codes(pdf).sort_index(),
         )
 
         pdf = pd.DataFrame(

@@ -1518,7 +1518,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
         # Assert approximate counts
         self.assert_eq(
-            ps.DataFrame({"A": range(100)}).nunique(approx=True), pd.Series([103], index=["A"]),
+            ps.DataFrame({"A": range(100)}).nunique(approx=True),
+            pd.Series([103], index=["A"]),
         )
         self.assert_eq(
             ps.DataFrame({"A": range(100)}).nunique(approx=True, rsd=0.01),
@@ -3354,7 +3355,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
         columns2 = pd.Index(["numbers", "2", "3"], name="cols2")
         self.assert_eq(
-            pdf.reindex(columns=columns2).sort_index(), psdf.reindex(columns=columns2).sort_index(),
+            pdf.reindex(columns=columns2).sort_index(),
+            psdf.reindex(columns=columns2).sort_index(),
         )
 
         columns = pd.Index(["numbers"], name="cols")
@@ -3398,12 +3400,14 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
         columns2 = pd.Index(["numbers", "2", "3"])
         self.assert_eq(
-            pdf.reindex(columns=columns2).sort_index(), psdf.reindex(columns=columns2).sort_index(),
+            pdf.reindex(columns=columns2).sort_index(),
+            psdf.reindex(columns=columns2).sort_index(),
         )
 
         columns2 = pd.Index(["numbers", "2", "3"], name="cols2")
         self.assert_eq(
-            pdf.reindex(columns=columns2).sort_index(), psdf.reindex(columns=columns2).sort_index(),
+            pdf.reindex(columns=columns2).sort_index(),
+            psdf.reindex(columns=columns2).sort_index(),
         )
 
         # Reindexing single Index on single Index
@@ -3506,7 +3510,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
             [("X", "numbers"), ("Y", "2"), ("Y", "3")], names=["cols3", "cols4"]
         )
         self.assert_eq(
-            pdf.reindex(columns=columns2).sort_index(), psdf.reindex(columns=columns2).sort_index(),
+            pdf.reindex(columns=columns2).sort_index(),
+            psdf.reindex(columns=columns2).sort_index(),
         )
 
         self.assertRaises(TypeError, lambda: psdf.reindex(columns=["X"]))
@@ -3527,7 +3532,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
         psdf2 = ps.from_pandas(pdf2)
         self.assert_eq(
-            pdf.reindex_like(pdf2).sort_index(), psdf.reindex_like(psdf2).sort_index(),
+            pdf.reindex_like(pdf2).sort_index(),
+            psdf.reindex_like(psdf2).sort_index(),
         )
 
         pdf2 = pd.DataFrame({"index_level_1": ["A", "C", "I"]})
@@ -3546,7 +3552,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
         psdf2 = ps.from_pandas(pdf2)
         self.assert_eq(
-            pdf.reindex_like(pdf2).sort_index(), psdf.reindex_like(psdf2).sort_index(),
+            pdf.reindex_like(pdf2).sort_index(),
+            psdf.reindex_like(psdf2).sort_index(),
         )
 
         self.assertRaises(TypeError, lambda: psdf.reindex_like(index2))
@@ -3569,7 +3576,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
         psdf = ps.from_pandas(pdf)
         self.assert_eq(
-            pdf.reindex_like(pdf2).sort_index(), psdf.reindex_like(psdf2).sort_index(),
+            pdf.reindex_like(pdf2).sort_index(),
+            psdf.reindex_like(psdf2).sort_index(),
         )
 
     def test_melt(self):
@@ -3953,16 +3961,20 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
         self.assert_eq(pdf.duplicated().sort_index(), psdf.duplicated().sort_index())
         self.assert_eq(
-            pdf.duplicated(keep="last").sort_index(), psdf.duplicated(keep="last").sort_index(),
+            pdf.duplicated(keep="last").sort_index(),
+            psdf.duplicated(keep="last").sort_index(),
         )
         self.assert_eq(
-            pdf.duplicated(keep=False).sort_index(), psdf.duplicated(keep=False).sort_index(),
+            pdf.duplicated(keep=False).sort_index(),
+            psdf.duplicated(keep=False).sort_index(),
         )
         self.assert_eq(
-            pdf.duplicated(subset="b").sort_index(), psdf.duplicated(subset="b").sort_index(),
+            pdf.duplicated(subset="b").sort_index(),
+            psdf.duplicated(subset="b").sort_index(),
         )
         self.assert_eq(
-            pdf.duplicated(subset=["b"]).sort_index(), psdf.duplicated(subset=["b"]).sort_index(),
+            pdf.duplicated(subset=["b"]).sort_index(),
+            psdf.duplicated(subset=["b"]).sort_index(),
         )
         with self.assertRaisesRegex(ValueError, "'keep' only supports 'first', 'last' and False"):
             psdf.duplicated(keep="false")
@@ -4009,7 +4021,8 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
         self.assert_eq(pdf.duplicated().sort_index(), psdf.duplicated().sort_index())
         self.assert_eq(
-            pdf.duplicated(subset=10).sort_index(), psdf.duplicated(subset=10).sort_index(),
+            pdf.duplicated(subset=10).sort_index(),
+            psdf.duplicated(subset=10).sort_index(),
         )
 
     def test_ffill(self):
@@ -4651,17 +4664,20 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
             psdf.take([-1, -2], axis=1).sort_index(), pdf.take([-1, -2], axis=1).sort_index()
         )
         self.assert_eq(
-            psdf.take(range(1, 3), axis=1).sort_index(), pdf.take(range(1, 3), axis=1).sort_index(),
+            psdf.take(range(1, 3), axis=1).sort_index(),
+            pdf.take(range(1, 3), axis=1).sort_index(),
         )
         self.assert_eq(
             psdf.take(range(-1, -3), axis=1).sort_index(),
             pdf.take(range(-1, -3), axis=1).sort_index(),
         )
         self.assert_eq(
-            psdf.take([2, 1], axis=1).sort_index(), pdf.take([2, 1], axis=1).sort_index(),
+            psdf.take([2, 1], axis=1).sort_index(),
+            pdf.take([2, 1], axis=1).sort_index(),
         )
         self.assert_eq(
-            psdf.take([-1, -2], axis=1).sort_index(), pdf.take([-1, -2], axis=1).sort_index(),
+            psdf.take([-1, -2], axis=1).sort_index(),
+            pdf.take([-1, -2], axis=1).sort_index(),
         )
 
         # MultiIndex columns
@@ -4695,17 +4711,20 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
             psdf.take([-1, -2], axis=1).sort_index(), pdf.take([-1, -2], axis=1).sort_index()
         )
         self.assert_eq(
-            psdf.take(range(1, 3), axis=1).sort_index(), pdf.take(range(1, 3), axis=1).sort_index(),
+            psdf.take(range(1, 3), axis=1).sort_index(),
+            pdf.take(range(1, 3), axis=1).sort_index(),
         )
         self.assert_eq(
             psdf.take(range(-1, -3), axis=1).sort_index(),
             pdf.take(range(-1, -3), axis=1).sort_index(),
         )
         self.assert_eq(
-            psdf.take([2, 1], axis=1).sort_index(), pdf.take([2, 1], axis=1).sort_index(),
+            psdf.take([2, 1], axis=1).sort_index(),
+            pdf.take([2, 1], axis=1).sort_index(),
         )
         self.assert_eq(
-            psdf.take([-1, -2], axis=1).sort_index(), pdf.take([-1, -2], axis=1).sort_index(),
+            psdf.take([-1, -2], axis=1).sort_index(),
+            pdf.take([-1, -2], axis=1).sort_index(),
         )
 
         # Checking the type of indices.
@@ -5524,35 +5543,40 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
         psdf = ps.from_pandas(pdf)
         psdf.at_time("0:20")
         self.assert_eq(
-            pdf.at_time("0:20").sort_index(), psdf.at_time("0:20").sort_index(),
+            pdf.at_time("0:20").sort_index(),
+            psdf.at_time("0:20").sort_index(),
         )
 
         # Index name is 'ts'
         pdf.index.name = "ts"
         psdf = ps.from_pandas(pdf)
         self.assert_eq(
-            pdf.at_time("0:20").sort_index(), psdf.at_time("0:20").sort_index(),
+            pdf.at_time("0:20").sort_index(),
+            psdf.at_time("0:20").sort_index(),
         )
 
         # Index name is 'ts', column label is 'index'
         pdf.columns = pd.Index(["index"])
         psdf = ps.from_pandas(pdf)
         self.assert_eq(
-            pdf.at_time("0:40").sort_index(), psdf.at_time("0:40").sort_index(),
+            pdf.at_time("0:40").sort_index(),
+            psdf.at_time("0:40").sort_index(),
         )
 
         # Both index name and column label are 'index'
         pdf.index.name = "index"
         psdf = ps.from_pandas(pdf)
         self.assert_eq(
-            pdf.at_time("0:40").sort_index(), psdf.at_time("0:40").sort_index(),
+            pdf.at_time("0:40").sort_index(),
+            psdf.at_time("0:40").sort_index(),
         )
 
         # Index name is 'index', column label is ('X', 'A')
         pdf.columns = pd.MultiIndex.from_arrays([["X"], ["A"]])
         psdf = ps.from_pandas(pdf)
         self.assert_eq(
-            pdf.at_time("0:40").sort_index(), psdf.at_time("0:40").sort_index(),
+            pdf.at_time("0:40").sort_index(),
+            psdf.at_time("0:40").sort_index(),
         )
 
         with self.assertRaisesRegex(NotImplementedError, "'asof' argument is not supported"):

@@ -310,7 +310,8 @@ class DataFrameSparkIOTest(PandasOnSparkTestCase, TestUtils):
             self.assert_eq(psdfs["Sheet_name_2"], pdfs1_squeezed["Sheet_name_2"])
             self.assert_eq(
-                ps.read_excel(tmp, index_col=0, sheet_name="Sheet_name_2"), pdfs1["Sheet_name_2"],
+                ps.read_excel(tmp, index_col=0, sheet_name="Sheet_name_2"),
+                pdfs1["Sheet_name_2"],
             )
 
             for sheet_name in sheet_names:

@@ -96,7 +96,8 @@ class ExpandingTest(PandasOnSparkTestCase, TestUtils):
         psdf = ps.DataFrame({"a": [1, 2, 3, 2], "b": [4.0, 2.0, 3.0, 1.0]}, index=idx)
         psdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
         expected_result = pd.DataFrame(
-            {("a", "x"): [None, 2.0, 3.0, 4.0], ("a", "y"): [None, 2.0, 3.0, 4.0]}, index=idx,
+            {("a", "x"): [None, 2.0, 3.0, 4.0], ("a", "y"): [None, 2.0, 3.0, 4.0]},
+            index=idx,
         )
         self.assert_eq(psdf.expanding(2).count().sort_index(), expected_result.sort_index())

@@ -934,7 +934,8 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
             expected = ps.DataFrame({"b": [2, 2]}, index=pd.Index([0, 1], name="a"))
             self.assert_eq(psdf.groupby("a").nunique().sort_index(), expected)
             self.assert_eq(
-                psdf.groupby("a").nunique(dropna=False).sort_index(), expected,
+                psdf.groupby("a").nunique(dropna=False).sort_index(),
+                expected,
             )
         else:
             self.assert_eq(
@@ -968,10 +969,12 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
         if LooseVersion(pd.__version__) < LooseVersion("1.1.0"):
             expected = ps.DataFrame({("y", "b"): [2, 2]}, index=pd.Index([0, 1], name=("x", "a")))
             self.assert_eq(
-                psdf.groupby(("x", "a")).nunique().sort_index(), expected,
+                psdf.groupby(("x", "a")).nunique().sort_index(),
+                expected,
             )
             self.assert_eq(
-                psdf.groupby(("x", "a")).nunique(dropna=False).sort_index(), expected,
+                psdf.groupby(("x", "a")).nunique(dropna=False).sort_index(),
+                expected,
             )
         else:
             self.assert_eq(
@@ -1785,7 +1788,8 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
             pdf.groupby("A")[["B"]].bfill().sort_index(),
         )
         self.assert_eq(
-            psdf.groupby("A")["B"].bfill().sort_index(), pdf.groupby("A")["B"].bfill().sort_index(),
+            psdf.groupby("A")["B"].bfill().sort_index(),
+            pdf.groupby("A")["B"].bfill().sort_index(),
         )
         self.assert_eq(
             psdf.groupby("A")["B"].bfill()[idx[6]], pdf.groupby("A")["B"].bfill()[idx[6]]
@@ -1893,7 +1897,8 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
             pdf.groupby("b").apply(lambda x: x + x.min()).sort_index(),
         )
         self.assert_eq(
-            psdf.groupby("b").apply(len).sort_index(), pdf.groupby("b").apply(len).sort_index(),
+            psdf.groupby("b").apply(len).sort_index(),
+            pdf.groupby("b").apply(len).sort_index(),
        )
         self.assert_eq(
             psdf.groupby("b")["a"]
@@ -2556,7 +2561,8 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
         psdf = ps.from_pandas(pdf)
 
         self.assert_eq(
-            psdf.groupby("class").get_group("bird"), pdf.groupby("class").get_group("bird"),
+            psdf.groupby("class").get_group("bird"),
+            pdf.groupby("class").get_group("bird"),
         )
         self.assert_eq(
             psdf.groupby("class")["name"].get_group("mammal"),
@@ -2583,7 +2589,8 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
             (pdf.max_speed + 1).groupby(pdf["class"]).get_group("mammal"),
         )
         self.assert_eq(
-            psdf.groupby("max_speed").get_group(80.5), pdf.groupby("max_speed").get_group(80.5),
+            psdf.groupby("max_speed").get_group(80.5),
+            pdf.groupby("max_speed").get_group(80.5),
        )
 
         self.assertRaises(KeyError, lambda: psdf.groupby("class").get_group("fish"))
@@ -2646,7 +2653,8 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
             lambda: psdf.groupby([("B", "class"), ("A", "name")]).get_group(("lion", "mammal")),
         )
         self.assertRaises(
-            ValueError, lambda: psdf.groupby([("B", "class"), ("A", "name")]).get_group(("lion",)),
+            ValueError,
+            lambda: psdf.groupby([("B", "class"), ("A", "name")]).get_group(("lion",)),
         )
         self.assertRaises(
             ValueError, lambda: psdf.groupby([("B", "class"), ("A", "name")]).get_group(("mammal",))

@@ -266,7 +266,10 @@ class NamespaceTest(PandasOnSparkTestCase, SQLTestUtils):
         objs = [
             ([psdf1.A, psdf1.A.rename("B")], [pdf1.A, pdf1.A.rename("B")]),
-            ([psdf3[("X", "A")], psdf3[("X", "B")]], [pdf3[("X", "A")], pdf3[("X", "B")]],),
+            (
+                [psdf3[("X", "A")], psdf3[("X", "B")]],
+                [pdf3[("X", "A")], pdf3[("X", "B")]],
+            ),
             (
                 [psdf3[("X", "A")], psdf3[("X", "B")].rename("ABC")],
                 [pdf3[("X", "A")], pdf3[("X", "B")].rename("ABC")],

@@ -728,7 +728,8 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
         psser1 = ps.from_pandas(pser1)
         psser2 = ps.from_pandas(pser2)
         self.assert_eq(
-            pser1.compare(pser2).sort_index(), psser1.compare(psser2).sort_index(),
+            pser1.compare(pser2).sort_index(),
+            psser1.compare(psser2).sort_index(),
         )
 
         # `keep_shape=True`
@@ -757,7 +758,8 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
         psser1 = ps.from_pandas(pser1)
         psser2 = ps.from_pandas(pser2)
         self.assert_eq(
-            pser1.compare(pser2).sort_index(), psser1.compare(psser2).sort_index(),
+            pser1.compare(pser2).sort_index(),
+            psser1.compare(psser2).sort_index(),
         )
 
         # `keep_shape=True` with MultiIndex
@@ -790,14 +792,16 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
             columns=["self", "other"],
         )
         self.assert_eq(
-            expected, psser1.compare(psser2, keep_shape=True).sort_index(),
+            expected,
+            psser1.compare(psser2, keep_shape=True).sort_index(),
         )
         # `keep_equal=True`
         expected = ps.DataFrame(
             [["b", "a"], ["g", None], [None, "h"]], index=[0, 3, 4], columns=["self", "other"]
         )
         self.assert_eq(
-            expected, psser1.compare(psser2, keep_equal=True).sort_index(),
+            expected,
+            psser1.compare(psser2, keep_equal=True).sort_index(),
         )
         # `keep_shape=True` and `keep_equal=True`
         expected = ps.DataFrame(
@@ -806,7 +810,8 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
             columns=["self", "other"],
         )
         self.assert_eq(
-            expected, psser1.compare(psser2, keep_shape=True, keep_equal=True).sort_index(),
+            expected,
+            psser1.compare(psser2, keep_shape=True, keep_equal=True).sort_index(),
         )
 
         # MultiIndex
@@ -838,7 +843,8 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
             columns=["self", "other"],
         )
         self.assert_eq(
-            expected, psser1.compare(psser2, keep_shape=True).sort_index(),
+            expected,
+            psser1.compare(psser2, keep_shape=True).sort_index(),
         )
         # `keep_equal=True`
         expected = ps.DataFrame(
@@ -847,7 +853,8 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
             columns=["self", "other"],
        )
         self.assert_eq(
-            expected, psser1.compare(psser2, keep_equal=True).sort_index(),
+            expected,
+            psser1.compare(psser2, keep_equal=True).sort_index(),
         )
         # `keep_shape=True` and `keep_equal=True`
         expected = ps.DataFrame(
@@ -858,15 +865,22 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
             columns=["self", "other"],
         )
         self.assert_eq(
-            expected, psser1.compare(psser2, keep_shape=True, keep_equal=True).sort_index(),
+            expected,
+            psser1.compare(psser2, keep_shape=True, keep_equal=True).sort_index(),
         )
 
         # Different Index
         with self.assertRaisesRegex(
             ValueError, "Can only compare identically-labeled Series objects"
         ):
-            psser1 = ps.Series([1, 2, 3, 4, 5], index=pd.Index([1, 2, 3, 4, 5]),)
-            psser2 = ps.Series([2, 2, 3, 4, 1], index=pd.Index([5, 4, 3, 2, 1]),)
+            psser1 = ps.Series(
+                [1, 2, 3, 4, 5],
+                index=pd.Index([1, 2, 3, 4, 5]),
+            )
+            psser2 = ps.Series(
+                [2, 2, 3, 4, 1],
+                index=pd.Index([5, 4, 3, 2, 1]),
+            )
             psser1.compare(psser2)
 
         # Different MultiIndex
         with self.assertRaisesRegex(
@@ -1153,7 +1167,8 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
         self.assert_eq(psdf, pdf)
 
         with self.assertRaisesRegex(
-            ValueError, "shape mismatch",
+            ValueError,
+            "shape mismatch",
         ):
             psdf.iloc[[1, 2], [1]] = -another_psdf.max_speed

@@ -135,7 +135,8 @@ class OpsOnDiffFramesGroupByTest(PandasOnSparkTestCase, SQLTestUtils):
         )
         self.assert_eq(
-            psdf1.B.groupby(psdf2.A).sum().sort_index(), pdf1.B.groupby(pdf2.A).sum().sort_index(),
+            psdf1.B.groupby(psdf2.A).sum().sort_index(),
+            pdf1.B.groupby(pdf2.A).sum().sort_index(),
         )
         self.assert_eq(
             (psdf1.B + 1).groupby(psdf2.A).sum().sort_index(),

@@ -264,7 +264,8 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         psser = ps.from_pandas(pser)
         self.assert_eq(
-            pser.rename_axis("index2").sort_index(), psser.rename_axis("index2").sort_index(),
+            pser.rename_axis("index2").sort_index(),
+            psser.rename_axis("index2").sort_index(),
         )
 
         self.assert_eq(
@@ -463,7 +464,8 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         self.assert_eq(pser, psser)
         self.assert_eq(
-            pser.reindex(["A", "B"]).sort_index(), psser.reindex(["A", "B"]).sort_index(),
+            pser.reindex(["A", "B"]).sort_index(),
+            psser.reindex(["A", "B"]).sort_index(),
         )
 
         self.assert_eq(
@@ -491,7 +493,8 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         psser2 = ps.from_pandas(pser2)
         self.assert_eq(
-            pser.reindex_like(pser2).sort_index(), psser.reindex_like(psser2).sort_index(),
+            pser.reindex_like(pser2).sort_index(),
+            psser.reindex_like(psser2).sort_index(),
         )
 
         self.assert_eq(
@@ -507,7 +510,8 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         psser2 = ps.from_pandas(pser2)
         self.assert_eq(
-            pser.reindex_like(pser2).sort_index(), psser.reindex_like(psser2).sort_index(),
+            pser.reindex_like(pser2).sort_index(),
+            psser.reindex_like(psser2).sort_index(),
        )
 
         self.assertRaises(TypeError, lambda: psser.reindex_like(index2))
@@ -521,7 +525,8 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         psser = ps.from_pandas(pser)
         self.assert_eq(
-            pser.reindex_like(pser2).sort_index(), psser.reindex_like(psser2).sort_index(),
+            pser.reindex_like(pser2).sort_index(),
+            psser.reindex_like(psser2).sort_index(),
         )
 
         # Reindexing with DataFrame
@@ -532,7 +537,8 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         psdf = ps.from_pandas(pdf)
         self.assert_eq(
-            pser.reindex_like(pdf).sort_index(), psser.reindex_like(psdf).sort_index(),
+            pser.reindex_like(pdf).sort_index(),
+            psser.reindex_like(psdf).sort_index(),
         )
 
     def test_fillna(self):
@@ -2897,19 +2903,22 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         pser = pd.Series([1, 2, 3, 4], index=idx)
         psser = ps.from_pandas(pser)
         self.assert_eq(
-            pser.at_time("0:20").sort_index(), psser.at_time("0:20").sort_index(),
+            pser.at_time("0:20").sort_index(),
+            psser.at_time("0:20").sort_index(),
        )
 
         pser.index.name = "ts"
         psser = ps.from_pandas(pser)
         self.assert_eq(
-            pser.at_time("0:20").sort_index(), psser.at_time("0:20").sort_index(),
+            pser.at_time("0:20").sort_index(),
+            psser.at_time("0:20").sort_index(),
         )
 
         pser.index.name = "index"
         psser = ps.from_pandas(pser)
         self.assert_eq(
-            pser.at_time("0:20").sort_index(), psser.at_time("0:20").sort_index(),
+            pser.at_time("0:20").sort_index(),
+            psser.at_time("0:20").sort_index(),
         )

@@ -216,7 +216,7 @@ def as_spark_type(tpe: Union[str, type, Dtype], *, raise_error: bool = True) ->
 def spark_type_to_pandas_dtype(
     spark_type: types.DataType, *, use_extension_dtypes: bool = False
 ) -> Dtype:
-    """ Return the given Spark DataType to pandas dtype. """
+    """Return the given Spark DataType to pandas dtype."""
     if use_extension_dtypes and extension_dtypes_available:
         # IntegralType

@@ -25,7 +25,7 @@ from typing import Any, Optional
 def get_logger() -> Any:
-    """ An entry point of the plug-in and return the usage logger. """
+    """An entry point of the plug-in and return the usage logger."""
     return PandasOnSparkUsageLogger()

@@ -440,7 +440,7 @@ def align_diff_frames(
 def is_testing() -> bool:
-    """ Indicates whether Spark is currently running tests. """
+    """Indicates whether Spark is currently running tests."""
     return "SPARK_TESTING" in os.environ
@@ -574,12 +574,12 @@ def lazy_property(fn: Callable[[Any], Any]) -> property:
 def scol_for(sdf: spark.DataFrame, column_name: str) -> spark.Column:
-    """ Return Spark Column for the given column name. """
+    """Return Spark Column for the given column name."""
     return sdf["`{}`".format(column_name)]
 
 
 def column_labels_level(column_labels: List[Tuple]) -> int:
-    """ Return the level of the column index. """
+    """Return the level of the column index."""
     if len(column_labels) == 0:
         return 1
     else:
@@ -700,7 +700,7 @@ def is_name_like_value(
 def validate_axis(axis: Optional[Union[int, str]] = 0, none_axis: int = 0) -> int:
-    """ Check the given axis is valid. """
+    """Check the given axis is valid."""
     # convert to numeric axis
     axis = cast(
         Dict[Optional[Union[int, str]], int], {None: none_axis, "index": 0, "columns": 1}
@@ -712,7 +712,7 @@ def validate_axis(axis: Optional[Union[int, str]] = 0, none_axis: int = 0) -> in
 def validate_bool_kwarg(value: Any, arg_name: str) -> Optional[bool]:
-    """ Ensures that argument passed in arg_name is of type bool. """
+    """Ensures that argument passed in arg_name is of type bool."""
     if not (isinstance(value, bool) or value is None):
         raise TypeError(
             'For argument "{}" expected type bool, received '
@@ -722,7 +722,7 @@ def validate_bool_kwarg(value: Any, arg_name: str) -> Optional[bool]:
 def validate_how(how: str) -> str:
-    """ Check the given how for join is valid. """
+    """Check the given how for join is valid."""
     if how == "full":
         warnings.warn(
             "Warning: While pandas-on-Spark will accept 'full', you should use 'outer' "