[SPARK-33457][PYTHON] Adjust mypy configuration

### What changes were proposed in this pull request?

This pull request:

- Adds following flags to the main mypy configuration:
  - [`strict_optional`](https://mypy.readthedocs.io/en/stable/config_file.html#confval-strict_optional)
  - [`no_implicit_optional`](https://mypy.readthedocs.io/en/stable/config_file.html#confval-no_implicit_optional)
  - [`disallow_untyped_defs`](https://mypy.readthedocs.io/en/stable/config_file.html#confval-disallow_untyped_calls)

These flags are enabled only for the public API and disabled for tests and internal modules.

Additionally, this PR fixes missing annotations.
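
For illustration, a minimal hypothetical snippet (not taken from this PR) showing the kind of code the new flags accept or reject in public modules:

```python
from typing import Optional

# Under no_implicit_optional, a None default must be spelled Optional explicitly;
# `def get_option(key: str, default: str = None)` would now be rejected.
def get_option(key: str, default: Optional[str] = None) -> Optional[str]:
    ...

# Under disallow_untyped_defs, public functions need complete annotations;
# the per-module overrides in mypy.ini relax this for tests and internal modules.
def scale(value: float, factor: float = 2.0) -> float:
    return value * factor
```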

### Why are the changes needed?

The primary reason for proposing these changes is to use the standard configuration used by the typeshed project. This will allow us to be stricter, especially when interacting with JVM code. See for example https://github.com/apache/spark/pull/29122#pullrequestreview-513112882

Additionally, it will allow us to detect cases where annotations have been unintentionally omitted.
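
As a rough, hypothetical illustration of the `strict_optional` benefit (not code from this PR): values that may be `None`, for example results bridged back from the JVM, now have to be narrowed before use.

```python
from typing import Optional

def first_line(text: Optional[str]) -> str:
    # With strict_optional, mypy flags `text.splitlines()` unless the
    # None branch is handled first.
    if text is None:
        return ""
    return text.splitlines()[0]
```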

### Does this PR introduce _any_ user-facing change?

Annotations only.

### How was this patch tested?

`dev/lint-python`.

Closes #30382 from zero323/SPARK-33457.

Authored-by: zero323 <mszymkiewicz@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
zero323 2020-11-25 09:27:04 +09:00 committed by HyukjinKwon
parent 95b6dabc33
commit 665817bd4f
28 changed files with 277 additions and 114 deletions


@ -16,10 +16,97 @@
;
[mypy]
strict_optional = True
no_implicit_optional = True
disallow_untyped_defs = True
; Allow untyped def in internal modules and tests
[mypy-pyspark.daemon]
disallow_untyped_defs = False
[mypy-pyspark.find_spark_home]
disallow_untyped_defs = False
[mypy-pyspark._globals]
disallow_untyped_defs = False
[mypy-pyspark.install]
disallow_untyped_defs = False
[mypy-pyspark.java_gateway]
disallow_untyped_defs = False
[mypy-pyspark.join]
disallow_untyped_defs = False
[mypy-pyspark.ml.tests.*]
disallow_untyped_defs = False
[mypy-pyspark.mllib.tests.*]
disallow_untyped_defs = False
[mypy-pyspark.rddsampler]
disallow_untyped_defs = False
[mypy-pyspark.resource.tests.*]
disallow_untyped_defs = False
[mypy-pyspark.serializers]
disallow_untyped_defs = False
[mypy-pyspark.shuffle]
disallow_untyped_defs = False
[mypy-pyspark.streaming.tests.*]
disallow_untyped_defs = False
[mypy-pyspark.streaming.util]
disallow_untyped_defs = False
[mypy-pyspark.sql.tests.*]
disallow_untyped_defs = False
[mypy-pyspark.sql.pandas.serializers]
disallow_untyped_defs = False
[mypy-pyspark.sql.pandas.types]
disallow_untyped_defs = False
[mypy-pyspark.sql.pandas.typehints]
disallow_untyped_defs = False
[mypy-pyspark.sql.pandas.utils]
disallow_untyped_defs = False
[mypy-pyspark.sql.pandas._typing.protocols.*]
disallow_untyped_defs = False
[mypy-pyspark.sql.utils]
disallow_untyped_defs = False
[mypy-pyspark.tests.*]
disallow_untyped_defs = False
[mypy-pyspark.testing.*]
disallow_untyped_defs = False
[mypy-pyspark.traceback_utils]
disallow_untyped_defs = False
[mypy-pyspark.util]
disallow_untyped_defs = False
[mypy-pyspark.worker]
disallow_untyped_defs = False
; Ignore errors in embedded third party code
[mypy-pyspark.cloudpickle.*]
ignore_errors = True
; Ignore missing imports for external untyped packages
[mypy-py4j.*]
ignore_missing_imports = True


@ -17,7 +17,7 @@
# under the License.
import threading
from typing import Any, Dict, Generic, Optional, TypeVar
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar
T = TypeVar("T")
@ -32,14 +32,14 @@ class Broadcast(Generic[T]):
path: Optional[Any] = ...,
sock_file: Optional[Any] = ...,
) -> None: ...
def dump(self, value: Any, f: Any) -> None: ...
def load_from_path(self, path: Any): ...
def load(self, file: Any): ...
def dump(self, value: T, f: Any) -> None: ...
def load_from_path(self, path: Any) -> T: ...
def load(self, file: Any) -> T: ...
@property
def value(self) -> T: ...
def unpersist(self, blocking: bool = ...) -> None: ...
def destroy(self, blocking: bool = ...) -> None: ...
def __reduce__(self): ...
def __reduce__(self) -> Tuple[Callable[[int], T], Tuple[int]]: ...
class BroadcastPickleRegistry(threading.local):
def __init__(self) -> None: ...
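
A hypothetical usage sketch (not part of this diff) of how the parametrized `Broadcast` stubs help mypy track element types instead of falling back to `Any`:

```python
from pyspark import SparkContext

sc = SparkContext.getOrCreate()
weights = sc.broadcast({"a": 1.0, "b": 2.0})  # inferred as Broadcast[Dict[str, float]]
total = sum(weights.value.values())           # .value is Dict[str, float], not Any
```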


@ -16,7 +16,19 @@
# specific language governing permissions and limitations
# under the License.
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
NoReturn,
Optional,
Tuple,
Type,
TypeVar,
)
from types import TracebackType
from py4j.java_gateway import JavaGateway, JavaObject # type: ignore[import]
@ -51,9 +63,14 @@ class SparkContext:
jsc: Optional[JavaObject] = ...,
profiler_cls: type = ...,
) -> None: ...
def __getnewargs__(self): ...
def __enter__(self): ...
def __exit__(self, type, value, trace): ...
def __getnewargs__(self) -> NoReturn: ...
def __enter__(self) -> SparkContext: ...
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
trace: Optional[TracebackType],
) -> None: ...
@classmethod
def getOrCreate(cls, conf: Optional[SparkConf] = ...) -> SparkContext: ...
def setLogLevel(self, logLevel: str) -> None: ...
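
For context, a hypothetical sketch showing that the typed `__enter__`/`__exit__` keep `SparkContext` usable as a context manager without `Any` leaking in:

```python
from pyspark import SparkConf, SparkContext

with SparkContext.getOrCreate(SparkConf().setAppName("typed")) as sc:
    print(sc.applicationId)  # sc is typed as SparkContext inside the block
```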


@ -107,7 +107,7 @@ class _JavaProbabilisticClassifier(
class _JavaProbabilisticClassificationModel(
ProbabilisticClassificationModel, _JavaClassificationModel[T]
):
def predictProbability(self, value: Any): ...
def predictProbability(self, value: Vector) -> Vector: ...
class _ClassificationSummary(JavaWrapper):
@property
@ -543,7 +543,7 @@ class RandomForestClassificationModel(
@property
def trees(self) -> List[DecisionTreeClassificationModel]: ...
def summary(self) -> RandomForestClassificationTrainingSummary: ...
def evaluate(self, dataset) -> RandomForestClassificationSummary: ...
def evaluate(self, dataset: DataFrame) -> RandomForestClassificationSummary: ...
class RandomForestClassificationSummary(_ClassificationSummary): ...
class RandomForestClassificationTrainingSummary(
@ -891,7 +891,7 @@ class FMClassifier(
solver: str = ...,
thresholds: Optional[Any] = ...,
seed: Optional[Any] = ...,
): ...
) -> FMClassifier: ...
def setFactorSize(self, value: int) -> FMClassifier: ...
def setFitLinear(self, value: bool) -> FMClassifier: ...
def setMiniBatchFraction(self, value: float) -> FMClassifier: ...


@ -16,5 +16,11 @@
# specific language governing permissions and limitations
# under the License.
def callJavaFunc(sc, func, *args): ...
def inherit_doc(cls): ...
from typing import Any, TypeVar
import pyspark.context
C = TypeVar("C", bound=type)
def callJavaFunc(sc: pyspark.context.SparkContext, func: Any, *args: Any) -> Any: ...
def inherit_doc(cls: C) -> C: ...


@ -39,9 +39,12 @@ from pyspark.ml.param.shared import (
HasWeightCol,
)
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.sql.dataframe import DataFrame
class Evaluator(Params, metaclass=abc.ABCMeta):
def evaluate(self, dataset, params: Optional[ParamMap] = ...) -> float: ...
def evaluate(
self, dataset: DataFrame, params: Optional[ParamMap] = ...
) -> float: ...
def isLargerBetter(self) -> bool: ...
class JavaEvaluator(JavaParams, Evaluator, metaclass=abc.ABCMeta):
@ -75,7 +78,6 @@ class BinaryClassificationEvaluator(
def setLabelCol(self, value: str) -> BinaryClassificationEvaluator: ...
def setRawPredictionCol(self, value: str) -> BinaryClassificationEvaluator: ...
def setWeightCol(self, value: str) -> BinaryClassificationEvaluator: ...
def setParams(
self,
*,


@ -100,9 +100,9 @@ class _LSHParams(HasInputCol, HasOutputCol):
def getNumHashTables(self) -> int: ...
class _LSH(Generic[JM], JavaEstimator[JM], _LSHParams, JavaMLReadable, JavaMLWritable):
def setNumHashTables(self: P, value) -> P: ...
def setInputCol(self: P, value) -> P: ...
def setOutputCol(self: P, value) -> P: ...
def setNumHashTables(self: P, value: int) -> P: ...
def setInputCol(self: P, value: str) -> P: ...
def setOutputCol(self: P, value: str) -> P: ...
class _LSHModel(JavaModel, _LSHParams):
def setInputCol(self: P, value: str) -> P: ...
@ -1518,7 +1518,7 @@ class ChiSqSelector(
fpr: float = ...,
fdr: float = ...,
fwe: float = ...
): ...
) -> ChiSqSelector: ...
def setSelectorType(self, value: str) -> ChiSqSelector: ...
def setNumTopFeatures(self, value: int) -> ChiSqSelector: ...
def setPercentile(self, value: float) -> ChiSqSelector: ...
@ -1602,7 +1602,10 @@ class _VarianceThresholdSelectorParams(HasFeaturesCol, HasOutputCol):
def getVarianceThreshold(self) -> float: ...
class VarianceThresholdSelector(
JavaEstimator, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable
JavaEstimator[VarianceThresholdSelectorModel],
_VarianceThresholdSelectorParams,
JavaMLReadable[VarianceThresholdSelector],
JavaMLWritable,
):
def __init__(
self,
@ -1615,13 +1618,16 @@ class VarianceThresholdSelector(
featuresCol: str = ...,
outputCol: Optional[str] = ...,
varianceThreshold: float = ...,
): ...
) -> VarianceThresholdSelector: ...
def setVarianceThreshold(self, value: float) -> VarianceThresholdSelector: ...
def setFeaturesCol(self, value: str) -> VarianceThresholdSelector: ...
def setOutputCol(self, value: str) -> VarianceThresholdSelector: ...
class VarianceThresholdSelectorModel(
JavaModel, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable
JavaModel,
_VarianceThresholdSelectorParams,
JavaMLReadable[VarianceThresholdSelectorModel],
JavaMLWritable,
):
def setFeaturesCol(self, value: str) -> VarianceThresholdSelectorModel: ...
def setOutputCol(self, value: str) -> VarianceThresholdSelectorModel: ...


@ -17,7 +17,7 @@
# under the License.
from typing import overload
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, NoReturn, Optional, Tuple, Type, Union
from pyspark.ml import linalg as newlinalg # noqa: F401
from pyspark.sql.types import StructType, UserDefinedType
@ -45,7 +45,7 @@ class MatrixUDT(UserDefinedType):
@classmethod
def scalaUDT(cls) -> str: ...
def serialize(
self, obj
self, obj: Matrix
) -> Tuple[
int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool
]: ...
@ -64,9 +64,7 @@ class DenseVector(Vector):
def __init__(self, __arr: bytes) -> None: ...
@overload
def __init__(self, __arr: Iterable[float]) -> None: ...
@staticmethod
def parse(s) -> DenseVector: ...
def __reduce__(self) -> Tuple[type, bytes]: ...
def __reduce__(self) -> Tuple[Type[DenseVector], bytes]: ...
def numNonzeros(self) -> int: ...
def norm(self, p: Union[float, str]) -> float64: ...
def dot(self, other: Iterable[float]) -> float64: ...
@ -112,16 +110,14 @@ class SparseVector(Vector):
def __init__(self, size: int, __map: Dict[int, float]) -> None: ...
def numNonzeros(self) -> int: ...
def norm(self, p: Union[float, str]) -> float64: ...
def __reduce__(self): ...
@staticmethod
def parse(s: str) -> SparseVector: ...
def __reduce__(self) -> Tuple[Type[SparseVector], Tuple[int, bytes, bytes]]: ...
def dot(self, other: Iterable[float]) -> float64: ...
def squared_distance(self, other: Iterable[float]) -> float64: ...
def toArray(self) -> ndarray: ...
def __len__(self) -> int: ...
def __eq__(self, other) -> bool: ...
def __eq__(self, other: Any) -> bool: ...
def __getitem__(self, index: int) -> float64: ...
def __ne__(self, other) -> bool: ...
def __ne__(self, other: Any) -> bool: ...
def __hash__(self) -> int: ...
class Vectors:
@ -144,13 +140,13 @@ class Vectors:
def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ...
@overload
@staticmethod
def dense(self, *elements: float) -> DenseVector: ...
def dense(*elements: float) -> DenseVector: ...
@overload
@staticmethod
def dense(self, __arr: bytes) -> DenseVector: ...
def dense(__arr: bytes) -> DenseVector: ...
@overload
@staticmethod
def dense(self, __arr: Iterable[float]) -> DenseVector: ...
def dense(__arr: Iterable[float]) -> DenseVector: ...
@staticmethod
def stringify(vector: Vector) -> str: ...
@staticmethod
@ -158,8 +154,6 @@ class Vectors:
@staticmethod
def norm(vector: Vector, p: Union[float, str]) -> float64: ...
@staticmethod
def parse(s: str) -> Vector: ...
@staticmethod
def zeros(size: int) -> DenseVector: ...
class Matrix:
@ -170,7 +164,7 @@ class Matrix:
def __init__(
self, numRows: int, numCols: int, isTransposed: bool = ...
) -> None: ...
def toArray(self): ...
def toArray(self) -> NoReturn: ...
class DenseMatrix(Matrix):
values: Any
@ -186,11 +180,11 @@ class DenseMatrix(Matrix):
values: Iterable[float],
isTransposed: bool = ...,
) -> None: ...
def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, int]]: ...
def __reduce__(self) -> Tuple[Type[DenseMatrix], Tuple[int, int, bytes, int]]: ...
def toArray(self) -> ndarray: ...
def toSparse(self) -> SparseMatrix: ...
def __getitem__(self, indices: Tuple[int, int]) -> float64: ...
def __eq__(self, other) -> bool: ...
def __eq__(self, other: Any) -> bool: ...
class SparseMatrix(Matrix):
colPtrs: ndarray
@ -216,11 +210,13 @@ class SparseMatrix(Matrix):
values: Iterable[float],
isTransposed: bool = ...,
) -> None: ...
def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, bytes, bytes, int]]: ...
def __reduce__(
self,
) -> Tuple[Type[SparseMatrix], Tuple[int, int, bytes, bytes, bytes, int]]: ...
def __getitem__(self, indices: Tuple[int, int]) -> float64: ...
def toArray(self) -> ndarray: ...
def toDense(self) -> DenseMatrix: ...
def __eq__(self, other) -> bool: ...
def __eq__(self, other: Any) -> bool: ...
class Matrices:
@overload


@ -51,7 +51,7 @@ class PipelineWriter(MLWriter):
def __init__(self, instance: Pipeline) -> None: ...
def saveImpl(self, path: str) -> None: ...
class PipelineReader(MLReader):
class PipelineReader(MLReader[Pipeline]):
cls: Type[Pipeline]
def __init__(self, cls: Type[Pipeline]) -> None: ...
def load(self, path: str) -> Pipeline: ...
@ -61,7 +61,7 @@ class PipelineModelWriter(MLWriter):
def __init__(self, instance: PipelineModel) -> None: ...
def saveImpl(self, path: str) -> None: ...
class PipelineModelReader(MLReader):
class PipelineModelReader(MLReader[PipelineModel]):
cls: Type[PipelineModel]
def __init__(self, cls: Type[PipelineModel]) -> None: ...
def load(self, path: str) -> PipelineModel: ...


@ -414,7 +414,7 @@ class RandomForestRegressionModel(
_TreeEnsembleModel,
_RandomForestRegressorParams,
JavaMLWritable,
JavaMLReadable,
JavaMLReadable[RandomForestRegressionModel],
):
@property
def trees(self) -> List[DecisionTreeRegressionModel]: ...
@ -749,10 +749,10 @@ class _FactorizationMachinesParams(
initStd: Param[float]
solver: Param[str]
def __init__(self, *args: Any): ...
def getFactorSize(self): ...
def getFitLinear(self): ...
def getMiniBatchFraction(self): ...
def getInitStd(self): ...
def getFactorSize(self) -> int: ...
def getFitLinear(self) -> bool: ...
def getMiniBatchFraction(self) -> float: ...
def getInitStd(self) -> float: ...
class FMRegressor(
_JavaRegressor[FMRegressionModel],


@ -118,7 +118,7 @@ class NaiveBayesModel(Saveable, Loader[NaiveBayesModel]):
labels: ndarray
pi: ndarray
theta: ndarray
def __init__(self, labels, pi, theta) -> None: ...
def __init__(self, labels: ndarray, pi: ndarray, theta: ndarray) -> None: ...
@overload
def predict(self, x: VectorLike) -> float64: ...
@overload


@ -63,7 +63,7 @@ class BisectingKMeans:
class KMeansModel(Saveable, Loader[KMeansModel]):
centers: List[ndarray]
def __init__(self, centers: List[ndarray]) -> None: ...
def __init__(self, centers: List[VectorLike]) -> None: ...
@property
def clusterCenters(self) -> List[ndarray]: ...
@property
@ -144,7 +144,9 @@ class PowerIterationClustering:
class Assignment(NamedTuple("Assignment", [("id", int), ("cluster", int)])): ...
class StreamingKMeansModel(KMeansModel):
def __init__(self, clusterCenters, clusterWeights) -> None: ...
def __init__(
self, clusterCenters: List[VectorLike], clusterWeights: VectorLike
) -> None: ...
@property
def clusterWeights(self) -> List[float64]: ...
centers: ndarray


@ -16,12 +16,20 @@
# specific language governing permissions and limitations
# under the License.
def callJavaFunc(sc, func, *args): ...
def callMLlibFunc(name, *args): ...
from typing import Any, TypeVar
import pyspark.context
from py4j.java_gateway import JavaObject
C = TypeVar("C", bound=type)
def callJavaFunc(sc: pyspark.context.SparkContext, func: Any, *args: Any) -> Any: ...
def callMLlibFunc(name: str, *args: Any) -> Any: ...
class JavaModelWrapper:
def __init__(self, java_model) -> None: ...
def __del__(self): ...
def call(self, name, *a): ...
def __init__(self, java_model: JavaObject) -> None: ...
def __del__(self) -> None: ...
def call(self, name: str, *a: Any) -> Any: ...
def inherit_doc(cls): ...
def inherit_doc(cls: C) -> C: ...


@ -17,7 +17,18 @@
# under the License.
from typing import overload
from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, TypeVar, Union
from typing import (
Any,
Dict,
Generic,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from pyspark.ml import linalg as newlinalg
from pyspark.sql.types import StructType, UserDefinedType
from numpy import float64, ndarray # type: ignore[import]
@ -46,7 +57,7 @@ class MatrixUDT(UserDefinedType):
@classmethod
def scalaUDT(cls) -> str: ...
def serialize(
self, obj
self, obj: Matrix
) -> Tuple[
int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool
]: ...
@ -67,8 +78,8 @@ class DenseVector(Vector):
@overload
def __init__(self, __arr: Iterable[float]) -> None: ...
@staticmethod
def parse(s) -> DenseVector: ...
def __reduce__(self) -> Tuple[type, bytes]: ...
def parse(s: str) -> DenseVector: ...
def __reduce__(self) -> Tuple[Type[DenseVector], bytes]: ...
def numNonzeros(self) -> int: ...
def norm(self, p: Union[float, str]) -> float64: ...
def dot(self, other: Iterable[float]) -> float64: ...
@ -115,7 +126,7 @@ class SparseVector(Vector):
def __init__(self, size: int, __map: Dict[int, float]) -> None: ...
def numNonzeros(self) -> int: ...
def norm(self, p: Union[float, str]) -> float64: ...
def __reduce__(self): ...
def __reduce__(self) -> Tuple[Type[SparseVector], Tuple[int, bytes, bytes]]: ...
@staticmethod
def parse(s: str) -> SparseVector: ...
def dot(self, other: Iterable[float]) -> float64: ...
@ -123,9 +134,9 @@ class SparseVector(Vector):
def toArray(self) -> ndarray: ...
def asML(self) -> newlinalg.SparseVector: ...
def __len__(self) -> int: ...
def __eq__(self, other) -> bool: ...
def __eq__(self, other: Any) -> bool: ...
def __getitem__(self, index: int) -> float64: ...
def __ne__(self, other) -> bool: ...
def __ne__(self, other: Any) -> bool: ...
def __hash__(self) -> int: ...
class Vectors:
@ -148,13 +159,13 @@ class Vectors:
def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ...
@overload
@staticmethod
def dense(self, *elements: float) -> DenseVector: ...
def dense(*elements: float) -> DenseVector: ...
@overload
@staticmethod
def dense(self, __arr: bytes) -> DenseVector: ...
def dense(__arr: bytes) -> DenseVector: ...
@overload
@staticmethod
def dense(self, __arr: Iterable[float]) -> DenseVector: ...
def dense(__arr: Iterable[float]) -> DenseVector: ...
@staticmethod
def fromML(vec: newlinalg.DenseVector) -> DenseVector: ...
@staticmethod
@ -176,8 +187,8 @@ class Matrix:
def __init__(
self, numRows: int, numCols: int, isTransposed: bool = ...
) -> None: ...
def toArray(self): ...
def asML(self): ...
def toArray(self) -> ndarray: ...
def asML(self) -> newlinalg.Matrix: ...
class DenseMatrix(Matrix):
values: Any
@ -193,12 +204,12 @@ class DenseMatrix(Matrix):
values: Iterable[float],
isTransposed: bool = ...,
) -> None: ...
def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, int]]: ...
def __reduce__(self) -> Tuple[Type[DenseMatrix], Tuple[int, int, bytes, int]]: ...
def toArray(self) -> ndarray: ...
def toSparse(self) -> SparseMatrix: ...
def asML(self) -> newlinalg.DenseMatrix: ...
def __getitem__(self, indices: Tuple[int, int]) -> float64: ...
def __eq__(self, other) -> bool: ...
def __eq__(self, other: Any) -> bool: ...
class SparseMatrix(Matrix):
colPtrs: ndarray
@ -224,12 +235,14 @@ class SparseMatrix(Matrix):
values: Iterable[float],
isTransposed: bool = ...,
) -> None: ...
def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, bytes, bytes, int]]: ...
def __reduce__(
self,
) -> Tuple[Type[SparseMatrix], Tuple[int, int, bytes, bytes, bytes, int]]: ...
def __getitem__(self, indices: Tuple[int, int]) -> float64: ...
def toArray(self) -> ndarray: ...
def toDense(self) -> DenseMatrix: ...
def asML(self) -> newlinalg.SparseMatrix: ...
def __eq__(self, other) -> bool: ...
def __eq__(self, other: Any) -> bool: ...
class Matrices:
@overload


@ -90,7 +90,7 @@ class RandomRDDs:
def logNormalVectorRDD(
sc: SparkContext,
mean: float,
std,
std: float,
numRows: int,
numCols: int,
numPartitions: Optional[int] = ...,


@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Type, Union
import array
from collections import namedtuple
@ -27,7 +27,7 @@ from pyspark.mllib.common import JavaModelWrapper
from pyspark.mllib.util import JavaLoader, JavaSaveable
class Rating(namedtuple("Rating", ["user", "product", "rating"])):
def __reduce__(self): ...
def __reduce__(self) -> Tuple[Type[Rating], Tuple[int, int, float]]: ...
class MatrixFactorizationModel(
JavaModelWrapper, JavaSaveable, JavaLoader[MatrixFactorizationModel]


@ -65,5 +65,5 @@ class Statistics:
def chiSqTest(observed: RDD[LabeledPoint]) -> List[ChiSqTestResult]: ...
@staticmethod
def kolmogorovSmirnovTest(
data, distName: Literal["norm"] = ..., *params: float
data: RDD[float], distName: Literal["norm"] = ..., *params: float
) -> KolmogorovSmirnovTestResult: ...


@ -85,12 +85,16 @@ class PythonEvalType:
SQL_COGROUPED_MAP_PANDAS_UDF: PandasCogroupedMapUDFType
class BoundedFloat(float):
def __new__(cls, mean: float, confidence: float, low: float, high: float): ...
def __new__(
cls, mean: float, confidence: float, low: float, high: float
) -> BoundedFloat: ...
class Partitioner:
numPartitions: int
partitionFunc: Callable[[Any], int]
def __init__(self, numPartitions, partitionFunc) -> None: ...
def __init__(
self, numPartitions: int, partitionFunc: Callable[[Any], int]
) -> None: ...
def __eq__(self, other: Any) -> bool: ...
def __call__(self, k: Any) -> int: ...


@ -49,7 +49,7 @@ class ResourceProfileBuilder:
def __init__(self) -> None: ...
def require(
self, resourceRequest: Union[ExecutorResourceRequest, TaskResourceRequests]
): ...
) -> ResourceProfileBuilder: ...
def clearExecutorResourceRequests(self) -> None: ...
def clearTaskResourceRequests(self) -> None: ...
@property


@ -32,7 +32,7 @@ from pyspark.sql.window import WindowSpec
from py4j.java_gateway import JavaObject # type: ignore[import]
class Column:
def __init__(self, JavaObject) -> None: ...
def __init__(self, jc: JavaObject) -> None: ...
def __neg__(self) -> Column: ...
def __add__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
def __sub__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
@ -105,7 +105,11 @@ class Column:
def name(self, *alias: str) -> Column: ...
def cast(self, dataType: Union[DataType, str]) -> Column: ...
def astype(self, dataType: Union[DataType, str]) -> Column: ...
def between(self, lowerBound, upperBound) -> Column: ...
def between(
self,
lowerBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral],
upperBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral],
) -> Column: ...
def when(self, condition: Column, value: Any) -> Column: ...
def otherwise(self, value: Any) -> Column: ...
def over(self, window: WindowSpec) -> Column: ...
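
A hypothetical usage sketch for the expanded `Column.between` signature, which now accepts literal or `Column` bounds:

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.range(10)
df.filter(F.col("id").between(2, 5)).show()                 # literal bounds
df.filter(F.col("id").between(F.lit(2), F.lit(5))).show()   # Column bounds
```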


@ -43,14 +43,14 @@ class SQLContext:
sparkSession: SparkSession
def __init__(
self,
sparkContext,
sparkContext: SparkContext,
sparkSession: Optional[SparkSession] = ...,
jsqlContext: Optional[JavaObject] = ...,
) -> None: ...
@classmethod
def getOrCreate(cls: type, sc: SparkContext) -> SQLContext: ...
def newSession(self) -> SQLContext: ...
def setConf(self, key: str, value) -> None: ...
def setConf(self, key: str, value: Union[bool, int, str]) -> None: ...
def getConf(self, key: str, defaultValue: Optional[str] = ...) -> str: ...
@property
def udf(self) -> UDFRegistration: ...
@ -116,7 +116,7 @@ class SQLContext:
path: Optional[str] = ...,
source: Optional[str] = ...,
schema: Optional[StructType] = ...,
**options
**options: str
) -> DataFrame: ...
def sql(self, sqlQuery: str) -> DataFrame: ...
def table(self, tableName: str) -> DataFrame: ...


@ -65,13 +65,13 @@ def round(col: ColumnOrName, scale: int = ...) -> Column: ...
def bround(col: ColumnOrName, scale: int = ...) -> Column: ...
def shiftLeft(col: ColumnOrName, numBits: int) -> Column: ...
def shiftRight(col: ColumnOrName, numBits: int) -> Column: ...
def shiftRightUnsigned(col, numBits) -> Column: ...
def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column: ...
def spark_partition_id() -> Column: ...
def expr(str: str) -> Column: ...
def struct(*cols: ColumnOrName) -> Column: ...
def greatest(*cols: ColumnOrName) -> Column: ...
def least(*cols: Column) -> Column: ...
def when(condition: Column, value) -> Column: ...
def when(condition: Column, value: Any) -> Column: ...
@overload
def log(arg1: ColumnOrName) -> Column: ...
@overload
@ -174,7 +174,9 @@ def create_map(*cols: ColumnOrName) -> Column: ...
def array(*cols: ColumnOrName) -> Column: ...
def array_contains(col: ColumnOrName, value: Any) -> Column: ...
def arrays_overlap(a1: ColumnOrName, a2: ColumnOrName) -> Column: ...
def slice(x: ColumnOrName, start: Union[Column, int], length: Union[Column, int]) -> Column: ...
def slice(
x: ColumnOrName, start: Union[Column, int], length: Union[Column, int]
) -> Column: ...
def array_join(
col: ColumnOrName, delimiter: str, null_replacement: Optional[str] = ...
) -> Column: ...
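
For context, a hypothetical snippet exercising the sharpened `functions` signatures above:

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([([1, 2, 3, 4],)], ["xs"])
df.select(F.slice("xs", 2, 2).alias("mid")).show()                      # int start/length
df.select(F.when(F.size("xs") > 3, "long").otherwise("short")).show()   # typed condition, Any value
```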


@ -17,7 +17,8 @@
# under the License.
from typing import overload
from typing import Any, Iterable, List, Optional, Tuple, TypeVar, Union
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar, Union
from types import TracebackType
from py4j.java_gateway import JavaObject # type: ignore[import]
@ -122,4 +123,9 @@ class SparkSession(SparkConversionMixin):
def streams(self) -> StreamingQueryManager: ...
def stop(self) -> None: ...
def __enter__(self) -> SparkSession: ...
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None: ...
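
Similarly, a hypothetical sketch for the now fully typed `SparkSession` context manager:

```python
from pyspark.sql import SparkSession

with SparkSession.builder.appName("typed-session").getOrCreate() as spark:
    spark.range(3).show()  # spark is typed as SparkSession inside the block
```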


@ -17,7 +17,8 @@
# under the License.
from typing import overload
from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, TypeVar
from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, Type, TypeVar
from py4j.java_gateway import JavaGateway, JavaObject
import datetime
T = TypeVar("T")
@ -37,7 +38,7 @@ class DataType:
def fromInternal(self, obj: Any) -> Any: ...
class DataTypeSingleton(type):
def __call__(cls): ...
def __call__(cls: Type[T]) -> T: ... # type: ignore
class NullType(DataType, metaclass=DataTypeSingleton): ...
class AtomicType(DataType): ...
@ -85,8 +86,8 @@ class ShortType(IntegralType):
class ArrayType(DataType):
elementType: DataType
containsNull: bool
def __init__(self, elementType=DataType, containsNull: bool = ...) -> None: ...
def simpleString(self): ...
def __init__(self, elementType: DataType, containsNull: bool = ...) -> None: ...
def simpleString(self) -> str: ...
def jsonValue(self) -> Dict[str, Any]: ...
@classmethod
def fromJson(cls, json: Dict[str, Any]) -> ArrayType: ...
@ -197,8 +198,8 @@ class Row(tuple):
class DateConverter:
def can_convert(self, obj: Any) -> bool: ...
def convert(self, obj, gateway_client) -> Any: ...
def convert(self, obj: datetime.date, gateway_client: JavaGateway) -> JavaObject: ...
class DatetimeConverter:
def can_convert(self, obj) -> bool: ...
def convert(self, obj, gateway_client) -> Any: ...
def can_convert(self, obj: Any) -> bool: ...
def convert(self, obj: datetime.datetime, gateway_client: JavaGateway) -> JavaObject: ...


@ -18,8 +18,9 @@
from typing import Any, Callable, Optional
from pyspark.sql._typing import ColumnOrName, DataTypeOrString
from pyspark.sql._typing import ColumnOrName, DataTypeOrString, UserDefinedFunctionLike
from pyspark.sql.column import Column
from pyspark.sql.types import DataType
import pyspark.sql.session
class UserDefinedFunction:
@ -35,7 +36,7 @@ class UserDefinedFunction:
deterministic: bool = ...,
) -> None: ...
@property
def returnType(self): ...
def returnType(self) -> DataType: ...
def __call__(self, *cols: ColumnOrName) -> Column: ...
def asNondeterministic(self) -> UserDefinedFunction: ...
@ -47,7 +48,7 @@ class UDFRegistration:
name: str,
f: Callable[..., Any],
returnType: Optional[DataTypeOrString] = ...,
): ...
) -> UserDefinedFunctionLike: ...
def registerJavaFunction(
self,
name: str,


@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
from typing import Any, Callable, List, Optional, TypeVar, Union
from typing import Any, Callable, List, Optional, TypeVar
from py4j.java_gateway import JavaObject # type: ignore[import]


@ -30,9 +30,12 @@ from typing import (
)
import datetime
from pyspark.rdd import RDD
import pyspark.serializers
from pyspark.storagelevel import StorageLevel
import pyspark.streaming.context
from py4j.java_gateway import JavaObject
S = TypeVar("S")
T = TypeVar("T")
U = TypeVar("U")
@ -42,7 +45,12 @@ V = TypeVar("V")
class DStream(Generic[T]):
is_cached: bool
is_checkpointed: bool
def __init__(self, jdstream, ssc, jrdd_deserializer) -> None: ...
def __init__(
self,
jdstream: JavaObject,
ssc: pyspark.streaming.context.StreamingContext,
jrdd_deserializer: pyspark.serializers.Serializer,
) -> None: ...
def context(self) -> pyspark.streaming.context.StreamingContext: ...
def count(self) -> DStream[int]: ...
def filter(self, f: Callable[[T], bool]) -> DStream[T]: ...


@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
from typing import Any, Callable, Optional, TypeVar
from typing import Callable, Optional, TypeVar
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.context import StreamingContext
from pyspark.streaming.dstream import DStream