diff --git a/python/mypy.ini b/python/mypy.ini index 4a5368a519..5103452a05 100644 --- a/python/mypy.ini +++ b/python/mypy.ini @@ -16,10 +16,97 @@ ; [mypy] +strict_optional = True +no_implicit_optional = True +disallow_untyped_defs = True + +; Allow untyped def in internal modules and tests + +[mypy-pyspark.daemon] +disallow_untyped_defs = False + +[mypy-pyspark.find_spark_home] +disallow_untyped_defs = False + +[mypy-pyspark._globals] +disallow_untyped_defs = False + +[mypy-pyspark.install] +disallow_untyped_defs = False + +[mypy-pyspark.java_gateway] +disallow_untyped_defs = False + +[mypy-pyspark.join] +disallow_untyped_defs = False + +[mypy-pyspark.ml.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.mllib.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.rddsampler] +disallow_untyped_defs = False + +[mypy-pyspark.resource.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.serializers] +disallow_untyped_defs = False + +[mypy-pyspark.shuffle] +disallow_untyped_defs = False + +[mypy-pyspark.streaming.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.streaming.util] +disallow_untyped_defs = False + +[mypy-pyspark.sql.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.serializers] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.types] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.typehints] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.utils] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas._typing.protocols.*] +disallow_untyped_defs = False + +[mypy-pyspark.sql.utils] +disallow_untyped_defs = False + +[mypy-pyspark.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.testing.*] +disallow_untyped_defs = False + +[mypy-pyspark.traceback_utils] +disallow_untyped_defs = False + +[mypy-pyspark.util] +disallow_untyped_defs = False + +[mypy-pyspark.worker] +disallow_untyped_defs = False + +; Ignore errors in embedded third party code [mypy-pyspark.cloudpickle.*] ignore_errors = True +; Ignore missing imports for external untyped packages + [mypy-py4j.*] ignore_missing_imports = True diff --git a/python/pyspark/broadcast.pyi b/python/pyspark/broadcast.pyi index 4b019a509a..944cb06d41 100644 --- a/python/pyspark/broadcast.pyi +++ b/python/pyspark/broadcast.pyi @@ -17,7 +17,7 @@ # under the License. import threading -from typing import Any, Dict, Generic, Optional, TypeVar +from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar T = TypeVar("T") @@ -32,14 +32,14 @@ class Broadcast(Generic[T]): path: Optional[Any] = ..., sock_file: Optional[Any] = ..., ) -> None: ... - def dump(self, value: Any, f: Any) -> None: ... - def load_from_path(self, path: Any): ... - def load(self, file: Any): ... + def dump(self, value: T, f: Any) -> None: ... + def load_from_path(self, path: Any) -> T: ... + def load(self, file: Any) -> T: ... @property def value(self) -> T: ... def unpersist(self, blocking: bool = ...) -> None: ... def destroy(self, blocking: bool = ...) -> None: ... - def __reduce__(self): ... + def __reduce__(self) -> Tuple[Callable[[int], T], Tuple[int]]: ... class BroadcastPickleRegistry(threading.local): def __init__(self) -> None: ... diff --git a/python/pyspark/context.pyi b/python/pyspark/context.pyi index 2789a38b3b..640a69cad0 100644 --- a/python/pyspark/context.pyi +++ b/python/pyspark/context.pyi @@ -16,7 +16,19 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + NoReturn, + Optional, + Tuple, + Type, + TypeVar, +) +from types import TracebackType from py4j.java_gateway import JavaGateway, JavaObject # type: ignore[import] @@ -51,9 +63,14 @@ class SparkContext: jsc: Optional[JavaObject] = ..., profiler_cls: type = ..., ) -> None: ... - def __getnewargs__(self): ... - def __enter__(self): ... - def __exit__(self, type, value, trace): ... + def __getnewargs__(self) -> NoReturn: ... + def __enter__(self) -> SparkContext: ... + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + trace: Optional[TracebackType], + ) -> None: ... @classmethod def getOrCreate(cls, conf: Optional[SparkConf] = ...) -> SparkContext: ... def setLogLevel(self, logLevel: str) -> None: ... diff --git a/python/pyspark/ml/classification.pyi b/python/pyspark/ml/classification.pyi index 4bde851bb1..c44176a13a 100644 --- a/python/pyspark/ml/classification.pyi +++ b/python/pyspark/ml/classification.pyi @@ -107,7 +107,7 @@ class _JavaProbabilisticClassifier( class _JavaProbabilisticClassificationModel( ProbabilisticClassificationModel, _JavaClassificationModel[T] ): - def predictProbability(self, value: Any): ... + def predictProbability(self, value: Vector) -> Vector: ... class _ClassificationSummary(JavaWrapper): @property @@ -543,7 +543,7 @@ class RandomForestClassificationModel( @property def trees(self) -> List[DecisionTreeClassificationModel]: ... def summary(self) -> RandomForestClassificationTrainingSummary: ... - def evaluate(self, dataset) -> RandomForestClassificationSummary: ... + def evaluate(self, dataset: DataFrame) -> RandomForestClassificationSummary: ... class RandomForestClassificationSummary(_ClassificationSummary): ... class RandomForestClassificationTrainingSummary( @@ -891,7 +891,7 @@ class FMClassifier( solver: str = ..., thresholds: Optional[Any] = ..., seed: Optional[Any] = ..., - ): ... + ) -> FMClassifier: ... def setFactorSize(self, value: int) -> FMClassifier: ... def setFitLinear(self, value: bool) -> FMClassifier: ... def setMiniBatchFraction(self, value: float) -> FMClassifier: ... diff --git a/python/pyspark/ml/common.pyi b/python/pyspark/ml/common.pyi index 7bf0ed6183..a38fc5734f 100644 --- a/python/pyspark/ml/common.pyi +++ b/python/pyspark/ml/common.pyi @@ -16,5 +16,11 @@ # specific language governing permissions and limitations # under the License. -def callJavaFunc(sc, func, *args): ... -def inherit_doc(cls): ... +from typing import Any, TypeVar + +import pyspark.context + +C = TypeVar("C", bound=type) + +def callJavaFunc(sc: pyspark.context.SparkContext, func: Any, *args: Any) -> Any: ... +def inherit_doc(cls: C) -> C: ... diff --git a/python/pyspark/ml/evaluation.pyi b/python/pyspark/ml/evaluation.pyi index ea0a9f045c..55a3ae2774 100644 --- a/python/pyspark/ml/evaluation.pyi +++ b/python/pyspark/ml/evaluation.pyi @@ -39,9 +39,12 @@ from pyspark.ml.param.shared import ( HasWeightCol, ) from pyspark.ml.util import JavaMLReadable, JavaMLWritable +from pyspark.sql.dataframe import DataFrame class Evaluator(Params, metaclass=abc.ABCMeta): - def evaluate(self, dataset, params: Optional[ParamMap] = ...) -> float: ... + def evaluate( + self, dataset: DataFrame, params: Optional[ParamMap] = ... + ) -> float: ... def isLargerBetter(self) -> bool: ... class JavaEvaluator(JavaParams, Evaluator, metaclass=abc.ABCMeta): @@ -75,16 +78,15 @@ class BinaryClassificationEvaluator( def setLabelCol(self, value: str) -> BinaryClassificationEvaluator: ... def setRawPredictionCol(self, value: str) -> BinaryClassificationEvaluator: ... def setWeightCol(self, value: str) -> BinaryClassificationEvaluator: ... - -def setParams( - self, - *, - rawPredictionCol: str = ..., - labelCol: str = ..., - metricName: BinaryClassificationEvaluatorMetricType = ..., - weightCol: Optional[str] = ..., - numBins: int = ... -) -> BinaryClassificationEvaluator: ... + def setParams( + self, + *, + rawPredictionCol: str = ..., + labelCol: str = ..., + metricName: BinaryClassificationEvaluatorMetricType = ..., + weightCol: Optional[str] = ..., + numBins: int = ... + ) -> BinaryClassificationEvaluator: ... class RegressionEvaluator( JavaEvaluator, diff --git a/python/pyspark/ml/feature.pyi b/python/pyspark/ml/feature.pyi index f5b12a5b2f..4999defdf8 100644 --- a/python/pyspark/ml/feature.pyi +++ b/python/pyspark/ml/feature.pyi @@ -100,9 +100,9 @@ class _LSHParams(HasInputCol, HasOutputCol): def getNumHashTables(self) -> int: ... class _LSH(Generic[JM], JavaEstimator[JM], _LSHParams, JavaMLReadable, JavaMLWritable): - def setNumHashTables(self: P, value) -> P: ... - def setInputCol(self: P, value) -> P: ... - def setOutputCol(self: P, value) -> P: ... + def setNumHashTables(self: P, value: int) -> P: ... + def setInputCol(self: P, value: str) -> P: ... + def setOutputCol(self: P, value: str) -> P: ... class _LSHModel(JavaModel, _LSHParams): def setInputCol(self: P, value: str) -> P: ... @@ -1518,7 +1518,7 @@ class ChiSqSelector( fpr: float = ..., fdr: float = ..., fwe: float = ... - ): ... + ) -> ChiSqSelector: ... def setSelectorType(self, value: str) -> ChiSqSelector: ... def setNumTopFeatures(self, value: int) -> ChiSqSelector: ... def setPercentile(self, value: float) -> ChiSqSelector: ... @@ -1602,7 +1602,10 @@ class _VarianceThresholdSelectorParams(HasFeaturesCol, HasOutputCol): def getVarianceThreshold(self) -> float: ... class VarianceThresholdSelector( - JavaEstimator, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable + JavaEstimator[VarianceThresholdSelectorModel], + _VarianceThresholdSelectorParams, + JavaMLReadable[VarianceThresholdSelector], + JavaMLWritable, ): def __init__( self, @@ -1615,13 +1618,16 @@ class VarianceThresholdSelector( featuresCol: str = ..., outputCol: Optional[str] = ..., varianceThreshold: float = ..., - ): ... + ) -> VarianceThresholdSelector: ... def setVarianceThreshold(self, value: float) -> VarianceThresholdSelector: ... def setFeaturesCol(self, value: str) -> VarianceThresholdSelector: ... def setOutputCol(self, value: str) -> VarianceThresholdSelector: ... class VarianceThresholdSelectorModel( - JavaModel, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable + JavaModel, + _VarianceThresholdSelectorParams, + JavaMLReadable[VarianceThresholdSelectorModel], + JavaMLWritable, ): def setFeaturesCol(self, value: str) -> VarianceThresholdSelectorModel: ... def setOutputCol(self, value: str) -> VarianceThresholdSelectorModel: ... diff --git a/python/pyspark/ml/linalg/__init__.pyi b/python/pyspark/ml/linalg/__init__.pyi index a576b30aec..b4fba8823b 100644 --- a/python/pyspark/ml/linalg/__init__.pyi +++ b/python/pyspark/ml/linalg/__init__.pyi @@ -17,7 +17,7 @@ # under the License. from typing import overload -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, NoReturn, Optional, Tuple, Type, Union from pyspark.ml import linalg as newlinalg # noqa: F401 from pyspark.sql.types import StructType, UserDefinedType @@ -45,7 +45,7 @@ class MatrixUDT(UserDefinedType): @classmethod def scalaUDT(cls) -> str: ... def serialize( - self, obj + self, obj: Matrix ) -> Tuple[ int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool ]: ... @@ -64,9 +64,7 @@ class DenseVector(Vector): def __init__(self, __arr: bytes) -> None: ... @overload def __init__(self, __arr: Iterable[float]) -> None: ... - @staticmethod - def parse(s) -> DenseVector: ... - def __reduce__(self) -> Tuple[type, bytes]: ... + def __reduce__(self) -> Tuple[Type[DenseVector], bytes]: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... def dot(self, other: Iterable[float]) -> float64: ... @@ -112,16 +110,14 @@ class SparseVector(Vector): def __init__(self, size: int, __map: Dict[int, float]) -> None: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... - def __reduce__(self): ... - @staticmethod - def parse(s: str) -> SparseVector: ... + def __reduce__(self) -> Tuple[Type[SparseVector], Tuple[int, bytes, bytes]]: ... def dot(self, other: Iterable[float]) -> float64: ... def squared_distance(self, other: Iterable[float]) -> float64: ... def toArray(self) -> ndarray: ... def __len__(self) -> int: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... def __getitem__(self, index: int) -> float64: ... - def __ne__(self, other) -> bool: ... + def __ne__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... class Vectors: @@ -144,13 +140,13 @@ class Vectors: def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ... @overload @staticmethod - def dense(self, *elements: float) -> DenseVector: ... + def dense(*elements: float) -> DenseVector: ... @overload @staticmethod - def dense(self, __arr: bytes) -> DenseVector: ... + def dense(__arr: bytes) -> DenseVector: ... @overload @staticmethod - def dense(self, __arr: Iterable[float]) -> DenseVector: ... + def dense(__arr: Iterable[float]) -> DenseVector: ... @staticmethod def stringify(vector: Vector) -> str: ... @staticmethod @@ -158,8 +154,6 @@ class Vectors: @staticmethod def norm(vector: Vector, p: Union[float, str]) -> float64: ... @staticmethod - def parse(s: str) -> Vector: ... - @staticmethod def zeros(size: int) -> DenseVector: ... class Matrix: @@ -170,7 +164,7 @@ class Matrix: def __init__( self, numRows: int, numCols: int, isTransposed: bool = ... ) -> None: ... - def toArray(self): ... + def toArray(self) -> NoReturn: ... class DenseMatrix(Matrix): values: Any @@ -186,11 +180,11 @@ class DenseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, int]]: ... + def __reduce__(self) -> Tuple[Type[DenseMatrix], Tuple[int, int, bytes, int]]: ... def toArray(self) -> ndarray: ... def toSparse(self) -> SparseMatrix: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class SparseMatrix(Matrix): colPtrs: ndarray @@ -216,11 +210,13 @@ class SparseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, bytes, bytes, int]]: ... + def __reduce__( + self, + ) -> Tuple[Type[SparseMatrix], Tuple[int, int, bytes, bytes, bytes, int]]: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... def toArray(self) -> ndarray: ... def toDense(self) -> DenseMatrix: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class Matrices: @overload diff --git a/python/pyspark/ml/pipeline.pyi b/python/pyspark/ml/pipeline.pyi index 44680586d7..f47e9e012a 100644 --- a/python/pyspark/ml/pipeline.pyi +++ b/python/pyspark/ml/pipeline.pyi @@ -51,7 +51,7 @@ class PipelineWriter(MLWriter): def __init__(self, instance: Pipeline) -> None: ... def saveImpl(self, path: str) -> None: ... -class PipelineReader(MLReader): +class PipelineReader(MLReader[Pipeline]): cls: Type[Pipeline] def __init__(self, cls: Type[Pipeline]) -> None: ... def load(self, path: str) -> Pipeline: ... @@ -61,7 +61,7 @@ class PipelineModelWriter(MLWriter): def __init__(self, instance: PipelineModel) -> None: ... def saveImpl(self, path: str) -> None: ... -class PipelineModelReader(MLReader): +class PipelineModelReader(MLReader[PipelineModel]): cls: Type[PipelineModel] def __init__(self, cls: Type[PipelineModel]) -> None: ... def load(self, path: str) -> PipelineModel: ... diff --git a/python/pyspark/ml/regression.pyi b/python/pyspark/ml/regression.pyi index 5cb0e7a509..b8f1e61859 100644 --- a/python/pyspark/ml/regression.pyi +++ b/python/pyspark/ml/regression.pyi @@ -414,7 +414,7 @@ class RandomForestRegressionModel( _TreeEnsembleModel, _RandomForestRegressorParams, JavaMLWritable, - JavaMLReadable, + JavaMLReadable[RandomForestRegressionModel], ): @property def trees(self) -> List[DecisionTreeRegressionModel]: ... @@ -749,10 +749,10 @@ class _FactorizationMachinesParams( initStd: Param[float] solver: Param[str] def __init__(self, *args: Any): ... - def getFactorSize(self): ... - def getFitLinear(self): ... - def getMiniBatchFraction(self): ... - def getInitStd(self): ... + def getFactorSize(self) -> int: ... + def getFitLinear(self) -> bool: ... + def getMiniBatchFraction(self) -> float: ... + def getInitStd(self) -> float: ... class FMRegressor( _JavaRegressor[FMRegressionModel], diff --git a/python/pyspark/mllib/classification.pyi b/python/pyspark/mllib/classification.pyi index c51882c87b..967b0a9f28 100644 --- a/python/pyspark/mllib/classification.pyi +++ b/python/pyspark/mllib/classification.pyi @@ -118,7 +118,7 @@ class NaiveBayesModel(Saveable, Loader[NaiveBayesModel]): labels: ndarray pi: ndarray theta: ndarray - def __init__(self, labels, pi, theta) -> None: ... + def __init__(self, labels: ndarray, pi: ndarray, theta: ndarray) -> None: ... @overload def predict(self, x: VectorLike) -> float64: ... @overload diff --git a/python/pyspark/mllib/clustering.pyi b/python/pyspark/mllib/clustering.pyi index 1c3eba17e2..b4f349612f 100644 --- a/python/pyspark/mllib/clustering.pyi +++ b/python/pyspark/mllib/clustering.pyi @@ -63,7 +63,7 @@ class BisectingKMeans: class KMeansModel(Saveable, Loader[KMeansModel]): centers: List[ndarray] - def __init__(self, centers: List[ndarray]) -> None: ... + def __init__(self, centers: List[VectorLike]) -> None: ... @property def clusterCenters(self) -> List[ndarray]: ... @property @@ -144,7 +144,9 @@ class PowerIterationClustering: class Assignment(NamedTuple("Assignment", [("id", int), ("cluster", int)])): ... class StreamingKMeansModel(KMeansModel): - def __init__(self, clusterCenters, clusterWeights) -> None: ... + def __init__( + self, clusterCenters: List[VectorLike], clusterWeights: VectorLike + ) -> None: ... @property def clusterWeights(self) -> List[float64]: ... centers: ndarray diff --git a/python/pyspark/mllib/common.pyi b/python/pyspark/mllib/common.pyi index 1df308b91b..daba212d93 100644 --- a/python/pyspark/mllib/common.pyi +++ b/python/pyspark/mllib/common.pyi @@ -16,12 +16,20 @@ # specific language governing permissions and limitations # under the License. -def callJavaFunc(sc, func, *args): ... -def callMLlibFunc(name, *args): ... +from typing import Any, TypeVar + +import pyspark.context + +from py4j.java_gateway import JavaObject + +C = TypeVar("C", bound=type) + +def callJavaFunc(sc: pyspark.context.SparkContext, func: Any, *args: Any) -> Any: ... +def callMLlibFunc(name: str, *args: Any) -> Any: ... class JavaModelWrapper: - def __init__(self, java_model) -> None: ... - def __del__(self): ... - def call(self, name, *a): ... + def __init__(self, java_model: JavaObject) -> None: ... + def __del__(self) -> None: ... + def call(self, name: str, *a: Any) -> Any: ... -def inherit_doc(cls): ... +def inherit_doc(cls: C) -> C: ... diff --git a/python/pyspark/mllib/linalg/__init__.pyi b/python/pyspark/mllib/linalg/__init__.pyi index c0719c535c..60d16b26f3 100644 --- a/python/pyspark/mllib/linalg/__init__.pyi +++ b/python/pyspark/mllib/linalg/__init__.pyi @@ -17,7 +17,18 @@ # under the License. from typing import overload -from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Dict, + Generic, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) from pyspark.ml import linalg as newlinalg from pyspark.sql.types import StructType, UserDefinedType from numpy import float64, ndarray # type: ignore[import] @@ -46,7 +57,7 @@ class MatrixUDT(UserDefinedType): @classmethod def scalaUDT(cls) -> str: ... def serialize( - self, obj + self, obj: Matrix ) -> Tuple[ int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool ]: ... @@ -67,8 +78,8 @@ class DenseVector(Vector): @overload def __init__(self, __arr: Iterable[float]) -> None: ... @staticmethod - def parse(s) -> DenseVector: ... - def __reduce__(self) -> Tuple[type, bytes]: ... + def parse(s: str) -> DenseVector: ... + def __reduce__(self) -> Tuple[Type[DenseVector], bytes]: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... def dot(self, other: Iterable[float]) -> float64: ... @@ -115,7 +126,7 @@ class SparseVector(Vector): def __init__(self, size: int, __map: Dict[int, float]) -> None: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... - def __reduce__(self): ... + def __reduce__(self) -> Tuple[Type[SparseVector], Tuple[int, bytes, bytes]]: ... @staticmethod def parse(s: str) -> SparseVector: ... def dot(self, other: Iterable[float]) -> float64: ... @@ -123,9 +134,9 @@ class SparseVector(Vector): def toArray(self) -> ndarray: ... def asML(self) -> newlinalg.SparseVector: ... def __len__(self) -> int: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... def __getitem__(self, index: int) -> float64: ... - def __ne__(self, other) -> bool: ... + def __ne__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... class Vectors: @@ -148,13 +159,13 @@ class Vectors: def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ... @overload @staticmethod - def dense(self, *elements: float) -> DenseVector: ... + def dense(*elements: float) -> DenseVector: ... @overload @staticmethod - def dense(self, __arr: bytes) -> DenseVector: ... + def dense(__arr: bytes) -> DenseVector: ... @overload @staticmethod - def dense(self, __arr: Iterable[float]) -> DenseVector: ... + def dense(__arr: Iterable[float]) -> DenseVector: ... @staticmethod def fromML(vec: newlinalg.DenseVector) -> DenseVector: ... @staticmethod @@ -176,8 +187,8 @@ class Matrix: def __init__( self, numRows: int, numCols: int, isTransposed: bool = ... ) -> None: ... - def toArray(self): ... - def asML(self): ... + def toArray(self) -> ndarray: ... + def asML(self) -> newlinalg.Matrix: ... class DenseMatrix(Matrix): values: Any @@ -193,12 +204,12 @@ class DenseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, int]]: ... + def __reduce__(self) -> Tuple[Type[DenseMatrix], Tuple[int, int, bytes, int]]: ... def toArray(self) -> ndarray: ... def toSparse(self) -> SparseMatrix: ... def asML(self) -> newlinalg.DenseMatrix: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class SparseMatrix(Matrix): colPtrs: ndarray @@ -224,12 +235,14 @@ class SparseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, bytes, bytes, int]]: ... + def __reduce__( + self, + ) -> Tuple[Type[SparseMatrix], Tuple[int, int, bytes, bytes, bytes, int]]: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... def toArray(self) -> ndarray: ... def toDense(self) -> DenseMatrix: ... def asML(self) -> newlinalg.SparseMatrix: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class Matrices: @overload diff --git a/python/pyspark/mllib/random.pyi b/python/pyspark/mllib/random.pyi index dc5f470161..ec83170625 100644 --- a/python/pyspark/mllib/random.pyi +++ b/python/pyspark/mllib/random.pyi @@ -90,7 +90,7 @@ class RandomRDDs: def logNormalVectorRDD( sc: SparkContext, mean: float, - std, + std: float, numRows: int, numCols: int, numPartitions: Optional[int] = ..., diff --git a/python/pyspark/mllib/recommendation.pyi b/python/pyspark/mllib/recommendation.pyi index e2f1549420..4fea0acf3c 100644 --- a/python/pyspark/mllib/recommendation.pyi +++ b/python/pyspark/mllib/recommendation.pyi @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Type, Union import array from collections import namedtuple @@ -27,7 +27,7 @@ from pyspark.mllib.common import JavaModelWrapper from pyspark.mllib.util import JavaLoader, JavaSaveable class Rating(namedtuple("Rating", ["user", "product", "rating"])): - def __reduce__(self): ... + def __reduce__(self) -> Tuple[Type[Rating], Tuple[int, int, float]]: ... class MatrixFactorizationModel( JavaModelWrapper, JavaSaveable, JavaLoader[MatrixFactorizationModel] diff --git a/python/pyspark/mllib/stat/_statistics.pyi b/python/pyspark/mllib/stat/_statistics.pyi index 4d2701d486..3834d51639 100644 --- a/python/pyspark/mllib/stat/_statistics.pyi +++ b/python/pyspark/mllib/stat/_statistics.pyi @@ -65,5 +65,5 @@ class Statistics: def chiSqTest(observed: RDD[LabeledPoint]) -> List[ChiSqTestResult]: ... @staticmethod def kolmogorovSmirnovTest( - data, distName: Literal["norm"] = ..., *params: float + data: RDD[float], distName: Literal["norm"] = ..., *params: float ) -> KolmogorovSmirnovTestResult: ... diff --git a/python/pyspark/rdd.pyi b/python/pyspark/rdd.pyi index 35c49e952b..a277cd9f7e 100644 --- a/python/pyspark/rdd.pyi +++ b/python/pyspark/rdd.pyi @@ -85,12 +85,16 @@ class PythonEvalType: SQL_COGROUPED_MAP_PANDAS_UDF: PandasCogroupedMapUDFType class BoundedFloat(float): - def __new__(cls, mean: float, confidence: float, low: float, high: float): ... + def __new__( + cls, mean: float, confidence: float, low: float, high: float + ) -> BoundedFloat: ... class Partitioner: numPartitions: int partitionFunc: Callable[[Any], int] - def __init__(self, numPartitions, partitionFunc) -> None: ... + def __init__( + self, numPartitions: int, partitionFunc: Callable[[Any], int] + ) -> None: ... def __eq__(self, other: Any) -> bool: ... def __call__(self, k: Any) -> int: ... diff --git a/python/pyspark/resource/profile.pyi b/python/pyspark/resource/profile.pyi index 6763baf659..0483869243 100644 --- a/python/pyspark/resource/profile.pyi +++ b/python/pyspark/resource/profile.pyi @@ -49,7 +49,7 @@ class ResourceProfileBuilder: def __init__(self) -> None: ... def require( self, resourceRequest: Union[ExecutorResourceRequest, TaskResourceRequests] - ): ... + ) -> ResourceProfileBuilder: ... def clearExecutorResourceRequests(self) -> None: ... def clearTaskResourceRequests(self) -> None: ... @property diff --git a/python/pyspark/sql/column.pyi b/python/pyspark/sql/column.pyi index 0fbb10053f..1f63e65b3d 100644 --- a/python/pyspark/sql/column.pyi +++ b/python/pyspark/sql/column.pyi @@ -32,7 +32,7 @@ from pyspark.sql.window import WindowSpec from py4j.java_gateway import JavaObject # type: ignore[import] class Column: - def __init__(self, JavaObject) -> None: ... + def __init__(self, jc: JavaObject) -> None: ... def __neg__(self) -> Column: ... def __add__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ... def __sub__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ... @@ -105,7 +105,11 @@ class Column: def name(self, *alias: str) -> Column: ... def cast(self, dataType: Union[DataType, str]) -> Column: ... def astype(self, dataType: Union[DataType, str]) -> Column: ... - def between(self, lowerBound, upperBound) -> Column: ... + def between( + self, + lowerBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral], + upperBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral], + ) -> Column: ... def when(self, condition: Column, value: Any) -> Column: ... def otherwise(self, value: Any) -> Column: ... def over(self, window: WindowSpec) -> Column: ... diff --git a/python/pyspark/sql/context.pyi b/python/pyspark/sql/context.pyi index 64927b37ac..915a0fe1f6 100644 --- a/python/pyspark/sql/context.pyi +++ b/python/pyspark/sql/context.pyi @@ -43,14 +43,14 @@ class SQLContext: sparkSession: SparkSession def __init__( self, - sparkContext, + sparkContext: SparkContext, sparkSession: Optional[SparkSession] = ..., jsqlContext: Optional[JavaObject] = ..., ) -> None: ... @classmethod def getOrCreate(cls: type, sc: SparkContext) -> SQLContext: ... def newSession(self) -> SQLContext: ... - def setConf(self, key: str, value) -> None: ... + def setConf(self, key: str, value: Union[bool, int, str]) -> None: ... def getConf(self, key: str, defaultValue: Optional[str] = ...) -> str: ... @property def udf(self) -> UDFRegistration: ... @@ -116,7 +116,7 @@ class SQLContext: path: Optional[str] = ..., source: Optional[str] = ..., schema: Optional[StructType] = ..., - **options + **options: str ) -> DataFrame: ... def sql(self, sqlQuery: str) -> DataFrame: ... def table(self, tableName: str) -> DataFrame: ... diff --git a/python/pyspark/sql/functions.pyi b/python/pyspark/sql/functions.pyi index 281c1d7543..252f883b5f 100644 --- a/python/pyspark/sql/functions.pyi +++ b/python/pyspark/sql/functions.pyi @@ -65,13 +65,13 @@ def round(col: ColumnOrName, scale: int = ...) -> Column: ... def bround(col: ColumnOrName, scale: int = ...) -> Column: ... def shiftLeft(col: ColumnOrName, numBits: int) -> Column: ... def shiftRight(col: ColumnOrName, numBits: int) -> Column: ... -def shiftRightUnsigned(col, numBits) -> Column: ... +def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column: ... def spark_partition_id() -> Column: ... def expr(str: str) -> Column: ... def struct(*cols: ColumnOrName) -> Column: ... def greatest(*cols: ColumnOrName) -> Column: ... def least(*cols: Column) -> Column: ... -def when(condition: Column, value) -> Column: ... +def when(condition: Column, value: Any) -> Column: ... @overload def log(arg1: ColumnOrName) -> Column: ... @overload @@ -174,7 +174,9 @@ def create_map(*cols: ColumnOrName) -> Column: ... def array(*cols: ColumnOrName) -> Column: ... def array_contains(col: ColumnOrName, value: Any) -> Column: ... def arrays_overlap(a1: ColumnOrName, a2: ColumnOrName) -> Column: ... -def slice(x: ColumnOrName, start: Union[Column, int], length: Union[Column, int]) -> Column: ... +def slice( + x: ColumnOrName, start: Union[Column, int], length: Union[Column, int] +) -> Column: ... def array_join( col: ColumnOrName, delimiter: str, null_replacement: Optional[str] = ... ) -> Column: ... diff --git a/python/pyspark/sql/session.pyi b/python/pyspark/sql/session.pyi index 17ba8894c1..6cd2d3bed2 100644 --- a/python/pyspark/sql/session.pyi +++ b/python/pyspark/sql/session.pyi @@ -17,7 +17,8 @@ # under the License. from typing import overload -from typing import Any, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar, Union +from types import TracebackType from py4j.java_gateway import JavaObject # type: ignore[import] @@ -122,4 +123,9 @@ class SparkSession(SparkConversionMixin): def streams(self) -> StreamingQueryManager: ... def stop(self) -> None: ... def __enter__(self) -> SparkSession: ... - def __exit__(self, exc_type, exc_val, exc_tb) -> None: ... + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: ... diff --git a/python/pyspark/sql/types.pyi b/python/pyspark/sql/types.pyi index 31765e9488..3adf823d99 100644 --- a/python/pyspark/sql/types.pyi +++ b/python/pyspark/sql/types.pyi @@ -17,7 +17,8 @@ # under the License. from typing import overload -from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, TypeVar +from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, Type, TypeVar +from py4j.java_gateway import JavaGateway, JavaObject import datetime T = TypeVar("T") @@ -37,7 +38,7 @@ class DataType: def fromInternal(self, obj: Any) -> Any: ... class DataTypeSingleton(type): - def __call__(cls): ... + def __call__(cls: Type[T]) -> T: ... # type: ignore class NullType(DataType, metaclass=DataTypeSingleton): ... class AtomicType(DataType): ... @@ -85,8 +86,8 @@ class ShortType(IntegralType): class ArrayType(DataType): elementType: DataType containsNull: bool - def __init__(self, elementType=DataType, containsNull: bool = ...) -> None: ... - def simpleString(self): ... + def __init__(self, elementType: DataType, containsNull: bool = ...) -> None: ... + def simpleString(self) -> str: ... def jsonValue(self) -> Dict[str, Any]: ... @classmethod def fromJson(cls, json: Dict[str, Any]) -> ArrayType: ... @@ -197,8 +198,8 @@ class Row(tuple): class DateConverter: def can_convert(self, obj: Any) -> bool: ... - def convert(self, obj, gateway_client) -> Any: ... + def convert(self, obj: datetime.date, gateway_client: JavaGateway) -> JavaObject: ... class DatetimeConverter: - def can_convert(self, obj) -> bool: ... - def convert(self, obj, gateway_client) -> Any: ... + def can_convert(self, obj: Any) -> bool: ... + def convert(self, obj: datetime.datetime, gateway_client: JavaGateway) -> JavaObject: ... diff --git a/python/pyspark/sql/udf.pyi b/python/pyspark/sql/udf.pyi index 87c3672780..ea61397a67 100644 --- a/python/pyspark/sql/udf.pyi +++ b/python/pyspark/sql/udf.pyi @@ -18,8 +18,9 @@ from typing import Any, Callable, Optional -from pyspark.sql._typing import ColumnOrName, DataTypeOrString +from pyspark.sql._typing import ColumnOrName, DataTypeOrString, UserDefinedFunctionLike from pyspark.sql.column import Column +from pyspark.sql.types import DataType import pyspark.sql.session class UserDefinedFunction: @@ -35,7 +36,7 @@ class UserDefinedFunction: deterministic: bool = ..., ) -> None: ... @property - def returnType(self): ... + def returnType(self) -> DataType: ... def __call__(self, *cols: ColumnOrName) -> Column: ... def asNondeterministic(self) -> UserDefinedFunction: ... @@ -47,7 +48,7 @@ class UDFRegistration: name: str, f: Callable[..., Any], returnType: Optional[DataTypeOrString] = ..., - ): ... + ) -> UserDefinedFunctionLike: ... def registerJavaFunction( self, name: str, diff --git a/python/pyspark/streaming/context.pyi b/python/pyspark/streaming/context.pyi index 026163fc9a..117a6742e6 100644 --- a/python/pyspark/streaming/context.pyi +++ b/python/pyspark/streaming/context.pyi @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, List, Optional, TypeVar, Union +from typing import Any, Callable, List, Optional, TypeVar from py4j.java_gateway import JavaObject # type: ignore[import] diff --git a/python/pyspark/streaming/dstream.pyi b/python/pyspark/streaming/dstream.pyi index 7b76ce4c65..1521d838fc 100644 --- a/python/pyspark/streaming/dstream.pyi +++ b/python/pyspark/streaming/dstream.pyi @@ -30,9 +30,12 @@ from typing import ( ) import datetime from pyspark.rdd import RDD +import pyspark.serializers from pyspark.storagelevel import StorageLevel import pyspark.streaming.context +from py4j.java_gateway import JavaObject + S = TypeVar("S") T = TypeVar("T") U = TypeVar("U") @@ -42,7 +45,12 @@ V = TypeVar("V") class DStream(Generic[T]): is_cached: bool is_checkpointed: bool - def __init__(self, jdstream, ssc, jrdd_deserializer) -> None: ... + def __init__( + self, + jdstream: JavaObject, + ssc: pyspark.streaming.context.StreamingContext, + jrdd_deserializer: pyspark.serializers.Serializer, + ) -> None: ... def context(self) -> pyspark.streaming.context.StreamingContext: ... def count(self) -> DStream[int]: ... def filter(self, f: Callable[[T], bool]) -> DStream[T]: ... diff --git a/python/pyspark/streaming/kinesis.pyi b/python/pyspark/streaming/kinesis.pyi index af7cd6f6ec..399c37f869 100644 --- a/python/pyspark/streaming/kinesis.pyi +++ b/python/pyspark/streaming/kinesis.pyi @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Optional, TypeVar +from typing import Callable, Optional, TypeVar from pyspark.storagelevel import StorageLevel from pyspark.streaming.context import StreamingContext from pyspark.streaming.dstream import DStream