[SPARK-36350][PYTHON] Move some logic related to F.nanvl to DataTypeOps

### What changes were proposed in this pull request?

Move the NaN-to-null handling based on `F.nanvl` into `DataTypeOps` by introducing a `nan_to_null` method that each data type's ops class can override.
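
Roughly, the new hook behaves like the sketch below. This is a simplified illustration only: the real methods take and return pandas-on-Spark `IndexOpsLike` objects (via `_with_new_scol`, preserving field metadata) rather than raw Spark `Column`s.

```python
from pyspark.sql import Column, functions as F


class DataTypeOps:
    """Base behavior: most types have no NaN notion, so nothing to convert."""

    def nan_to_null(self, col: Column) -> Column:
        return col


class FractionalOps(DataTypeOps):
    """Float/double columns: map NaN to null so that counts match pandas."""

    def nan_to_null(self, col: Column) -> Column:
        return F.nanvl(col, F.lit(None))


class DecimalOps(FractionalOps):
    """Decimal columns cannot hold NaN, so the column is kept as is."""

    def nan_to_null(self, col: Column) -> Column:
        return col
```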

### Why are the changes needed?

Several places branch on `FloatType` or `DoubleType` in order to apply `F.nanvl`, but `DataTypeOps` should handle this instead.
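
For example, `Frame._count_expr` used to special-case float/double columns inline; with the hook in place, call sites can stay type-agnostic. A minimal before/after sketch, reusing the simplified `Column`-based signature from the sketch above (`count_expr_before`/`count_expr_after` are illustrative names; the actual code calls `psser._dtype_op.nan_to_null(psser)` on a pandas-on-Spark Series):

```python
from pyspark.sql import Column, functions as F
from pyspark.sql.types import DataType, DoubleType, FloatType


def count_expr_before(spark_column: Column, spark_type: DataType) -> Column:
    # Old pattern: each call site branches on the Spark type itself, because
    # Spark's count treats NaN as a valid value whereas pandas' count does not.
    if isinstance(spark_type, (FloatType, DoubleType)):
        return F.count(F.nanvl(spark_column, F.lit(None)))
    return F.count(spark_column)


def count_expr_after(spark_column: Column, ops: "DataTypeOps") -> Column:
    # New pattern: NaN handling is delegated to the column's DataTypeOps, so
    # the call site no longer needs to know the concrete numeric type.
    return F.count(ops.nan_to_null(spark_column))
```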

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #33582 from ueshin/issues/SPARK-36350/nan_to_null.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
(cherry picked from commit 895e3f5e2a)
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
Committed by Takuya UESHIN on 2021-07-30 11:19:49 -07:00
parent fee87f13d1
commit a4dcda1794
6 changed files with 112 additions and 106 deletions


@@ -366,5 +366,8 @@ class DataTypeOps(object, metaclass=ABCMeta):
),
)
def nan_to_null(self, index_ops: IndexOpsLike) -> IndexOpsLike:
return index_ops.copy()
def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
raise TypeError("astype can not be applied to %s." % self.pretty_name)


@@ -326,6 +326,14 @@ class FractionalOps(NumericOps):
),
)
def nan_to_null(self, index_ops: IndexOpsLike) -> IndexOpsLike:
# Special handle floating point types because Spark's count treats nan as a valid value,
# whereas pandas count doesn't include nan.
return index_ops._with_new_scol(
F.nanvl(index_ops.spark.column, SF.lit(None)),
field=index_ops._internal.data_fields[0].copy(nullable=True),
)
def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
dtype, spark_type = pandas_on_spark_type(dtype)
@@ -385,6 +393,9 @@ class DecimalOps(FractionalOps):
),
)
def nan_to_null(self, index_ops: IndexOpsLike) -> IndexOpsLike:
return index_ops.copy()
def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
# TODO(SPARK-36230): check index_ops.hasnans after fixing SPARK-36230
dtype, spark_type = pandas_on_spark_type(dtype)


@@ -642,7 +642,7 @@ class DataFrame(Frame, Generic[T]):
def _reduce_for_stat_function(
self,
sfun: Union[Callable[[Column], Column], Callable[[Column, DataType], Column]],
sfun: Callable[["Series"], Column],
name: str,
axis: Optional[Axis] = None,
numeric_only: bool = True,
@@ -664,7 +664,6 @@ class DataFrame(Frame, Generic[T]):
is mainly for pandas compatibility. Only 'DataFrame.count' uses this parameter
currently.
"""
from inspect import signature
from pyspark.pandas.series import Series, first_series
axis = validate_axis(axis)
@@ -673,29 +672,19 @@ class DataFrame(Frame, Generic[T]):
exprs = [SF.lit(None).cast(StringType()).alias(SPARK_DEFAULT_INDEX_NAME)]
new_column_labels = []
num_args = len(signature(sfun).parameters)
for label in self._internal.column_labels:
spark_column = self._internal.spark_column_for(label)
spark_type = self._internal.spark_type_for(label)
psser = self._psser_for(label)
is_numeric_or_boolean = isinstance(spark_type, (NumericType, BooleanType))
is_numeric_or_boolean = isinstance(
psser.spark.data_type, (NumericType, BooleanType)
)
keep_column = not numeric_only or is_numeric_or_boolean
if keep_column:
if num_args == 1:
# Only pass in the column if sfun accepts only one arg
scol = cast(Callable[[Column], Column], sfun)(spark_column)
else: # must be 2
assert num_args == 2
# Pass in both the column and its data type if sfun accepts two args
scol = cast(Callable[[Column, DataType], Column], sfun)(
spark_column, spark_type
)
scol = sfun(psser)
if min_count > 0:
scol = F.when(
Frame._count_expr(spark_column, spark_type) >= min_count, scol
)
scol = F.when(Frame._count_expr(psser) >= min_count, scol)
exprs.append(scol.alias(name_like_string(label)))
new_column_labels.append(label)
@@ -8485,16 +8474,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
exprs = []
column_labels = []
for label in self._internal.column_labels:
scol = self._internal.spark_column_for(label)
spark_type = self._internal.spark_type_for(label)
# TODO(SPARK-36350): Make this work with DataTypeOps.
if isinstance(spark_type, (FloatType, DoubleType)):
exprs.append(
F.nanvl(scol, SF.lit(None)).alias(self._internal.spark_column_name_for(label))
)
column_labels.append(label)
elif isinstance(spark_type, NumericType):
exprs.append(scol)
psser = self._psser_for(label)
if isinstance(psser.spark.data_type, NumericType):
exprs.append(psser._dtype_op.nan_to_null(psser).spark.column)
column_labels.append(label)
if len(exprs) == 0:
@@ -10813,7 +10795,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
if v < 0.0 or v > 1.0:
raise ValueError("percentiles should all be in the interval [0, 1].")
def quantile(spark_column: Column, spark_type: DataType) -> Column:
def quantile(psser: "Series") -> Column:
spark_type = psser.spark.data_type
spark_column = psser.spark.column
if isinstance(spark_type, (BooleanType, NumericType)):
return F.percentile_approx(spark_column.cast(DoubleType()), qq, accuracy)
else:
@@ -10839,13 +10823,15 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
for label, column in zip(
self._internal.column_labels, self._internal.data_spark_column_names
):
spark_type = self._internal.spark_type_for(label)
psser = self._psser_for(label)
is_numeric_or_boolean = isinstance(spark_type, (NumericType, BooleanType))
is_numeric_or_boolean = isinstance(
psser.spark.data_type, (NumericType, BooleanType)
)
keep_column = not numeric_only or is_numeric_or_boolean
if keep_column:
percentile_col = quantile(self._internal.spark_column_for(label), spark_type)
percentile_col = quantile(psser)
percentile_cols.append(percentile_col.alias(column))
percentile_col_names.append(column)
column_labels.append(label)


@@ -44,9 +44,7 @@ from pandas.api.types import is_list_like
from pyspark.sql import Column, functions as F
from pyspark.sql.types import (
BooleanType,
DataType,
DoubleType,
FloatType,
IntegralType,
LongType,
NumericType,
@@ -114,7 +112,7 @@ class Frame(object, metaclass=ABCMeta):
@abstractmethod
def _reduce_for_stat_function(
self,
sfun: Union[Callable[[Column], Column], Callable[[Column, DataType], Column]],
sfun: Callable[["Series"], Column],
name: str,
axis: Optional[Axis] = None,
numeric_only: bool = True,
@@ -1204,7 +1202,9 @@ class Frame(object, metaclass=ABCMeta):
if numeric_only is None and axis == 0:
numeric_only = True
def mean(spark_column: Column, spark_type: DataType) -> Column:
def mean(psser: "Series") -> Column:
spark_type = psser.spark.data_type
spark_column = psser.spark.column
if isinstance(spark_type, BooleanType):
spark_column = spark_column.cast(LongType())
elif not isinstance(spark_type, NumericType):
@@ -1289,7 +1289,9 @@ class Frame(object, metaclass=ABCMeta):
elif numeric_only is True and axis == 1:
numeric_only = None
def sum(spark_column: Column, spark_type: DataType) -> Column:
def sum(psser: "Series") -> Column:
spark_type = psser.spark.data_type
spark_column = psser.spark.column
if isinstance(spark_type, BooleanType):
spark_column = spark_column.cast(LongType())
elif not isinstance(spark_type, NumericType):
@@ -1373,7 +1375,9 @@ class Frame(object, metaclass=ABCMeta):
elif numeric_only is True and axis == 1:
numeric_only = None
def prod(spark_column: Column, spark_type: DataType) -> Column:
def prod(psser: "Series") -> Column:
spark_type = psser.spark.data_type
spark_column = psser.spark.column
if isinstance(spark_type, BooleanType):
scol = F.min(F.coalesce(spark_column, SF.lit(True))).cast(LongType())
elif isinstance(spark_type, NumericType):
@@ -1444,7 +1448,9 @@ class Frame(object, metaclass=ABCMeta):
if numeric_only is None and axis == 0:
numeric_only = True
def skew(spark_column: Column, spark_type: DataType) -> Column:
def skew(psser: "Series") -> Column:
spark_type = psser.spark.data_type
spark_column = psser.spark.column
if isinstance(spark_type, BooleanType):
spark_column = spark_column.cast(LongType())
elif not isinstance(spark_type, NumericType):
@@ -1501,7 +1507,9 @@ class Frame(object, metaclass=ABCMeta):
if numeric_only is None and axis == 0:
numeric_only = True
def kurtosis(spark_column: Column, spark_type: DataType) -> Column:
def kurtosis(psser: "Series") -> Column:
spark_type = psser.spark.data_type
spark_column = psser.spark.column
if isinstance(spark_type, BooleanType):
spark_column = spark_column.cast(LongType())
elif not isinstance(spark_type, NumericType):
@@ -1570,7 +1578,10 @@ class Frame(object, metaclass=ABCMeta):
numeric_only = None
return self._reduce_for_stat_function(
F.min, name="min", axis=axis, numeric_only=numeric_only
lambda psser: F.min(psser.spark.column),
name="min",
axis=axis,
numeric_only=numeric_only,
)
def max(
@@ -1625,7 +1636,10 @@ class Frame(object, metaclass=ABCMeta):
numeric_only = None
return self._reduce_for_stat_function(
F.max, name="max", axis=axis, numeric_only=numeric_only
lambda psser: F.max(psser.spark.column),
name="max",
axis=axis,
numeric_only=numeric_only,
)
def count(
@@ -1763,7 +1777,9 @@ class Frame(object, metaclass=ABCMeta):
if numeric_only is None and axis == 0:
numeric_only = True
def std(spark_column: Column, spark_type: DataType) -> Column:
def std(psser: "Series") -> Column:
spark_type = psser.spark.data_type
spark_column = psser.spark.column
if isinstance(spark_type, BooleanType):
spark_column = spark_column.cast(LongType())
elif not isinstance(spark_type, NumericType):
@@ -1842,7 +1858,9 @@ class Frame(object, metaclass=ABCMeta):
if numeric_only is None and axis == 0:
numeric_only = True
def var(spark_column: Column, spark_type: DataType) -> Column:
def var(psser: "Series") -> Column:
spark_type = psser.spark.data_type
spark_column = psser.spark.column
if isinstance(spark_type, BooleanType):
spark_column = spark_column.cast(LongType())
elif not isinstance(spark_type, NumericType):
@@ -1955,7 +1973,9 @@ class Frame(object, metaclass=ABCMeta):
"accuracy must be an integer; however, got [%s]" % type(accuracy).__name__
)
def median(spark_column: Column, spark_type: DataType) -> Column:
def median(psser: "Series") -> Column:
spark_type = psser.spark.data_type
spark_column = psser.spark.column
if isinstance(spark_type, (BooleanType, NumericType)):
return F.percentile_approx(spark_column.cast(DoubleType()), 0.5, accuracy)
else:
@@ -2037,7 +2057,9 @@ class Frame(object, metaclass=ABCMeta):
if numeric_only is None and axis == 0:
numeric_only = True
def std(spark_column: Column, spark_type: DataType) -> Column:
def std(psser: "Series") -> Column:
spark_type = psser.spark.data_type
spark_column = psser.spark.column
if isinstance(spark_type, BooleanType):
spark_column = spark_column.cast(LongType())
elif not isinstance(spark_type, NumericType):
@@ -2051,10 +2073,8 @@ class Frame(object, metaclass=ABCMeta):
else:
return F.stddev_samp(spark_column)
def sem(spark_column: Column, spark_type: DataType) -> Column:
return std(spark_column, spark_type) / pow(
Frame._count_expr(spark_column, spark_type), 0.5
)
def sem(psser: "Series") -> Column:
return std(psser) / pow(Frame._count_expr(psser), 0.5)
return self._reduce_for_stat_function(
sem, name="sem", numeric_only=numeric_only, axis=axis, ddof=ddof
@@ -3180,14 +3200,8 @@ class Frame(object, metaclass=ABCMeta):
)
@staticmethod
def _count_expr(spark_column: Column, spark_type: DataType) -> Column:
# Special handle floating point types because Spark's count treats nan as a valid value,
# whereas pandas count doesn't include nan.
# TODO(SPARK-36350): Make this work with DataTypeOps.
if isinstance(spark_type, (FloatType, DoubleType)):
return F.count(F.nanvl(spark_column, SF.lit(None)))
else:
return F.count(spark_column)
def _count_expr(psser: "Series") -> Column:
return F.count(psser._dtype_op.nan_to_null(psser).spark.column)
def _test() -> None:


@@ -2524,40 +2524,18 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
def _reduce_for_stat_function(
self, sfun: Callable[[Column], Column], only_numeric: bool
) -> FrameLike:
agg_columns = self._agg_columns
agg_columns_scols = self._agg_columns_scols
groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
groupkey_scols = [s.alias(name) for s, name in zip(self._groupkeys_scols, groupkey_names)]
sdf = self._psdf._internal.spark_frame.select(groupkey_scols + agg_columns_scols)
agg_columns = [
psser
for psser in self._agg_columns
if isinstance(psser.spark.data_type, NumericType) or not only_numeric
]
data_columns = []
column_labels = []
if len(agg_columns) > 0:
stat_exprs = []
for psser in agg_columns:
spark_type = psser.spark.data_type
name = psser._internal.data_spark_column_names[0]
label = psser._column_label
scol = scol_for(sdf, name)
# TODO: we should have a function that takes dataframes and converts the numeric
# types. Converting the NaNs is used in a few places, it should be in utils.
# Special handle floating point types because Spark's count treats nan as a valid
# value, whereas pandas count doesn't include nan.
# TODO(SPARK-36350): Make this work with DataTypeOps.
if isinstance(spark_type, (FloatType, DoubleType)):
stat_exprs.append(sfun(F.nanvl(scol, SF.lit(None))).alias(name))
data_columns.append(name)
column_labels.append(label)
elif isinstance(spark_type, NumericType) or not only_numeric:
stat_exprs.append(sfun(scol).alias(name))
data_columns.append(name)
column_labels.append(label)
sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
else:
sdf = sdf.select(*groupkey_names).distinct()
sdf = self._psdf._internal.spark_frame.select(
*groupkey_scols, *[psser.spark.column for psser in agg_columns]
)
internal = InternalFrame(
spark_frame=sdf,
@@ -2567,12 +2545,36 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
psser._internal.data_fields[0].copy(name=name)
for psser, name in zip(self._groupkeys, groupkey_names)
],
column_labels=column_labels,
data_spark_columns=[scol_for(sdf, col) for col in data_columns],
data_spark_columns=[
scol_for(sdf, psser._internal.data_spark_column_names[0]) for psser in agg_columns
],
column_labels=[psser._column_label for psser in agg_columns],
data_fields=[psser._internal.data_fields[0] for psser in agg_columns],
column_label_names=self._psdf._internal.column_label_names,
)
psdf = DataFrame(internal) # type: DataFrame
if len(psdf._internal.column_labels) > 0:
stat_exprs = []
for label in psdf._internal.column_labels:
psser = psdf._psser_for(label)
stat_exprs.append(
sfun(psser._dtype_op.nan_to_null(psser).spark.column).alias(
psser._internal.data_spark_column_names[0]
)
)
sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
else:
sdf = sdf.select(*groupkey_names).distinct()
internal = internal.copy(
spark_frame=sdf,
index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
data_fields=None,
)
psdf = DataFrame(internal)
if self._dropna:
psdf = DataFrame(
psdf._internal.with_new_sdf(


@@ -54,7 +54,6 @@ from pyspark.sql import functions as F, Column, DataFrame as SparkDataFrame
from pyspark.sql.types import (
ArrayType,
BooleanType,
DataType,
DecimalType,
DoubleType,
FloatType,
@@ -3453,7 +3452,9 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
if q_float < 0.0 or q_float > 1.0:
raise ValueError("percentiles should all be in the interval [0, 1].")
def quantile(spark_column: Column, spark_type: DataType) -> Column:
def quantile(psser: Series) -> Column:
spark_type = psser.spark.data_type
spark_column = psser.spark.column
if isinstance(spark_type, (BooleanType, NumericType)):
return F.percentile_approx(spark_column.cast(DoubleType()), q_float, accuracy)
else:
@@ -6186,7 +6187,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
def _reduce_for_stat_function(
self,
sfun: Union[Callable[[Column], Column], Callable[[Column, DataType], Column]],
sfun: Callable[["Series"], Column],
name: str_type,
axis: Optional[Axis] = None,
numeric_only: bool = True,
@@ -6202,26 +6203,15 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
axis : used only for sanity check because series only support index axis.
numeric_only : not used by this implementation, but passed down by stats functions
"""
from inspect import signature
axis = validate_axis(axis)
if axis == 1:
raise ValueError("Series does not support columns axis.")
num_args = len(signature(sfun).parameters)
spark_column = self.spark.column
spark_type = self.spark.data_type
if num_args == 1:
# Only pass in the column if sfun accepts only one arg
scol = cast(Callable[[Column], Column], sfun)(spark_column)
else: # must be 2
assert num_args == 2
# Pass in both the column and its data type if sfun accepts two args
scol = cast(Callable[[Column, DataType], Column], sfun)(spark_column, spark_type)
scol = sfun(self)
min_count = kwargs.get("min_count", 0)
if min_count > 0:
scol = F.when(Frame._count_expr(spark_column, spark_type) >= min_count, scol)
scol = F.when(Frame._count_expr(self) >= min_count, scol)
result = unpack_scalar(self._internal.spark_frame.select(scol))
return result if result is not None else np.nan