[SPARK-2550][MLLIB][APACHE SPARK] Support regularization and intercept in pyspark's linear methods.

Related to issue: [SPARK-2550](https://issues.apache.org/jira/browse/SPARK-2550?jql=project%20%3D%20SPARK%20AND%20resolution%20%3D%20Unresolved%20AND%20priority%20%3D%20Major%20ORDER%20BY%20key%20DESC).

Author: Michael Giannakopoulos <miccagiann@gmail.com>

Closes #1624 from miccagiann/new-branch and squashes the following commits:

c02e5f5 [Michael Giannakopoulos] Merge cleanly with upstream/master.
8dcb888 [Michael Giannakopoulos] Putting the if/else if statements in brackets.
fed8eaa [Michael Giannakopoulos] Adding a space in the message related to the IllegalArgumentException.
44e6ff0 [Michael Giannakopoulos] Adding a blank line before python class LinearRegressionWithSGD.
8eba9c5 [Michael Giannakopoulos] Change function signatures. Exception is thrown from the scala component and not from the python one.
638be47 [Michael Giannakopoulos] Modified code to comply with code standards.
ec50ee9 [Michael Giannakopoulos] Shorten the if-elif-else statement in regression.py file
b962744 [Michael Giannakopoulos] Replaced the enum classes, with strings-keywords for defining the values of 'regType' parameter.
78853ec [Michael Giannakopoulos] Providing intercept and regualizer functionallity for linear methods in only one function.
3ac8874 [Michael Giannakopoulos] Added support for regularizer and intercection parameters for linear regression method.
This commit is contained in:
Michael Giannakopoulos 2014-08-01 21:00:31 -07:00 committed by Xiangrui Meng
parent f6a1899306
commit c281189222
2 changed files with 49 additions and 11 deletions

View file

@ -23,6 +23,8 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.random.{RandomRDDGenerators => RG}
import org.apache.spark.mllib.recommendation._
@ -252,15 +254,27 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
initialWeightsBA: Array[Byte],
regParam: Double,
regType: String,
intercept: Boolean): java.util.List[java.lang.Object] = {
val lrAlg = new LinearRegressionWithSGD()
lrAlg.setIntercept(intercept)
lrAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
.setStepSize(stepSize)
if (regType == "l2") {
lrAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
lrAlg.optimizer.setUpdater(new L1Updater)
} else if (regType != "none") {
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+ " Can only be initialized using the following string values: [l1, l2, none].")
}
trainRegressionModel(
(data, initialWeights) =>
LinearRegressionWithSGD.train(
data,
numIterations,
stepSize,
miniBatchFraction,
initialWeights),
lrAlg.run(data, initialWeights),
dataBytesJRDD,
initialWeightsBA)
}

View file

@ -112,12 +112,36 @@ class LinearRegressionModel(LinearRegressionModelBase):
class LinearRegressionWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0,
miniBatchFraction=1.0, initialWeights=None):
"""Train a linear regression model on the given data."""
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
initialWeights=None, regParam=1.0, regType=None, intercept=False):
"""
Train a linear regression model on the given data.
@param data: The training data.
@param iterations: The number of iterations (default: 100).
@param step: The step parameter used in SGD
(default: 1.0).
@param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
@param initialWeights: The initial weights (default: None).
@param regParam: The regularizer parameter (default: 1.0).
@param regType: The type of regularizer used for training
our model.
Allowed values: "l1" for using L1Updater,
"l2" for using
SquaredL2Updater,
"none" for no regularizer.
(default: "none")
@param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).
"""
sc = data.context
if regType is None:
regType = "none"
train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
d._jrdd, iterations, step, miniBatchFraction, i)
d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
return _regression_train_wrapper(sc, train_f, LinearRegressionModel, data, initialWeights)