[SPARK-33457][PYTHON] Adjust mypy configuration
### What changes were proposed in this pull request?

This pull request adds the following flags to the main mypy configuration:

- [`strict_optional`](https://mypy.readthedocs.io/en/stable/config_file.html#confval-strict_optional)
- [`no_implicit_optional`](https://mypy.readthedocs.io/en/stable/config_file.html#confval-no_implicit_optional)
- [`disallow_untyped_defs`](https://mypy.readthedocs.io/en/stable/config_file.html#confval-disallow_untyped_defs)

These flags are enabled only for the public API and disabled for tests and internal modules. Additionally, this PR fixes missing annotations.

### Why are the changes needed?

The primary reason to propose these changes is to use the standard configuration used by the typeshed project. This will allow us to be more strict, especially when interacting with JVM code. See for example https://github.com/apache/spark/pull/29122#pullrequestreview-513112882

Additionally, it will allow us to detect cases where annotations have been unintentionally omitted.

### Does this PR introduce _any_ user-facing change?

Annotations only.

### How was this patch tested?

`dev/lint-python`.

Closes #30382 from zero323/SPARK-33457.

Authored-by: zero323 <mszymkiewicz@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
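For illustration, a minimal sketch (not part of this PR; all function names hypothetical) of the kind of code each flag starts flagging:

```python
from typing import Optional

# no_implicit_optional: a None default no longer implies Optional[str].
def set_name(name: str = None) -> None: ...  # error under no_implicit_optional
def set_name_ok(name: Optional[str] = None) -> None: ...  # accepted

# strict_optional: Optional values must be narrowed before use.
def greet(name: Optional[str]) -> str:
    return "Hello, " + name  # error: "name" may be None

# disallow_untyped_defs: public functions must be fully annotated.
def helper(x):  # error: function is missing type annotations
    return x
```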
This commit is contained in:
parent 95b6dabc33
commit 665817bd4f
@@ -16,10 +16,97 @@
 ;
 
 [mypy]
+strict_optional = True
+no_implicit_optional = True
+disallow_untyped_defs = True
+
+; Allow untyped def in internal modules and tests
+
+[mypy-pyspark.daemon]
+disallow_untyped_defs = False
+
+[mypy-pyspark.find_spark_home]
+disallow_untyped_defs = False
+
+[mypy-pyspark._globals]
+disallow_untyped_defs = False
+
+[mypy-pyspark.install]
+disallow_untyped_defs = False
+
+[mypy-pyspark.java_gateway]
+disallow_untyped_defs = False
+
+[mypy-pyspark.join]
+disallow_untyped_defs = False
+
+[mypy-pyspark.ml.tests.*]
+disallow_untyped_defs = False
+
+[mypy-pyspark.mllib.tests.*]
+disallow_untyped_defs = False
+
+[mypy-pyspark.rddsampler]
+disallow_untyped_defs = False
+
+[mypy-pyspark.resource.tests.*]
+disallow_untyped_defs = False
+
+[mypy-pyspark.serializers]
+disallow_untyped_defs = False
+
+[mypy-pyspark.shuffle]
+disallow_untyped_defs = False
+
+[mypy-pyspark.streaming.tests.*]
+disallow_untyped_defs = False
+
+[mypy-pyspark.streaming.util]
+disallow_untyped_defs = False
+
+[mypy-pyspark.sql.tests.*]
+disallow_untyped_defs = False
+
+[mypy-pyspark.sql.pandas.serializers]
+disallow_untyped_defs = False
+
+[mypy-pyspark.sql.pandas.types]
+disallow_untyped_defs = False
+
+[mypy-pyspark.sql.pandas.typehints]
+disallow_untyped_defs = False
+
+[mypy-pyspark.sql.pandas.utils]
+disallow_untyped_defs = False
+
+[mypy-pyspark.sql.pandas._typing.protocols.*]
+disallow_untyped_defs = False
+
+[mypy-pyspark.sql.utils]
+disallow_untyped_defs = False
+
+[mypy-pyspark.tests.*]
+disallow_untyped_defs = False
+
+[mypy-pyspark.testing.*]
+disallow_untyped_defs = False
+
+[mypy-pyspark.traceback_utils]
+disallow_untyped_defs = False
+
+[mypy-pyspark.util]
+disallow_untyped_defs = False
+
+[mypy-pyspark.worker]
+disallow_untyped_defs = False
+
+; Ignore errors in embedded third party code
+
 [mypy-pyspark.cloudpickle.*]
 ignore_errors = True
 
+; Ignore missing imports for external untyped packages
+
 [mypy-py4j.*]
 ignore_missing_imports = True
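To sketch how the per-module overrides compose (the module paths below are only for illustration): mypy applies the most specific `[mypy-...]` section, so the same unannotated definition is an error in a public module but accepted in an exempted one.

```python
# In a public module such as pyspark/rdd.py, the global
# disallow_untyped_defs = True applies:
def first(rdd):  # error: function is missing type annotations
    ...

# In pyspark/daemon.py, the [mypy-pyspark.daemon] section sets
# disallow_untyped_defs = False, so the same shape of definition passes:
def manager():  # accepted under the per-module override
    ...
```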
@@ -17,7 +17,7 @@
 # under the License.
 
 import threading
-from typing import Any, Dict, Generic, Optional, TypeVar
+from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar
 
 T = TypeVar("T")
 
@@ -32,14 +32,14 @@ class Broadcast(Generic[T]):
         path: Optional[Any] = ...,
         sock_file: Optional[Any] = ...,
     ) -> None: ...
-    def dump(self, value: Any, f: Any) -> None: ...
-    def load_from_path(self, path: Any): ...
-    def load(self, file: Any): ...
+    def dump(self, value: T, f: Any) -> None: ...
+    def load_from_path(self, path: Any) -> T: ...
+    def load(self, file: Any) -> T: ...
     @property
     def value(self) -> T: ...
     def unpersist(self, blocking: bool = ...) -> None: ...
     def destroy(self, blocking: bool = ...) -> None: ...
-    def __reduce__(self): ...
+    def __reduce__(self) -> Tuple[Callable[[int], T], Tuple[int]]: ...
 
 class BroadcastPickleRegistry(threading.local):
     def __init__(self) -> None: ...
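One practical effect of the `Broadcast[T]` changes above: the element type now flows from `sc.broadcast(...)` through `.value`. A small usage sketch (assumes a local Spark installation):

```python
from pyspark import SparkContext

sc = SparkContext.getOrCreate()

# With the annotated stub, b is inferred as Broadcast[List[int]],
# so b.value is a List[int] rather than Any.
b = sc.broadcast([1, 2, 3])
total = sum(b.value)  # type-checks: sum() accepts a list of ints
b.unpersist()
print(total)
```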
@@ -16,7 +16,19 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    NoReturn,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+)
+from types import TracebackType
 
 from py4j.java_gateway import JavaGateway, JavaObject  # type: ignore[import]
 
@@ -51,9 +63,14 @@ class SparkContext:
         jsc: Optional[JavaObject] = ...,
         profiler_cls: type = ...,
     ) -> None: ...
-    def __getnewargs__(self): ...
-    def __enter__(self): ...
-    def __exit__(self, type, value, trace): ...
+    def __getnewargs__(self) -> NoReturn: ...
+    def __enter__(self) -> SparkContext: ...
+    def __exit__(
+        self,
+        type: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        trace: Optional[TracebackType],
+    ) -> None: ...
     @classmethod
     def getOrCreate(cls, conf: Optional[SparkConf] = ...) -> SparkContext: ...
     def setLogLevel(self, logLevel: str) -> None: ...
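The `__enter__`/`__exit__` signatures above follow the standard context-manager protocol, so `with`-style use of `SparkContext` is now typed. A brief sketch (configuration values are illustrative):

```python
from pyspark import SparkConf, SparkContext

conf = SparkConf().setMaster("local[1]").setAppName("mypy-demo")

# __enter__ returning SparkContext types `sc` inside the block; __exit__'s
# Optional[...] parameters match what Python passes on normal and error exits.
with SparkContext(conf=conf) as sc:
    print(sc.parallelize(range(10)).sum())
```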
@@ -107,7 +107,7 @@ class _JavaProbabilisticClassifier(
 class _JavaProbabilisticClassificationModel(
     ProbabilisticClassificationModel, _JavaClassificationModel[T]
 ):
-    def predictProbability(self, value: Any): ...
+    def predictProbability(self, value: Vector) -> Vector: ...
 
 class _ClassificationSummary(JavaWrapper):
     @property
@@ -543,7 +543,7 @@ class RandomForestClassificationModel(
     @property
     def trees(self) -> List[DecisionTreeClassificationModel]: ...
     def summary(self) -> RandomForestClassificationTrainingSummary: ...
-    def evaluate(self, dataset) -> RandomForestClassificationSummary: ...
+    def evaluate(self, dataset: DataFrame) -> RandomForestClassificationSummary: ...
 
 class RandomForestClassificationSummary(_ClassificationSummary): ...
 class RandomForestClassificationTrainingSummary(
@@ -891,7 +891,7 @@ class FMClassifier(
         solver: str = ...,
         thresholds: Optional[Any] = ...,
         seed: Optional[Any] = ...,
-    ): ...
+    ) -> FMClassifier: ...
     def setFactorSize(self, value: int) -> FMClassifier: ...
     def setFitLinear(self, value: bool) -> FMClassifier: ...
     def setMiniBatchFraction(self, value: float) -> FMClassifier: ...
@@ -16,5 +16,11 @@
 # specific language governing permissions and limitations
 # under the License.
 
-def callJavaFunc(sc, func, *args): ...
-def inherit_doc(cls): ...
+from typing import Any, TypeVar
+
+import pyspark.context
+
+C = TypeVar("C", bound=type)
+
+def callJavaFunc(sc: pyspark.context.SparkContext, func: Any, *args: Any) -> Any: ...
+def inherit_doc(cls: C) -> C: ...
@@ -39,9 +39,12 @@ from pyspark.ml.param.shared import (
     HasWeightCol,
 )
 from pyspark.ml.util import JavaMLReadable, JavaMLWritable
+from pyspark.sql.dataframe import DataFrame
 
 class Evaluator(Params, metaclass=abc.ABCMeta):
-    def evaluate(self, dataset, params: Optional[ParamMap] = ...) -> float: ...
+    def evaluate(
+        self, dataset: DataFrame, params: Optional[ParamMap] = ...
+    ) -> float: ...
     def isLargerBetter(self) -> bool: ...
 
 class JavaEvaluator(JavaParams, Evaluator, metaclass=abc.ABCMeta):
@@ -75,7 +78,6 @@ class BinaryClassificationEvaluator(
     def setLabelCol(self, value: str) -> BinaryClassificationEvaluator: ...
     def setRawPredictionCol(self, value: str) -> BinaryClassificationEvaluator: ...
     def setWeightCol(self, value: str) -> BinaryClassificationEvaluator: ...
-
     def setParams(
         self,
         *,
@@ -100,9 +100,9 @@ class _LSHParams(HasInputCol, HasOutputCol):
     def getNumHashTables(self) -> int: ...
 
 class _LSH(Generic[JM], JavaEstimator[JM], _LSHParams, JavaMLReadable, JavaMLWritable):
-    def setNumHashTables(self: P, value) -> P: ...
-    def setInputCol(self: P, value) -> P: ...
-    def setOutputCol(self: P, value) -> P: ...
+    def setNumHashTables(self: P, value: int) -> P: ...
+    def setInputCol(self: P, value: str) -> P: ...
+    def setOutputCol(self: P, value: str) -> P: ...
 
 class _LSHModel(JavaModel, _LSHParams):
     def setInputCol(self: P, value: str) -> P: ...
@@ -1518,7 +1518,7 @@ class ChiSqSelector(
         fpr: float = ...,
         fdr: float = ...,
         fwe: float = ...
-    ): ...
+    ) -> ChiSqSelector: ...
     def setSelectorType(self, value: str) -> ChiSqSelector: ...
     def setNumTopFeatures(self, value: int) -> ChiSqSelector: ...
     def setPercentile(self, value: float) -> ChiSqSelector: ...
@@ -1602,7 +1602,10 @@ class _VarianceThresholdSelectorParams(HasFeaturesCol, HasOutputCol):
     def getVarianceThreshold(self) -> float: ...
 
 class VarianceThresholdSelector(
-    JavaEstimator, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable
+    JavaEstimator[VarianceThresholdSelectorModel],
+    _VarianceThresholdSelectorParams,
+    JavaMLReadable[VarianceThresholdSelector],
+    JavaMLWritable,
 ):
     def __init__(
         self,
@@ -1615,13 +1618,16 @@ class VarianceThresholdSelector(
         featuresCol: str = ...,
         outputCol: Optional[str] = ...,
         varianceThreshold: float = ...,
-    ): ...
+    ) -> VarianceThresholdSelector: ...
     def setVarianceThreshold(self, value: float) -> VarianceThresholdSelector: ...
     def setFeaturesCol(self, value: str) -> VarianceThresholdSelector: ...
     def setOutputCol(self, value: str) -> VarianceThresholdSelector: ...
 
 class VarianceThresholdSelectorModel(
-    JavaModel, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable
+    JavaModel,
+    _VarianceThresholdSelectorParams,
+    JavaMLReadable[VarianceThresholdSelectorModel],
+    JavaMLWritable,
 ):
     def setFeaturesCol(self, value: str) -> VarianceThresholdSelectorModel: ...
     def setOutputCol(self, value: str) -> VarianceThresholdSelectorModel: ...
@@ -17,7 +17,7 @@
 # under the License.
 
 from typing import overload
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, NoReturn, Optional, Tuple, Type, Union
 
 from pyspark.ml import linalg as newlinalg  # noqa: F401
 from pyspark.sql.types import StructType, UserDefinedType
@@ -45,7 +45,7 @@ class MatrixUDT(UserDefinedType):
     @classmethod
     def scalaUDT(cls) -> str: ...
     def serialize(
-        self, obj
+        self, obj: Matrix
     ) -> Tuple[
         int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool
     ]: ...
@@ -64,9 +64,7 @@ class DenseVector(Vector):
     def __init__(self, __arr: bytes) -> None: ...
     @overload
    def __init__(self, __arr: Iterable[float]) -> None: ...
-    @staticmethod
-    def parse(s) -> DenseVector: ...
-    def __reduce__(self) -> Tuple[type, bytes]: ...
+    def __reduce__(self) -> Tuple[Type[DenseVector], bytes]: ...
     def numNonzeros(self) -> int: ...
     def norm(self, p: Union[float, str]) -> float64: ...
     def dot(self, other: Iterable[float]) -> float64: ...
@@ -112,16 +110,14 @@ class SparseVector(Vector):
     def __init__(self, size: int, __map: Dict[int, float]) -> None: ...
     def numNonzeros(self) -> int: ...
     def norm(self, p: Union[float, str]) -> float64: ...
-    def __reduce__(self): ...
-    @staticmethod
-    def parse(s: str) -> SparseVector: ...
+    def __reduce__(self) -> Tuple[Type[SparseVector], Tuple[int, bytes, bytes]]: ...
     def dot(self, other: Iterable[float]) -> float64: ...
     def squared_distance(self, other: Iterable[float]) -> float64: ...
     def toArray(self) -> ndarray: ...
     def __len__(self) -> int: ...
-    def __eq__(self, other) -> bool: ...
+    def __eq__(self, other: Any) -> bool: ...
     def __getitem__(self, index: int) -> float64: ...
-    def __ne__(self, other) -> bool: ...
+    def __ne__(self, other: Any) -> bool: ...
     def __hash__(self) -> int: ...
 
 class Vectors:
@@ -144,13 +140,13 @@ class Vectors:
     def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ...
     @overload
     @staticmethod
-    def dense(self, *elements: float) -> DenseVector: ...
+    def dense(*elements: float) -> DenseVector: ...
     @overload
     @staticmethod
-    def dense(self, __arr: bytes) -> DenseVector: ...
+    def dense(__arr: bytes) -> DenseVector: ...
     @overload
     @staticmethod
-    def dense(self, __arr: Iterable[float]) -> DenseVector: ...
+    def dense(__arr: Iterable[float]) -> DenseVector: ...
     @staticmethod
     def stringify(vector: Vector) -> str: ...
     @staticmethod
@@ -158,8 +154,6 @@ class Vectors:
     @staticmethod
     def norm(vector: Vector, p: Union[float, str]) -> float64: ...
-    @staticmethod
-    def parse(s: str) -> Vector: ...
     @staticmethod
     def zeros(size: int) -> DenseVector: ...
 
 class Matrix:
@@ -170,7 +164,7 @@ class Matrix:
     def __init__(
         self, numRows: int, numCols: int, isTransposed: bool = ...
     ) -> None: ...
-    def toArray(self): ...
+    def toArray(self) -> NoReturn: ...
 
 class DenseMatrix(Matrix):
     values: Any
@@ -186,11 +180,11 @@ class DenseMatrix(Matrix):
         values: Iterable[float],
         isTransposed: bool = ...,
     ) -> None: ...
-    def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, int]]: ...
+    def __reduce__(self) -> Tuple[Type[DenseMatrix], Tuple[int, int, bytes, int]]: ...
     def toArray(self) -> ndarray: ...
     def toSparse(self) -> SparseMatrix: ...
     def __getitem__(self, indices: Tuple[int, int]) -> float64: ...
-    def __eq__(self, other) -> bool: ...
+    def __eq__(self, other: Any) -> bool: ...
 
 class SparseMatrix(Matrix):
     colPtrs: ndarray
@@ -216,11 +210,13 @@ class SparseMatrix(Matrix):
         values: Iterable[float],
         isTransposed: bool = ...,
     ) -> None: ...
-    def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, bytes, bytes, int]]: ...
+    def __reduce__(
+        self,
+    ) -> Tuple[Type[SparseMatrix], Tuple[int, int, bytes, bytes, bytes, int]]: ...
     def __getitem__(self, indices: Tuple[int, int]) -> float64: ...
     def toArray(self) -> ndarray: ...
     def toDense(self) -> DenseMatrix: ...
-    def __eq__(self, other) -> bool: ...
+    def __eq__(self, other: Any) -> bool: ...
 
 class Matrices:
     @overload
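The corrected `Vectors.dense` overloads above drop the spurious `self` parameter, so ordinary static calls resolve against the right overload. For example:

```python
from pyspark.ml.linalg import Vectors

v1 = Vectors.dense(1.0, 2.0, 3.0)    # matches the `*elements: float` overload
v2 = Vectors.dense([4.0, 5.0, 6.0])  # matches the `Iterable[float]` overload
print(v1.dot([4.0, 5.0, 6.0]))       # dot(other: Iterable[float]) -> float64
```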
@@ -51,7 +51,7 @@ class PipelineWriter(MLWriter):
     def __init__(self, instance: Pipeline) -> None: ...
     def saveImpl(self, path: str) -> None: ...
 
-class PipelineReader(MLReader):
+class PipelineReader(MLReader[Pipeline]):
     cls: Type[Pipeline]
     def __init__(self, cls: Type[Pipeline]) -> None: ...
     def load(self, path: str) -> Pipeline: ...
@@ -61,7 +61,7 @@ class PipelineModelWriter(MLWriter):
     def __init__(self, instance: PipelineModel) -> None: ...
     def saveImpl(self, path: str) -> None: ...
 
-class PipelineModelReader(MLReader):
+class PipelineModelReader(MLReader[PipelineModel]):
     cls: Type[PipelineModel]
     def __init__(self, cls: Type[PipelineModel]) -> None: ...
     def load(self, path: str) -> PipelineModel: ...
@@ -414,7 +414,7 @@ class RandomForestRegressionModel(
     _TreeEnsembleModel,
     _RandomForestRegressorParams,
     JavaMLWritable,
-    JavaMLReadable,
+    JavaMLReadable[RandomForestRegressionModel],
 ):
     @property
     def trees(self) -> List[DecisionTreeRegressionModel]: ...
@@ -749,10 +749,10 @@ class _FactorizationMachinesParams(
     initStd: Param[float]
     solver: Param[str]
     def __init__(self, *args: Any): ...
-    def getFactorSize(self): ...
-    def getFitLinear(self): ...
-    def getMiniBatchFraction(self): ...
-    def getInitStd(self): ...
+    def getFactorSize(self) -> int: ...
+    def getFitLinear(self) -> bool: ...
+    def getMiniBatchFraction(self) -> float: ...
+    def getInitStd(self) -> float: ...
 
 class FMRegressor(
     _JavaRegressor[FMRegressionModel],
@@ -118,7 +118,7 @@ class NaiveBayesModel(Saveable, Loader[NaiveBayesModel]):
     labels: ndarray
     pi: ndarray
     theta: ndarray
-    def __init__(self, labels, pi, theta) -> None: ...
+    def __init__(self, labels: ndarray, pi: ndarray, theta: ndarray) -> None: ...
     @overload
     def predict(self, x: VectorLike) -> float64: ...
     @overload
@@ -63,7 +63,7 @@ class BisectingKMeans:
 
 class KMeansModel(Saveable, Loader[KMeansModel]):
     centers: List[ndarray]
-    def __init__(self, centers: List[ndarray]) -> None: ...
+    def __init__(self, centers: List[VectorLike]) -> None: ...
     @property
     def clusterCenters(self) -> List[ndarray]: ...
     @property
@@ -144,7 +144,9 @@ class PowerIterationClustering:
     class Assignment(NamedTuple("Assignment", [("id", int), ("cluster", int)])): ...
 
 class StreamingKMeansModel(KMeansModel):
-    def __init__(self, clusterCenters, clusterWeights) -> None: ...
+    def __init__(
+        self, clusterCenters: List[VectorLike], clusterWeights: VectorLike
+    ) -> None: ...
     @property
     def clusterWeights(self) -> List[float64]: ...
     centers: ndarray
@@ -16,12 +16,20 @@
 # specific language governing permissions and limitations
 # under the License.
 
-def callJavaFunc(sc, func, *args): ...
-def callMLlibFunc(name, *args): ...
+from typing import Any, TypeVar
+
+import pyspark.context
+
+from py4j.java_gateway import JavaObject
+
+C = TypeVar("C", bound=type)
+
+def callJavaFunc(sc: pyspark.context.SparkContext, func: Any, *args: Any) -> Any: ...
+def callMLlibFunc(name: str, *args: Any) -> Any: ...
 
 class JavaModelWrapper:
-    def __init__(self, java_model) -> None: ...
-    def __del__(self): ...
-    def call(self, name, *a): ...
+    def __init__(self, java_model: JavaObject) -> None: ...
+    def __del__(self) -> None: ...
+    def call(self, name: str, *a: Any) -> Any: ...
 
-def inherit_doc(cls): ...
+def inherit_doc(cls: C) -> C: ...
@@ -17,7 +17,18 @@
 # under the License.
 
 from typing import overload
-from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 from pyspark.ml import linalg as newlinalg
 from pyspark.sql.types import StructType, UserDefinedType
 from numpy import float64, ndarray  # type: ignore[import]
@@ -46,7 +57,7 @@ class MatrixUDT(UserDefinedType):
     @classmethod
     def scalaUDT(cls) -> str: ...
     def serialize(
-        self, obj
+        self, obj: Matrix
     ) -> Tuple[
         int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool
     ]: ...
@@ -67,8 +78,8 @@ class DenseVector(Vector):
     @overload
     def __init__(self, __arr: Iterable[float]) -> None: ...
     @staticmethod
-    def parse(s) -> DenseVector: ...
-    def __reduce__(self) -> Tuple[type, bytes]: ...
+    def parse(s: str) -> DenseVector: ...
+    def __reduce__(self) -> Tuple[Type[DenseVector], bytes]: ...
     def numNonzeros(self) -> int: ...
     def norm(self, p: Union[float, str]) -> float64: ...
     def dot(self, other: Iterable[float]) -> float64: ...
@@ -115,7 +126,7 @@ class SparseVector(Vector):
     def __init__(self, size: int, __map: Dict[int, float]) -> None: ...
     def numNonzeros(self) -> int: ...
     def norm(self, p: Union[float, str]) -> float64: ...
-    def __reduce__(self): ...
+    def __reduce__(self) -> Tuple[Type[SparseVector], Tuple[int, bytes, bytes]]: ...
     @staticmethod
     def parse(s: str) -> SparseVector: ...
     def dot(self, other: Iterable[float]) -> float64: ...
@@ -123,9 +134,9 @@ class SparseVector(Vector):
     def toArray(self) -> ndarray: ...
     def asML(self) -> newlinalg.SparseVector: ...
     def __len__(self) -> int: ...
-    def __eq__(self, other) -> bool: ...
+    def __eq__(self, other: Any) -> bool: ...
     def __getitem__(self, index: int) -> float64: ...
-    def __ne__(self, other) -> bool: ...
+    def __ne__(self, other: Any) -> bool: ...
     def __hash__(self) -> int: ...
 
 class Vectors:
@@ -148,13 +159,13 @@ class Vectors:
     def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ...
     @overload
     @staticmethod
-    def dense(self, *elements: float) -> DenseVector: ...
+    def dense(*elements: float) -> DenseVector: ...
     @overload
     @staticmethod
-    def dense(self, __arr: bytes) -> DenseVector: ...
+    def dense(__arr: bytes) -> DenseVector: ...
     @overload
     @staticmethod
-    def dense(self, __arr: Iterable[float]) -> DenseVector: ...
+    def dense(__arr: Iterable[float]) -> DenseVector: ...
     @staticmethod
     def fromML(vec: newlinalg.DenseVector) -> DenseVector: ...
     @staticmethod
@@ -176,8 +187,8 @@ class Matrix:
     def __init__(
         self, numRows: int, numCols: int, isTransposed: bool = ...
     ) -> None: ...
-    def toArray(self): ...
-    def asML(self): ...
+    def toArray(self) -> ndarray: ...
+    def asML(self) -> newlinalg.Matrix: ...
 
 class DenseMatrix(Matrix):
     values: Any
@@ -193,12 +204,12 @@ class DenseMatrix(Matrix):
         values: Iterable[float],
         isTransposed: bool = ...,
     ) -> None: ...
-    def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, int]]: ...
+    def __reduce__(self) -> Tuple[Type[DenseMatrix], Tuple[int, int, bytes, int]]: ...
     def toArray(self) -> ndarray: ...
     def toSparse(self) -> SparseMatrix: ...
     def asML(self) -> newlinalg.DenseMatrix: ...
     def __getitem__(self, indices: Tuple[int, int]) -> float64: ...
-    def __eq__(self, other) -> bool: ...
+    def __eq__(self, other: Any) -> bool: ...
 
 class SparseMatrix(Matrix):
     colPtrs: ndarray
@@ -224,12 +235,14 @@ class SparseMatrix(Matrix):
         values: Iterable[float],
         isTransposed: bool = ...,
     ) -> None: ...
-    def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, bytes, bytes, int]]: ...
+    def __reduce__(
+        self,
+    ) -> Tuple[Type[SparseMatrix], Tuple[int, int, bytes, bytes, bytes, int]]: ...
     def __getitem__(self, indices: Tuple[int, int]) -> float64: ...
     def toArray(self) -> ndarray: ...
     def toDense(self) -> DenseMatrix: ...
     def asML(self) -> newlinalg.SparseMatrix: ...
-    def __eq__(self, other) -> bool: ...
+    def __eq__(self, other: Any) -> bool: ...
 
 class Matrices:
     @overload
@@ -90,7 +90,7 @@ class RandomRDDs:
     def logNormalVectorRDD(
         sc: SparkContext,
         mean: float,
-        std,
+        std: float,
         numRows: int,
         numCols: int,
         numPartitions: Optional[int] = ...,
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Type, Union
 
 import array
 from collections import namedtuple
@@ -27,7 +27,7 @@ from pyspark.mllib.common import JavaModelWrapper
 from pyspark.mllib.util import JavaLoader, JavaSaveable
 
 class Rating(namedtuple("Rating", ["user", "product", "rating"])):
-    def __reduce__(self): ...
+    def __reduce__(self) -> Tuple[Type[Rating], Tuple[int, int, float]]: ...
 
 class MatrixFactorizationModel(
     JavaModelWrapper, JavaSaveable, JavaLoader[MatrixFactorizationModel]
@@ -65,5 +65,5 @@ class Statistics:
     def chiSqTest(observed: RDD[LabeledPoint]) -> List[ChiSqTestResult]: ...
     @staticmethod
     def kolmogorovSmirnovTest(
-        data, distName: Literal["norm"] = ..., *params: float
+        data: RDD[float], distName: Literal["norm"] = ..., *params: float
     ) -> KolmogorovSmirnovTestResult: ...
@@ -85,12 +85,16 @@ class PythonEvalType:
     SQL_COGROUPED_MAP_PANDAS_UDF: PandasCogroupedMapUDFType
 
 class BoundedFloat(float):
-    def __new__(cls, mean: float, confidence: float, low: float, high: float): ...
+    def __new__(
+        cls, mean: float, confidence: float, low: float, high: float
+    ) -> BoundedFloat: ...
 
 class Partitioner:
     numPartitions: int
     partitionFunc: Callable[[Any], int]
-    def __init__(self, numPartitions, partitionFunc) -> None: ...
+    def __init__(
+        self, numPartitions: int, partitionFunc: Callable[[Any], int]
+    ) -> None: ...
     def __eq__(self, other: Any) -> bool: ...
     def __call__(self, k: Any) -> int: ...
@@ -49,7 +49,7 @@ class ResourceProfileBuilder:
     def __init__(self) -> None: ...
     def require(
         self, resourceRequest: Union[ExecutorResourceRequest, TaskResourceRequests]
-    ): ...
+    ) -> ResourceProfileBuilder: ...
     def clearExecutorResourceRequests(self) -> None: ...
     def clearTaskResourceRequests(self) -> None: ...
     @property
@@ -32,7 +32,7 @@ from pyspark.sql.window import WindowSpec
 from py4j.java_gateway import JavaObject  # type: ignore[import]
 
 class Column:
-    def __init__(self, JavaObject) -> None: ...
+    def __init__(self, jc: JavaObject) -> None: ...
     def __neg__(self) -> Column: ...
     def __add__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
     def __sub__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
@@ -105,7 +105,11 @@ class Column:
     def name(self, *alias: str) -> Column: ...
     def cast(self, dataType: Union[DataType, str]) -> Column: ...
     def astype(self, dataType: Union[DataType, str]) -> Column: ...
-    def between(self, lowerBound, upperBound) -> Column: ...
+    def between(
+        self,
+        lowerBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral],
+        upperBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral],
+    ) -> Column: ...
     def when(self, condition: Column, value: Any) -> Column: ...
     def otherwise(self, value: Any) -> Column: ...
     def over(self, window: WindowSpec) -> Column: ...
@@ -43,14 +43,14 @@ class SQLContext:
     sparkSession: SparkSession
     def __init__(
         self,
-        sparkContext,
+        sparkContext: SparkContext,
         sparkSession: Optional[SparkSession] = ...,
         jsqlContext: Optional[JavaObject] = ...,
     ) -> None: ...
     @classmethod
     def getOrCreate(cls: type, sc: SparkContext) -> SQLContext: ...
     def newSession(self) -> SQLContext: ...
-    def setConf(self, key: str, value) -> None: ...
+    def setConf(self, key: str, value: Union[bool, int, str]) -> None: ...
     def getConf(self, key: str, defaultValue: Optional[str] = ...) -> str: ...
     @property
     def udf(self) -> UDFRegistration: ...
@@ -116,7 +116,7 @@ class SQLContext:
         path: Optional[str] = ...,
         source: Optional[str] = ...,
         schema: Optional[StructType] = ...,
-        **options
+        **options: str
     ) -> DataFrame: ...
     def sql(self, sqlQuery: str) -> DataFrame: ...
     def table(self, tableName: str) -> DataFrame: ...
@@ -65,13 +65,13 @@ def round(col: ColumnOrName, scale: int = ...) -> Column: ...
 def bround(col: ColumnOrName, scale: int = ...) -> Column: ...
 def shiftLeft(col: ColumnOrName, numBits: int) -> Column: ...
 def shiftRight(col: ColumnOrName, numBits: int) -> Column: ...
-def shiftRightUnsigned(col, numBits) -> Column: ...
+def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column: ...
 def spark_partition_id() -> Column: ...
 def expr(str: str) -> Column: ...
 def struct(*cols: ColumnOrName) -> Column: ...
 def greatest(*cols: ColumnOrName) -> Column: ...
 def least(*cols: Column) -> Column: ...
-def when(condition: Column, value) -> Column: ...
+def when(condition: Column, value: Any) -> Column: ...
 @overload
 def log(arg1: ColumnOrName) -> Column: ...
 @overload
@@ -174,7 +174,9 @@ def create_map(*cols: ColumnOrName) -> Column: ...
 def array(*cols: ColumnOrName) -> Column: ...
 def array_contains(col: ColumnOrName, value: Any) -> Column: ...
 def arrays_overlap(a1: ColumnOrName, a2: ColumnOrName) -> Column: ...
-def slice(x: ColumnOrName, start: Union[Column, int], length: Union[Column, int]) -> Column: ...
+def slice(
+    x: ColumnOrName, start: Union[Column, int], length: Union[Column, int]
+) -> Column: ...
 def array_join(
     col: ColumnOrName, delimiter: str, null_replacement: Optional[str] = ...
 ) -> Column: ...
@@ -17,7 +17,8 @@
 # under the License.
 
 from typing import overload
-from typing import Any, Iterable, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar, Union
+from types import TracebackType
 
 from py4j.java_gateway import JavaObject  # type: ignore[import]
@@ -122,4 +123,9 @@ class SparkSession(SparkConversionMixin):
     def streams(self) -> StreamingQueryManager: ...
     def stop(self) -> None: ...
     def __enter__(self) -> SparkSession: ...
-    def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None: ...
@@ -17,7 +17,8 @@
 # under the License.
 
 from typing import overload
-from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, TypeVar
+from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, Type, TypeVar
+from py4j.java_gateway import JavaGateway, JavaObject
 import datetime
 
 T = TypeVar("T")
@@ -37,7 +38,7 @@ class DataType:
     def fromInternal(self, obj: Any) -> Any: ...
 
 class DataTypeSingleton(type):
-    def __call__(cls): ...
+    def __call__(cls: Type[T]) -> T: ...  # type: ignore
 
 class NullType(DataType, metaclass=DataTypeSingleton): ...
 class AtomicType(DataType): ...
@@ -85,8 +86,8 @@ class ShortType(IntegralType):
 class ArrayType(DataType):
     elementType: DataType
     containsNull: bool
-    def __init__(self, elementType=DataType, containsNull: bool = ...) -> None: ...
-    def simpleString(self): ...
+    def __init__(self, elementType: DataType, containsNull: bool = ...) -> None: ...
+    def simpleString(self) -> str: ...
     def jsonValue(self) -> Dict[str, Any]: ...
     @classmethod
     def fromJson(cls, json: Dict[str, Any]) -> ArrayType: ...
@@ -197,8 +198,8 @@ class Row(tuple):
 
 class DateConverter:
     def can_convert(self, obj: Any) -> bool: ...
-    def convert(self, obj, gateway_client) -> Any: ...
+    def convert(self, obj: datetime.date, gateway_client: JavaGateway) -> JavaObject: ...
 
 class DatetimeConverter:
-    def can_convert(self, obj) -> bool: ...
-    def convert(self, obj, gateway_client) -> Any: ...
+    def can_convert(self, obj: Any) -> bool: ...
+    def convert(self, obj: datetime.datetime, gateway_client: JavaGateway) -> JavaObject: ...
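`DataTypeSingleton` is the metaclass that makes instantiating a type like `NullType()` always return the same object; annotating `__call__` as `(cls: Type[T]) -> T` keeps that instantiation typed. A self-contained sketch of the same pattern (independent of the pyspark internals):

```python
from typing import Any, Dict, Type, TypeVar

T = TypeVar("T")

class SingletonMeta(type):
    _instances: Dict[type, Any] = {}

    # Typing cls as Type[T] makes `SomeType()` come back as SomeType rather
    # than Any; the ignore mirrors the one in the pyspark stub.
    def __call__(cls: Type[T], *args: Any, **kwargs: Any) -> T:  # type: ignore
        if cls not in SingletonMeta._instances:
            SingletonMeta._instances[cls] = super().__call__(*args, **kwargs)
        return SingletonMeta._instances[cls]

class NullLike(metaclass=SingletonMeta):
    pass

assert NullLike() is NullLike()  # the same instance every time
```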
@@ -18,8 +18,9 @@
 
 from typing import Any, Callable, Optional
 
-from pyspark.sql._typing import ColumnOrName, DataTypeOrString
+from pyspark.sql._typing import ColumnOrName, DataTypeOrString, UserDefinedFunctionLike
 from pyspark.sql.column import Column
+from pyspark.sql.types import DataType
 import pyspark.sql.session
 
 class UserDefinedFunction:
@@ -35,7 +36,7 @@ class UserDefinedFunction:
         deterministic: bool = ...,
     ) -> None: ...
     @property
-    def returnType(self): ...
+    def returnType(self) -> DataType: ...
     def __call__(self, *cols: ColumnOrName) -> Column: ...
     def asNondeterministic(self) -> UserDefinedFunction: ...
 
@@ -47,7 +48,7 @@ class UDFRegistration:
         name: str,
         f: Callable[..., Any],
         returnType: Optional[DataTypeOrString] = ...,
-    ): ...
+    ) -> UserDefinedFunctionLike: ...
     def registerJavaFunction(
         self,
         name: str,
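With `register` now returning `UserDefinedFunctionLike`, the value handed back can itself be used as a typed column function. A usage sketch (assumes an active local session):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("udf-demo").getOrCreate()

# register(...) returns the registered UDF, usable directly in DataFrame code
# (and by name in SQL); its return type is no longer untyped.
plus_one = spark.udf.register("plus_one", lambda x: x + 1, "int")

df = spark.range(3)
df.select(plus_one(df["id"]).alias("id_plus_one")).show()
```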
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Any, Callable, List, Optional, TypeVar, Union
+from typing import Any, Callable, List, Optional, TypeVar
 
 from py4j.java_gateway import JavaObject  # type: ignore[import]
@@ -30,9 +30,12 @@ from typing import (
 )
 import datetime
 from pyspark.rdd import RDD
+import pyspark.serializers
 from pyspark.storagelevel import StorageLevel
+import pyspark.streaming.context
 
+from py4j.java_gateway import JavaObject
 
 S = TypeVar("S")
 T = TypeVar("T")
 U = TypeVar("U")
@@ -42,7 +45,12 @@ V = TypeVar("V")
 class DStream(Generic[T]):
     is_cached: bool
     is_checkpointed: bool
-    def __init__(self, jdstream, ssc, jrdd_deserializer) -> None: ...
+    def __init__(
+        self,
+        jdstream: JavaObject,
+        ssc: pyspark.streaming.context.StreamingContext,
+        jrdd_deserializer: pyspark.serializers.Serializer,
+    ) -> None: ...
     def context(self) -> pyspark.streaming.context.StreamingContext: ...
     def count(self) -> DStream[int]: ...
     def filter(self, f: Callable[[T], bool]) -> DStream[T]: ...
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Any, Callable, Optional, TypeVar
+from typing import Callable, Optional, TypeVar
 from pyspark.storagelevel import StorageLevel
 from pyspark.streaming.context import StreamingContext
 from pyspark.streaming.dstream import DStream