[SPARK-15316][PYSPARK][ML] Add linkPredictionCol to GeneralizedLinearRegression
## What changes were proposed in this pull request?

Add linkPredictionCol to GeneralizedLinearRegression and fix the PyDoc so that it generates the bullet list.

## How was this patch tested?

Doctests, and docs built locally.

Author: Holden Karau <holden@us.ibm.com>

Closes #13106 from holdenk/SPARK-15316-add-linkPredictionCol-toGeneralizedLinearRegression.
This commit is contained in:
parent
f5065abf49
commit
e71cd96bf7
|
@ -1245,10 +1245,14 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
|
|||
predictor (link function) and a description of the error distribution (family). It supports
|
||||
"gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family
|
||||
are listed below. The first link function of each family is the default one.
|
||||
- "gaussian" -> "identity", "log", "inverse"
|
||||
- "binomial" -> "logit", "probit", "cloglog"
|
||||
- "poisson" -> "log", "identity", "sqrt"
|
||||
- "gamma" -> "inverse", "identity", "log"
|
||||
|
||||
* "gaussian" -> "identity", "log", "inverse"
|
||||
|
||||
* "binomial" -> "logit", "probit", "cloglog"
|
||||
|
||||
* "poisson" -> "log", "identity", "sqrt"
|
||||
|
||||
* "gamma" -> "inverse", "identity", "log"
|
||||
|
||||
.. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_
|
||||
|
||||
|
@ -1258,9 +1262,12 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
|
|||
... (1.0, Vectors.dense(1.0, 2.0)),
|
||||
... (2.0, Vectors.dense(0.0, 0.0)),
|
||||
... (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
|
||||
>>> glr = GeneralizedLinearRegression(family="gaussian", link="identity")
|
||||
>>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
|
||||
>>> model = glr.fit(df)
|
||||
>>> abs(model.transform(df).head().prediction - 1.5) < 0.001
|
||||
>>> transformed = model.transform(df)
|
||||
>>> abs(transformed.head().prediction - 1.5) < 0.001
|
||||
True
|
||||
>>> abs(transformed.head().p - 1.5) < 0.001
|
||||
True
|
||||
>>> model.coefficients
|
||||
DenseVector([1.5..., -1.0...])
|
||||
|
@ -1290,20 +1297,23 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
|
|||
"relationship between the linear predictor and the mean of the distribution " +
|
||||
"function. Supported options: identity, log, inverse, logit, probit, cloglog " +
|
||||
"and sqrt.", typeConverter=TypeConverters.toString)
|
||||
linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " +
|
||||
"predictor) column name", typeConverter=TypeConverters.toString)
|
||||
|
||||
@keyword_only
|
||||
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
|
||||
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
|
||||
regParam=0.0, weightCol=None, solver="irls"):
|
||||
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=""):
|
||||
"""
|
||||
__init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
|
||||
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
|
||||
regParam=0.0, weightCol=None, solver="irls")
|
||||
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol="")
|
||||
"""
|
||||
super(GeneralizedLinearRegression, self).__init__()
|
||||
self._java_obj = self._new_java_obj(
|
||||
"org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
|
||||
self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls")
|
||||
self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls",
|
||||
linkPredictionCol="")
|
||||
kwargs = self.__init__._input_kwargs
|
||||
self.setParams(**kwargs)
|
||||
|
||||
|
@ -1311,11 +1321,11 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
|
|||
@since("2.0.0")
|
||||
def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
|
||||
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
|
||||
regParam=0.0, weightCol=None, solver="irls"):
|
||||
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=""):
|
||||
"""
|
||||
setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
|
||||
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
|
||||
regParam=0.0, weightCol=None, solver="irls")
|
||||
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol="")
|
||||
Sets params for generalized linear regression.
|
||||
"""
|
||||
kwargs = self.setParams._input_kwargs
|
||||
|
@ -1338,6 +1348,20 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
|
|||
"""
|
||||
return self.getOrDefault(self.family)
|
||||
|
||||
@since("2.0.0")
|
||||
def setLinkPredictionCol(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`linkPredictionCol`.
|
||||
"""
|
||||
return self._set(linkPredictionCol=value)
|
||||
|
||||
@since("2.0.0")
|
||||
def getLinkPredictionCol(self):
|
||||
"""
|
||||
Gets the value of linkPredictionCol or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.linkPredictionCol)
|
||||
|
||||
@since("2.0.0")
|
||||
def setLink(self, value):
|
||||
"""
|
||||
|
|
Loading…
Reference in a new issue