[SPARK-13597][PYSPARK][ML] Python API for GeneralizedLinearRegression
## What changes were proposed in this pull request? Python API for GeneralizedLinearRegression JIRA: https://issues.apache.org/jira/browse/SPARK-13597 ## How was this patch tested? The patch is tested with Python doctest. Author: Kai Jiang <jiangkai@gmail.com> Closes #11468 from vectorijk/spark-13597.
This commit is contained in:
parent
101663f1ae
commit
7f024c4744
|
@ -28,6 +28,7 @@ from pyspark.sql import DataFrame
|
|||
__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
|
||||
'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
|
||||
'GBTRegressor', 'GBTRegressionModel',
|
||||
'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel',
|
||||
'IsotonicRegression', 'IsotonicRegressionModel',
|
||||
'LinearRegression', 'LinearRegressionModel',
|
||||
'LinearRegressionSummary', 'LinearRegressionTrainingSummary',
|
||||
|
@ -1197,6 +1198,150 @@ class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
|
|||
return self._call_java("predict", features)
|
||||
|
||||
|
||||
@inherit_doc
class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, HasPredictionCol,
                                  HasFitIntercept, HasMaxIter, HasTol, HasRegParam, HasWeightCol,
                                  HasSolver, JavaMLWritable, JavaMLReadable):
    """
    Generalized Linear Regression.

    Fit a Generalized Linear Model specified by giving a symbolic description of the linear
    predictor (link function) and a description of the error distribution (family). It supports
    "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family
    are listed below. The first link function of each family is the default one.

    - "gaussian" -> "identity", "log", "inverse"
    - "binomial" -> "logit", "probit", "cloglog"
    - "poisson"  -> "log", "identity", "sqrt"
    - "gamma"    -> "inverse", "identity", "log"

    .. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_

    >>> from pyspark.mllib.linalg import Vectors
    >>> df = sqlContext.createDataFrame([
    ...     (1.0, Vectors.dense(0.0, 0.0)),
    ...     (1.0, Vectors.dense(1.0, 2.0)),
    ...     (2.0, Vectors.dense(0.0, 0.0)),
    ...     (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
    >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity")
    >>> model = glr.fit(df)
    >>> abs(model.transform(df).head().prediction - 1.5) < 0.001
    True
    >>> model.coefficients
    DenseVector([1.5..., -1.0...])
    >>> abs(model.intercept - 1.5) < 0.001
    True
    >>> glr_path = temp_path + "/glr"
    >>> glr.save(glr_path)
    >>> glr2 = GeneralizedLinearRegression.load(glr_path)
    >>> glr.getFamily() == glr2.getFamily()
    True
    >>> model_path = temp_path + "/glr_model"
    >>> model.save(model_path)
    >>> model2 = GeneralizedLinearRegressionModel.load(model_path)
    >>> model.intercept == model2.intercept
    True
    >>> model.coefficients[0] == model2.coefficients[0]
    True

    .. versionadded:: 2.0.0
    """

    # Name of the error distribution; a default ("gaussian") is registered in __init__.
    family = Param(Params._dummy(), "family", "The name of family which is a description of " +
                   "the error distribution to be used in the model. Supported options: " +
                   "gaussian(default), binomial, poisson and gamma.")
    # Name of the link function; no default is registered for it in __init__.
    link = Param(Params._dummy(), "link", "The name of link function which provides the " +
                 "relationship between the linear predictor and the mean of the distribution " +
                 "function. Supported options: identity, log, inverse, logit, probit, cloglog " +
                 "and sqrt.")

    @keyword_only
    def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
                 family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
                 regParam=0.0, weightCol=None, solver="irls"):
        """
        __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
                 family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
                 regParam=0.0, weightCol=None, solver="irls")
        """
        super(GeneralizedLinearRegression, self).__init__()
        # Create the backing JVM estimator, sharing this wrapper's uid.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
        self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls")
        # Keyword arguments captured by the @keyword_only decorator.
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("2.0.0")
    def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
                  family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
                  regParam=0.0, weightCol=None, solver="irls"):
        """
        setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
                  family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
                  regParam=0.0, weightCol=None, solver="irls")
        Sets params for generalized linear regression.
        """
        # Keyword arguments captured by the @keyword_only decorator.
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        """
        Wraps the given Java model in a :py:class:`GeneralizedLinearRegressionModel`.
        """
        return GeneralizedLinearRegressionModel(java_model)

    @since("2.0.0")
    def setFamily(self, value):
        """
        Sets the value of :py:attr:`family`.
        """
        # Go through _set, as setParams does, instead of writing the private
        # _paramMap dict directly.
        self._set(family=value)
        return self

    @since("2.0.0")
    def getFamily(self):
        """
        Gets the value of family or its default value.
        """
        return self.getOrDefault(self.family)

    @since("2.0.0")
    def setLink(self, value):
        """
        Sets the value of :py:attr:`link`.
        """
        # Go through _set, as setParams does, instead of writing the private
        # _paramMap dict directly.
        self._set(link=value)
        return self

    @since("2.0.0")
    def getLink(self):
        """
        Gets the value of link or its default value.
        """
        return self.getOrDefault(self.link)
|
||||
|
||||
|
||||
class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
    """
    Fitted model produced by GeneralizedLinearRegression.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def intercept(self):
        """
        Intercept of the model, obtained from the wrapped Java model.
        """
        java_intercept = self._call_java("intercept")
        return java_intercept

    @property
    @since("2.0.0")
    def coefficients(self):
        """
        Coefficients of the model, obtained from the wrapped Java model.
        """
        java_coefficients = self._call_java("coefficients")
        return java_coefficients
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
import pyspark.ml.regression
|
||||
|
|
Loading…
Reference in a new issue