[SPARK-2550][MLLIB][APACHE SPARK] Support regularization and intercept in pyspark's linear methods.
Related to issue: [SPARK-2550](https://issues.apache.org/jira/browse/SPARK-2550?jql=project%20%3D%20SPARK%20AND%20resolution%20%3D%20Unresolved%20AND%20priority%20%3D%20Major%20ORDER%20BY%20key%20DESC). Author: Michael Giannakopoulos <miccagiann@gmail.com> Closes #1624 from miccagiann/new-branch and squashes the following commits: c02e5f5 [Michael Giannakopoulos] Merge cleanly with upstream/master. 8dcb888 [Michael Giannakopoulos] Putting the if/else if statements in brackets. fed8eaa [Michael Giannakopoulos] Adding a space in the message related to the IllegalArgumentException. 44e6ff0 [Michael Giannakopoulos] Adding a blank line before python class LinearRegressionWithSGD. 8eba9c5 [Michael Giannakopoulos] Change function signatures. Exception is thrown from the scala component and not from the python one. 638be47 [Michael Giannakopoulos] Modified code to comply with code standards. ec50ee9 [Michael Giannakopoulos] Shorten the if-elif-else statement in regression.py file b962744 [Michael Giannakopoulos] Replaced the enum classes, with strings-keywords for defining the values of 'regType' parameter. 78853ec [Michael Giannakopoulos] Providing intercept and regularizer functionality for linear methods in only one function. 3ac8874 [Michael Giannakopoulos] Added support for regularizer and intercept parameters for linear regression method.
This commit is contained in:
parent
f6a1899306
commit
c281189222
|
@ -23,6 +23,8 @@ import org.apache.spark.annotation.DeveloperApi
|
|||
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
|
||||
import org.apache.spark.mllib.classification._
|
||||
import org.apache.spark.mllib.clustering._
|
||||
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
|
||||
import org.apache.spark.mllib.optimization._
|
||||
import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors}
|
||||
import org.apache.spark.mllib.random.{RandomRDDGenerators => RG}
|
||||
import org.apache.spark.mllib.recommendation._
|
||||
|
@ -252,15 +254,27 @@ class PythonMLLibAPI extends Serializable {
|
|||
numIterations: Int,
|
||||
stepSize: Double,
|
||||
miniBatchFraction: Double,
|
||||
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
|
||||
initialWeightsBA: Array[Byte],
|
||||
regParam: Double,
|
||||
regType: String,
|
||||
intercept: Boolean): java.util.List[java.lang.Object] = {
|
||||
val lrAlg = new LinearRegressionWithSGD()
|
||||
lrAlg.setIntercept(intercept)
|
||||
lrAlg.optimizer
|
||||
.setNumIterations(numIterations)
|
||||
.setRegParam(regParam)
|
||||
.setStepSize(stepSize)
|
||||
if (regType == "l2") {
|
||||
lrAlg.optimizer.setUpdater(new SquaredL2Updater)
|
||||
} else if (regType == "l1") {
|
||||
lrAlg.optimizer.setUpdater(new L1Updater)
|
||||
} else if (regType != "none") {
|
||||
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
|
||||
+ " Can only be initialized using the following string values: [l1, l2, none].")
|
||||
}
|
||||
trainRegressionModel(
|
||||
(data, initialWeights) =>
|
||||
LinearRegressionWithSGD.train(
|
||||
data,
|
||||
numIterations,
|
||||
stepSize,
|
||||
miniBatchFraction,
|
||||
initialWeights),
|
||||
lrAlg.run(data, initialWeights),
|
||||
dataBytesJRDD,
|
||||
initialWeightsBA)
|
||||
}
|
||||
|
|
|
@ -112,12 +112,36 @@ class LinearRegressionModel(LinearRegressionModelBase):
|
|||
|
||||
class LinearRegressionWithSGD(object):
|
||||
@classmethod
|
||||
def train(cls, data, iterations=100, step=1.0,
|
||||
miniBatchFraction=1.0, initialWeights=None):
|
||||
"""Train a linear regression model on the given data."""
|
||||
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
|
||||
initialWeights=None, regParam=1.0, regType=None, intercept=False):
|
||||
"""
|
||||
Train a linear regression model on the given data.
|
||||
|
||||
@param data: The training data.
|
||||
@param iterations: The number of iterations (default: 100).
|
||||
@param step: The step parameter used in SGD
|
||||
(default: 1.0).
|
||||
@param miniBatchFraction: Fraction of data to be used for each SGD
|
||||
iteration.
|
||||
@param initialWeights: The initial weights (default: None).
|
||||
@param regParam: The regularizer parameter (default: 1.0).
|
||||
@param regType: The type of regularizer used for training
|
||||
our model.
|
||||
Allowed values: "l1" for using L1Updater,
|
||||
"l2" for using
|
||||
SquaredL2Updater,
|
||||
"none" for no regularizer.
|
||||
(default: "none")
|
||||
@param intercept: Boolean parameter which indicates the use
|
||||
or not of the augmented representation for
|
||||
training data (i.e. whether bias features
|
||||
are activated or not).
|
||||
"""
|
||||
sc = data.context
|
||||
if regType is None:
|
||||
regType = "none"
|
||||
train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
|
||||
d._jrdd, iterations, step, miniBatchFraction, i)
|
||||
d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
|
||||
return _regression_train_wrapper(sc, train_f, LinearRegressionModel, data, initialWeights)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue