[SPARK-17138][ML][MLIB] Add Python API for multinomial logistic regression
## What changes were proposed in this pull request? Add Python API for multinomial logistic regression. - add `family` param in python api. - expose `coefficientMatrix` and `interceptVector` for `LogisticRegressionModel` - add python-side testcase for multinomial logistic regression - update python doc. ## How was this patch tested? existing and added doc tests. Author: WeichenXu <WeichenXu123@outlook.com> Closes #14852 from WeichenXu123/add_MLOR_python.
This commit is contained in:
parent
85b0a15754
commit
7f16affa26
|
@ -67,21 +67,34 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
|
|||
HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable):
|
||||
"""
|
||||
Logistic regression.
|
||||
Currently, this class only supports binary classification.
|
||||
This class supports multinomial logistic (softmax) and binomial logistic regression.
|
||||
|
||||
>>> from pyspark.sql import Row
|
||||
>>> from pyspark.ml.linalg import Vectors
|
||||
>>> df = sc.parallelize([
|
||||
>>> bdf = sc.parallelize([
|
||||
... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)),
|
||||
... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF()
|
||||
>>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight")
|
||||
>>> model = lr.fit(df)
|
||||
>>> model.coefficients
|
||||
>>> blor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight")
|
||||
>>> blorModel = blor.fit(bdf)
|
||||
>>> blorModel.coefficients
|
||||
DenseVector([5.5...])
|
||||
>>> model.intercept
|
||||
>>> blorModel.intercept
|
||||
-2.68...
|
||||
>>> mdf = sc.parallelize([
|
||||
... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)),
|
||||
... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], [])),
|
||||
... Row(label=2.0, weight=2.0, features=Vectors.dense(3.0))]).toDF()
|
||||
>>> mlor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight",
|
||||
... family="multinomial")
|
||||
>>> mlorModel = mlor.fit(mdf)
|
||||
>>> print(mlorModel.coefficientMatrix)
|
||||
DenseMatrix([[-2.3...],
|
||||
[ 0.2...],
|
||||
[ 2.1... ]])
|
||||
>>> mlorModel.interceptVector
|
||||
DenseVector([2.0..., 0.8..., -2.8...])
|
||||
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
|
||||
>>> result = model.transform(test0).head()
|
||||
>>> result = blorModel.transform(test0).head()
|
||||
>>> result.prediction
|
||||
0.0
|
||||
>>> result.probability
|
||||
|
@ -89,23 +102,23 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
|
|||
>>> result.rawPrediction
|
||||
DenseVector([8.22..., -8.22...])
|
||||
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
|
||||
>>> model.transform(test1).head().prediction
|
||||
>>> blorModel.transform(test1).head().prediction
|
||||
1.0
|
||||
>>> lr.setParams("vector")
|
||||
>>> blor.setParams("vector")
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TypeError: Method setParams forces keyword arguments.
|
||||
>>> lr_path = temp_path + "/lr"
|
||||
>>> lr.save(lr_path)
|
||||
>>> blor.save(lr_path)
|
||||
>>> lr2 = LogisticRegression.load(lr_path)
|
||||
>>> lr2.getMaxIter()
|
||||
5
|
||||
>>> model_path = temp_path + "/lr_model"
|
||||
>>> model.save(model_path)
|
||||
>>> blorModel.save(model_path)
|
||||
>>> model2 = LogisticRegressionModel.load(model_path)
|
||||
>>> model.coefficients[0] == model2.coefficients[0]
|
||||
>>> blorModel.coefficients[0] == model2.coefficients[0]
|
||||
True
|
||||
>>> model.intercept == model2.intercept
|
||||
>>> blorModel.intercept == model2.intercept
|
||||
True
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
|
@ -117,24 +130,29 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
|
|||
"e.g. if threshold is p, then thresholds must be equal to [1-p, p].",
|
||||
typeConverter=TypeConverters.toFloat)
|
||||
|
||||
family = Param(Params._dummy(), "family",
|
||||
"The name of family which is a description of the label distribution to " +
|
||||
"be used in the model. Supported options: auto, binomial, multinomial",
|
||||
typeConverter=TypeConverters.toString)
|
||||
|
||||
@keyword_only
|
||||
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
||||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
|
||||
threshold=0.5, thresholds=None, probabilityCol="probability",
|
||||
rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
|
||||
aggregationDepth=2):
|
||||
aggregationDepth=2, family="auto"):
|
||||
"""
|
||||
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
|
||||
threshold=0.5, thresholds=None, probabilityCol="probability", \
|
||||
rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
|
||||
aggregationDepth=2)
|
||||
aggregationDepth=2, family="auto")
|
||||
If the threshold and thresholds Params are both set, they must be equivalent.
|
||||
"""
|
||||
super(LogisticRegression, self).__init__()
|
||||
self._java_obj = self._new_java_obj(
|
||||
"org.apache.spark.ml.classification.LogisticRegression", self.uid)
|
||||
self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5)
|
||||
self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5, family="auto")
|
||||
kwargs = self.__init__._input_kwargs
|
||||
self.setParams(**kwargs)
|
||||
self._checkThresholdConsistency()
|
||||
|
@ -145,13 +163,13 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
|
|||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
|
||||
threshold=0.5, thresholds=None, probabilityCol="probability",
|
||||
rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
|
||||
aggregationDepth=2):
|
||||
aggregationDepth=2, family="auto"):
|
||||
"""
|
||||
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
|
||||
threshold=0.5, thresholds=None, probabilityCol="probability", \
|
||||
rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
|
||||
aggregationDepth=2)
|
||||
aggregationDepth=2, family="auto")
|
||||
Sets params for logistic regression.
|
||||
If the threshold and thresholds Params are both set, they must be equivalent.
|
||||
"""
|
||||
|
@ -232,6 +250,20 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
|
|||
raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
|
||||
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
|
||||
|
||||
@since("2.1.0")
|
||||
def setFamily(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`family`.
|
||||
"""
|
||||
return self._set(family=value)
|
||||
|
||||
@since("2.1.0")
|
||||
def getFamily(self):
|
||||
"""
|
||||
Gets the value of :py:attr:`family` or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.family)
|
||||
|
||||
|
||||
class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
|
||||
"""
|
||||
|
@ -244,7 +276,8 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable
|
|||
@since("2.0.0")
|
||||
def coefficients(self):
|
||||
"""
|
||||
Model coefficients.
|
||||
Model coefficients of binomial logistic regression.
|
||||
An exception is thrown in the case of multinomial logistic regression.
|
||||
"""
|
||||
return self._call_java("coefficients")
|
||||
|
||||
|
@ -252,10 +285,27 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable
|
|||
@since("1.4.0")
|
||||
def intercept(self):
|
||||
"""
|
||||
Model intercept.
|
||||
Model intercept of binomial logistic regression.
|
||||
An exception is thrown in the case of multinomial logistic regression.
|
||||
"""
|
||||
return self._call_java("intercept")
|
||||
|
||||
@property
|
||||
@since("2.1.0")
|
||||
def coefficientMatrix(self):
|
||||
"""
|
||||
Model coefficients.
|
||||
"""
|
||||
return self._call_java("coefficientMatrix")
|
||||
|
||||
@property
|
||||
@since("2.1.0")
|
||||
def interceptVector(self):
|
||||
"""
|
||||
Model intercept.
|
||||
"""
|
||||
return self._call_java("interceptVector")
|
||||
|
||||
@property
|
||||
@since("2.0.0")
|
||||
def summary(self):
|
||||
|
|
Loading…
Reference in a new issue