[SPARK-36350][PYTHON] Move some logic related to F.nanvl to DataTypeOps
### What changes were proposed in this pull request?
Move some logic related to `F.nanvl` to `DataTypeOps`.
### Why are the changes needed?
There are several places that branch on `FloatType` or `DoubleType` to apply `F.nanvl`, but `DataTypeOps` should handle this instead.
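For context, here is a minimal standalone sketch (not part of this patch, and assuming a local Spark session with pyspark available) of the behavior the `F.nanvl` branching exists for: Spark's `count` treats NaN as a valid value, whereas pandas' `count` does not, so floating-point columns are converted NaN to null first. After this change that conversion is centralized in `DataTypeOps.nan_to_null`.

```python
# Illustrative only: why NaN must be converted to null before counting.
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
sdf = spark.createDataFrame([(1.0,), (float("nan"),), (None,)], ["x"])

sdf.select(
    F.count("x").alias("spark_count"),                         # 2: NaN is counted
    F.count(F.nanvl("x", F.lit(None))).alias("pandas_like"),   # 1: NaN treated as missing
).show()
```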
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
Closes #33582 from ueshin/issues/SPARK-36350/nan_to_null.
Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
(cherry picked from commit 895e3f5e2a)
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
commit a4dcda1794
parent fee87f13d1
@@ -366,5 +366,8 @@ class DataTypeOps(object, metaclass=ABCMeta):
             ),
         )
 
+    def nan_to_null(self, index_ops: IndexOpsLike) -> IndexOpsLike:
+        return index_ops.copy()
+
     def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
         raise TypeError("astype can not be applied to %s." % self.pretty_name)

@@ -326,6 +326,14 @@ class FractionalOps(NumericOps):
             ),
         )
 
+    def nan_to_null(self, index_ops: IndexOpsLike) -> IndexOpsLike:
+        # Special handle floating point types because Spark's count treats nan as a valid value,
+        # whereas pandas count doesn't include nan.
+        return index_ops._with_new_scol(
+            F.nanvl(index_ops.spark.column, SF.lit(None)),
+            field=index_ops._internal.data_fields[0].copy(nullable=True),
+        )
+
     def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
         dtype, spark_type = pandas_on_spark_type(dtype)

@@ -385,6 +393,9 @@ class DecimalOps(FractionalOps):
             ),
         )
 
+    def nan_to_null(self, index_ops: IndexOpsLike) -> IndexOpsLike:
+        return index_ops.copy()
+
     def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
         # TODO(SPARK-36230): check index_ops.hasnans after fixing SPARK-36230
         dtype, spark_type = pandas_on_spark_type(dtype)
@@ -642,7 +642,7 @@ class DataFrame(Frame, Generic[T]):
 
     def _reduce_for_stat_function(
         self,
-        sfun: Union[Callable[[Column], Column], Callable[[Column, DataType], Column]],
+        sfun: Callable[["Series"], Column],
         name: str,
         axis: Optional[Axis] = None,
         numeric_only: bool = True,

@@ -664,7 +664,6 @@ class DataFrame(Frame, Generic[T]):
         is mainly for pandas compatibility. Only 'DataFrame.count' uses this parameter
         currently.
         """
-        from inspect import signature
         from pyspark.pandas.series import Series, first_series
 
         axis = validate_axis(axis)

@@ -673,29 +672,19 @@ class DataFrame(Frame, Generic[T]):
 
         exprs = [SF.lit(None).cast(StringType()).alias(SPARK_DEFAULT_INDEX_NAME)]
         new_column_labels = []
-        num_args = len(signature(sfun).parameters)
         for label in self._internal.column_labels:
-            spark_column = self._internal.spark_column_for(label)
-            spark_type = self._internal.spark_type_for(label)
+            psser = self._psser_for(label)
 
-            is_numeric_or_boolean = isinstance(spark_type, (NumericType, BooleanType))
+            is_numeric_or_boolean = isinstance(
+                psser.spark.data_type, (NumericType, BooleanType)
+            )
             keep_column = not numeric_only or is_numeric_or_boolean
 
             if keep_column:
-                if num_args == 1:
-                    # Only pass in the column if sfun accepts only one arg
-                    scol = cast(Callable[[Column], Column], sfun)(spark_column)
-                else:  # must be 2
-                    assert num_args == 2
-                    # Pass in both the column and its data type if sfun accepts two args
-                    scol = cast(Callable[[Column, DataType], Column], sfun)(
-                        spark_column, spark_type
-                    )
+                scol = sfun(psser)
 
                 if min_count > 0:
-                    scol = F.when(
-                        Frame._count_expr(spark_column, spark_type) >= min_count, scol
-                    )
+                    scol = F.when(Frame._count_expr(psser) >= min_count, scol)
 
                 exprs.append(scol.alias(name_like_string(label)))
                 new_column_labels.append(label)
@@ -8485,16 +8474,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         exprs = []
         column_labels = []
         for label in self._internal.column_labels:
-            scol = self._internal.spark_column_for(label)
-            spark_type = self._internal.spark_type_for(label)
-            # TODO(SPARK-36350): Make this work with DataTypeOps.
-            if isinstance(spark_type, (FloatType, DoubleType)):
-                exprs.append(
-                    F.nanvl(scol, SF.lit(None)).alias(self._internal.spark_column_name_for(label))
-                )
-                column_labels.append(label)
-            elif isinstance(spark_type, NumericType):
-                exprs.append(scol)
+            psser = self._psser_for(label)
+            if isinstance(psser.spark.data_type, NumericType):
+                exprs.append(psser._dtype_op.nan_to_null(psser).spark.column)
                 column_labels.append(label)
 
         if len(exprs) == 0:

@@ -10813,7 +10795,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             if v < 0.0 or v > 1.0:
                 raise ValueError("percentiles should all be in the interval [0, 1].")
 
-        def quantile(spark_column: Column, spark_type: DataType) -> Column:
+        def quantile(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, (BooleanType, NumericType)):
                 return F.percentile_approx(spark_column.cast(DoubleType()), qq, accuracy)
             else:

@@ -10839,13 +10823,15 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         for label, column in zip(
             self._internal.column_labels, self._internal.data_spark_column_names
         ):
-            spark_type = self._internal.spark_type_for(label)
+            psser = self._psser_for(label)
 
-            is_numeric_or_boolean = isinstance(spark_type, (NumericType, BooleanType))
+            is_numeric_or_boolean = isinstance(
+                psser.spark.data_type, (NumericType, BooleanType)
+            )
             keep_column = not numeric_only or is_numeric_or_boolean
 
             if keep_column:
-                percentile_col = quantile(self._internal.spark_column_for(label), spark_type)
+                percentile_col = quantile(psser)
                 percentile_cols.append(percentile_col.alias(column))
                 percentile_col_names.append(column)
                 column_labels.append(label)
@@ -44,9 +44,7 @@ from pandas.api.types import is_list_like
 from pyspark.sql import Column, functions as F
 from pyspark.sql.types import (
     BooleanType,
-    DataType,
     DoubleType,
-    FloatType,
     IntegralType,
     LongType,
     NumericType,
@@ -114,7 +112,7 @@ class Frame(object, metaclass=ABCMeta):
     @abstractmethod
     def _reduce_for_stat_function(
         self,
-        sfun: Union[Callable[[Column], Column], Callable[[Column, DataType], Column]],
+        sfun: Callable[["Series"], Column],
         name: str,
         axis: Optional[Axis] = None,
         numeric_only: bool = True,

@@ -1204,7 +1202,9 @@ class Frame(object, metaclass=ABCMeta):
         if numeric_only is None and axis == 0:
             numeric_only = True
 
-        def mean(spark_column: Column, spark_type: DataType) -> Column:
+        def mean(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 spark_column = spark_column.cast(LongType())
             elif not isinstance(spark_type, NumericType):
@@ -1289,7 +1289,9 @@ class Frame(object, metaclass=ABCMeta):
         elif numeric_only is True and axis == 1:
             numeric_only = None
 
-        def sum(spark_column: Column, spark_type: DataType) -> Column:
+        def sum(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 spark_column = spark_column.cast(LongType())
             elif not isinstance(spark_type, NumericType):

@@ -1373,7 +1375,9 @@ class Frame(object, metaclass=ABCMeta):
         elif numeric_only is True and axis == 1:
             numeric_only = None
 
-        def prod(spark_column: Column, spark_type: DataType) -> Column:
+        def prod(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 scol = F.min(F.coalesce(spark_column, SF.lit(True))).cast(LongType())
             elif isinstance(spark_type, NumericType):

@@ -1444,7 +1448,9 @@ class Frame(object, metaclass=ABCMeta):
         if numeric_only is None and axis == 0:
             numeric_only = True
 
-        def skew(spark_column: Column, spark_type: DataType) -> Column:
+        def skew(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 spark_column = spark_column.cast(LongType())
             elif not isinstance(spark_type, NumericType):

@@ -1501,7 +1507,9 @@ class Frame(object, metaclass=ABCMeta):
         if numeric_only is None and axis == 0:
             numeric_only = True
 
-        def kurtosis(spark_column: Column, spark_type: DataType) -> Column:
+        def kurtosis(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 spark_column = spark_column.cast(LongType())
             elif not isinstance(spark_type, NumericType):
@@ -1570,7 +1578,10 @@ class Frame(object, metaclass=ABCMeta):
             numeric_only = None
 
         return self._reduce_for_stat_function(
-            F.min, name="min", axis=axis, numeric_only=numeric_only
+            lambda psser: F.min(psser.spark.column),
+            name="min",
+            axis=axis,
+            numeric_only=numeric_only,
         )
 
     def max(

@@ -1625,7 +1636,10 @@ class Frame(object, metaclass=ABCMeta):
             numeric_only = None
 
         return self._reduce_for_stat_function(
-            F.max, name="max", axis=axis, numeric_only=numeric_only
+            lambda psser: F.max(psser.spark.column),
+            name="max",
+            axis=axis,
+            numeric_only=numeric_only,
        )
 
     def count(

@@ -1763,7 +1777,9 @@ class Frame(object, metaclass=ABCMeta):
         if numeric_only is None and axis == 0:
             numeric_only = True
 
-        def std(spark_column: Column, spark_type: DataType) -> Column:
+        def std(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 spark_column = spark_column.cast(LongType())
             elif not isinstance(spark_type, NumericType):
@@ -1842,7 +1858,9 @@ class Frame(object, metaclass=ABCMeta):
         if numeric_only is None and axis == 0:
             numeric_only = True
 
-        def var(spark_column: Column, spark_type: DataType) -> Column:
+        def var(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 spark_column = spark_column.cast(LongType())
             elif not isinstance(spark_type, NumericType):

@@ -1955,7 +1973,9 @@ class Frame(object, metaclass=ABCMeta):
                 "accuracy must be an integer; however, got [%s]" % type(accuracy).__name__
             )
 
-        def median(spark_column: Column, spark_type: DataType) -> Column:
+        def median(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, (BooleanType, NumericType)):
                 return F.percentile_approx(spark_column.cast(DoubleType()), 0.5, accuracy)
             else:

@@ -2037,7 +2057,9 @@ class Frame(object, metaclass=ABCMeta):
         if numeric_only is None and axis == 0:
             numeric_only = True
 
-        def std(spark_column: Column, spark_type: DataType) -> Column:
+        def std(psser: "Series") -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, BooleanType):
                 spark_column = spark_column.cast(LongType())
             elif not isinstance(spark_type, NumericType):

@@ -2051,10 +2073,8 @@ class Frame(object, metaclass=ABCMeta):
             else:
                 return F.stddev_samp(spark_column)
 
-        def sem(spark_column: Column, spark_type: DataType) -> Column:
-            return std(spark_column, spark_type) / pow(
-                Frame._count_expr(spark_column, spark_type), 0.5
-            )
+        def sem(psser: "Series") -> Column:
+            return std(psser) / pow(Frame._count_expr(psser), 0.5)
 
         return self._reduce_for_stat_function(
             sem, name="sem", numeric_only=numeric_only, axis=axis, ddof=ddof
@@ -3180,14 +3200,8 @@ class Frame(object, metaclass=ABCMeta):
         )
 
     @staticmethod
-    def _count_expr(spark_column: Column, spark_type: DataType) -> Column:
-        # Special handle floating point types because Spark's count treats nan as a valid value,
-        # whereas pandas count doesn't include nan.
-        # TODO(SPARK-36350): Make this work with DataTypeOps.
-        if isinstance(spark_type, (FloatType, DoubleType)):
-            return F.count(F.nanvl(spark_column, SF.lit(None)))
-        else:
-            return F.count(spark_column)
+    def _count_expr(psser: "Series") -> Column:
+        return F.count(psser._dtype_op.nan_to_null(psser).spark.column)
 
 
 def _test() -> None:
@@ -2524,40 +2524,18 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
     def _reduce_for_stat_function(
         self, sfun: Callable[[Column], Column], only_numeric: bool
     ) -> FrameLike:
-        agg_columns = self._agg_columns
-        agg_columns_scols = self._agg_columns_scols
-
         groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
         groupkey_scols = [s.alias(name) for s, name in zip(self._groupkeys_scols, groupkey_names)]
 
-        sdf = self._psdf._internal.spark_frame.select(groupkey_scols + agg_columns_scols)
+        agg_columns = [
+            psser
+            for psser in self._agg_columns
+            if isinstance(psser.spark.data_type, NumericType) or not only_numeric
+        ]
 
-        data_columns = []
-        column_labels = []
-        if len(agg_columns) > 0:
-            stat_exprs = []
-            for psser in agg_columns:
-                spark_type = psser.spark.data_type
-                name = psser._internal.data_spark_column_names[0]
-                label = psser._column_label
-                scol = scol_for(sdf, name)
-                # TODO: we should have a function that takes dataframes and converts the numeric
-                # types. Converting the NaNs is used in a few places, it should be in utils.
-                # Special handle floating point types because Spark's count treats nan as a valid
-                # value, whereas pandas count doesn't include nan.
-                # TODO(SPARK-36350): Make this work with DataTypeOps.
-                if isinstance(spark_type, (FloatType, DoubleType)):
-                    stat_exprs.append(sfun(F.nanvl(scol, SF.lit(None))).alias(name))
-                    data_columns.append(name)
-                    column_labels.append(label)
-                elif isinstance(spark_type, NumericType) or not only_numeric:
-                    stat_exprs.append(sfun(scol).alias(name))
-                    data_columns.append(name)
-                    column_labels.append(label)
-            sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
-        else:
-            sdf = sdf.select(*groupkey_names).distinct()
+        sdf = self._psdf._internal.spark_frame.select(
+            *groupkey_scols, *[psser.spark.column for psser in agg_columns]
+        )
 
         internal = InternalFrame(
             spark_frame=sdf,
@@ -2567,12 +2545,36 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
                 psser._internal.data_fields[0].copy(name=name)
                 for psser, name in zip(self._groupkeys, groupkey_names)
             ],
-            column_labels=column_labels,
-            data_spark_columns=[scol_for(sdf, col) for col in data_columns],
+            data_spark_columns=[
+                scol_for(sdf, psser._internal.data_spark_column_names[0]) for psser in agg_columns
+            ],
+            column_labels=[psser._column_label for psser in agg_columns],
+            data_fields=[psser._internal.data_fields[0] for psser in agg_columns],
             column_label_names=self._psdf._internal.column_label_names,
         )
         psdf = DataFrame(internal)  # type: DataFrame
 
+        if len(psdf._internal.column_labels) > 0:
+            stat_exprs = []
+            for label in psdf._internal.column_labels:
+                psser = psdf._psser_for(label)
+                stat_exprs.append(
+                    sfun(psser._dtype_op.nan_to_null(psser).spark.column).alias(
+                        psser._internal.data_spark_column_names[0]
+                    )
+                )
+            sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
+        else:
+            sdf = sdf.select(*groupkey_names).distinct()
+
+        internal = internal.copy(
+            spark_frame=sdf,
+            index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
+            data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
+            data_fields=None,
+        )
+        psdf = DataFrame(internal)
 
         if self._dropna:
             psdf = DataFrame(
                 psdf._internal.with_new_sdf(
@@ -54,7 +54,6 @@ from pyspark.sql import functions as F, Column, DataFrame as SparkDataFrame
 from pyspark.sql.types import (
     ArrayType,
     BooleanType,
-    DataType,
     DecimalType,
     DoubleType,
     FloatType,
@@ -3453,7 +3452,9 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
         if q_float < 0.0 or q_float > 1.0:
             raise ValueError("percentiles should all be in the interval [0, 1].")
 
-        def quantile(spark_column: Column, spark_type: DataType) -> Column:
+        def quantile(psser: Series) -> Column:
+            spark_type = psser.spark.data_type
+            spark_column = psser.spark.column
             if isinstance(spark_type, (BooleanType, NumericType)):
                 return F.percentile_approx(spark_column.cast(DoubleType()), q_float, accuracy)
             else:
@@ -6186,7 +6187,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
 
     def _reduce_for_stat_function(
         self,
-        sfun: Union[Callable[[Column], Column], Callable[[Column, DataType], Column]],
+        sfun: Callable[["Series"], Column],
         name: str_type,
         axis: Optional[Axis] = None,
         numeric_only: bool = True,

@@ -6202,26 +6203,15 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
         axis : used only for sanity check because series only support index axis.
         numeric_only : not used by this implementation, but passed down by stats functions
         """
-        from inspect import signature
-
         axis = validate_axis(axis)
         if axis == 1:
             raise ValueError("Series does not support columns axis.")
-        num_args = len(signature(sfun).parameters)
-        spark_column = self.spark.column
-        spark_type = self.spark.data_type
 
-        if num_args == 1:
-            # Only pass in the column if sfun accepts only one arg
-            scol = cast(Callable[[Column], Column], sfun)(spark_column)
-        else:  # must be 2
-            assert num_args == 2
-            # Pass in both the column and its data type if sfun accepts two args
-            scol = cast(Callable[[Column, DataType], Column], sfun)(spark_column, spark_type)
+        scol = sfun(self)
 
         min_count = kwargs.get("min_count", 0)
         if min_count > 0:
-            scol = F.when(Frame._count_expr(spark_column, spark_type) >= min_count, scol)
+            scol = F.when(Frame._count_expr(self) >= min_count, scol)
 
         result = unpack_scalar(self._internal.spark_frame.select(scol))
         return result if result is not None else np.nan