[SPARK-35522][PYTHON] Introduce BinaryOps for BinaryType
### What changes were proposed in this pull request? BinaryType, which represents byte sequence values in Spark, doesn't support data-type-based operations yet. This change introduces BinaryOps for it. ### Why are the changes needed? The data-type-based-operations class should be set for each individual data type, including BinaryType. In addition, BinaryType has its own special addition semantics: adding two binary values means concatenation. ### Does this PR introduce _any_ user-facing change? Yes. Before the change: ```py >>> import pyspark.pandas as ps >>> psser = ps.Series([b'1', b'2', b'3']) >>> psser + psser Traceback (most recent call last): ... TypeError: Type object was not understood. >>> psser + b'1' Traceback (most recent call last): ... TypeError: Type object was not understood. ``` After the change: ```py >>> import pyspark.pandas as ps >>> psser = ps.Series([b'1', b'2', b'3']) >>> psser + psser 0 [49, 49] 1 [50, 50] 2 [51, 51] dtype: object >>> psser + b'1' 0 [49, 49] 1 [50, 49] 2 [51, 49] dtype: object ``` ### How was this patch tested? Unit tests. Closes #32665 from xinrong-databricks/datatypeops_binary. Lead-authored-by: Xinrong Meng <xinrong.meng@databricks.com> Co-authored-by: xinrong-databricks <47337188+xinrong-databricks@users.noreply.github.com> Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
This commit is contained in:
parent
266608d50e
commit
8cc7232ffa
|
@ -611,6 +611,7 @@ pyspark_pandas = Module(
|
|||
"pyspark.pandas.spark.utils",
|
||||
"pyspark.pandas.typedef.typehints",
|
||||
# unittests
|
||||
"pyspark.pandas.tests.data_type_ops.test_binary_ops",
|
||||
"pyspark.pandas.tests.data_type_ops.test_boolean_ops",
|
||||
"pyspark.pandas.tests.data_type_ops.test_categorical_ops",
|
||||
"pyspark.pandas.tests.data_type_ops.test_complex_ops",
|
||||
|
|
|
@ -22,6 +22,7 @@ from pandas.api.types import CategoricalDtype
|
|||
|
||||
from pyspark.sql.types import (
|
||||
ArrayType,
|
||||
BinaryType,
|
||||
BooleanType,
|
||||
DataType,
|
||||
DateType,
|
||||
|
@ -44,6 +45,7 @@ class DataTypeOps(object, metaclass=ABCMeta):
|
|||
"""The base class for binary operations of pandas-on-Spark objects (of different data types)."""
|
||||
|
||||
def __new__(cls, dtype: Dtype, spark_type: DataType):
|
||||
from pyspark.pandas.data_type_ops.binary_ops import BinaryOps
|
||||
from pyspark.pandas.data_type_ops.boolean_ops import BooleanOps
|
||||
from pyspark.pandas.data_type_ops.categorical_ops import CategoricalOps
|
||||
from pyspark.pandas.data_type_ops.complex_ops import ArrayOps, MapOps, StructOps
|
||||
|
@ -69,6 +71,8 @@ class DataTypeOps(object, metaclass=ABCMeta):
|
|||
return object.__new__(DatetimeOps)
|
||||
elif isinstance(spark_type, DateType):
|
||||
return object.__new__(DateOps)
|
||||
elif isinstance(spark_type, BinaryType):
|
||||
return object.__new__(BinaryOps)
|
||||
elif isinstance(spark_type, ArrayType):
|
||||
return object.__new__(ArrayOps)
|
||||
elif isinstance(spark_type, MapType):
|
||||
|
|
53
python/pyspark/pandas/data_type_ops/binary_ops.py
Normal file
53
python/pyspark/pandas/data_type_ops/binary_ops.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from pyspark.sql import functions as F
|
||||
from pyspark.sql.types import BinaryType
|
||||
from pyspark.pandas.base import column_op, IndexOpsMixin
|
||||
from pyspark.pandas.data_type_ops.base import DataTypeOps
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
|
||||
from pyspark.pandas.series import Series # noqa: F401 (SPARK-34943)
|
||||
|
||||
|
||||
class BinaryOps(DataTypeOps):
    """
    The class for binary operations of pandas-on-Spark objects with BinaryType.
    """

    @property
    def pretty_name(self) -> str:
        # Human-readable type name used when composing error messages.
        return 'binaries'

    def add(self, left, right) -> Union["Series", "Index"]:
        # Adding binary values means concatenation; the right operand may be a
        # plain bytes literal or another binary-typed pandas-on-Spark object.
        if isinstance(right, bytes):
            return column_op(F.concat)(left, F.lit(right))
        if isinstance(right, IndexOpsMixin) and isinstance(right.spark.data_type, BinaryType):
            return column_op(F.concat)(left, right)
        raise TypeError(
            "Concatenation can not be applied to %s and the given type." % self.pretty_name)

    def radd(self, left, right) -> Union["Series", "Index"]:
        # Reflected addition: only a plain bytes value can appear on the left side;
        # prepend it to every element of the column.
        if not isinstance(right, bytes):
            raise TypeError(
                "Concatenation can not be applied to %s and the given type." % self.pretty_name)
        return left._with_new_scol(F.concat(F.lit(right), left.spark.column))
|
135
python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py
Normal file
135
python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py
Normal file
|
@ -0,0 +1,135 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from pyspark import pandas as ps
|
||||
from pyspark.pandas.config import option_context
|
||||
from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils
|
||||
from pyspark.testing.pandasutils import PandasOnSparkTestCase
|
||||
|
||||
|
||||
class BinaryOpsTest(PandasOnSparkTestCase, TestCasesUtils):
    """Tests for data-type-based arithmetic operations on BinaryType Series."""

    @property
    def pser(self):
        # Reference pandas Series of byte sequences.
        return pd.Series([b'1', b'2', b'3'])

    @property
    def psser(self):
        # Equivalent pandas-on-Spark Series.
        return ps.from_pandas(self.pser)

    def _assert_binary_op_fails(self, op):
        # The operation must raise TypeError for non-binary operands, both for
        # plain scalars and for Series of every other data type.
        for scalar in ("x", 1):
            self.assertRaises(TypeError, lambda: op(self.psser, scalar))

        with option_context("compute.ops_on_diff_frames", True):
            for other in self.pssers:
                self.assertRaises(TypeError, lambda: op(self.psser, other))

    def test_add(self):
        pser, psser = self.pser, self.psser
        # Adding binary values concatenates them, matching pandas behavior.
        self.assert_eq(psser + b'1', pser + b'1')
        self.assert_eq(psser + psser, pser + pser)
        self.assert_eq(psser + psser.astype('bytes'), pser + pser.astype('bytes'))
        self._assert_binary_op_fails(lambda a, b: a + b)

        with option_context("compute.ops_on_diff_frames", True):
            self.assert_eq(self.psser + self.psser, self.pser + self.pser)

    def test_sub(self):
        # Subtraction is undefined for binary values.
        self._assert_binary_op_fails(lambda a, b: a - b)

    def test_mul(self):
        self._assert_binary_op_fails(lambda a, b: a * b)

    def test_truediv(self):
        self._assert_binary_op_fails(lambda a, b: a / b)

    def test_floordiv(self):
        self._assert_binary_op_fails(lambda a, b: a // b)

    def test_mod(self):
        self._assert_binary_op_fails(lambda a, b: a % b)

    def test_pow(self):
        self._assert_binary_op_fails(lambda a, b: a ** b)

    def test_radd(self):
        # A bytes literal on the left concatenates as well.
        self.assert_eq(b'1' + self.psser, b'1' + self.pser)
        for scalar in ("x", 1):
            self.assertRaises(TypeError, lambda: scalar + self.psser)

    def test_rsub(self):
        for scalar in ("x", 1):
            self.assertRaises(TypeError, lambda: scalar - self.psser)

    def test_rmul(self):
        for scalar in ("x", 2):
            self.assertRaises(TypeError, lambda: scalar * self.psser)

    def test_rtruediv(self):
        for scalar in ("x", 1):
            self.assertRaises(TypeError, lambda: scalar / self.psser)

    def test_rfloordiv(self):
        for scalar in ("x", 1):
            self.assertRaises(TypeError, lambda: scalar // self.psser)

    def test_rmod(self):
        self.assertRaises(TypeError, lambda: 1 % self.psser)

    def test_rpow(self):
        for scalar in ("x", 1):
            self.assertRaises(TypeError, lambda: scalar ** self.psser)
||||
if __name__ == "__main__":
    import unittest
    from pyspark.pandas.tests.data_type_ops.test_binary_ops import *  # noqa: F401

    # Prefer the XML test runner when available so CI can collect JUnit-style
    # reports; otherwise fall back to unittest's default runner.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
|
Loading…
Reference in a new issue