[SPARK-36211][PYTHON] Correct typing of udf return value
The following code should type-check:

```python3
import uuid

import pyspark.sql.functions as F

my_udf = F.udf(lambda: str(uuid.uuid4())).asNondeterministic()
```

### What changes were proposed in this pull request?

The `udf` function should return a more specific type.

### Why are the changes needed?

Right now, `mypy` will throw spurious errors, such as for the code given above.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

This was not tested. Sorry, I am not very familiar with this repo -- are there any typing tests?

Closes #33399 from luranhe/patch-1.

Lead-authored-by: Luran He &lt;luranjhe@gmail.com&gt;
Co-authored-by: Luran He &lt;luran.he@compass.com&gt;
Signed-off-by: zero323 &lt;mszymkiewicz@gmail.com&gt;
This commit is contained in:
parent
068f8d434a
commit
ede1bc6b51
|
@ -18,6 +18,7 @@
|
|||
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
|
@ -30,11 +31,10 @@ import datetime
|
|||
import decimal
|
||||
|
||||
from pyspark._typing import PrimitiveType
|
||||
import pyspark.sql.column
|
||||
import pyspark.sql.types
|
||||
from pyspark.sql.column import Column
|
||||
|
||||
ColumnOrName = Union[pyspark.sql.column.Column, str]
|
||||
ColumnOrName = Union[Column, str]
|
||||
DecimalLiteral = decimal.Decimal
|
||||
DateTimeLiteral = Union[datetime.datetime, datetime.date]
|
||||
LiteralType = PrimitiveType
|
||||
|
@ -54,4 +54,10 @@ class SupportsClose(Protocol):
|
|||
def close(self, error: Exception) -> None: ...
|
||||
|
||||
class UserDefinedFunctionLike(Protocol):
|
||||
def __call__(self, *_: ColumnOrName) -> Column: ...
|
||||
func: Callable[..., Any]
|
||||
evalType: int
|
||||
deterministic: bool
|
||||
@property
|
||||
def returnType(self) -> pyspark.sql.types.DataType: ...
|
||||
def __call__(self, *args: ColumnOrName) -> Column: ...
|
||||
def asNondeterministic(self) -> UserDefinedFunctionLike: ...
|
||||
|
|
|
@ -22,6 +22,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
|||
from pyspark.sql._typing import (
|
||||
ColumnOrName,
|
||||
DataTypeOrString,
|
||||
UserDefinedFunctionLike,
|
||||
)
|
||||
from pyspark.sql.pandas.functions import ( # noqa: F401
|
||||
pandas_udf as pandas_udf,
|
||||
|
@ -359,13 +360,13 @@ def variance(col: ColumnOrName) -> Column: ...
|
|||
@overload
|
||||
def udf(
|
||||
f: Callable[..., Any], returnType: DataTypeOrString = ...
|
||||
) -> Callable[..., Column]: ...
|
||||
) -> UserDefinedFunctionLike: ...
|
||||
@overload
|
||||
def udf(
|
||||
f: DataTypeOrString = ...,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Column]]: ...
|
||||
) -> Callable[[Callable[..., Any]], UserDefinedFunctionLike]: ...
|
||||
@overload
|
||||
def udf(
|
||||
*,
|
||||
returnType: DataTypeOrString = ...,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Column]]: ...
|
||||
) -> Callable[[Callable[..., Any]], UserDefinedFunctionLike]: ...
|
||||
|
|
Loading…
Reference in a new issue