[SPARK-6255] [MLLIB] Support multiclass classification in Python API
Python API parity check for classification and multiclass classification support. The major disparities that need to be added to the Python API are: ```scala LogisticRegressionWithLBFGS setNumClasses setValidateData LogisticRegressionModel getThreshold numClasses numFeatures SVMWithSGD setValidateData SVMModel getThreshold ``` For users, the greatest benefit of this PR is that multiclass classification is now supported by the Python API: users can train a multiclass classification model and use it to predict in PySpark. Author: Yanbo Liang <ybliang8@gmail.com> Closes #5137 from yanboliang/spark-6255 and squashes the following commits: 0bd531e [Yanbo Liang] address comments 444d5e2 [Yanbo Liang] LogisticRegressionModel.predict() optimization fc7990b [Yanbo Liang] address comments b0d9c63 [Yanbo Liang] Support Mulinomial LR model predict in Python API ded847c [Yanbo Liang] Python API parity check for classification (support multiclass classification)
This commit is contained in:
parent
46de6c05e0
commit
b5bd75d90a
|
@ -77,7 +77,13 @@ private[python] class PythonMLLibAPI extends Serializable {
|
|||
initialWeights: Vector): JList[Object] = {
|
||||
try {
|
||||
val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights)
|
||||
List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
|
||||
if (model.isInstanceOf[LogisticRegressionModel]) {
|
||||
val lrModel = model.asInstanceOf[LogisticRegressionModel]
|
||||
List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses)
|
||||
.map(_.asInstanceOf[Object]).asJava
|
||||
} else {
|
||||
List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
|
||||
}
|
||||
} finally {
|
||||
data.rdd.unpersist(blocking = false)
|
||||
}
|
||||
|
@ -190,9 +196,11 @@ private[python] class PythonMLLibAPI extends Serializable {
|
|||
miniBatchFraction: Double,
|
||||
initialWeights: Vector,
|
||||
regType: String,
|
||||
intercept: Boolean): JList[Object] = {
|
||||
intercept: Boolean,
|
||||
validateData: Boolean): JList[Object] = {
|
||||
val SVMAlg = new SVMWithSGD()
|
||||
SVMAlg.setIntercept(intercept)
|
||||
.setValidateData(validateData)
|
||||
SVMAlg.optimizer
|
||||
.setNumIterations(numIterations)
|
||||
.setRegParam(regParam)
|
||||
|
@ -216,9 +224,11 @@ private[python] class PythonMLLibAPI extends Serializable {
|
|||
initialWeights: Vector,
|
||||
regParam: Double,
|
||||
regType: String,
|
||||
intercept: Boolean): JList[Object] = {
|
||||
intercept: Boolean,
|
||||
validateData: Boolean): JList[Object] = {
|
||||
val LogRegAlg = new LogisticRegressionWithSGD()
|
||||
LogRegAlg.setIntercept(intercept)
|
||||
.setValidateData(validateData)
|
||||
LogRegAlg.optimizer
|
||||
.setNumIterations(numIterations)
|
||||
.setRegParam(regParam)
|
||||
|
@ -242,9 +252,13 @@ private[python] class PythonMLLibAPI extends Serializable {
|
|||
regType: String,
|
||||
intercept: Boolean,
|
||||
corrections: Int,
|
||||
tolerance: Double): JList[Object] = {
|
||||
tolerance: Double,
|
||||
validateData: Boolean,
|
||||
numClasses: Int): JList[Object] = {
|
||||
val LogRegAlg = new LogisticRegressionWithLBFGS()
|
||||
LogRegAlg.setIntercept(intercept)
|
||||
.setValidateData(validateData)
|
||||
.setNumClasses(numClasses)
|
||||
LogRegAlg.optimizer
|
||||
.setNumIterations(numIterations)
|
||||
.setRegParam(regParam)
|
||||
|
|
|
@ -22,7 +22,7 @@ from numpy import array
|
|||
|
||||
from pyspark import RDD
|
||||
from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
|
||||
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
|
||||
from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector
|
||||
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
|
||||
from pyspark.mllib.util import Saveable, Loader, inherit_doc
|
||||
|
||||
|
@ -31,13 +31,13 @@ __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegr
|
|||
'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
|
||||
|
||||
|
||||
class LinearBinaryClassificationModel(LinearModel):
|
||||
class LinearClassificationModel(LinearModel):
|
||||
"""
|
||||
Represents a linear binary classification model that predicts to whether an
|
||||
example is positive (1.0) or negative (0.0).
|
||||
A private abstract class representing a multiclass classification model.
|
||||
The categories are represented by int values: 0, 1, 2, etc.
|
||||
"""
|
||||
def __init__(self, weights, intercept):
|
||||
super(LinearBinaryClassificationModel, self).__init__(weights, intercept)
|
||||
super(LinearClassificationModel, self).__init__(weights, intercept)
|
||||
self._threshold = None
|
||||
|
||||
def setThreshold(self, value):
|
||||
|
@ -47,14 +47,26 @@ class LinearBinaryClassificationModel(LinearModel):
|
|||
Sets the threshold that separates positive predictions from negative
|
||||
predictions. An example with prediction score greater than or equal
|
||||
to this threshold is identified as a positive, and negative otherwise.
|
||||
It is used for binary classification only.
|
||||
"""
|
||||
self._threshold = value
|
||||
|
||||
@property
|
||||
def threshold(self):
|
||||
"""
|
||||
.. note:: Experimental
|
||||
|
||||
Returns the threshold (if any) used for converting raw prediction scores
|
||||
into 0/1 predictions. It is used for binary classification only.
|
||||
"""
|
||||
return self._threshold
|
||||
|
||||
def clearThreshold(self):
|
||||
"""
|
||||
.. note:: Experimental
|
||||
|
||||
Clears the threshold so that `predict` will output raw prediction scores.
|
||||
It is used for binary classification only.
|
||||
"""
|
||||
self._threshold = None
|
||||
|
||||
|
@ -66,7 +78,7 @@ class LinearBinaryClassificationModel(LinearModel):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class LogisticRegressionModel(LinearBinaryClassificationModel):
|
||||
class LogisticRegressionModel(LinearClassificationModel):
|
||||
|
||||
"""A linear binary classification model derived from logistic regression.
|
||||
|
||||
|
@ -112,10 +124,39 @@ class LogisticRegressionModel(LinearBinaryClassificationModel):
|
|||
... os.removedirs(path)
|
||||
... except:
|
||||
... pass
|
||||
>>> multi_class_data = [
|
||||
... LabeledPoint(0.0, [0.0, 1.0, 0.0]),
|
||||
... LabeledPoint(1.0, [1.0, 0.0, 0.0]),
|
||||
... LabeledPoint(2.0, [0.0, 0.0, 1.0])
|
||||
... ]
|
||||
>>> mcm = LogisticRegressionWithLBFGS.train(data=sc.parallelize(multi_class_data), numClasses=3)
|
||||
>>> mcm.predict([0.0, 0.5, 0.0])
|
||||
0
|
||||
>>> mcm.predict([0.8, 0.0, 0.0])
|
||||
1
|
||||
>>> mcm.predict([0.0, 0.0, 0.3])
|
||||
2
|
||||
"""
|
||||
def __init__(self, weights, intercept):
|
||||
def __init__(self, weights, intercept, numFeatures, numClasses):
|
||||
super(LogisticRegressionModel, self).__init__(weights, intercept)
|
||||
self._numFeatures = int(numFeatures)
|
||||
self._numClasses = int(numClasses)
|
||||
self._threshold = 0.5
|
||||
if self._numClasses == 2:
|
||||
self._dataWithBiasSize = None
|
||||
self._weightsMatrix = None
|
||||
else:
|
||||
self._dataWithBiasSize = self._coeff.size / (self._numClasses - 1)
|
||||
self._weightsMatrix = self._coeff.toArray().reshape(self._numClasses - 1,
|
||||
self._dataWithBiasSize)
|
||||
|
||||
@property
|
||||
def numFeatures(self):
|
||||
return self._numFeatures
|
||||
|
||||
@property
|
||||
def numClasses(self):
|
||||
return self._numClasses
|
||||
|
||||
def predict(self, x):
|
||||
"""
|
||||
|
@ -126,20 +167,38 @@ class LogisticRegressionModel(LinearBinaryClassificationModel):
|
|||
return x.map(lambda v: self.predict(v))
|
||||
|
||||
x = _convert_to_vector(x)
|
||||
margin = self.weights.dot(x) + self._intercept
|
||||
if margin > 0:
|
||||
prob = 1 / (1 + exp(-margin))
|
||||
if self.numClasses == 2:
|
||||
margin = self.weights.dot(x) + self._intercept
|
||||
if margin > 0:
|
||||
prob = 1 / (1 + exp(-margin))
|
||||
else:
|
||||
exp_margin = exp(margin)
|
||||
prob = exp_margin / (1 + exp_margin)
|
||||
if self._threshold is None:
|
||||
return prob
|
||||
else:
|
||||
return 1 if prob > self._threshold else 0
|
||||
else:
|
||||
exp_margin = exp(margin)
|
||||
prob = exp_margin / (1 + exp_margin)
|
||||
if self._threshold is None:
|
||||
return prob
|
||||
else:
|
||||
return 1 if prob > self._threshold else 0
|
||||
best_class = 0
|
||||
max_margin = 0.0
|
||||
if x.size + 1 == self._dataWithBiasSize:
|
||||
for i in range(0, self._numClasses - 1):
|
||||
margin = x.dot(self._weightsMatrix[i][0:x.size]) + \
|
||||
self._weightsMatrix[i][x.size]
|
||||
if margin > max_margin:
|
||||
max_margin = margin
|
||||
best_class = i + 1
|
||||
else:
|
||||
for i in range(0, self._numClasses - 1):
|
||||
margin = x.dot(self._weightsMatrix[i])
|
||||
if margin > max_margin:
|
||||
max_margin = margin
|
||||
best_class = i + 1
|
||||
return best_class
|
||||
|
||||
def save(self, sc, path):
|
||||
java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel(
|
||||
_py2java(sc, self._coeff), self.intercept)
|
||||
_py2java(sc, self._coeff), self.intercept, self.numFeatures, self.numClasses)
|
||||
java_model.save(sc._jsc.sc(), path)
|
||||
|
||||
@classmethod
|
||||
|
@ -148,8 +207,10 @@ class LogisticRegressionModel(LinearBinaryClassificationModel):
|
|||
sc._jsc.sc(), path)
|
||||
weights = _java2py(sc, java_model.weights())
|
||||
intercept = java_model.intercept()
|
||||
numFeatures = java_model.numFeatures()
|
||||
numClasses = java_model.numClasses()
|
||||
threshold = java_model.getThreshold().get()
|
||||
model = LogisticRegressionModel(weights, intercept)
|
||||
model = LogisticRegressionModel(weights, intercept, numFeatures, numClasses)
|
||||
model.setThreshold(threshold)
|
||||
return model
|
||||
|
||||
|
@ -158,7 +219,8 @@ class LogisticRegressionWithSGD(object):
|
|||
|
||||
@classmethod
|
||||
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
|
||||
initialWeights=None, regParam=0.01, regType="l2", intercept=False):
|
||||
initialWeights=None, regParam=0.01, regType="l2", intercept=False,
|
||||
validateData=True):
|
||||
"""
|
||||
Train a logistic regression model on the given data.
|
||||
|
||||
|
@ -184,11 +246,14 @@ class LogisticRegressionWithSGD(object):
|
|||
or not of the augmented representation for
|
||||
training data (i.e. whether bias features
|
||||
are activated or not).
|
||||
:param validateData: Boolean parameter which indicates if the
|
||||
algorithm should validate data before training.
|
||||
(default: True)
|
||||
"""
|
||||
def train(rdd, i):
|
||||
return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations),
|
||||
float(step), float(miniBatchFraction), i, float(regParam), regType,
|
||||
bool(intercept))
|
||||
bool(intercept), bool(validateData))
|
||||
|
||||
return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
|
||||
|
||||
|
@ -197,7 +262,7 @@ class LogisticRegressionWithLBFGS(object):
|
|||
|
||||
@classmethod
|
||||
def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2",
|
||||
intercept=False, corrections=10, tolerance=1e-4):
|
||||
intercept=False, corrections=10, tolerance=1e-4, validateData=True, numClasses=2):
|
||||
"""
|
||||
Train a logistic regression model on the given data.
|
||||
|
||||
|
@ -223,6 +288,11 @@ class LogisticRegressionWithLBFGS(object):
|
|||
update (default: 10).
|
||||
:param tolerance: The convergence tolerance of iterations for
|
||||
L-BFGS (default: 1e-4).
|
||||
:param validateData: Boolean parameter which indicates if the
|
||||
algorithm should validate data before training.
|
||||
(default: True)
|
||||
:param numClasses: The number of classes (i.e., outcomes) a label can take
|
||||
in Multinomial Logistic Regression (default: 2).
|
||||
|
||||
>>> data = [
|
||||
... LabeledPoint(0.0, [0.0, 1.0]),
|
||||
|
@ -237,12 +307,20 @@ class LogisticRegressionWithLBFGS(object):
|
|||
def train(rdd, i):
|
||||
return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i,
|
||||
float(regParam), regType, bool(intercept), int(corrections),
|
||||
float(tolerance))
|
||||
float(tolerance), bool(validateData), int(numClasses))
|
||||
|
||||
if initialWeights is None:
|
||||
if numClasses == 2:
|
||||
initialWeights = [0.0] * len(data.first().features)
|
||||
else:
|
||||
if intercept:
|
||||
initialWeights = [0.0] * (len(data.first().features) + 1) * (numClasses - 1)
|
||||
else:
|
||||
initialWeights = [0.0] * len(data.first().features) * (numClasses - 1)
|
||||
return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
|
||||
|
||||
|
||||
class SVMModel(LinearBinaryClassificationModel):
|
||||
class SVMModel(LinearClassificationModel):
|
||||
|
||||
"""A support vector machine.
|
||||
|
||||
|
@ -325,7 +403,8 @@ class SVMWithSGD(object):
|
|||
|
||||
@classmethod
|
||||
def train(cls, data, iterations=100, step=1.0, regParam=0.01,
|
||||
miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False):
|
||||
miniBatchFraction=1.0, initialWeights=None, regType="l2",
|
||||
intercept=False, validateData=True):
|
||||
"""
|
||||
Train a support vector machine on the given data.
|
||||
|
||||
|
@ -351,11 +430,14 @@ class SVMWithSGD(object):
|
|||
or not of the augmented representation for
|
||||
training data (i.e. whether bias features
|
||||
are activated or not).
|
||||
:param validateData: Boolean parameter which indicates if the
|
||||
algorithm should validate data before training.
|
||||
(default: True)
|
||||
"""
|
||||
def train(rdd, i):
|
||||
return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step),
|
||||
float(regParam), float(miniBatchFraction), i, regType,
|
||||
bool(intercept))
|
||||
bool(intercept), bool(validateData))
|
||||
|
||||
return _regression_train_wrapper(train, SVMModel, data, initialWeights)
|
||||
|
||||
|
|
|
@ -167,13 +167,19 @@ class LinearRegressionModel(LinearRegressionModelBase):
|
|||
# return the result of a call to the appropriate JVM stub.
|
||||
# _regression_train_wrapper is responsible for setup and error checking.
|
||||
def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
|
||||
from pyspark.mllib.classification import LogisticRegressionModel
|
||||
first = data.first()
|
||||
if not isinstance(first, LabeledPoint):
|
||||
raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
|
||||
if initial_weights is None:
|
||||
initial_weights = [0.0] * len(data.first().features)
|
||||
weights, intercept = train_func(data, _convert_to_vector(initial_weights))
|
||||
return modelClass(weights, intercept)
|
||||
if (modelClass == LogisticRegressionModel):
|
||||
weights, intercept, numFeatures, numClasses = train_func(
|
||||
data, _convert_to_vector(initial_weights))
|
||||
return modelClass(weights, intercept, numFeatures, numClasses)
|
||||
else:
|
||||
weights, intercept = train_func(data, _convert_to_vector(initial_weights))
|
||||
return modelClass(weights, intercept)
|
||||
|
||||
|
||||
class LinearRegressionWithSGD(object):
|
||||
|
|
Loading…
Reference in a new issue