[SPARK-30378][ML][PYSPARK] Add getter/setter in Python FM
### What changes were proposed in this pull request? add getter/setter in Python FM ### Why are the changes needed? to be consistent with other algorithms ### Does this PR introduce any user-facing change? Yes. add getter/setter in Python FMRegressor/FMRegressionModel/FMClassifier/FMClassificationModel ### How was this patch tested? doctest Closes #27044 from huaxingao/spark-30378. Authored-by: Huaxin Gao <huaxing@us.ibm.com> Signed-off-by: zhengruifeng <ruifengz@foxmail.com>
This commit is contained in:
parent
32a5233d12
commit
9ee8da298d
|
@ -25,7 +25,7 @@ from pyspark.ml.param.shared import *
|
|||
from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
|
||||
_TreeEnsembleModel, _RandomForestParams, _GBTParams, \
|
||||
_HasVarianceImpurity, _TreeClassifierParams, _TreeEnsembleParams
|
||||
from pyspark.ml.regression import DecisionTreeRegressionModel
|
||||
from pyspark.ml.regression import _FactorizationMachinesParams, DecisionTreeRegressionModel
|
||||
from pyspark.ml.util import *
|
||||
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
|
||||
JavaPredictor, _JavaPredictorParams, JavaPredictionModel, JavaWrapper
|
||||
|
@ -2765,8 +2765,8 @@ class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
|
|||
|
||||
|
||||
@inherit_doc
|
||||
class FMClassifier(JavaProbabilisticClassifier, HasMaxIter, HasStepSize, HasTol, HasSolver,
|
||||
HasSeed, HasFitIntercept, HasRegParam, JavaMLWritable, JavaMLReadable):
|
||||
class FMClassifier(JavaProbabilisticClassifier, _FactorizationMachinesParams, JavaMLWritable,
|
||||
JavaMLReadable):
|
||||
"""
|
||||
Factorization Machines learning algorithm for classification.
|
||||
|
||||
|
@ -2780,8 +2780,12 @@ class FMClassifier(JavaProbabilisticClassifier, HasMaxIter, HasStepSize, HasTol,
|
|||
>>> df = spark.createDataFrame([
|
||||
... (1.0, Vectors.dense(1.0)),
|
||||
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
|
||||
>>> fm = FMClassifier(factorSize=2, seed=11)
|
||||
>>> fm = FMClassifier(factorSize=2)
|
||||
>>> fm.setSeed(11)
|
||||
FMClassifier...
|
||||
>>> model = fm.fit(df)
|
||||
>>> model.getMaxIter()
|
||||
100
|
||||
>>> test0 = spark.createDataFrame([
|
||||
... (Vectors.dense(-1.0),),
|
||||
... (Vectors.dense(0.5),),
|
||||
|
@ -2895,8 +2899,58 @@ class FMClassifier(JavaProbabilisticClassifier, HasMaxIter, HasStepSize, HasTol,
|
|||
"""
|
||||
return self._set(initStd=value)
|
||||
|
||||
@since("3.0.0")
|
||||
def setMaxIter(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`maxIter`.
|
||||
"""
|
||||
return self._set(maxIter=value)
|
||||
|
||||
class FMClassificationModel(JavaProbabilisticClassificationModel, JavaMLWritable, JavaMLReadable):
|
||||
@since("3.0.0")
|
||||
def setStepSize(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`stepSize`.
|
||||
"""
|
||||
return self._set(stepSize=value)
|
||||
|
||||
@since("3.0.0")
|
||||
def setTol(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`tol`.
|
||||
"""
|
||||
return self._set(tol=value)
|
||||
|
||||
@since("3.0.0")
|
||||
def setSolver(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`solver`.
|
||||
"""
|
||||
return self._set(solver=value)
|
||||
|
||||
@since("3.0.0")
|
||||
def setSeed(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`seed`.
|
||||
"""
|
||||
return self._set(seed=value)
|
||||
|
||||
@since("3.0.0")
|
||||
def setFitIntercept(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`fitIntercept`.
|
||||
"""
|
||||
return self._set(fitIntercept=value)
|
||||
|
||||
@since("3.0.0")
|
||||
def setRegParam(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`regParam`.
|
||||
"""
|
||||
return self._set(regParam=value)
|
||||
|
||||
|
||||
class FMClassificationModel(JavaProbabilisticClassificationModel, _FactorizationMachinesParams,
|
||||
JavaMLWritable, JavaMLReadable):
|
||||
"""
|
||||
Model fitted by :class:`FMClassifier`.
|
||||
|
||||
|
|
|
@ -2297,9 +2297,63 @@ class GeneralizedLinearRegressionTrainingSummary(GeneralizedLinearRegressionSumm
|
|||
return self._call_java("toString")
|
||||
|
||||
|
||||
class _FactorizationMachinesParams(_JavaPredictorParams, HasMaxIter, HasStepSize, HasTol,
|
||||
HasSolver, HasSeed, HasFitIntercept, HasRegParam):
|
||||
"""
|
||||
Params for :py:class:`FMRegressor`, :py:class:`FMRegressionModel`, :py:class:`FMClassifier`
|
||||
and :py:class:`FMClassifierModel`.
|
||||
|
||||
.. versionadded:: 3.0.0
|
||||
"""
|
||||
|
||||
factorSize = Param(Params._dummy(), "factorSize", "Dimensionality of the factor vectors, " +
|
||||
"which are used to get pairwise interactions between variables",
|
||||
typeConverter=TypeConverters.toInt)
|
||||
|
||||
fitLinear = Param(Params._dummy(), "fitLinear", "whether to fit linear term (aka 1-way term)",
|
||||
typeConverter=TypeConverters.toBoolean)
|
||||
|
||||
miniBatchFraction = Param(Params._dummy(), "miniBatchFraction", "fraction of the input data " +
|
||||
"set that should be used for one iteration of gradient descent",
|
||||
typeConverter=TypeConverters.toFloat)
|
||||
|
||||
initStd = Param(Params._dummy(), "initStd", "standard deviation of initial coefficients",
|
||||
typeConverter=TypeConverters.toFloat)
|
||||
|
||||
solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
|
||||
"options: gd, adamW. (Default adamW)", typeConverter=TypeConverters.toString)
|
||||
|
||||
@since("3.0.0")
|
||||
def getFactorSize(self):
|
||||
"""
|
||||
Gets the value of factorSize or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.factorSize)
|
||||
|
||||
@since("3.0.0")
|
||||
def getFitLinear(self):
|
||||
"""
|
||||
Gets the value of fitLinear or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.fitLinear)
|
||||
|
||||
@since("3.0.0")
|
||||
def getMiniBatchFraction(self):
|
||||
"""
|
||||
Gets the value of miniBatchFraction or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.miniBatchFraction)
|
||||
|
||||
@since("3.0.0")
|
||||
def getInitStd(self):
|
||||
"""
|
||||
Gets the value of initStd or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.initStd)
|
||||
|
||||
|
||||
@inherit_doc
|
||||
class FMRegressor(JavaPredictor, HasMaxIter, HasStepSize, HasTol, HasSolver, HasSeed,
|
||||
HasFitIntercept, HasRegParam, JavaMLWritable, JavaMLReadable):
|
||||
class FMRegressor(JavaPredictor, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable):
|
||||
"""
|
||||
Factorization Machines learning algorithm for regression.
|
||||
|
||||
|
@ -2315,8 +2369,12 @@ class FMRegressor(JavaPredictor, HasMaxIter, HasStepSize, HasTol, HasSolver, Has
|
|||
... (1.0, Vectors.dense(1.0)),
|
||||
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
|
||||
>>>
|
||||
>>> fm = FMRegressor(factorSize=2, seed=16)
|
||||
>>> fm = FMRegressor(factorSize=2)
|
||||
>>> fm.setSeed(16)
|
||||
FMRegressor...
|
||||
>>> model = fm.fit(df)
|
||||
>>> model.getMaxIter()
|
||||
100
|
||||
>>> test0 = spark.createDataFrame([
|
||||
... (Vectors.dense(-2.0),),
|
||||
... (Vectors.dense(0.5),),
|
||||
|
@ -2426,8 +2484,58 @@ class FMRegressor(JavaPredictor, HasMaxIter, HasStepSize, HasTol, HasSolver, Has
|
|||
"""
|
||||
return self._set(initStd=value)
|
||||
|
||||
@since("3.0.0")
|
||||
def setMaxIter(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`maxIter`.
|
||||
"""
|
||||
return self._set(maxIter=value)
|
||||
|
||||
class FMRegressionModel(JavaPredictionModel, JavaMLWritable, JavaMLReadable):
|
||||
@since("3.0.0")
|
||||
def setStepSize(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`stepSize`.
|
||||
"""
|
||||
return self._set(stepSize=value)
|
||||
|
||||
@since("3.0.0")
|
||||
def setTol(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`tol`.
|
||||
"""
|
||||
return self._set(tol=value)
|
||||
|
||||
@since("3.0.0")
|
||||
def setSolver(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`solver`.
|
||||
"""
|
||||
return self._set(solver=value)
|
||||
|
||||
@since("3.0.0")
|
||||
def setSeed(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`seed`.
|
||||
"""
|
||||
return self._set(seed=value)
|
||||
|
||||
@since("3.0.0")
|
||||
def setFitIntercept(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`fitIntercept`.
|
||||
"""
|
||||
return self._set(fitIntercept=value)
|
||||
|
||||
@since("3.0.0")
|
||||
def setRegParam(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`regParam`.
|
||||
"""
|
||||
return self._set(regParam=value)
|
||||
|
||||
|
||||
class FMRegressionModel(JavaPredictionModel, _FactorizationMachinesParams, JavaMLWritable,
|
||||
JavaMLReadable):
|
||||
"""
|
||||
Model fitted by :class:`FMRegressor`.
|
||||
|
||||
|
|
Loading…
Reference in a new issue