[SPARK-36103][PYTHON] Manage InternalField in DataTypeOps.invert

### What changes were proposed in this pull request?

Properly set `InternalField` for `DataTypeOps.invert`.

### Why are the changes needed?

The Spark data type and nullability of the result of `DataTypeOps.invert` must stay the same as the original operand's.
We should manage `InternalField` for this case.
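
For illustration only (not part of this patch), here is a rough sketch of the invariant using the public pandas-on-Spark API; the printed type reprs are indicative:

```python
# Unary ~ should return a Series whose Spark data type (and nullability,
# tracked by its InternalField) matches the original operand's.
import pyspark.pandas as ps

psser = ps.Series([1, 2, 3])          # backed by a Spark LongType column

inverted = ~psser                     # bitwise NOT (IntegralOps.invert after this change)

print(psser.spark.data_type)          # LongType
print(inverted.spark.data_type)       # LongType -- same as the original
```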

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #33306 from ueshin/issues/SPARK-36103/invert.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
Takuya UESHIN 2021-07-13 09:22:27 +09:00 committed by Hyukjin Kwon
parent 92bf83ed0a
commit e2021daafb
3 changed files with 9 additions and 23 deletions

python/pyspark/pandas/data_type_ops/boolean_ops.py

@@ -16,7 +16,7 @@
 #

 import numbers
-from typing import cast, Any, Union
+from typing import Any, Union

 import pandas as pd
 from pandas.api.types import CategoricalDtype
@@ -281,29 +281,19 @@ class BooleanOps(DataTypeOps):
         return operand

     def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
-        from pyspark.pandas.base import column_op
         return column_op(Column.__lt__)(left, right)

     def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
-        from pyspark.pandas.base import column_op
         return column_op(Column.__le__)(left, right)

     def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
-        from pyspark.pandas.base import column_op
         return column_op(Column.__ge__)(left, right)

     def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
-        from pyspark.pandas.base import column_op
         return column_op(Column.__gt__)(left, right)

     def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
-        from pyspark.pandas.base import column_op
-        return cast(IndexOpsLike, column_op(Column.__invert__)(operand))
+        return operand._with_new_scol(~operand.spark.column, field=operand._internal.data_fields[0])


 class BooleanExtensionOps(BooleanOps):

python/pyspark/pandas/data_type_ops/num_ops.py

@@ -122,9 +122,6 @@ class NumericOps(DataTypeOps):
         right = transform_boolean_operand_to_numeric(right)
         return column_op(rmod)(left, right)

-    def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
-        return cast(IndexOpsLike, column_op(F.bitwise_not)(operand))

     def neg(self, operand: IndexOpsLike) -> IndexOpsLike:
         return cast(IndexOpsLike, column_op(Column.__neg__)(operand))
@@ -214,6 +211,11 @@ class IntegralOps(NumericOps):
         right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
         return numpy_column_op(rfloordiv)(left, right)

+    def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
+        return operand._with_new_scol(
+            F.bitwise_not(operand.spark.column), field=operand._internal.data_fields[0]
+        )

     def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
         dtype, spark_type = pandas_on_spark_type(dtype)
@@ -304,9 +306,6 @@ class FractionalOps(NumericOps):
         right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
         return numpy_column_op(rfloordiv)(left, right)

-    def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
-        raise TypeError("Unary ~ can not be applied to %s." % self.pretty_name)

     def isnull(self, index_ops: IndexOpsLike) -> IndexOpsLike:
         return index_ops._with_new_scol(
             index_ops.spark.column.isNull() | F.isnan(index_ops.spark.column),
@@ -348,9 +347,6 @@ class DecimalOps(FractionalOps):
     def pretty_name(self) -> str:
         return "decimal"

-    def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
-        raise TypeError("Unary ~ can not be applied to %s." % self.pretty_name)

     def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         raise TypeError("< can not be applied to %s." % self.pretty_name)

python/pyspark/pandas/tests/data_type_ops/test_num_ops.py

@@ -30,7 +30,7 @@ from pyspark.pandas.typedef.typehints import (
     extension_dtypes_available,
     extension_float_dtypes_available,
 )
-from pyspark.sql.types import ByteType, DecimalType, IntegerType, LongType
+from pyspark.sql.types import DecimalType, IntegralType
 from pyspark.testing.pandasutils import PandasOnSparkTestCase
@@ -328,7 +328,7 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
     def test_invert(self):
         for pser, psser in self.numeric_pser_psser_pairs:
-            if type(psser.spark.data_type) in [ByteType, IntegerType, LongType]:
+            if isinstance(psser.spark.data_type, IntegralType):
                 self.assert_eq(~pser, ~psser)
             else:
                 self.assertRaises(TypeError, lambda: ~psser)