[SPARK-36103][PYTHON] Manage InternalField in DataTypeOps.invert

### What changes were proposed in this pull request?

Properly set `InternalField` for `DataTypeOps.invert`.

### Why are the changes needed?

The Spark data type and nullability must remain the same as the original's when `DataTypeOps.invert` is applied.
We should manage `InternalField` for this case.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #33306 from ueshin/issues/SPARK-36103/invert.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
This commit is contained in:
Takuya UESHIN 2021-07-13 09:22:27 +09:00 committed by Hyukjin Kwon
parent 92bf83ed0a
commit e2021daafb
3 changed files with 9 additions and 23 deletions

View file

@@ -16,7 +16,7 @@
# #
import numbers import numbers
-from typing import cast, Any, Union
+from typing import Any, Union
import pandas as pd import pandas as pd
from pandas.api.types import CategoricalDtype from pandas.api.types import CategoricalDtype
@@ -281,29 +281,19 @@ class BooleanOps(DataTypeOps):
return operand return operand
def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op
return column_op(Column.__lt__)(left, right) return column_op(Column.__lt__)(left, right)
def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op
return column_op(Column.__le__)(left, right) return column_op(Column.__le__)(left, right)
def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op
return column_op(Column.__ge__)(left, right) return column_op(Column.__ge__)(left, right)
def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op
return column_op(Column.__gt__)(left, right) return column_op(Column.__gt__)(left, right)
def invert(self, operand: IndexOpsLike) -> IndexOpsLike: def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
from pyspark.pandas.base import column_op return operand._with_new_scol(~operand.spark.column, field=operand._internal.data_fields[0])
return cast(IndexOpsLike, column_op(Column.__invert__)(operand))
class BooleanExtensionOps(BooleanOps): class BooleanExtensionOps(BooleanOps):

View file

@@ -122,9 +122,6 @@ class NumericOps(DataTypeOps):
right = transform_boolean_operand_to_numeric(right) right = transform_boolean_operand_to_numeric(right)
return column_op(rmod)(left, right) return column_op(rmod)(left, right)
def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
return cast(IndexOpsLike, column_op(F.bitwise_not)(operand))
def neg(self, operand: IndexOpsLike) -> IndexOpsLike: def neg(self, operand: IndexOpsLike) -> IndexOpsLike:
return cast(IndexOpsLike, column_op(Column.__neg__)(operand)) return cast(IndexOpsLike, column_op(Column.__neg__)(operand))
@@ -214,6 +211,11 @@ class IntegralOps(NumericOps):
right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type) right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
return numpy_column_op(rfloordiv)(left, right) return numpy_column_op(rfloordiv)(left, right)
def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
return operand._with_new_scol(
F.bitwise_not(operand.spark.column), field=operand._internal.data_fields[0]
)
def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike: def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
dtype, spark_type = pandas_on_spark_type(dtype) dtype, spark_type = pandas_on_spark_type(dtype)
@@ -304,9 +306,6 @@ class FractionalOps(NumericOps):
right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type) right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
return numpy_column_op(rfloordiv)(left, right) return numpy_column_op(rfloordiv)(left, right)
def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
raise TypeError("Unary ~ can not be applied to %s." % self.pretty_name)
def isnull(self, index_ops: IndexOpsLike) -> IndexOpsLike: def isnull(self, index_ops: IndexOpsLike) -> IndexOpsLike:
return index_ops._with_new_scol( return index_ops._with_new_scol(
index_ops.spark.column.isNull() | F.isnan(index_ops.spark.column), index_ops.spark.column.isNull() | F.isnan(index_ops.spark.column),
@@ -348,9 +347,6 @@ class DecimalOps(FractionalOps):
def pretty_name(self) -> str: def pretty_name(self) -> str:
return "decimal" return "decimal"
def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
raise TypeError("Unary ~ can not be applied to %s." % self.pretty_name)
def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
raise TypeError("< can not be applied to %s." % self.pretty_name) raise TypeError("< can not be applied to %s." % self.pretty_name)

View file

@@ -30,7 +30,7 @@ from pyspark.pandas.typedef.typehints import (
extension_dtypes_available, extension_dtypes_available,
extension_float_dtypes_available, extension_float_dtypes_available,
) )
-from pyspark.sql.types import ByteType, DecimalType, IntegerType, LongType
+from pyspark.sql.types import DecimalType, IntegralType
from pyspark.testing.pandasutils import PandasOnSparkTestCase from pyspark.testing.pandasutils import PandasOnSparkTestCase
@@ -328,7 +328,7 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
def test_invert(self): def test_invert(self):
for pser, psser in self.numeric_pser_psser_pairs: for pser, psser in self.numeric_pser_psser_pairs:
-        if type(psser.spark.data_type) in [ByteType, IntegerType, LongType]:
+        if isinstance(psser.spark.data_type, IntegralType):
self.assert_eq(~pser, ~psser) self.assert_eq(~pser, ~psser)
else: else:
self.assertRaises(TypeError, lambda: ~psser) self.assertRaises(TypeError, lambda: ~psser)