From 665817bd4fc07b18cee0f8c6ff759288472514c2 Mon Sep 17 00:00:00 2001
From: zero323
Date: Wed, 25 Nov 2020 09:27:04 +0900
Subject: [PATCH] [SPARK-33457][PYTHON] Adjust mypy configuration

### What changes were proposed in this pull request?

This pull request:

- Adds the following flags to the main mypy configuration:
  - [`strict_optional`](https://mypy.readthedocs.io/en/stable/config_file.html#confval-strict_optional)
  - [`no_implicit_optional`](https://mypy.readthedocs.io/en/stable/config_file.html#confval-no_implicit_optional)
  - [`disallow_untyped_defs`](https://mypy.readthedocs.io/en/stable/config_file.html#confval-disallow_untyped_defs)

These flags are enabled only for the public API and disabled for tests and internal modules.

Additionally, this PR fixes missing annotations.

### Why are the changes needed?

The primary reason for proposing these changes is to use the standard configuration used by the typeshed project. This will allow us to be more strict, especially when interacting with JVM code. See for example https://github.com/apache/spark/pull/29122#pullrequestreview-513112882

Additionally, it will allow us to detect cases where annotations have been unintentionally omitted.

### Does this PR introduce _any_ user-facing change?

Annotations only.

### How was this patch tested?

`dev/lint-python`.

Closes #30382 from zero323/SPARK-33457.

Authored-by: zero323
Signed-off-by: HyukjinKwon
---
 python/mypy.ini                           | 87 +++++++++++++++++++++++
 python/pyspark/broadcast.pyi              | 10 +--
 python/pyspark/context.pyi                | 25 +++++--
 python/pyspark/ml/classification.pyi      |  6 +-
 python/pyspark/ml/common.pyi              | 10 ++-
 python/pyspark/ml/evaluation.pyi          | 24 ++++---
 python/pyspark/ml/feature.pyi             | 20 ++++--
 python/pyspark/ml/linalg/__init__.pyi     | 36 +++++-----
 python/pyspark/ml/pipeline.pyi            |  4 +-
 python/pyspark/ml/regression.pyi          | 10 +--
 python/pyspark/mllib/classification.pyi   |  2 +-
 python/pyspark/mllib/clustering.pyi       |  6 +-
 python/pyspark/mllib/common.pyi           | 20 ++++--
 python/pyspark/mllib/linalg/__init__.pyi  | 45 +++++++-----
 python/pyspark/mllib/random.pyi           |  2 +-
 python/pyspark/mllib/recommendation.pyi   |  4 +-
 python/pyspark/mllib/stat/_statistics.pyi |  2 +-
 python/pyspark/rdd.pyi                    |  8 ++-
 python/pyspark/resource/profile.pyi       |  2 +-
 python/pyspark/sql/column.pyi             |  8 ++-
 python/pyspark/sql/context.pyi            |  6 +-
 python/pyspark/sql/functions.pyi          |  8 ++-
 python/pyspark/sql/session.pyi            | 10 ++-
 python/pyspark/sql/types.pyi              | 15 ++--
 python/pyspark/sql/udf.pyi                |  7 +-
 python/pyspark/streaming/context.pyi      |  2 +-
 python/pyspark/streaming/dstream.pyi      | 10 ++-
 python/pyspark/streaming/kinesis.pyi      |  2 +-
 28 files changed, 277 insertions(+), 114 deletions(-)

diff --git a/python/mypy.ini b/python/mypy.ini
index 4a5368a519..5103452a05 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -16,10 +16,97 @@
 ;
 
 [mypy]
+strict_optional = True
+no_implicit_optional = True
+disallow_untyped_defs = True
+
+; Allow untyped def in internal modules and tests
+
+[mypy-pyspark.daemon]
+disallow_untyped_defs = False
+
+[mypy-pyspark.find_spark_home]
+disallow_untyped_defs = False
+
+[mypy-pyspark._globals]
+disallow_untyped_defs = False
+
+[mypy-pyspark.install]
+disallow_untyped_defs = False
+
+[mypy-pyspark.java_gateway]
+disallow_untyped_defs = False
+
+[mypy-pyspark.join]
+disallow_untyped_defs = False
+
+[mypy-pyspark.ml.tests.*]
+disallow_untyped_defs = False
+
+[mypy-pyspark.mllib.tests.*]
+disallow_untyped_defs = False
+
+[mypy-pyspark.rddsampler]
+disallow_untyped_defs = False
+
+[mypy-pyspark.resource.tests.*]
+disallow_untyped_defs = False
+
+[mypy-pyspark.serializers] +disallow_untyped_defs = False + +[mypy-pyspark.shuffle] +disallow_untyped_defs = False + +[mypy-pyspark.streaming.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.streaming.util] +disallow_untyped_defs = False + +[mypy-pyspark.sql.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.serializers] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.types] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.typehints] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.utils] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas._typing.protocols.*] +disallow_untyped_defs = False + +[mypy-pyspark.sql.utils] +disallow_untyped_defs = False + +[mypy-pyspark.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.testing.*] +disallow_untyped_defs = False + +[mypy-pyspark.traceback_utils] +disallow_untyped_defs = False + +[mypy-pyspark.util] +disallow_untyped_defs = False + +[mypy-pyspark.worker] +disallow_untyped_defs = False + +; Ignore errors in embedded third party code [mypy-pyspark.cloudpickle.*] ignore_errors = True +; Ignore missing imports for external untyped packages + [mypy-py4j.*] ignore_missing_imports = True diff --git a/python/pyspark/broadcast.pyi b/python/pyspark/broadcast.pyi index 4b019a509a..944cb06d41 100644 --- a/python/pyspark/broadcast.pyi +++ b/python/pyspark/broadcast.pyi @@ -17,7 +17,7 @@ # under the License. import threading -from typing import Any, Dict, Generic, Optional, TypeVar +from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar T = TypeVar("T") @@ -32,14 +32,14 @@ class Broadcast(Generic[T]): path: Optional[Any] = ..., sock_file: Optional[Any] = ..., ) -> None: ... - def dump(self, value: Any, f: Any) -> None: ... - def load_from_path(self, path: Any): ... - def load(self, file: Any): ... + def dump(self, value: T, f: Any) -> None: ... + def load_from_path(self, path: Any) -> T: ... + def load(self, file: Any) -> T: ... @property def value(self) -> T: ... def unpersist(self, blocking: bool = ...) -> None: ... def destroy(self, blocking: bool = ...) -> None: ... - def __reduce__(self): ... + def __reduce__(self) -> Tuple[Callable[[int], T], Tuple[int]]: ... class BroadcastPickleRegistry(threading.local): def __init__(self) -> None: ... diff --git a/python/pyspark/context.pyi b/python/pyspark/context.pyi index 2789a38b3b..640a69cad0 100644 --- a/python/pyspark/context.pyi +++ b/python/pyspark/context.pyi @@ -16,7 +16,19 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + NoReturn, + Optional, + Tuple, + Type, + TypeVar, +) +from types import TracebackType from py4j.java_gateway import JavaGateway, JavaObject # type: ignore[import] @@ -51,9 +63,14 @@ class SparkContext: jsc: Optional[JavaObject] = ..., profiler_cls: type = ..., ) -> None: ... - def __getnewargs__(self): ... - def __enter__(self): ... - def __exit__(self, type, value, trace): ... + def __getnewargs__(self) -> NoReturn: ... + def __enter__(self) -> SparkContext: ... + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + trace: Optional[TracebackType], + ) -> None: ... @classmethod def getOrCreate(cls, conf: Optional[SparkConf] = ...) -> SparkContext: ... def setLogLevel(self, logLevel: str) -> None: ... 
diff --git a/python/pyspark/ml/classification.pyi b/python/pyspark/ml/classification.pyi index 4bde851bb1..c44176a13a 100644 --- a/python/pyspark/ml/classification.pyi +++ b/python/pyspark/ml/classification.pyi @@ -107,7 +107,7 @@ class _JavaProbabilisticClassifier( class _JavaProbabilisticClassificationModel( ProbabilisticClassificationModel, _JavaClassificationModel[T] ): - def predictProbability(self, value: Any): ... + def predictProbability(self, value: Vector) -> Vector: ... class _ClassificationSummary(JavaWrapper): @property @@ -543,7 +543,7 @@ class RandomForestClassificationModel( @property def trees(self) -> List[DecisionTreeClassificationModel]: ... def summary(self) -> RandomForestClassificationTrainingSummary: ... - def evaluate(self, dataset) -> RandomForestClassificationSummary: ... + def evaluate(self, dataset: DataFrame) -> RandomForestClassificationSummary: ... class RandomForestClassificationSummary(_ClassificationSummary): ... class RandomForestClassificationTrainingSummary( @@ -891,7 +891,7 @@ class FMClassifier( solver: str = ..., thresholds: Optional[Any] = ..., seed: Optional[Any] = ..., - ): ... + ) -> FMClassifier: ... def setFactorSize(self, value: int) -> FMClassifier: ... def setFitLinear(self, value: bool) -> FMClassifier: ... def setMiniBatchFraction(self, value: float) -> FMClassifier: ... diff --git a/python/pyspark/ml/common.pyi b/python/pyspark/ml/common.pyi index 7bf0ed6183..a38fc5734f 100644 --- a/python/pyspark/ml/common.pyi +++ b/python/pyspark/ml/common.pyi @@ -16,5 +16,11 @@ # specific language governing permissions and limitations # under the License. -def callJavaFunc(sc, func, *args): ... -def inherit_doc(cls): ... +from typing import Any, TypeVar + +import pyspark.context + +C = TypeVar("C", bound=type) + +def callJavaFunc(sc: pyspark.context.SparkContext, func: Any, *args: Any) -> Any: ... +def inherit_doc(cls: C) -> C: ... diff --git a/python/pyspark/ml/evaluation.pyi b/python/pyspark/ml/evaluation.pyi index ea0a9f045c..55a3ae2774 100644 --- a/python/pyspark/ml/evaluation.pyi +++ b/python/pyspark/ml/evaluation.pyi @@ -39,9 +39,12 @@ from pyspark.ml.param.shared import ( HasWeightCol, ) from pyspark.ml.util import JavaMLReadable, JavaMLWritable +from pyspark.sql.dataframe import DataFrame class Evaluator(Params, metaclass=abc.ABCMeta): - def evaluate(self, dataset, params: Optional[ParamMap] = ...) -> float: ... + def evaluate( + self, dataset: DataFrame, params: Optional[ParamMap] = ... + ) -> float: ... def isLargerBetter(self) -> bool: ... class JavaEvaluator(JavaParams, Evaluator, metaclass=abc.ABCMeta): @@ -75,16 +78,15 @@ class BinaryClassificationEvaluator( def setLabelCol(self, value: str) -> BinaryClassificationEvaluator: ... def setRawPredictionCol(self, value: str) -> BinaryClassificationEvaluator: ... def setWeightCol(self, value: str) -> BinaryClassificationEvaluator: ... - -def setParams( - self, - *, - rawPredictionCol: str = ..., - labelCol: str = ..., - metricName: BinaryClassificationEvaluatorMetricType = ..., - weightCol: Optional[str] = ..., - numBins: int = ... -) -> BinaryClassificationEvaluator: ... + def setParams( + self, + *, + rawPredictionCol: str = ..., + labelCol: str = ..., + metricName: BinaryClassificationEvaluatorMetricType = ..., + weightCol: Optional[str] = ..., + numBins: int = ... + ) -> BinaryClassificationEvaluator: ... 
class RegressionEvaluator( JavaEvaluator, diff --git a/python/pyspark/ml/feature.pyi b/python/pyspark/ml/feature.pyi index f5b12a5b2f..4999defdf8 100644 --- a/python/pyspark/ml/feature.pyi +++ b/python/pyspark/ml/feature.pyi @@ -100,9 +100,9 @@ class _LSHParams(HasInputCol, HasOutputCol): def getNumHashTables(self) -> int: ... class _LSH(Generic[JM], JavaEstimator[JM], _LSHParams, JavaMLReadable, JavaMLWritable): - def setNumHashTables(self: P, value) -> P: ... - def setInputCol(self: P, value) -> P: ... - def setOutputCol(self: P, value) -> P: ... + def setNumHashTables(self: P, value: int) -> P: ... + def setInputCol(self: P, value: str) -> P: ... + def setOutputCol(self: P, value: str) -> P: ... class _LSHModel(JavaModel, _LSHParams): def setInputCol(self: P, value: str) -> P: ... @@ -1518,7 +1518,7 @@ class ChiSqSelector( fpr: float = ..., fdr: float = ..., fwe: float = ... - ): ... + ) -> ChiSqSelector: ... def setSelectorType(self, value: str) -> ChiSqSelector: ... def setNumTopFeatures(self, value: int) -> ChiSqSelector: ... def setPercentile(self, value: float) -> ChiSqSelector: ... @@ -1602,7 +1602,10 @@ class _VarianceThresholdSelectorParams(HasFeaturesCol, HasOutputCol): def getVarianceThreshold(self) -> float: ... class VarianceThresholdSelector( - JavaEstimator, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable + JavaEstimator[VarianceThresholdSelectorModel], + _VarianceThresholdSelectorParams, + JavaMLReadable[VarianceThresholdSelector], + JavaMLWritable, ): def __init__( self, @@ -1615,13 +1618,16 @@ class VarianceThresholdSelector( featuresCol: str = ..., outputCol: Optional[str] = ..., varianceThreshold: float = ..., - ): ... + ) -> VarianceThresholdSelector: ... def setVarianceThreshold(self, value: float) -> VarianceThresholdSelector: ... def setFeaturesCol(self, value: str) -> VarianceThresholdSelector: ... def setOutputCol(self, value: str) -> VarianceThresholdSelector: ... class VarianceThresholdSelectorModel( - JavaModel, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable + JavaModel, + _VarianceThresholdSelectorParams, + JavaMLReadable[VarianceThresholdSelectorModel], + JavaMLWritable, ): def setFeaturesCol(self, value: str) -> VarianceThresholdSelectorModel: ... def setOutputCol(self, value: str) -> VarianceThresholdSelectorModel: ... diff --git a/python/pyspark/ml/linalg/__init__.pyi b/python/pyspark/ml/linalg/__init__.pyi index a576b30aec..b4fba8823b 100644 --- a/python/pyspark/ml/linalg/__init__.pyi +++ b/python/pyspark/ml/linalg/__init__.pyi @@ -17,7 +17,7 @@ # under the License. from typing import overload -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, NoReturn, Optional, Tuple, Type, Union from pyspark.ml import linalg as newlinalg # noqa: F401 from pyspark.sql.types import StructType, UserDefinedType @@ -45,7 +45,7 @@ class MatrixUDT(UserDefinedType): @classmethod def scalaUDT(cls) -> str: ... def serialize( - self, obj + self, obj: Matrix ) -> Tuple[ int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool ]: ... @@ -64,9 +64,7 @@ class DenseVector(Vector): def __init__(self, __arr: bytes) -> None: ... @overload def __init__(self, __arr: Iterable[float]) -> None: ... - @staticmethod - def parse(s) -> DenseVector: ... - def __reduce__(self) -> Tuple[type, bytes]: ... + def __reduce__(self) -> Tuple[Type[DenseVector], bytes]: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... 
def dot(self, other: Iterable[float]) -> float64: ... @@ -112,16 +110,14 @@ class SparseVector(Vector): def __init__(self, size: int, __map: Dict[int, float]) -> None: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... - def __reduce__(self): ... - @staticmethod - def parse(s: str) -> SparseVector: ... + def __reduce__(self) -> Tuple[Type[SparseVector], Tuple[int, bytes, bytes]]: ... def dot(self, other: Iterable[float]) -> float64: ... def squared_distance(self, other: Iterable[float]) -> float64: ... def toArray(self) -> ndarray: ... def __len__(self) -> int: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... def __getitem__(self, index: int) -> float64: ... - def __ne__(self, other) -> bool: ... + def __ne__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... class Vectors: @@ -144,13 +140,13 @@ class Vectors: def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ... @overload @staticmethod - def dense(self, *elements: float) -> DenseVector: ... + def dense(*elements: float) -> DenseVector: ... @overload @staticmethod - def dense(self, __arr: bytes) -> DenseVector: ... + def dense(__arr: bytes) -> DenseVector: ... @overload @staticmethod - def dense(self, __arr: Iterable[float]) -> DenseVector: ... + def dense(__arr: Iterable[float]) -> DenseVector: ... @staticmethod def stringify(vector: Vector) -> str: ... @staticmethod @@ -158,8 +154,6 @@ class Vectors: @staticmethod def norm(vector: Vector, p: Union[float, str]) -> float64: ... @staticmethod - def parse(s: str) -> Vector: ... - @staticmethod def zeros(size: int) -> DenseVector: ... class Matrix: @@ -170,7 +164,7 @@ class Matrix: def __init__( self, numRows: int, numCols: int, isTransposed: bool = ... ) -> None: ... - def toArray(self): ... + def toArray(self) -> NoReturn: ... class DenseMatrix(Matrix): values: Any @@ -186,11 +180,11 @@ class DenseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, int]]: ... + def __reduce__(self) -> Tuple[Type[DenseMatrix], Tuple[int, int, bytes, int]]: ... def toArray(self) -> ndarray: ... def toSparse(self) -> SparseMatrix: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class SparseMatrix(Matrix): colPtrs: ndarray @@ -216,11 +210,13 @@ class SparseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, bytes, bytes, int]]: ... + def __reduce__( + self, + ) -> Tuple[Type[SparseMatrix], Tuple[int, int, bytes, bytes, bytes, int]]: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... def toArray(self) -> ndarray: ... def toDense(self) -> DenseMatrix: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class Matrices: @overload diff --git a/python/pyspark/ml/pipeline.pyi b/python/pyspark/ml/pipeline.pyi index 44680586d7..f47e9e012a 100644 --- a/python/pyspark/ml/pipeline.pyi +++ b/python/pyspark/ml/pipeline.pyi @@ -51,7 +51,7 @@ class PipelineWriter(MLWriter): def __init__(self, instance: Pipeline) -> None: ... def saveImpl(self, path: str) -> None: ... -class PipelineReader(MLReader): +class PipelineReader(MLReader[Pipeline]): cls: Type[Pipeline] def __init__(self, cls: Type[Pipeline]) -> None: ... def load(self, path: str) -> Pipeline: ... 
@@ -61,7 +61,7 @@ class PipelineModelWriter(MLWriter): def __init__(self, instance: PipelineModel) -> None: ... def saveImpl(self, path: str) -> None: ... -class PipelineModelReader(MLReader): +class PipelineModelReader(MLReader[PipelineModel]): cls: Type[PipelineModel] def __init__(self, cls: Type[PipelineModel]) -> None: ... def load(self, path: str) -> PipelineModel: ... diff --git a/python/pyspark/ml/regression.pyi b/python/pyspark/ml/regression.pyi index 5cb0e7a509..b8f1e61859 100644 --- a/python/pyspark/ml/regression.pyi +++ b/python/pyspark/ml/regression.pyi @@ -414,7 +414,7 @@ class RandomForestRegressionModel( _TreeEnsembleModel, _RandomForestRegressorParams, JavaMLWritable, - JavaMLReadable, + JavaMLReadable[RandomForestRegressionModel], ): @property def trees(self) -> List[DecisionTreeRegressionModel]: ... @@ -749,10 +749,10 @@ class _FactorizationMachinesParams( initStd: Param[float] solver: Param[str] def __init__(self, *args: Any): ... - def getFactorSize(self): ... - def getFitLinear(self): ... - def getMiniBatchFraction(self): ... - def getInitStd(self): ... + def getFactorSize(self) -> int: ... + def getFitLinear(self) -> bool: ... + def getMiniBatchFraction(self) -> float: ... + def getInitStd(self) -> float: ... class FMRegressor( _JavaRegressor[FMRegressionModel], diff --git a/python/pyspark/mllib/classification.pyi b/python/pyspark/mllib/classification.pyi index c51882c87b..967b0a9f28 100644 --- a/python/pyspark/mllib/classification.pyi +++ b/python/pyspark/mllib/classification.pyi @@ -118,7 +118,7 @@ class NaiveBayesModel(Saveable, Loader[NaiveBayesModel]): labels: ndarray pi: ndarray theta: ndarray - def __init__(self, labels, pi, theta) -> None: ... + def __init__(self, labels: ndarray, pi: ndarray, theta: ndarray) -> None: ... @overload def predict(self, x: VectorLike) -> float64: ... @overload diff --git a/python/pyspark/mllib/clustering.pyi b/python/pyspark/mllib/clustering.pyi index 1c3eba17e2..b4f349612f 100644 --- a/python/pyspark/mllib/clustering.pyi +++ b/python/pyspark/mllib/clustering.pyi @@ -63,7 +63,7 @@ class BisectingKMeans: class KMeansModel(Saveable, Loader[KMeansModel]): centers: List[ndarray] - def __init__(self, centers: List[ndarray]) -> None: ... + def __init__(self, centers: List[VectorLike]) -> None: ... @property def clusterCenters(self) -> List[ndarray]: ... @property @@ -144,7 +144,9 @@ class PowerIterationClustering: class Assignment(NamedTuple("Assignment", [("id", int), ("cluster", int)])): ... class StreamingKMeansModel(KMeansModel): - def __init__(self, clusterCenters, clusterWeights) -> None: ... + def __init__( + self, clusterCenters: List[VectorLike], clusterWeights: VectorLike + ) -> None: ... @property def clusterWeights(self) -> List[float64]: ... centers: ndarray diff --git a/python/pyspark/mllib/common.pyi b/python/pyspark/mllib/common.pyi index 1df308b91b..daba212d93 100644 --- a/python/pyspark/mllib/common.pyi +++ b/python/pyspark/mllib/common.pyi @@ -16,12 +16,20 @@ # specific language governing permissions and limitations # under the License. -def callJavaFunc(sc, func, *args): ... -def callMLlibFunc(name, *args): ... +from typing import Any, TypeVar + +import pyspark.context + +from py4j.java_gateway import JavaObject + +C = TypeVar("C", bound=type) + +def callJavaFunc(sc: pyspark.context.SparkContext, func: Any, *args: Any) -> Any: ... +def callMLlibFunc(name: str, *args: Any) -> Any: ... class JavaModelWrapper: - def __init__(self, java_model) -> None: ... - def __del__(self): ... - def call(self, name, *a): ... 
+ def __init__(self, java_model: JavaObject) -> None: ... + def __del__(self) -> None: ... + def call(self, name: str, *a: Any) -> Any: ... -def inherit_doc(cls): ... +def inherit_doc(cls: C) -> C: ... diff --git a/python/pyspark/mllib/linalg/__init__.pyi b/python/pyspark/mllib/linalg/__init__.pyi index c0719c535c..60d16b26f3 100644 --- a/python/pyspark/mllib/linalg/__init__.pyi +++ b/python/pyspark/mllib/linalg/__init__.pyi @@ -17,7 +17,18 @@ # under the License. from typing import overload -from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Dict, + Generic, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) from pyspark.ml import linalg as newlinalg from pyspark.sql.types import StructType, UserDefinedType from numpy import float64, ndarray # type: ignore[import] @@ -46,7 +57,7 @@ class MatrixUDT(UserDefinedType): @classmethod def scalaUDT(cls) -> str: ... def serialize( - self, obj + self, obj: Matrix ) -> Tuple[ int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool ]: ... @@ -67,8 +78,8 @@ class DenseVector(Vector): @overload def __init__(self, __arr: Iterable[float]) -> None: ... @staticmethod - def parse(s) -> DenseVector: ... - def __reduce__(self) -> Tuple[type, bytes]: ... + def parse(s: str) -> DenseVector: ... + def __reduce__(self) -> Tuple[Type[DenseVector], bytes]: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... def dot(self, other: Iterable[float]) -> float64: ... @@ -115,7 +126,7 @@ class SparseVector(Vector): def __init__(self, size: int, __map: Dict[int, float]) -> None: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... - def __reduce__(self): ... + def __reduce__(self) -> Tuple[Type[SparseVector], Tuple[int, bytes, bytes]]: ... @staticmethod def parse(s: str) -> SparseVector: ... def dot(self, other: Iterable[float]) -> float64: ... @@ -123,9 +134,9 @@ class SparseVector(Vector): def toArray(self) -> ndarray: ... def asML(self) -> newlinalg.SparseVector: ... def __len__(self) -> int: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... def __getitem__(self, index: int) -> float64: ... - def __ne__(self, other) -> bool: ... + def __ne__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... class Vectors: @@ -148,13 +159,13 @@ class Vectors: def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ... @overload @staticmethod - def dense(self, *elements: float) -> DenseVector: ... + def dense(*elements: float) -> DenseVector: ... @overload @staticmethod - def dense(self, __arr: bytes) -> DenseVector: ... + def dense(__arr: bytes) -> DenseVector: ... @overload @staticmethod - def dense(self, __arr: Iterable[float]) -> DenseVector: ... + def dense(__arr: Iterable[float]) -> DenseVector: ... @staticmethod def fromML(vec: newlinalg.DenseVector) -> DenseVector: ... @staticmethod @@ -176,8 +187,8 @@ class Matrix: def __init__( self, numRows: int, numCols: int, isTransposed: bool = ... ) -> None: ... - def toArray(self): ... - def asML(self): ... + def toArray(self) -> ndarray: ... + def asML(self) -> newlinalg.Matrix: ... class DenseMatrix(Matrix): values: Any @@ -193,12 +204,12 @@ class DenseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, int]]: ... + def __reduce__(self) -> Tuple[Type[DenseMatrix], Tuple[int, int, bytes, int]]: ... 
def toArray(self) -> ndarray: ... def toSparse(self) -> SparseMatrix: ... def asML(self) -> newlinalg.DenseMatrix: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class SparseMatrix(Matrix): colPtrs: ndarray @@ -224,12 +235,14 @@ class SparseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, bytes, bytes, int]]: ... + def __reduce__( + self, + ) -> Tuple[Type[SparseMatrix], Tuple[int, int, bytes, bytes, bytes, int]]: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... def toArray(self) -> ndarray: ... def toDense(self) -> DenseMatrix: ... def asML(self) -> newlinalg.SparseMatrix: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class Matrices: @overload diff --git a/python/pyspark/mllib/random.pyi b/python/pyspark/mllib/random.pyi index dc5f470161..ec83170625 100644 --- a/python/pyspark/mllib/random.pyi +++ b/python/pyspark/mllib/random.pyi @@ -90,7 +90,7 @@ class RandomRDDs: def logNormalVectorRDD( sc: SparkContext, mean: float, - std, + std: float, numRows: int, numCols: int, numPartitions: Optional[int] = ..., diff --git a/python/pyspark/mllib/recommendation.pyi b/python/pyspark/mllib/recommendation.pyi index e2f1549420..4fea0acf3c 100644 --- a/python/pyspark/mllib/recommendation.pyi +++ b/python/pyspark/mllib/recommendation.pyi @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Type, Union import array from collections import namedtuple @@ -27,7 +27,7 @@ from pyspark.mllib.common import JavaModelWrapper from pyspark.mllib.util import JavaLoader, JavaSaveable class Rating(namedtuple("Rating", ["user", "product", "rating"])): - def __reduce__(self): ... + def __reduce__(self) -> Tuple[Type[Rating], Tuple[int, int, float]]: ... class MatrixFactorizationModel( JavaModelWrapper, JavaSaveable, JavaLoader[MatrixFactorizationModel] diff --git a/python/pyspark/mllib/stat/_statistics.pyi b/python/pyspark/mllib/stat/_statistics.pyi index 4d2701d486..3834d51639 100644 --- a/python/pyspark/mllib/stat/_statistics.pyi +++ b/python/pyspark/mllib/stat/_statistics.pyi @@ -65,5 +65,5 @@ class Statistics: def chiSqTest(observed: RDD[LabeledPoint]) -> List[ChiSqTestResult]: ... @staticmethod def kolmogorovSmirnovTest( - data, distName: Literal["norm"] = ..., *params: float + data: RDD[float], distName: Literal["norm"] = ..., *params: float ) -> KolmogorovSmirnovTestResult: ... diff --git a/python/pyspark/rdd.pyi b/python/pyspark/rdd.pyi index 35c49e952b..a277cd9f7e 100644 --- a/python/pyspark/rdd.pyi +++ b/python/pyspark/rdd.pyi @@ -85,12 +85,16 @@ class PythonEvalType: SQL_COGROUPED_MAP_PANDAS_UDF: PandasCogroupedMapUDFType class BoundedFloat(float): - def __new__(cls, mean: float, confidence: float, low: float, high: float): ... + def __new__( + cls, mean: float, confidence: float, low: float, high: float + ) -> BoundedFloat: ... class Partitioner: numPartitions: int partitionFunc: Callable[[Any], int] - def __init__(self, numPartitions, partitionFunc) -> None: ... + def __init__( + self, numPartitions: int, partitionFunc: Callable[[Any], int] + ) -> None: ... def __eq__(self, other: Any) -> bool: ... def __call__(self, k: Any) -> int: ... 
diff --git a/python/pyspark/resource/profile.pyi b/python/pyspark/resource/profile.pyi index 6763baf659..0483869243 100644 --- a/python/pyspark/resource/profile.pyi +++ b/python/pyspark/resource/profile.pyi @@ -49,7 +49,7 @@ class ResourceProfileBuilder: def __init__(self) -> None: ... def require( self, resourceRequest: Union[ExecutorResourceRequest, TaskResourceRequests] - ): ... + ) -> ResourceProfileBuilder: ... def clearExecutorResourceRequests(self) -> None: ... def clearTaskResourceRequests(self) -> None: ... @property diff --git a/python/pyspark/sql/column.pyi b/python/pyspark/sql/column.pyi index 0fbb10053f..1f63e65b3d 100644 --- a/python/pyspark/sql/column.pyi +++ b/python/pyspark/sql/column.pyi @@ -32,7 +32,7 @@ from pyspark.sql.window import WindowSpec from py4j.java_gateway import JavaObject # type: ignore[import] class Column: - def __init__(self, JavaObject) -> None: ... + def __init__(self, jc: JavaObject) -> None: ... def __neg__(self) -> Column: ... def __add__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ... def __sub__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ... @@ -105,7 +105,11 @@ class Column: def name(self, *alias: str) -> Column: ... def cast(self, dataType: Union[DataType, str]) -> Column: ... def astype(self, dataType: Union[DataType, str]) -> Column: ... - def between(self, lowerBound, upperBound) -> Column: ... + def between( + self, + lowerBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral], + upperBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral], + ) -> Column: ... def when(self, condition: Column, value: Any) -> Column: ... def otherwise(self, value: Any) -> Column: ... def over(self, window: WindowSpec) -> Column: ... diff --git a/python/pyspark/sql/context.pyi b/python/pyspark/sql/context.pyi index 64927b37ac..915a0fe1f6 100644 --- a/python/pyspark/sql/context.pyi +++ b/python/pyspark/sql/context.pyi @@ -43,14 +43,14 @@ class SQLContext: sparkSession: SparkSession def __init__( self, - sparkContext, + sparkContext: SparkContext, sparkSession: Optional[SparkSession] = ..., jsqlContext: Optional[JavaObject] = ..., ) -> None: ... @classmethod def getOrCreate(cls: type, sc: SparkContext) -> SQLContext: ... def newSession(self) -> SQLContext: ... - def setConf(self, key: str, value) -> None: ... + def setConf(self, key: str, value: Union[bool, int, str]) -> None: ... def getConf(self, key: str, defaultValue: Optional[str] = ...) -> str: ... @property def udf(self) -> UDFRegistration: ... @@ -116,7 +116,7 @@ class SQLContext: path: Optional[str] = ..., source: Optional[str] = ..., schema: Optional[StructType] = ..., - **options + **options: str ) -> DataFrame: ... def sql(self, sqlQuery: str) -> DataFrame: ... def table(self, tableName: str) -> DataFrame: ... diff --git a/python/pyspark/sql/functions.pyi b/python/pyspark/sql/functions.pyi index 281c1d7543..252f883b5f 100644 --- a/python/pyspark/sql/functions.pyi +++ b/python/pyspark/sql/functions.pyi @@ -65,13 +65,13 @@ def round(col: ColumnOrName, scale: int = ...) -> Column: ... def bround(col: ColumnOrName, scale: int = ...) -> Column: ... def shiftLeft(col: ColumnOrName, numBits: int) -> Column: ... def shiftRight(col: ColumnOrName, numBits: int) -> Column: ... -def shiftRightUnsigned(col, numBits) -> Column: ... +def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column: ... def spark_partition_id() -> Column: ... def expr(str: str) -> Column: ... def struct(*cols: ColumnOrName) -> Column: ... 
def greatest(*cols: ColumnOrName) -> Column: ... def least(*cols: Column) -> Column: ... -def when(condition: Column, value) -> Column: ... +def when(condition: Column, value: Any) -> Column: ... @overload def log(arg1: ColumnOrName) -> Column: ... @overload @@ -174,7 +174,9 @@ def create_map(*cols: ColumnOrName) -> Column: ... def array(*cols: ColumnOrName) -> Column: ... def array_contains(col: ColumnOrName, value: Any) -> Column: ... def arrays_overlap(a1: ColumnOrName, a2: ColumnOrName) -> Column: ... -def slice(x: ColumnOrName, start: Union[Column, int], length: Union[Column, int]) -> Column: ... +def slice( + x: ColumnOrName, start: Union[Column, int], length: Union[Column, int] +) -> Column: ... def array_join( col: ColumnOrName, delimiter: str, null_replacement: Optional[str] = ... ) -> Column: ... diff --git a/python/pyspark/sql/session.pyi b/python/pyspark/sql/session.pyi index 17ba8894c1..6cd2d3bed2 100644 --- a/python/pyspark/sql/session.pyi +++ b/python/pyspark/sql/session.pyi @@ -17,7 +17,8 @@ # under the License. from typing import overload -from typing import Any, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar, Union +from types import TracebackType from py4j.java_gateway import JavaObject # type: ignore[import] @@ -122,4 +123,9 @@ class SparkSession(SparkConversionMixin): def streams(self) -> StreamingQueryManager: ... def stop(self) -> None: ... def __enter__(self) -> SparkSession: ... - def __exit__(self, exc_type, exc_val, exc_tb) -> None: ... + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: ... diff --git a/python/pyspark/sql/types.pyi b/python/pyspark/sql/types.pyi index 31765e9488..3adf823d99 100644 --- a/python/pyspark/sql/types.pyi +++ b/python/pyspark/sql/types.pyi @@ -17,7 +17,8 @@ # under the License. from typing import overload -from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, TypeVar +from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, Type, TypeVar +from py4j.java_gateway import JavaGateway, JavaObject import datetime T = TypeVar("T") @@ -37,7 +38,7 @@ class DataType: def fromInternal(self, obj: Any) -> Any: ... class DataTypeSingleton(type): - def __call__(cls): ... + def __call__(cls: Type[T]) -> T: ... # type: ignore class NullType(DataType, metaclass=DataTypeSingleton): ... class AtomicType(DataType): ... @@ -85,8 +86,8 @@ class ShortType(IntegralType): class ArrayType(DataType): elementType: DataType containsNull: bool - def __init__(self, elementType=DataType, containsNull: bool = ...) -> None: ... - def simpleString(self): ... + def __init__(self, elementType: DataType, containsNull: bool = ...) -> None: ... + def simpleString(self) -> str: ... def jsonValue(self) -> Dict[str, Any]: ... @classmethod def fromJson(cls, json: Dict[str, Any]) -> ArrayType: ... @@ -197,8 +198,8 @@ class Row(tuple): class DateConverter: def can_convert(self, obj: Any) -> bool: ... - def convert(self, obj, gateway_client) -> Any: ... + def convert(self, obj: datetime.date, gateway_client: JavaGateway) -> JavaObject: ... class DatetimeConverter: - def can_convert(self, obj) -> bool: ... - def convert(self, obj, gateway_client) -> Any: ... + def can_convert(self, obj: Any) -> bool: ... + def convert(self, obj: datetime.datetime, gateway_client: JavaGateway) -> JavaObject: ... 
diff --git a/python/pyspark/sql/udf.pyi b/python/pyspark/sql/udf.pyi index 87c3672780..ea61397a67 100644 --- a/python/pyspark/sql/udf.pyi +++ b/python/pyspark/sql/udf.pyi @@ -18,8 +18,9 @@ from typing import Any, Callable, Optional -from pyspark.sql._typing import ColumnOrName, DataTypeOrString +from pyspark.sql._typing import ColumnOrName, DataTypeOrString, UserDefinedFunctionLike from pyspark.sql.column import Column +from pyspark.sql.types import DataType import pyspark.sql.session class UserDefinedFunction: @@ -35,7 +36,7 @@ class UserDefinedFunction: deterministic: bool = ..., ) -> None: ... @property - def returnType(self): ... + def returnType(self) -> DataType: ... def __call__(self, *cols: ColumnOrName) -> Column: ... def asNondeterministic(self) -> UserDefinedFunction: ... @@ -47,7 +48,7 @@ class UDFRegistration: name: str, f: Callable[..., Any], returnType: Optional[DataTypeOrString] = ..., - ): ... + ) -> UserDefinedFunctionLike: ... def registerJavaFunction( self, name: str, diff --git a/python/pyspark/streaming/context.pyi b/python/pyspark/streaming/context.pyi index 026163fc9a..117a6742e6 100644 --- a/python/pyspark/streaming/context.pyi +++ b/python/pyspark/streaming/context.pyi @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, List, Optional, TypeVar, Union +from typing import Any, Callable, List, Optional, TypeVar from py4j.java_gateway import JavaObject # type: ignore[import] diff --git a/python/pyspark/streaming/dstream.pyi b/python/pyspark/streaming/dstream.pyi index 7b76ce4c65..1521d838fc 100644 --- a/python/pyspark/streaming/dstream.pyi +++ b/python/pyspark/streaming/dstream.pyi @@ -30,9 +30,12 @@ from typing import ( ) import datetime from pyspark.rdd import RDD +import pyspark.serializers from pyspark.storagelevel import StorageLevel import pyspark.streaming.context +from py4j.java_gateway import JavaObject + S = TypeVar("S") T = TypeVar("T") U = TypeVar("U") @@ -42,7 +45,12 @@ V = TypeVar("V") class DStream(Generic[T]): is_cached: bool is_checkpointed: bool - def __init__(self, jdstream, ssc, jrdd_deserializer) -> None: ... + def __init__( + self, + jdstream: JavaObject, + ssc: pyspark.streaming.context.StreamingContext, + jrdd_deserializer: pyspark.serializers.Serializer, + ) -> None: ... def context(self) -> pyspark.streaming.context.StreamingContext: ... def count(self) -> DStream[int]: ... def filter(self, f: Callable[[T], bool]) -> DStream[T]: ... diff --git a/python/pyspark/streaming/kinesis.pyi b/python/pyspark/streaming/kinesis.pyi index af7cd6f6ec..399c37f869 100644 --- a/python/pyspark/streaming/kinesis.pyi +++ b/python/pyspark/streaming/kinesis.pyi @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Optional, TypeVar +from typing import Callable, Optional, TypeVar from pyspark.storagelevel import StorageLevel from pyspark.streaming.context import StreamingContext from pyspark.streaming.dstream import DStream
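For illustration only (not part of the patch), here is a minimal sketch of what these settings enforce. The module and function names below are hypothetical; the point is that code covered by the main `[mypy]` section must carry full annotations, while modules listed with `disallow_untyped_defs = False` (tests and internal helpers) may keep untyped defs.

```python
# Illustrative sketch only -- names are hypothetical, not taken from the patch.
from typing import Optional

# Under the main [mypy] section (disallow_untyped_defs = True), a public
# function must be fully annotated; with no_implicit_optional = True, a
# None default also requires an explicit Optional[...] in the signature.
def normalize(value: float, scale: Optional[float] = None) -> float:
    return value * (scale if scale is not None else 1.0)

# In a module matched by an override such as [mypy-pyspark.tests.*]
# (disallow_untyped_defs = False), untyped helpers are still accepted.
def check_result(result):
    assert result == 1.0
```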