[SPARK-36211][PYTHON] Correct typing of udf return value

The following code should type-check:

```python3
import uuid

import pyspark.sql.functions as F

my_udf = F.udf(lambda: str(uuid.uuid4())).asNondeterministic()
```

### What changes were proposed in this pull request?

The `udf` stubs should declare a more specific return type: the `UserDefinedFunctionLike` protocol rather than a bare `Callable[..., Column]`, so that members such as `asNondeterministic` are visible to type checkers.

### Why are the changes needed?

Right now, `mypy` reports spurious errors for code like the example above, because a `Callable[..., Column]` declares no `asNondeterministic` attribute.
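
For concreteness, checking the snippet above against the old stubs fails roughly as follows (the exact diagnostic text is an assumption and varies across `mypy` versions):

```python3
import uuid

import pyspark.sql.functions as F

# Under the old stubs, F.udf(...) is typed as Callable[..., Column], which
# declares no asNondeterministic attribute, so mypy rejects the chained call:
#
#   error: "Callable[..., Column]" has no attribute "asNondeterministic"
#
my_udf = F.udf(lambda: str(uuid.uuid4())).asNondeterministic()
```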

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

This was not tested. Sorry, I am not very familiar with this repo -- are there any typing tests?

Closes #33399 from luranhe/patch-1.

Lead-authored-by: Luran He <luranjhe@gmail.com>
Co-authored-by: Luran He <luran.he@compass.com>
Signed-off-by: zero323 <mszymkiewicz@gmail.com>

`python/pyspark/sql/_typing.pyi`:

```diff
@@ -18,6 +18,7 @@
 from typing import (
     Any,
+    Callable,
     List,
     Optional,
     Tuple,
@@ -30,11 +31,10 @@ import datetime
 import decimal
 
 from pyspark._typing import PrimitiveType
-import pyspark.sql.column
 import pyspark.sql.types
+from pyspark.sql.column import Column
 
-ColumnOrName = Union[pyspark.sql.column.Column, str]
+ColumnOrName = Union[Column, str]
 DecimalLiteral = decimal.Decimal
 DateTimeLiteral = Union[datetime.datetime, datetime.date]
 LiteralType = PrimitiveType
@@ -54,4 +54,10 @@ class SupportsClose(Protocol):
     def close(self, error: Exception) -> None: ...
 
 class UserDefinedFunctionLike(Protocol):
-    def __call__(self, *_: ColumnOrName) -> Column: ...
+    func: Callable[..., Any]
+    evalType: int
+    deterministic: bool
+    @property
+    def returnType(self) -> pyspark.sql.types.DataType: ...
+    def __call__(self, *args: ColumnOrName) -> Column: ...
+    def asNondeterministic(self) -> UserDefinedFunctionLike: ...
```
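
For intuition, `UserDefinedFunctionLike` is a `Protocol`, so matching is structural: any object exposing these members satisfies it, with no inheritance required. A minimal self-contained sketch of the same pattern (the names `UDFLike`, `FakeUDF`, and `toggle` are illustrative, not part of the PR):

```python3
from typing import Protocol  # Python 3.8+


class UDFLike(Protocol):
    """Structural type: anything with these members matches it."""

    deterministic: bool

    def asNondeterministic(self) -> "UDFLike": ...


class FakeUDF:
    """Satisfies UDFLike by shape alone; no inheritance involved."""

    def __init__(self) -> None:
        self.deterministic = True

    def asNondeterministic(self) -> "FakeUDF":
        self.deterministic = False
        return self


def toggle(f: UDFLike) -> UDFLike:
    # mypy accepts this call because UDFLike declares asNondeterministic.
    return f.asNondeterministic()


assert toggle(FakeUDF()).deterministic is False
```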

`python/pyspark/sql/functions.pyi`:

```diff
@@ -22,6 +22,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 from pyspark.sql._typing import (
     ColumnOrName,
     DataTypeOrString,
+    UserDefinedFunctionLike,
 )
 from pyspark.sql.pandas.functions import (  # noqa: F401
     pandas_udf as pandas_udf,
@@ -359,13 +360,13 @@ def variance(col: ColumnOrName) -> Column: ...
 @overload
 def udf(
     f: Callable[..., Any], returnType: DataTypeOrString = ...
-) -> Callable[..., Column]: ...
+) -> UserDefinedFunctionLike: ...
 @overload
 def udf(
     f: DataTypeOrString = ...,
-) -> Callable[[Callable[..., Any]], Callable[..., Column]]: ...
+) -> Callable[[Callable[..., Any]], UserDefinedFunctionLike]: ...
 @overload
 def udf(
     *,
     returnType: DataTypeOrString = ...,
-) -> Callable[[Callable[..., Any]], Callable[..., Column]]: ...
+) -> Callable[[Callable[..., Any]], UserDefinedFunctionLike]: ...
```
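
With these overloads, both the direct call and the decorator form now yield a `UserDefinedFunctionLike`, so chained calls type-check. A small usage sketch under that assumption (`uuid_udf` and `to_upper` are illustrative names; the runtime attribute checks reflect PySpark's wrapped-UDF behavior and are hedged accordingly):

```python3
import uuid

import pyspark.sql.functions as F
from pyspark.sql.types import StringType

# Direct form: udf(f, returnType) -> UserDefinedFunctionLike,
# so the chained asNondeterministic() call now type-checks.
uuid_udf = F.udf(lambda: str(uuid.uuid4()), StringType()).asNondeterministic()

# Decorator form: udf(returnType=...) returns
# Callable[[Callable[..., Any]], UserDefinedFunctionLike].
@F.udf(returnType=StringType())
def to_upper(s: str) -> str:
    return s.upper()

# Protocol members such as returnType and deterministic are now visible
# to type checkers as well as at runtime:
assert isinstance(uuid_udf.returnType, StringType)
assert uuid_udf.deterministic is False
```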