[SPARK-36708][PYTHON] Support numpy.typing for annotating ArrayType in pandas API on Spark
### What changes were proposed in this pull request?

This PR adds support for understanding the `numpy.typing` package (specifically `numpy.typing.NDArray`, which is available from NumPy 1.21).

### Why are the changes needed?

For user-friendly return type specification in type hints for the function apply APIs in pandas API on Spark.

### Does this PR introduce _any_ user-facing change?

Yes, this PR enables users to specify the return type as `numpy.typing.NDArray[...]`, which is used internally to determine the pandas UDF's return type. For example,

```python
import numpy as np
import numpy.typing as ntp
import pandas as pd
import pyspark.pandas as ps

pdf = pd.DataFrame(
    {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [[e] for e in [4, 5, 6, 3, 2, 1, 0, 0, 0]]},
    index=np.random.rand(9),
)
psdf = ps.from_pandas(pdf)

def func(x) -> ps.DataFrame[float, [int, ntp.NDArray[int]]]:
    return x

psdf.pandas_on_spark.apply_batch(func)
```

### How was this patch tested?

Unit tests and e2e tests were added.

Closes #34028 from HyukjinKwon/SPARK-36708.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
This commit is contained in:
parent
6a5ee0283c
commit
cc2fcb4794
|
@ -21,6 +21,7 @@ import inspect
|
|||
import sys
|
||||
import unittest
|
||||
from io import StringIO
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
@ -4649,6 +4650,32 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
self.assert_eq(sorted(actual["a"].to_numpy()), sorted(expected["a"].to_numpy()))
|
||||
self.assert_eq(sorted(actual["b"].to_numpy()), sorted(expected["b"].to_numpy()))
|
||||
|
||||
pdf = pd.DataFrame(
|
||||
{"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [[e] for e in [4, 5, 6, 3, 2, 1, 0, 0, 0]]},
|
||||
index=np.random.rand(9),
|
||||
)
|
||||
psdf = ps.from_pandas(pdf)
|
||||
|
||||
def identify3(x) -> ps.DataFrame[float, [int, List[int]]]:
|
||||
return x
|
||||
|
||||
actual = psdf.pandas_on_spark.apply_batch(identify3)
|
||||
actual.columns = ["a", "b"]
|
||||
self.assert_eq(actual, pdf)
|
||||
|
||||
# For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+
|
||||
if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):
|
||||
import numpy.typing as ntp
|
||||
|
||||
psdf = ps.from_pandas(pdf)
|
||||
|
||||
def identify4(x) -> ps.DataFrame[float, [int, ntp.NDArray[int]]]: # type: ignore
|
||||
return x
|
||||
|
||||
actual = psdf.pandas_on_spark.apply_batch(identify4)
|
||||
actual.columns = ["a", "b"]
|
||||
self.assert_eq(actual, pdf)
|
||||
|
||||
def test_transform_batch(self):
|
||||
pdf = pd.DataFrame(
|
||||
{
|
||||
|
|
|
@ -19,6 +19,7 @@ import sys
|
|||
import unittest
|
||||
import datetime
|
||||
import decimal
|
||||
from distutils.version import LooseVersion
|
||||
from typing import List
|
||||
|
||||
import pandas
|
||||
|
@ -334,29 +335,6 @@ class TypeHintTests(unittest.TestCase):
|
|||
decimal.Decimal: (np.dtype("object"), DecimalType(38, 18)),
|
||||
# ArrayType
|
||||
np.ndarray: (np.dtype("object"), ArrayType(StringType())),
|
||||
List[bytes]: (np.dtype("object"), ArrayType(BinaryType())),
|
||||
List[np.character]: (np.dtype("object"), ArrayType(BinaryType())),
|
||||
List[np.bytes_]: (np.dtype("object"), ArrayType(BinaryType())),
|
||||
List[np.string_]: (np.dtype("object"), ArrayType(BinaryType())),
|
||||
List[bool]: (np.dtype("object"), ArrayType(BooleanType())),
|
||||
List[np.bool]: (np.dtype("object"), ArrayType(BooleanType())),
|
||||
List[datetime.date]: (np.dtype("object"), ArrayType(DateType())),
|
||||
List[np.int8]: (np.dtype("object"), ArrayType(ByteType())),
|
||||
List[np.byte]: (np.dtype("object"), ArrayType(ByteType())),
|
||||
List[decimal.Decimal]: (np.dtype("object"), ArrayType(DecimalType(38, 18))),
|
||||
List[float]: (np.dtype("object"), ArrayType(DoubleType())),
|
||||
List[np.float]: (np.dtype("object"), ArrayType(DoubleType())),
|
||||
List[np.float64]: (np.dtype("object"), ArrayType(DoubleType())),
|
||||
List[np.float32]: (np.dtype("object"), ArrayType(FloatType())),
|
||||
List[np.int32]: (np.dtype("object"), ArrayType(IntegerType())),
|
||||
List[int]: (np.dtype("object"), ArrayType(LongType())),
|
||||
List[np.int]: (np.dtype("object"), ArrayType(LongType())),
|
||||
List[np.int64]: (np.dtype("object"), ArrayType(LongType())),
|
||||
List[np.int16]: (np.dtype("object"), ArrayType(ShortType())),
|
||||
List[str]: (np.dtype("object"), ArrayType(StringType())),
|
||||
List[np.unicode_]: (np.dtype("object"), ArrayType(StringType())),
|
||||
List[datetime.datetime]: (np.dtype("object"), ArrayType(TimestampType())),
|
||||
List[np.datetime64]: (np.dtype("object"), ArrayType(TimestampType())),
|
||||
# CategoricalDtype
|
||||
CategoricalDtype(categories=["a", "b", "c"]): (
|
||||
CategoricalDtype(categories=["a", "b", "c"]),
|
||||
|
@ -368,6 +346,28 @@ class TypeHintTests(unittest.TestCase):
|
|||
self.assertEqual(as_spark_type(numpy_or_python_type), spark_type)
|
||||
self.assertEqual(pandas_on_spark_type(numpy_or_python_type), (dtype, spark_type))
|
||||
|
||||
if isinstance(numpy_or_python_type, CategoricalDtype):
|
||||
# Nested CategoricalDtype is not yet supported.
|
||||
continue
|
||||
|
||||
self.assertEqual(as_spark_type(List[numpy_or_python_type]), ArrayType(spark_type))
|
||||
self.assertEqual(
|
||||
pandas_on_spark_type(List[numpy_or_python_type]),
|
||||
(np.dtype("object"), ArrayType(spark_type)),
|
||||
)
|
||||
|
||||
# For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+
|
||||
if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):
|
||||
import numpy.typing as ntp
|
||||
|
||||
self.assertEqual(
|
||||
as_spark_type(ntp.NDArray[numpy_or_python_type]), ArrayType(spark_type)
|
||||
)
|
||||
self.assertEqual(
|
||||
pandas_on_spark_type(ntp.NDArray[numpy_or_python_type]),
|
||||
(np.dtype("object"), ArrayType(spark_type)),
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "Type uint64 was not understood."):
|
||||
as_spark_type(np.dtype("uint64"))
|
||||
|
||||
|
|
|
@ -20,8 +20,10 @@ Utilities to deal with types. This is mostly focused on python3.
|
|||
"""
|
||||
import datetime
|
||||
import decimal
|
||||
import sys
|
||||
import typing
|
||||
from collections import Iterable
|
||||
from distutils.version import LooseVersion
|
||||
from inspect import getfullargspec, isclass
|
||||
from typing import ( # noqa: F401
|
||||
Any,
|
||||
|
@ -152,6 +154,19 @@ def as_spark_type(
|
|||
- dictionaries of field_name -> type
|
||||
- Python3's typing system
|
||||
"""
|
||||
# For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+
|
||||
if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):
|
||||
if (
|
||||
hasattr(tpe, "__origin__")
|
||||
and tpe.__origin__ is np.ndarray # type: ignore
|
||||
and hasattr(tpe, "__args__")
|
||||
and len(tpe.__args__) > 1 # type: ignore
|
||||
):
|
||||
# numpy.typing.NDArray
|
||||
return types.ArrayType(
|
||||
as_spark_type(tpe.__args__[1].__args__[0], raise_error=raise_error) # type: ignore
|
||||
)
|
||||
|
||||
if isinstance(tpe, np.dtype) and tpe == np.dtype("object"):
|
||||
pass
|
||||
# ArrayType
|
||||
|
@ -568,7 +583,9 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
|
|||
else:
|
||||
parameters = getattr(tuple_type, "__args__")
|
||||
|
||||
index_parameters = [p for p in parameters if issubclass(p, IndexNameTypeHolder)]
|
||||
index_parameters = [
|
||||
p for p in parameters if isclass(p) and issubclass(p, IndexNameTypeHolder)
|
||||
]
|
||||
data_parameters = [p for p in parameters if p not in index_parameters]
|
||||
assert len(data_parameters) > 0, "Type hints for data must not be empty."
|
||||
|
||||
|
|
Loading…
Reference in a new issue