[SPARK-35343][PYTHON] Make the conversion from/to pandas data-type-based for non-ExtensionDtypes

### What changes were proposed in this pull request?

Make the conversion from/to pandas (for non-ExtensionDtype) data-type-based.
NOTE: The Ops classes per ExtensionDtype and their data-type-based from/to pandas conversion will be implemented in a separate PR, tracked as SPARK-35614 (https://issues.apache.org/jira/browse/SPARK-35614).

### Why are the changes needed?

The conversion from/to pandas currently interleaves data-type checks with the conversion logic, branching on each type inline.
That makes the code hard to change or maintain.
Since we have introduced an Ops class per non-ExtensionDtype data type, we ought to make the conversion from/to pandas data-type-based for non-ExtensionDtypes as well; the dispatch pattern is sketched below.
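
For illustration, here is a minimal, self-contained sketch of that pattern, simplified from the `DataTypeOps` hierarchy in the diff below. The real class also dispatches on the Spark type and covers many more dtypes; the `DefaultOps` subclass and the single-argument constructor here are illustrative only.

```python
import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype


class DataTypeOps:
    """Dispatch to a per-dtype Ops subclass; hooks for the pandas round trip."""

    def __new__(cls, dtype):
        if isinstance(dtype, CategoricalDtype):
            return object.__new__(CategoricalOps)
        return object.__new__(DefaultOps)

    def __init__(self, dtype):
        self.dtype = dtype

    def restore(self, col: pd.Series) -> pd.Series:
        """Restore column when to_pandas."""
        return col

    def prepare(self, col: pd.Series) -> pd.Series:
        """Prepare column when from_pandas."""
        return col.replace({np.nan: None})


class DefaultOps(DataTypeOps):
    pass


class CategoricalOps(DataTypeOps):
    def restore(self, col: pd.Series) -> pd.Series:
        return pd.Series(pd.Categorical.from_codes(
            col, categories=self.dtype.categories, ordered=self.dtype.ordered))

    def prepare(self, col: pd.Series) -> pd.Series:
        return col.cat.codes


s = pd.Series(["a", "b", "a"], dtype="category")
ops = DataTypeOps(s.dtype)       # dispatches to CategoricalOps
codes = ops.prepare(s)           # integer codes: what Spark would store
print(ops.restore(codes).dtype)  # category -- round-tripped
```

Each dtype's quirks live in its own Ops subclass instead of in inline `isinstance` chains.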

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

Closes #32592 from xinrong-databricks/datatypeop_pd_conversion.

Authored-by: Xinrong Meng <xinrong.meng@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
commit 04a8d2cbcf (parent 6c3b7f92cf), 2021-06-07 13:12:12 -07:00
19 changed files with 481 additions and 37 deletions


@@ -615,8 +615,10 @@ pyspark_pandas = Module(
"pyspark.pandas.tests.data_type_ops.test_complex_ops",
"pyspark.pandas.tests.data_type_ops.test_date_ops",
"pyspark.pandas.tests.data_type_ops.test_datetime_ops",
"pyspark.pandas.tests.data_type_ops.test_null_ops",
"pyspark.pandas.tests.data_type_ops.test_num_ops",
"pyspark.pandas.tests.data_type_ops.test_string_ops",
"pyspark.pandas.tests.data_type_ops.test_udt_ops",
"pyspark.pandas.tests.indexes.test_category",
"pyspark.pandas.tests.plot.test_frame_plot",
"pyspark.pandas.tests.plot.test_frame_plot_matplotlib",


@@ -16,9 +16,11 @@
#
import numbers
from abc import ABCMeta, abstractmethod
from abc import ABCMeta
from typing import Any, TYPE_CHECKING, Union
import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype
from pyspark.sql.types import (
@@ -30,14 +32,15 @@ from pyspark.sql.types import (
FractionalType,
IntegralType,
MapType,
NullType,
NumericType,
StringType,
StructType,
TimestampType,
UserDefinedType,
)
import pyspark.sql.types as types
from pyspark.pandas.base import IndexOpsMixin
from pyspark.pandas.typedef import Dtype
if TYPE_CHECKING:
@@ -47,6 +50,8 @@ if TYPE_CHECKING:
def is_valid_operand_for_numeric_arithmetic(operand: Any, *, allow_bool: bool = True) -> bool:
"""Check whether the operand is valid for arithmetic operations against numerics."""
from pyspark.pandas.base import IndexOpsMixin
if isinstance(operand, numbers.Number) and not isinstance(operand, bool):
return True
elif isinstance(operand, IndexOpsMixin):
@@ -66,6 +71,8 @@ def transform_boolean_operand_to_numeric(operand: Any, spark_type: types.DataTyp
Return the transformed operand if the operand is a boolean IndexOpsMixin,
otherwise return the original operand.
"""
from pyspark.pandas.base import IndexOpsMixin
if isinstance(operand, IndexOpsMixin) and isinstance(operand.spark.data_type, BooleanType):
return operand.spark.transform(lambda scol: scol.cast(spark_type))
else:
@@ -82,11 +89,13 @@ class DataTypeOps(object, metaclass=ABCMeta):
from pyspark.pandas.data_type_ops.complex_ops import ArrayOps, MapOps, StructOps
from pyspark.pandas.data_type_ops.date_ops import DateOps
from pyspark.pandas.data_type_ops.datetime_ops import DatetimeOps
from pyspark.pandas.data_type_ops.null_ops import NullOps
from pyspark.pandas.data_type_ops.num_ops import (
IntegralOps,
FractionalOps,
)
from pyspark.pandas.data_type_ops.string_ops import StringOps
from pyspark.pandas.data_type_ops.udt_ops import UDTOps
if isinstance(dtype, CategoricalDtype):
return object.__new__(CategoricalOps)
@@ -110,6 +119,10 @@ class DataTypeOps(object, metaclass=ABCMeta):
return object.__new__(MapOps)
elif isinstance(spark_type, StructType):
return object.__new__(StructOps)
elif isinstance(spark_type, NullType):
return object.__new__(NullOps)
elif isinstance(spark_type, UserDefinedType):
return object.__new__(UDTOps)
else:
raise TypeError("Type %s was not understood." % dtype)
@@ -118,7 +131,6 @@ class DataTypeOps(object, metaclass=ABCMeta):
self.spark_type = spark_type
@property
@abstractmethod
def pretty_name(self) -> str:
raise NotImplementedError()
@@ -163,3 +175,11 @@ class DataTypeOps(object, metaclass=ABCMeta):
def rpow(self, left, right) -> Union["Series", "Index"]:
raise TypeError("Exponentiation can not be applied to %s." % self.pretty_name)
def restore(self, col: pd.Series) -> pd.Series:
"""Restore column when to_pandas."""
return col
def prepare(self, col: pd.Series) -> pd.Series:
"""Prepare column when from_pandas."""
return col.replace({np.nan: None})
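
For reference, the default `prepare` above maps NaN to a Python `None` so the value becomes a proper SQL null on the Spark side. A quick illustration using plain pandas, no Spark required:

```python
import numpy as np
import pandas as pd

s = pd.Series(["x", np.nan, "z"])
prepared = s.replace({np.nan: None})  # what the default prepare() does
print(prepared[1] is None)            # True
```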


@@ -15,6 +15,8 @@
# limitations under the License.
#
import pandas as pd
from pyspark.pandas.data_type_ops.base import DataTypeOps
@@ -26,3 +28,13 @@ class CategoricalOps(DataTypeOps):
@property
def pretty_name(self) -> str:
return "categoricals"
def restore(self, col: pd.Series) -> pd.Series:
"""Restore column when to_pandas."""
return pd.Categorical.from_codes(
col, categories=self.dtype.categories, ordered=self.dtype.ordered
)
def prepare(self, col: pd.Series) -> pd.Series:
"""Prepare column when from_pandas."""
return col.cat.codes
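
A categorical column is thus stored as its integer codes and rebuilt on the way back. A minimal round trip mirroring the two hooks above, in plain pandas:

```python
import pandas as pd

dtype = pd.CategoricalDtype(categories=["a", "b", "c"], ordered=True)
s = pd.Series(["a", "c", "b"], dtype=dtype)

codes = s.cat.codes                              # prepare(): integer codes
restored = pd.Series(pd.Categorical.from_codes(  # restore(): rebuild the dtype
    codes, categories=dtype.categories, ordered=dtype.ordered))
print(restored.equals(s))                        # True
```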


@@ -74,3 +74,7 @@ class DatetimeOps(DataTypeOps):
)
else:
raise TypeError("datetime subtraction can only be applied to datetime series.")
def prepare(self, col):
"""Prepare column when from_pandas."""
return col
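
`DatetimeOps` overrides `prepare` as a pass-through because a datetime64 column already represents missing values as NaT, which the Spark conversion handles natively; this mirrors the datetime skip removed from the old `from_pandas` path further below. A small illustration:

```python
import pandas as pd

s = pd.Series(pd.to_datetime(["1994-01-31", None]))
print(s.dtype)  # datetime64[ns]
print(s[1])     # NaT -- already a proper missing value; no NaN -> None needed
```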


@@ -0,0 +1,28 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.pandas.data_type_ops.base import DataTypeOps
class NullOps(DataTypeOps):
"""
The class for binary operations of pandas-on-Spark objects with Spark type: NullType.
"""
@property
def pretty_name(self) -> str:
return "nulls"


@@ -0,0 +1,29 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.pandas.data_type_ops.base import DataTypeOps
class UDTOps(DataTypeOps):
"""
The class for binary operations of pandas-on-Spark objects with Spark type:
UserDefinedType or its subclasses.
"""
@property
def pretty_name(self) -> str:
return "user defined types"


@@ -25,12 +25,20 @@ import py4j
import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype, is_datetime64_dtype, is_datetime64tz_dtype
from pandas.api.types import CategoricalDtype # noqa: F401
from pyspark import sql as spark
from pyspark._globals import _NoValue, _NoValueType
from pyspark.sql import functions as F, Window
from pyspark.sql.functions import PandasUDFType, pandas_udf
from pyspark.sql.types import BooleanType, DataType, StructField, StructType, LongType
from pyspark.sql.types import ( # noqa: F401
BooleanType,
DataType,
IntegralType,
LongType,
StructField,
StructType,
StringType,
)
# For running doctests and reference resolution in PyCharm.
from pyspark import pandas as ps # noqa: F401
@@ -39,6 +47,7 @@ if TYPE_CHECKING:
# This is required in old Python 3.5 to prevent circular reference.
from pyspark.pandas.series import Series # noqa: F401 (SPARK-34943)
from pyspark.pandas.config import get_option
from pyspark.pandas.data_type_ops.base import DataTypeOps
from pyspark.pandas.typedef import (
Dtype,
as_spark_type,
@@ -951,11 +960,11 @@ class InternalFrame(object):
for col, dtype in zip(self.index_spark_column_names, self.index_dtypes)
if isinstance(dtype, extension_dtypes)
}
categorical_dtypes = {
col: dtype
for col, dtype in zip(self.index_spark_column_names, self.index_dtypes)
if isinstance(dtype, CategoricalDtype)
}
dtypes = [dtype for dtype in self.index_dtypes]
spark_types = [
self.spark_frame.select(scol).schema[0].dataType for scol in self.index_spark_columns
]
for spark_column, column_name, dtype in zip(
self.data_spark_columns, self.data_spark_column_names, self.data_dtypes
):
@@ -969,8 +978,8 @@ class InternalFrame(object):
column_names.append(column_name)
if isinstance(dtype, extension_dtypes):
ext_dtypes[column_name] = dtype
elif isinstance(dtype, CategoricalDtype):
categorical_dtypes[column_name] = dtype
dtypes.append(dtype)
spark_types.append(self.spark_frame.select(spark_column).schema[0].dataType)
return dict(
index_columns=self.index_spark_column_names,
@@ -979,7 +988,8 @@ class InternalFrame(object):
column_labels=self.column_labels,
column_label_names=self.column_label_names,
ext_dtypes=ext_dtypes,
categorical_dtypes=categorical_dtypes,
dtypes=dtypes,
spark_types=spark_types,
)
@staticmethod
@@ -991,8 +1001,9 @@ class InternalFrame(object):
data_columns: List[str],
column_labels: List[Tuple],
column_label_names: List[Tuple],
dtypes: List[Dtype],
spark_types: List[DataType],
ext_dtypes: Dict[str, Dtype] = None,
categorical_dtypes: Dict[str, CategoricalDtype] = None
) -> pd.DataFrame:
"""
Restore pandas DataFrame indices using the metadata.
@@ -1003,10 +1014,12 @@ class InternalFrame(object):
:param data_columns: the original column names for data columns.
:param column_labels: the column labels after restored.
:param column_label_names: the column label names after restored.
:param dtypes: the dtypes after restored.
:param spark_types: the Spark data types of the columns.
:param ext_dtypes: the map from the original column names to extension data types.
:param categorical_dtypes: the map from the original column names to categorical types.
:return: the restored pandas DataFrame
>>> from numpy import dtype
>>> pdf = pd.DataFrame({"index": [10, 20, 30], "a": ['a', 'b', 'c'], "b": [0, 2, 1]})
>>> InternalFrame.restore_index(
... pdf,
@@ -1015,8 +1028,9 @@ class InternalFrame(object):
... data_columns=["a", "b", "index"],
... column_labels=[("x",), ("y",), ("z",)],
... column_label_names=[("lv1",)],
... ext_dtypes=None,
... categorical_dtypes={"b": CategoricalDtype(categories=["i", "j", "k"])}
... dtypes=[dtype('int64'), dtype('object'),
... CategoricalDtype(categories=["i", "j", "k"]), dtype('int64')],
... spark_types=[LongType(), StringType(), StringType(), LongType()]
... ) # doctest: +NORMALIZE_WHITESPACE
lv1 x y z
idx
@@ -1027,11 +1041,8 @@ class InternalFrame(object):
if ext_dtypes is not None and len(ext_dtypes) > 0:
pdf = pdf.astype(ext_dtypes, copy=True)
if categorical_dtypes is not None:
for col, dtype in categorical_dtypes.items():
pdf[col] = pd.Categorical.from_codes(
pdf[col], categories=dtype.categories, ordered=dtype.ordered
)
for col, expected_dtype, spark_type in zip(pdf.columns, dtypes, spark_types):
pdf[col] = DataTypeOps(expected_dtype, spark_type).restore(pdf[col])
append = False
for index_field in index_columns:
@@ -1071,7 +1082,7 @@ class InternalFrame(object):
*,
index_dtypes: Optional[List[Dtype]] = None,
data_columns: Optional[List[str]] = None,
data_dtypes: Optional[List[Dtype]] = None
data_dtypes: Optional[List[Dtype]] = None,
) -> "InternalFrame":
"""Copy the immutable InternalFrame with the updates by the specified Spark DataFrame.
@@ -1121,7 +1132,7 @@ class InternalFrame(object):
column_labels: Optional[List[Tuple]] = None,
data_dtypes: Optional[List[Dtype]] = None,
column_label_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue,
keep_order: bool = True
keep_order: bool = True,
) -> "InternalFrame":
"""
Copy the immutable InternalFrame with the updates by the specified Spark Columns or Series.
@@ -1225,7 +1236,7 @@ class InternalFrame(object):
scol: spark.Column,
*,
dtype: Optional[Dtype] = None,
keep_order: bool = True
keep_order: bool = True,
) -> "InternalFrame":
"""
Copy the immutable InternalFrame with the updates by the specified Spark Column.
@@ -1273,7 +1284,7 @@ class InternalFrame(object):
column_labels: Union[Optional[List[Tuple]], _NoValueType] = _NoValue,
data_spark_columns: Union[Optional[List[spark.Column]], _NoValueType] = _NoValue,
data_dtypes: Union[Optional[List[Dtype]], _NoValueType] = _NoValue,
column_label_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue
column_label_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue,
) -> "InternalFrame":
"""Copy the immutable InternalFrame.
@@ -1423,13 +1434,9 @@ class InternalFrame(object):
index_dtypes = list(reset_index.dtypes)[:index_nlevels]
data_dtypes = list(reset_index.dtypes)[index_nlevels:]
for name, col in reset_index.iteritems():
dt = col.dtype
if is_datetime64_dtype(dt) or is_datetime64tz_dtype(dt):
continue
elif isinstance(dt, CategoricalDtype):
col = col.cat.codes
reset_index[name] = col.replace({np.nan: None})
for col, dtype in zip(reset_index.columns, reset_index.dtypes):
spark_type = infer_pd_series_spark_type(reset_index[col], dtype)
reset_index[col] = DataTypeOps(dtype, spark_type).prepare(reset_index[col])
return reset_index, index_columns, index_dtypes, data_columns, data_dtypes
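
End to end, the two hooks make the round trip uniform across types. For example (mirroring the new `test_from_to_pandas` cases below; assumes a running Spark session):

```python
import pandas as pd
import pyspark.pandas as ps

pser = pd.Series(["a", "b", "a"], dtype="category")
psser = ps.from_pandas(pser)    # prepare(): categorical stored as codes
print(psser.to_pandas().dtype)  # restore(): category, matching the original
```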


@@ -122,6 +122,13 @@ class BinaryOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assertRaises(TypeError, lambda: "x" ** self.psser)
self.assertRaises(TypeError, lambda: 1 ** self.psser)
def test_from_to_pandas(self):
data = [b"1", b"2", b"3"]
pser = pd.Series(data)
psser = ps.Series(data)
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
if __name__ == "__main__":
import unittest


@@ -229,6 +229,13 @@ class BooleanOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assertRaises(TypeError, lambda: datetime.date(1994, 1, 1) % self.psser)
self.assertRaises(TypeError, lambda: True % self.psser)
def test_from_to_pandas(self):
data = [True, True, False]
pser = pd.Series(data)
psser = ps.Series(data)
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
if __name__ == "__main__":
import unittest


@@ -115,6 +115,13 @@ class CategoricalOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assertRaises(TypeError, lambda: "x" ** self.psser)
self.assertRaises(TypeError, lambda: 1 ** self.psser)
def test_from_to_pandas(self):
data = [1, "x", "y"]
pser = pd.Series(data, dtype="category")
psser = ps.Series(data, dtype="category")
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
if __name__ == "__main__":
import unittest


@@ -190,6 +190,11 @@ class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assertRaises(TypeError, lambda: "x" ** self.psser)
self.assertRaises(TypeError, lambda: 1 ** self.psser)
def test_from_to_pandas(self):
for pser, psser in zip(self.psers, self.pssers):
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
if __name__ == "__main__":
import unittest


@@ -147,6 +147,13 @@ class DateOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assertRaises(TypeError, lambda: 1 ** self.psser)
self.assertRaises(TypeError, lambda: self.some_date ** self.psser)
def test_from_to_pandas(self):
data = [datetime.date(1994, 1, 31), datetime.date(1994, 2, 1), datetime.date(1994, 2, 2)]
pser = pd.Series(data)
psser = ps.Series(data)
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
if __name__ == "__main__":
import unittest


@@ -147,6 +147,13 @@ class DatetimeOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assertRaises(TypeError, lambda: 1 ** self.psser)
self.assertRaises(TypeError, lambda: self.some_datetime ** self.psser)
def test_from_to_pandas(self):
data = pd.date_range("1994-1-31 10:30:15", periods=3, freq="M")
pser = pd.Series(data)
psser = ps.Series(data)
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
if __name__ == "__main__":
import unittest


@@ -0,0 +1,136 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pandas as pd
import pyspark.pandas as ps
from pyspark.pandas.config import option_context
from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils
from pyspark.testing.pandasutils import PandasOnSparkTestCase
class NullOpsTest(PandasOnSparkTestCase, TestCasesUtils):
@property
def pser(self):
return pd.Series([None, None, None])
@property
def psser(self):
return ps.from_pandas(self.pser)
def test_add(self):
self.assertRaises(TypeError, lambda: self.psser + "x")
self.assertRaises(TypeError, lambda: self.psser + 1)
with option_context("compute.ops_on_diff_frames", True):
for psser in self.pssers:
self.assertRaises(TypeError, lambda: self.psser + psser)
def test_sub(self):
self.assertRaises(TypeError, lambda: self.psser - "x")
self.assertRaises(TypeError, lambda: self.psser - 1)
with option_context("compute.ops_on_diff_frames", True):
for psser in self.pssers:
self.assertRaises(TypeError, lambda: self.psser - psser)
def test_mul(self):
self.assertRaises(TypeError, lambda: self.psser * "x")
self.assertRaises(TypeError, lambda: self.psser * 1)
with option_context("compute.ops_on_diff_frames", True):
for psser in self.pssers:
self.assertRaises(TypeError, lambda: self.psser * psser)
def test_truediv(self):
self.assertRaises(TypeError, lambda: self.psser / "x")
self.assertRaises(TypeError, lambda: self.psser / 1)
with option_context("compute.ops_on_diff_frames", True):
for psser in self.pssers:
self.assertRaises(TypeError, lambda: self.psser / psser)
def test_floordiv(self):
self.assertRaises(TypeError, lambda: self.psser // "x")
self.assertRaises(TypeError, lambda: self.psser // 1)
with option_context("compute.ops_on_diff_frames", True):
for psser in self.pssers:
self.assertRaises(TypeError, lambda: self.psser // psser)
def test_mod(self):
self.assertRaises(TypeError, lambda: self.psser % "x")
self.assertRaises(TypeError, lambda: self.psser % 1)
with option_context("compute.ops_on_diff_frames", True):
for psser in self.pssers:
self.assertRaises(TypeError, lambda: self.psser % psser)
def test_pow(self):
self.assertRaises(TypeError, lambda: self.psser ** "x")
self.assertRaises(TypeError, lambda: self.psser ** 1)
with option_context("compute.ops_on_diff_frames", True):
for psser in self.pssers:
self.assertRaises(TypeError, lambda: self.psser ** psser)
def test_radd(self):
self.assertRaises(TypeError, lambda: "x" + self.psser)
self.assertRaises(TypeError, lambda: 1 + self.psser)
def test_rsub(self):
self.assertRaises(TypeError, lambda: "x" - self.psser)
self.assertRaises(TypeError, lambda: 1 - self.psser)
def test_rmul(self):
self.assertRaises(TypeError, lambda: "x" * self.psser)
self.assertRaises(TypeError, lambda: 2 * self.psser)
def test_rtruediv(self):
self.assertRaises(TypeError, lambda: "x" / self.psser)
self.assertRaises(TypeError, lambda: 1 / self.psser)
def test_rfloordiv(self):
self.assertRaises(TypeError, lambda: "x" // self.psser)
self.assertRaises(TypeError, lambda: 1 // self.psser)
def test_rmod(self):
self.assertRaises(TypeError, lambda: 1 % self.psser)
def test_rpow(self):
self.assertRaises(TypeError, lambda: "x" ** self.psser)
self.assertRaises(TypeError, lambda: 1 ** self.psser)
def test_from_to_pandas(self):
data = [None, None, None]
pser = pd.Series(data)
psser = ps.Series(data)
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
if __name__ == "__main__":
import unittest
from pyspark.pandas.tests.data_type_ops.test_null_ops import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)


@@ -254,6 +254,11 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assertRaises(TypeError, lambda: datetime.date(1994, 1, 1) % psser)
self.assertRaises(TypeError, lambda: datetime.datetime(1994, 1, 1) % psser)
def test_from_to_pandas(self):
for pser, psser in self.numeric_pser_psser_pairs:
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
if __name__ == "__main__":
import unittest


@@ -129,6 +129,13 @@ class StringOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assertRaises(TypeError, lambda: "x" ** self.psser)
self.assertRaises(TypeError, lambda: 1 ** self.psser)
def test_from_to_pandas(self):
data = ["x", "y", "z"]
pser = pd.Series(data)
psser = ps.Series(data)
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
if __name__ == "__main__":
import unittest


@@ -0,0 +1,139 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pandas as pd
import pyspark.pandas as ps
from pyspark.ml.linalg import SparseVector
from pyspark.pandas.config import option_context
from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils
from pyspark.testing.pandasutils import PandasOnSparkTestCase
class UDTOpsTest(PandasOnSparkTestCase, TestCasesUtils):
@property
def pser(self):
sparse_values = {0: 0.1, 1: 1.1}
return pd.Series([SparseVector(len(sparse_values), sparse_values)])
@property
def psser(self):
return ps.from_pandas(self.pser)
def test_add(self):
self.assertRaises(TypeError, lambda: self.psser + "x")
self.assertRaises(TypeError, lambda: self.psser + 1)
with option_context("compute.ops_on_diff_frames", True):
for psser in self.pssers:
self.assertRaises(TypeError, lambda: self.psser + psser)
def test_sub(self):
self.assertRaises(TypeError, lambda: self.psser - "x")
self.assertRaises(TypeError, lambda: self.psser - 1)
with option_context("compute.ops_on_diff_frames", True):
for psser in self.pssers:
self.assertRaises(TypeError, lambda: self.psser - psser)
def test_mul(self):
self.assertRaises(TypeError, lambda: self.psser * "x")
self.assertRaises(TypeError, lambda: self.psser * 1)
with option_context("compute.ops_on_diff_frames", True):
for psser in self.pssers:
self.assertRaises(TypeError, lambda: self.psser * psser)
def test_truediv(self):
self.assertRaises(TypeError, lambda: self.psser / "x")
self.assertRaises(TypeError, lambda: self.psser / 1)
with option_context("compute.ops_on_diff_frames", True):
for psser in self.pssers:
self.assertRaises(TypeError, lambda: self.psser / psser)
def test_floordiv(self):
self.assertRaises(TypeError, lambda: self.psser // "x")
self.assertRaises(TypeError, lambda: self.psser // 1)
with option_context("compute.ops_on_diff_frames", True):
for psser in self.pssers:
self.assertRaises(TypeError, lambda: self.psser // psser)
def test_mod(self):
self.assertRaises(TypeError, lambda: self.psser % "x")
self.assertRaises(TypeError, lambda: self.psser % 1)
with option_context("compute.ops_on_diff_frames", True):
for psser in self.pssers:
self.assertRaises(TypeError, lambda: self.psser % psser)
def test_pow(self):
self.assertRaises(TypeError, lambda: self.psser ** "x")
self.assertRaises(TypeError, lambda: self.psser ** 1)
with option_context("compute.ops_on_diff_frames", True):
for psser in self.pssers:
self.assertRaises(TypeError, lambda: self.psser ** psser)
def test_radd(self):
self.assertRaises(TypeError, lambda: "x" + self.psser)
self.assertRaises(TypeError, lambda: 1 + self.psser)
def test_rsub(self):
self.assertRaises(TypeError, lambda: "x" - self.psser)
self.assertRaises(TypeError, lambda: 1 - self.psser)
def test_rmul(self):
self.assertRaises(TypeError, lambda: "x" * self.psser)
self.assertRaises(TypeError, lambda: 2 * self.psser)
def test_rtruediv(self):
self.assertRaises(TypeError, lambda: "x" / self.psser)
self.assertRaises(TypeError, lambda: 1 / self.psser)
def test_rfloordiv(self):
self.assertRaises(TypeError, lambda: "x" // self.psser)
self.assertRaises(TypeError, lambda: 1 // self.psser)
def test_rmod(self):
self.assertRaises(TypeError, lambda: 1 % self.psser)
def test_rpow(self):
self.assertRaises(TypeError, lambda: "x" ** self.psser)
self.assertRaises(TypeError, lambda: 1 ** self.psser)
def test_from_to_pandas(self):
sparse_values = {0: 0.1, 1: 1.1}
sparse_vector = SparseVector(len(sparse_values), sparse_values)
pser = pd.Series([sparse_vector])
psser = ps.Series([sparse_vector])
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
if __name__ == "__main__":
import unittest
from pyspark.pandas.tests.data_type_ops.test_udt_ops import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)


@@ -1255,7 +1255,7 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils):
self.assert_eq(psmidx.is_monotonic_decreasing, False)
else:
[(-5, None), (-4, None), (-3, None), (-2, None), (-1, None)]
# For [(-5, None), (-4, None), (-3, None), (-2, None), (-1, None)]
psdf = ps.DataFrame({"a": [-5, -4, -3, -2, -1], "b": [1, 1, 1, 1, 1]})
psdf["b"] = None
psmidx = psdf.set_index(["a", "b"]).index
@@ -1263,7 +1263,7 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils):
self.assert_eq(psmidx.is_monotonic_increasing, pmidx.is_monotonic_increasing)
self.assert_eq(psmidx.is_monotonic_decreasing, pmidx.is_monotonic_decreasing)
[(None, "e"), (None, "c"), (None, "b"), (None, "d"), (None, "a")]
# For [(None, "e"), (None, "c"), (None, "b"), (None, "d"), (None, "a")]
psdf = ps.DataFrame({"a": [1, 1, 1, 1, 1], "b": ["e", "c", "b", "d", "a"]})
psdf["a"] = None
psmidx = psdf.set_index(["a", "b"]).index
@@ -1271,7 +1271,7 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils):
self.assert_eq(psmidx.is_monotonic_increasing, pmidx.is_monotonic_increasing)
self.assert_eq(psmidx.is_monotonic_decreasing, pmidx.is_monotonic_decreasing)
[(None, None), (None, None), (None, None), (None, None), (None, None)]
# For [(None, None), (None, None), (None, None), (None, None), (None, None)]
psdf = ps.DataFrame({"a": [1, 1, 1, 1, 1], "b": [1, 1, 1, 1, 1]})
psdf["a"] = None
psdf["b"] = None
@@ -1279,7 +1279,8 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils):
pmidx = psmidx.to_pandas()
self.assert_eq(psmidx.is_monotonic_increasing, pmidx.is_monotonic_increasing)
self.assert_eq(psmidx.is_monotonic_decreasing, pmidx.is_monotonic_decreasing)
[(None, None)]
# For [(None, None)]
psdf = ps.DataFrame({"a": [1], "b": [1]})
psdf["a"] = None
psdf["b"] = None


@@ -58,6 +58,20 @@ class InternalFrameTest(PandasOnSparkTestCase, SQLTestUtils):
self.assert_eq(internal.to_pandas_frame, pdf1)
# categorical column
pdf2 = pd.DataFrame({0: [1, 2, 3], 1: pd.Categorical([4, 5, 6])})
internal = InternalFrame.from_pandas(pdf2)
sdf = internal.spark_frame
self.assert_eq(internal.index_spark_column_names, [SPARK_DEFAULT_INDEX_NAME])
self.assert_eq(internal.index_names, [None])
self.assert_eq(internal.column_labels, [(0,), (1,)])
self.assert_eq(internal.data_spark_column_names, ["0", "1"])
self.assertTrue(spark_column_equals(internal.spark_column_for((0,)), sdf["0"]))
self.assertTrue(spark_column_equals(internal.spark_column_for((1,)), sdf["1"]))
self.assert_eq(internal.to_pandas_frame, pdf2)
# multi-index
pdf.set_index("a", append=True, inplace=True)