[SPARK-36708][PYTHON] Support numpy.typing for annotating ArrayType in pandas API on Spark

### What changes were proposed in this pull request?

This PR adds support for understanding the `numpy.typing` package, added in NumPy 1.21, when resolving type hints.

### Why are the changes needed?

To allow a user-friendly way of specifying return types in the type hints of function-apply APIs in pandas API on Spark.

### Does this PR introduce _any_ user-facing change?

Yes. Users can now annotate the return type with `numpy.typing.NDArray[...]`, which internally determines the return type of the underlying pandas UDF.

For example,

```python
import numpy as np
import numpy.typing as ntp
import pandas as pd
import pyspark.pandas as ps

pdf = pd.DataFrame(
    {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [[e] for e in [4, 5, 6, 3, 2, 1, 0, 0, 0]]},
    index=np.random.rand(9),
)
psdf = ps.from_pandas(pdf)

def func(x) -> ps.DataFrame[float, [int, ntp.NDArray[int]]]:
    return x

psdf.pandas_on_spark.apply_batch(func)
```
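
Internally, the `NDArray[...]` annotation resolves to a Spark `ArrayType`, just as `List[...]` does. Below is a minimal sketch of that mapping using the `as_spark_type` helper from `pyspark.pandas.typedef` (an internal API, so subject to change):

```python
import numpy.typing as ntp

# Internal helper that maps type hints to Spark SQL types.
from pyspark.pandas.typedef import as_spark_type

# NDArray[...] is treated the same as List[...]:
as_spark_type(ntp.NDArray[int])    # ArrayType(LongType(), True)
as_spark_type(ntp.NDArray[float])  # ArrayType(DoubleType(), True)
```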

### How was this patch tested?

Unit tests and end-to-end tests were added.

Closes #34028 from HyukjinKwon/SPARK-36708.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
Commit cc2fcb4794 (parent 6a5ee0283c), committed 2021-09-23 10:50:10 +09:00.
3 changed files with 68 additions and 24 deletions.

python/pyspark/pandas/tests/test_dataframe.py

@@ -21,6 +21,7 @@ import inspect
 import sys
 import unittest
 from io import StringIO
+from typing import List

 import numpy as np
 import pandas as pd
@@ -4649,6 +4650,32 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
         self.assert_eq(sorted(actual["a"].to_numpy()), sorted(expected["a"].to_numpy()))
         self.assert_eq(sorted(actual["b"].to_numpy()), sorted(expected["b"].to_numpy()))

+        pdf = pd.DataFrame(
+            {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [[e] for e in [4, 5, 6, 3, 2, 1, 0, 0, 0]]},
+            index=np.random.rand(9),
+        )
+        psdf = ps.from_pandas(pdf)
+
+        def identify3(x) -> ps.DataFrame[float, [int, List[int]]]:
+            return x
+
+        actual = psdf.pandas_on_spark.apply_batch(identify3)
+        actual.columns = ["a", "b"]
+        self.assert_eq(actual, pdf)
+
+        # For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+
+        if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):
+            import numpy.typing as ntp
+
+            psdf = ps.from_pandas(pdf)
+
+            def identify4(x) -> ps.DataFrame[float, [int, ntp.NDArray[int]]]:  # type: ignore
+                return x
+
+            actual = psdf.pandas_on_spark.apply_batch(identify4)
+            actual.columns = ["a", "b"]
+            self.assert_eq(actual, pdf)
+
     def test_transform_batch(self):
         pdf = pd.DataFrame(
             {

python/pyspark/pandas/tests/test_typedef.py

@@ -19,6 +19,7 @@ import sys
 import unittest
 import datetime
 import decimal
+from distutils.version import LooseVersion
 from typing import List

 import pandas
@@ -334,29 +335,6 @@ class TypeHintTests(unittest.TestCase):
             decimal.Decimal: (np.dtype("object"), DecimalType(38, 18)),
             # ArrayType
             np.ndarray: (np.dtype("object"), ArrayType(StringType())),
-            List[bytes]: (np.dtype("object"), ArrayType(BinaryType())),
-            List[np.character]: (np.dtype("object"), ArrayType(BinaryType())),
-            List[np.bytes_]: (np.dtype("object"), ArrayType(BinaryType())),
-            List[np.string_]: (np.dtype("object"), ArrayType(BinaryType())),
-            List[bool]: (np.dtype("object"), ArrayType(BooleanType())),
-            List[np.bool]: (np.dtype("object"), ArrayType(BooleanType())),
-            List[datetime.date]: (np.dtype("object"), ArrayType(DateType())),
-            List[np.int8]: (np.dtype("object"), ArrayType(ByteType())),
-            List[np.byte]: (np.dtype("object"), ArrayType(ByteType())),
-            List[decimal.Decimal]: (np.dtype("object"), ArrayType(DecimalType(38, 18))),
-            List[float]: (np.dtype("object"), ArrayType(DoubleType())),
-            List[np.float]: (np.dtype("object"), ArrayType(DoubleType())),
-            List[np.float64]: (np.dtype("object"), ArrayType(DoubleType())),
-            List[np.float32]: (np.dtype("object"), ArrayType(FloatType())),
-            List[np.int32]: (np.dtype("object"), ArrayType(IntegerType())),
-            List[int]: (np.dtype("object"), ArrayType(LongType())),
-            List[np.int]: (np.dtype("object"), ArrayType(LongType())),
-            List[np.int64]: (np.dtype("object"), ArrayType(LongType())),
-            List[np.int16]: (np.dtype("object"), ArrayType(ShortType())),
-            List[str]: (np.dtype("object"), ArrayType(StringType())),
-            List[np.unicode_]: (np.dtype("object"), ArrayType(StringType())),
-            List[datetime.datetime]: (np.dtype("object"), ArrayType(TimestampType())),
-            List[np.datetime64]: (np.dtype("object"), ArrayType(TimestampType())),
             # CategoricalDtype
             CategoricalDtype(categories=["a", "b", "c"]): (
                 CategoricalDtype(categories=["a", "b", "c"]),
@@ -368,6 +346,28 @@
             self.assertEqual(as_spark_type(numpy_or_python_type), spark_type)
             self.assertEqual(pandas_on_spark_type(numpy_or_python_type), (dtype, spark_type))

+            if isinstance(numpy_or_python_type, CategoricalDtype):
+                # Nested CategoricalDtype is not yet supported.
+                continue
+
+            self.assertEqual(as_spark_type(List[numpy_or_python_type]), ArrayType(spark_type))
+            self.assertEqual(
+                pandas_on_spark_type(List[numpy_or_python_type]),
+                (np.dtype("object"), ArrayType(spark_type)),
+            )
+
+            # For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+
+            if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):
+                import numpy.typing as ntp
+
+                self.assertEqual(
+                    as_spark_type(ntp.NDArray[numpy_or_python_type]), ArrayType(spark_type)
+                )
+                self.assertEqual(
+                    pandas_on_spark_type(ntp.NDArray[numpy_or_python_type]),
+                    (np.dtype("object"), ArrayType(spark_type)),
+                )
+
         with self.assertRaisesRegex(TypeError, "Type uint64 was not understood."):
             as_spark_type(np.dtype("uint64"))

python/pyspark/pandas/typedef/typehints.py

@@ -20,8 +20,10 @@ Utilities to deal with types. This is mostly focused on python3.
 """
 import datetime
 import decimal
+import sys
 import typing
 from collections import Iterable
+from distutils.version import LooseVersion
 from inspect import getfullargspec, isclass
 from typing import (  # noqa: F401
     Any,
@@ -152,6 +154,19 @@ def as_spark_type(
     - dictionaries of field_name -> type
     - Python3's typing system
     """
+    # For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+
+    if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):
+        if (
+            hasattr(tpe, "__origin__")
+            and tpe.__origin__ is np.ndarray  # type: ignore
+            and hasattr(tpe, "__args__")
+            and len(tpe.__args__) > 1  # type: ignore
+        ):
+            # numpy.typing.NDArray
+            return types.ArrayType(
+                as_spark_type(tpe.__args__[1].__args__[0], raise_error=raise_error)  # type: ignore
+            )
+
     if isinstance(tpe, np.dtype) and tpe == np.dtype("object"):
         pass
     # ArrayType
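
For reference, the `__origin__`/`__args__` probing above works because `ntp.NDArray[int]` is an alias for `numpy.ndarray[Any, numpy.dtype[int]]`, so the element type sits inside the second type argument. A quick illustration (not part of the patch; assumes Python 3.8+ and NumPy 1.21+):

```python
import numpy.typing as ntp

tpe = ntp.NDArray[int]  # alias for numpy.ndarray[typing.Any, numpy.dtype[int]]

print(tpe.__origin__)  # <class 'numpy.ndarray'>
print(tpe.__args__)    # (typing.Any, numpy.dtype[int])

# Hence tpe.__args__[1].__args__[0] recovers the element type for ArrayType:
print(tpe.__args__[1].__args__[0])  # <class 'int'>
```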
@@ -568,7 +583,9 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarType, UnknownType]:
         else:
             parameters = getattr(tuple_type, "__args__")

-        index_parameters = [p for p in parameters if issubclass(p, IndexNameTypeHolder)]
+        index_parameters = [
+            p for p in parameters if isclass(p) and issubclass(p, IndexNameTypeHolder)
+        ]
         data_parameters = [p for p in parameters if p not in index_parameters]
         assert len(data_parameters) > 0, "Type hints for data must not be empty."
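
The `isclass(p)` guard added here matters because `issubclass` raises `TypeError` when its first argument is a parameterized generic such as `ntp.NDArray[int]` rather than a plain class. A small illustration of the failure mode being avoided (assumes NumPy 1.21+):

```python
from inspect import isclass

import numpy.typing as ntp

p = ntp.NDArray[int]
print(isclass(p))  # False: a generic alias, not a class

try:
    issubclass(p, object)
except TypeError as e:
    print(e)  # issubclass() arg 1 must be a class
```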