104 lines
3.6 KiB
Python
104 lines
3.6 KiB
Python
|
#
|
||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||
|
# or more contributor license agreements. See the NOTICE file
|
||
|
# distributed with this work for additional information
|
||
|
# regarding copyright ownership. The ASF licenses this file
|
||
|
# to you under the Apache License, Version 2.0 (the
|
||
|
# "License"); you may not use this file except in compliance
|
||
|
# with the License. You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing,
|
||
|
# software distributed under the License is distributed on an
|
||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||
|
# KIND, either express or implied. See the License for the
|
||
|
# specific language governing permissions and limitations
|
||
|
# under the License.
|
||
|
|
||
|
from typing import overload
|
||
|
from typing import (
|
||
|
Callable,
|
||
|
Generic,
|
||
|
Iterable,
|
||
|
List,
|
||
|
Optional,
|
||
|
Tuple,
|
||
|
)
|
||
|
from pyspark.ml._typing import M, P, T, ParamMap
|
||
|
|
||
|
import _thread
|
||
|
|
||
|
import abc
|
||
|
from abc import abstractmethod
|
||
|
from pyspark import since as since # noqa: F401
|
||
|
from pyspark.ml.common import inherit_doc as inherit_doc # noqa: F401
|
||
|
from pyspark.ml.param.shared import (
|
||
|
HasFeaturesCol as HasFeaturesCol,
|
||
|
HasInputCol as HasInputCol,
|
||
|
HasLabelCol as HasLabelCol,
|
||
|
HasOutputCol as HasOutputCol,
|
||
|
HasPredictionCol as HasPredictionCol,
|
||
|
Params as Params,
|
||
|
)
|
||
|
from pyspark.sql.functions import udf as udf # noqa: F401
|
||
|
from pyspark.sql.types import ( # noqa: F401
|
||
|
DataType,
|
||
|
StructField as StructField,
|
||
|
StructType as StructType,
|
||
|
)
|
||
|
|
||
|
from pyspark.sql.dataframe import DataFrame
|
||
|
|
||
|
class _FitMultipleIterator:
    """Iterator type whose items are ``(int, Transformer)`` pairs.

    It is constructed from a callable that maps an ``int`` to a
    :class:`Transformer` and from a model count.
    NOTE(review): the ``int`` is presumably the index of the parameter map
    that produced the model -- confirm against the implementation.
    """

    # Callable that takes an int and returns a Transformer.
    fitSingleModel: Callable[[int], Transformer]
    # Spelled ``numModel`` here, while the constructor parameter is ``numModels``.
    numModel: int
    counter: int = ...
    # NOTE(review): presumably guards ``counter`` across threads -- confirm.
    lock: _thread.LockType

    def __init__(
        self,
        fitSingleModel: Callable[[int], Transformer],
        numModels: int,
    ) -> None: ...

    def __iter__(self) -> _FitMultipleIterator: ...

    def __next__(self) -> Tuple[int, Transformer]: ...

    # Declared with the same signature as ``__next__``.
    def next(self) -> Tuple[int, Transformer]: ...
|
||
|
|
||
|
class Estimator(Generic[M], Params, metaclass=abc.ABCMeta):
    """Abstract, parameterised type that fits models of type ``M`` to a DataFrame."""

    # fit(): a single optional ParamMap yields one model ...
    @overload
    def fit(
        self,
        dataset: DataFrame,
        params: Optional[ParamMap] = ...,
    ) -> M: ...

    # ... while a list of ParamMaps yields a list of models.
    @overload
    def fit(
        self,
        dataset: DataFrame,
        params: List[ParamMap],
    ) -> List[M]: ...

    # Yields (int, model) pairs.
    # NOTE(review): the int is presumably an index into ``params`` -- confirm.
    def fitMultiple(
        self,
        dataset: DataFrame,
        params: List[ParamMap],
    ) -> Iterable[Tuple[int, M]]: ...
|
||
|
|
||
|
class Transformer(Params, metaclass=abc.ABCMeta):
    """Abstract type exposing ``transform``, which maps a DataFrame to a DataFrame."""

    # ``params`` is an optional ParamMap; the result is a DataFrame.
    def transform(self, dataset: DataFrame, params: Optional[ParamMap] = ...) -> DataFrame: ...
|
||
|
|
||
|
class Model(Transformer, metaclass=abc.ABCMeta):
    """Abstract :class:`Transformer` subtype that declares no members of its own."""
|
||
|
|
||
|
class UnaryTransformer(HasInputCol, HasOutputCol, Transformer, metaclass=abc.ABCMeta):
    """Abstract :class:`Transformer` with an input column and an output column.

    The three hooks ``createTransformFunc``, ``outputDataType`` and
    ``validateInputType`` are declared abstract so that type checkers reject
    instantiation of subclasses that do not implement them.
    """

    @abstractmethod
    def createTransformFunc(self) -> Callable:
        """Return the callable used to perform the transformation."""
        ...

    @abstractmethod
    def outputDataType(self) -> DataType:
        """Return the :class:`DataType` of the output."""
        ...

    @abstractmethod
    def validateInputType(self, inputType: DataType) -> None:
        """Check ``inputType``; returns ``None``.

        NOTE(review): presumably raises on an unsupported type -- confirm
        against the implementation.
        """
        ...

    # Maps a StructType schema to a StructType schema.
    def transformSchema(self, schema: StructType) -> StructType: ...

    # Setters are annotated to return the receiver's own type (``M``).
    def setInputCol(self: M, value: str) -> M: ...

    def setOutputCol(self: M, value: str) -> M: ...
|
||
|
|
||
|
class _PredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol):
    """Mixin combining the label, features and prediction column params."""
|
||
|
|
||
|
class Predictor(Estimator[M], _PredictorParams, metaclass=abc.ABCMeta):
    """Abstract :class:`Estimator` of ``M`` that also carries the predictor params.

    Each setter takes a ``str`` and is annotated to return the receiver's
    own type (``P``).
    """

    def setLabelCol(
        self: P,
        value: str,
    ) -> P: ...

    def setFeaturesCol(
        self: P,
        value: str,
    ) -> P: ...

    def setPredictionCol(
        self: P,
        value: str,
    ) -> P: ...
|
||
|
|
||
|
class PredictionModel(Generic[T], Model, _PredictorParams, metaclass=abc.ABCMeta):
    """Abstract :class:`Model`, generic in the input type ``T`` of ``predict``."""

    # Setters take a ``str`` and are annotated to return the receiver's type (``M``).
    def setFeaturesCol(
        self: M,
        value: str,
    ) -> M: ...

    def setPredictionCol(
        self: M,
        value: str,
    ) -> M: ...

    # Abstract read-only property of type ``int``.
    @property
    @abstractmethod
    def numFeatures(self) -> int: ...

    # Abstract: maps a value of type ``T`` to a ``float``.
    @abstractmethod
    def predict(self, value: T) -> float: ...
|