[SPARK-12811][ML] Estimator for Generalized Linear Models(GLMs)
Estimator for Generalized Linear Models(GLMs) which will be solved by IRLS. cc mengxr Author: Yanbo Liang <ybliang8@gmail.com> Closes #11136 from yanboliang/spark-12811.
This commit is contained in:
parent
c43899a04e
commit
5ed48dd84d
|
@ -156,6 +156,12 @@ private[ml] class WeightedLeastSquares(
|
|||
|
||||
private[ml] object WeightedLeastSquares {
|
||||
|
||||
/**
|
||||
* In order to take the normal equation approach efficiently, [[WeightedLeastSquares]]
|
||||
* only supports the number of features is no more than 4096.
|
||||
*/
|
||||
val MAX_NUM_FEATURES: Int = 4096
|
||||
|
||||
/**
|
||||
* Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]].
|
||||
*/
|
||||
|
@ -174,8 +180,8 @@ private[ml] object WeightedLeastSquares {
|
|||
private var aaSum: DenseVector = _
|
||||
|
||||
private def init(k: Int): Unit = {
|
||||
require(k <= 4096, "In order to take the normal equation approach efficiently, " +
|
||||
s"we set the max number of features to 4096 but got $k.")
|
||||
require(k <= MAX_NUM_FEATURES, "In order to take the normal equation approach efficiently, " +
|
||||
s"we set the max number of features to $MAX_NUM_FEATURES but got $k.")
|
||||
this.k = k
|
||||
triK = k * (k + 1) / 2
|
||||
count = 0L
|
||||
|
|
|
@ -0,0 +1,577 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.regression
|
||||
|
||||
import breeze.stats.distributions.{Gaussian => GD}
|
||||
|
||||
import org.apache.spark.{Logging, SparkException}
|
||||
import org.apache.spark.annotation.{Experimental, Since}
|
||||
import org.apache.spark.ml.PredictorParams
|
||||
import org.apache.spark.ml.feature.Instance
|
||||
import org.apache.spark.ml.optim._
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared._
|
||||
import org.apache.spark.ml.util.Identifiable
|
||||
import org.apache.spark.mllib.linalg.{BLAS, Vector}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{DataFrame, Row}
|
||||
import org.apache.spark.sql.functions._
|
||||
|
||||
/**
|
||||
* Params for Generalized Linear Regression.
|
||||
*/
|
||||
private[regression] trait GeneralizedLinearRegressionBase extends PredictorParams
|
||||
with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol
|
||||
with HasSolver with Logging {
|
||||
|
||||
/**
|
||||
* Param for the name of family which is a description of the error distribution
|
||||
* to be used in the model.
|
||||
* Supported options: "gaussian", "binomial", "poisson" and "gamma".
|
||||
* Default is "gaussian".
|
||||
* @group param
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
final val family: Param[String] = new Param(this, "family",
|
||||
"The name of family which is a description of the error distribution to be used in the " +
|
||||
"model. Supported options: gaussian(default), binomial, poisson and gamma.",
|
||||
ParamValidators.inArray[String](GeneralizedLinearRegression.supportedFamilyNames.toArray))
|
||||
|
||||
/** @group getParam */
|
||||
@Since("2.0.0")
|
||||
def getFamily: String = $(family)
|
||||
|
||||
/**
|
||||
* Param for the name of link function which provides the relationship
|
||||
* between the linear predictor and the mean of the distribution function.
|
||||
* Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt".
|
||||
* @group param
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
final val link: Param[String] = new Param(this, "link", "The name of link function " +
|
||||
"which provides the relationship between the linear predictor and the mean of the " +
|
||||
"distribution function. Supported options: identity, log, inverse, logit, probit, " +
|
||||
"cloglog and sqrt.",
|
||||
ParamValidators.inArray[String](GeneralizedLinearRegression.supportedLinkNames.toArray))
|
||||
|
||||
/** @group getParam */
|
||||
@Since("2.0.0")
|
||||
def getLink: String = $(link)
|
||||
|
||||
import GeneralizedLinearRegression._
|
||||
|
||||
@Since("2.0.0")
|
||||
override def validateParams(): Unit = {
|
||||
if ($(solver) == "irls") {
|
||||
setDefault(maxIter -> 25)
|
||||
}
|
||||
if (isDefined(link)) {
|
||||
require(supportedFamilyAndLinkPairs.contains(
|
||||
Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " +
|
||||
s"with ${$(family)} family does not support ${$(link)} link function.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* :: Experimental ::
|
||||
*
|
||||
* Fit a Generalized Linear Model ([[https://en.wikipedia.org/wiki/Generalized_linear_model]])
|
||||
* specified by giving a symbolic description of the linear predictor (link function) and
|
||||
* a description of the error distribution (family).
|
||||
* It supports "gaussian", "binomial", "poisson" and "gamma" as family.
|
||||
* Valid link functions for each family is listed below. The first link function of each family
|
||||
* is the default one.
|
||||
* - "gaussian" -> "identity", "log", "inverse"
|
||||
* - "binomial" -> "logit", "probit", "cloglog"
|
||||
* - "poisson" -> "log", "identity", "sqrt"
|
||||
* - "gamma" -> "inverse", "identity", "log"
|
||||
*/
|
||||
@Experimental
|
||||
@Since("2.0.0")
|
||||
class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String)
|
||||
extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel]
|
||||
with GeneralizedLinearRegressionBase with Logging {
|
||||
|
||||
import GeneralizedLinearRegression._
|
||||
|
||||
@Since("2.0.0")
|
||||
def this() = this(Identifiable.randomUID("glm"))
|
||||
|
||||
/**
|
||||
* Sets the value of param [[family]].
|
||||
* Default is "gaussian".
|
||||
* @group setParam
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
def setFamily(value: String): this.type = set(family, value)
|
||||
setDefault(family -> Gaussian.name)
|
||||
|
||||
/**
|
||||
* Sets the value of param [[link]].
|
||||
* @group setParam
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
def setLink(value: String): this.type = set(link, value)
|
||||
|
||||
/**
|
||||
* Sets if we should fit the intercept.
|
||||
* Default is true.
|
||||
* @group setParam
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
|
||||
|
||||
/**
|
||||
* Sets the maximum number of iterations.
|
||||
* Default is 25 if the solver algorithm is "irls".
|
||||
* @group setParam
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
def setMaxIter(value: Int): this.type = set(maxIter, value)
|
||||
|
||||
/**
|
||||
* Sets the convergence tolerance of iterations.
|
||||
* Smaller value will lead to higher accuracy with the cost of more iterations.
|
||||
* Default is 1E-6.
|
||||
* @group setParam
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
def setTol(value: Double): this.type = set(tol, value)
|
||||
setDefault(tol -> 1E-6)
|
||||
|
||||
/**
|
||||
* Sets the regularization parameter.
|
||||
* Default is 0.0.
|
||||
* @group setParam
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
def setRegParam(value: Double): this.type = set(regParam, value)
|
||||
setDefault(regParam -> 0.0)
|
||||
|
||||
/**
|
||||
* Sets the value of param [[weightCol]].
|
||||
* If this is not set or empty, we treat all instance weights as 1.0.
|
||||
* Default is empty, so all instances have weight one.
|
||||
* @group setParam
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
def setWeightCol(value: String): this.type = set(weightCol, value)
|
||||
setDefault(weightCol -> "")
|
||||
|
||||
/**
|
||||
* Sets the solver algorithm used for optimization.
|
||||
* Currently only support "irls" which is also the default solver.
|
||||
* @group setParam
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
def setSolver(value: String): this.type = set(solver, value)
|
||||
setDefault(solver -> "irls")
|
||||
|
||||
override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = {
|
||||
val familyObj = Family.fromName($(family))
|
||||
val linkObj = if (isDefined(link)) {
|
||||
Link.fromName($(link))
|
||||
} else {
|
||||
familyObj.defaultLink
|
||||
}
|
||||
val familyAndLink = new FamilyAndLink(familyObj, linkObj)
|
||||
|
||||
val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd
|
||||
.map { case Row(features: Vector) =>
|
||||
features.size
|
||||
}.first()
|
||||
if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) {
|
||||
val msg = "Currently, GeneralizedLinearRegression only supports number of features" +
|
||||
s" <= ${WeightedLeastSquares.MAX_NUM_FEATURES}. Found $numFeatures in the input dataset."
|
||||
throw new SparkException(msg)
|
||||
}
|
||||
|
||||
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
||||
val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
|
||||
.map { case Row(label: Double, weight: Double, features: Vector) =>
|
||||
Instance(label, weight, features)
|
||||
}
|
||||
|
||||
if (familyObj == Gaussian && linkObj == Identity) {
|
||||
// TODO: Make standardizeFeatures and standardizeLabel configurable.
|
||||
val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
|
||||
standardizeFeatures = true, standardizeLabel = true)
|
||||
val wlsModel = optimizer.fit(instances)
|
||||
val model = copyValues(
|
||||
new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept)
|
||||
.setParent(this))
|
||||
return model
|
||||
}
|
||||
|
||||
// Fit Generalized Linear Model by iteratively reweighted least squares (IRLS).
|
||||
val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam))
|
||||
val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc,
|
||||
$(fitIntercept), $(regParam), $(maxIter), $(tol))
|
||||
val irlsModel = optimizer.fit(instances)
|
||||
|
||||
val model = copyValues(
|
||||
new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept)
|
||||
.setParent(this))
|
||||
model
|
||||
}
|
||||
|
||||
@Since("2.0.0")
|
||||
override def copy(extra: ParamMap): GeneralizedLinearRegression = defaultCopy(extra)
|
||||
}
|
||||
|
||||
@Since("2.0.0")
|
||||
private[ml] object GeneralizedLinearRegression {
|
||||
|
||||
/** Set of family and link pairs that GeneralizedLinearRegression supports. */
|
||||
lazy val supportedFamilyAndLinkPairs = Set(
|
||||
Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse,
|
||||
Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog,
|
||||
Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt,
|
||||
Gamma -> Inverse, Gamma -> Identity, Gamma -> Log
|
||||
)
|
||||
|
||||
/** Set of family names that GeneralizedLinearRegression supports. */
|
||||
lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
|
||||
|
||||
/** Set of link names that GeneralizedLinearRegression supports. */
|
||||
lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
|
||||
|
||||
val epsilon: Double = 1E-16
|
||||
|
||||
/**
|
||||
* Wrapper of family and link combination used in the model.
|
||||
*/
|
||||
private[ml] class FamilyAndLink(val family: Family, val link: Link) extends Serializable {
|
||||
|
||||
/** Linear predictor based on given mu. */
|
||||
def predict(mu: Double): Double = link.link(family.project(mu))
|
||||
|
||||
/** Fitted value based on linear predictor eta. */
|
||||
def fitted(eta: Double): Double = family.project(link.unlink(eta))
|
||||
|
||||
/**
|
||||
* Get the initial guess model for [[IterativelyReweightedLeastSquares]].
|
||||
*/
|
||||
def initialize(
|
||||
instances: RDD[Instance],
|
||||
fitIntercept: Boolean,
|
||||
regParam: Double): WeightedLeastSquaresModel = {
|
||||
val newInstances = instances.map { instance =>
|
||||
val mu = family.initialize(instance.label, instance.weight)
|
||||
val eta = predict(mu)
|
||||
Instance(eta, instance.weight, instance.features)
|
||||
}
|
||||
// TODO: Make standardizeFeatures and standardizeLabel configurable.
|
||||
val initialModel = new WeightedLeastSquares(fitIntercept, regParam,
|
||||
standardizeFeatures = true, standardizeLabel = true)
|
||||
.fit(newInstances)
|
||||
initialModel
|
||||
}
|
||||
|
||||
/**
|
||||
* The reweight function used to update offsets and weights
|
||||
* at each iteration of [[IterativelyReweightedLeastSquares]].
|
||||
*/
|
||||
val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = {
|
||||
(instance: Instance, model: WeightedLeastSquaresModel) => {
|
||||
val eta = model.predict(instance.features)
|
||||
val mu = fitted(eta)
|
||||
val offset = eta + (instance.label - mu) * link.deriv(mu)
|
||||
val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu))
|
||||
(offset, weight)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A description of the error distribution to be used in the model.
|
||||
* @param name the name of the family.
|
||||
*/
|
||||
private[ml] abstract class Family(val name: String) extends Serializable {
|
||||
|
||||
/** The default link instance of this family. */
|
||||
val defaultLink: Link
|
||||
|
||||
/** Initialize the starting value for mu. */
|
||||
def initialize(y: Double, weight: Double): Double
|
||||
|
||||
/** The variance of the endogenous variable's mean, given the value mu. */
|
||||
def variance(mu: Double): Double
|
||||
|
||||
/** Trim the fitted value so that it will be in valid range. */
|
||||
def project(mu: Double): Double = mu
|
||||
}
|
||||
|
||||
private[ml] object Family {
|
||||
|
||||
/**
|
||||
* Gets the [[Family]] object from its name.
|
||||
* @param name family name: "gaussian", "binomial", "poisson" or "gamma".
|
||||
*/
|
||||
def fromName(name: String): Family = {
|
||||
name match {
|
||||
case Gaussian.name => Gaussian
|
||||
case Binomial.name => Binomial
|
||||
case Poisson.name => Poisson
|
||||
case Gamma.name => Gamma
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gaussian exponential family distribution.
|
||||
* The default link for the Gaussian family is the identity link.
|
||||
*/
|
||||
private[ml] object Gaussian extends Family("gaussian") {
|
||||
|
||||
val defaultLink: Link = Identity
|
||||
|
||||
override def initialize(y: Double, weight: Double): Double = y
|
||||
|
||||
def variance(mu: Double): Double = 1.0
|
||||
|
||||
override def project(mu: Double): Double = {
|
||||
if (mu.isNegInfinity) {
|
||||
Double.MinValue
|
||||
} else if (mu.isPosInfinity) {
|
||||
Double.MaxValue
|
||||
} else {
|
||||
mu
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Binomial exponential family distribution.
|
||||
* The default link for the Binomial family is the logit link.
|
||||
*/
|
||||
private[ml] object Binomial extends Family("binomial") {
|
||||
|
||||
val defaultLink: Link = Logit
|
||||
|
||||
override def initialize(y: Double, weight: Double): Double = {
|
||||
val mu = (weight * y + 0.5) / (weight + 1.0)
|
||||
require(mu > 0.0 && mu < 1.0, "The response variable of Binomial family" +
|
||||
s"should be in range (0, 1), but got $mu")
|
||||
mu
|
||||
}
|
||||
|
||||
override def variance(mu: Double): Double = mu * (1.0 - mu)
|
||||
|
||||
override def project(mu: Double): Double = {
|
||||
if (mu < epsilon) {
|
||||
epsilon
|
||||
} else if (mu > 1.0 - epsilon) {
|
||||
1.0 - epsilon
|
||||
} else {
|
||||
mu
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Poisson exponential family distribution.
|
||||
* The default link for the Poisson family is the log link.
|
||||
*/
|
||||
private[ml] object Poisson extends Family("poisson") {
|
||||
|
||||
val defaultLink: Link = Log
|
||||
|
||||
override def initialize(y: Double, weight: Double): Double = {
|
||||
require(y > 0.0, "The response variable of Poisson family " +
|
||||
s"should be positive, but got $y")
|
||||
y
|
||||
}
|
||||
|
||||
override def variance(mu: Double): Double = mu
|
||||
|
||||
override def project(mu: Double): Double = {
|
||||
if (mu < epsilon) {
|
||||
epsilon
|
||||
} else if (mu.isInfinity) {
|
||||
Double.MaxValue
|
||||
} else {
|
||||
mu
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gamma exponential family distribution.
|
||||
* The default link for the Gamma family is the inverse link.
|
||||
*/
|
||||
private[ml] object Gamma extends Family("gamma") {
|
||||
|
||||
val defaultLink: Link = Inverse
|
||||
|
||||
override def initialize(y: Double, weight: Double): Double = {
|
||||
require(y > 0.0, "The response variable of Gamma family " +
|
||||
s"should be positive, but got $y")
|
||||
y
|
||||
}
|
||||
|
||||
override def variance(mu: Double): Double = math.pow(mu, 2.0)
|
||||
|
||||
override def project(mu: Double): Double = {
|
||||
if (mu < epsilon) {
|
||||
epsilon
|
||||
} else if (mu.isInfinity) {
|
||||
Double.MaxValue
|
||||
} else {
|
||||
mu
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A description of the link function to be used in the model.
|
||||
* The link function provides the relationship between the linear predictor
|
||||
* and the mean of the distribution function.
|
||||
* @param name the name of link function.
|
||||
*/
|
||||
private[ml] abstract class Link(val name: String) extends Serializable {
|
||||
|
||||
/** The link function. */
|
||||
def link(mu: Double): Double
|
||||
|
||||
/** Derivative of the link function. */
|
||||
def deriv(mu: Double): Double
|
||||
|
||||
/** The inverse link function. */
|
||||
def unlink(eta: Double): Double
|
||||
}
|
||||
|
||||
private[ml] object Link {
|
||||
|
||||
/**
|
||||
* Gets the [[Link]] object from its name.
|
||||
* @param name link name: "identity", "logit", "log",
|
||||
* "inverse", "probit", "cloglog" or "sqrt".
|
||||
*/
|
||||
def fromName(name: String): Link = {
|
||||
name match {
|
||||
case Identity.name => Identity
|
||||
case Logit.name => Logit
|
||||
case Log.name => Log
|
||||
case Inverse.name => Inverse
|
||||
case Probit.name => Probit
|
||||
case CLogLog.name => CLogLog
|
||||
case Sqrt.name => Sqrt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private[ml] object Identity extends Link("identity") {
|
||||
|
||||
override def link(mu: Double): Double = mu
|
||||
|
||||
override def deriv(mu: Double): Double = 1.0
|
||||
|
||||
override def unlink(eta: Double): Double = eta
|
||||
}
|
||||
|
||||
private[ml] object Logit extends Link("logit") {
|
||||
|
||||
override def link(mu: Double): Double = math.log(mu / (1.0 - mu))
|
||||
|
||||
override def deriv(mu: Double): Double = 1.0 / (mu * (1.0 - mu))
|
||||
|
||||
override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta))
|
||||
}
|
||||
|
||||
private[ml] object Log extends Link("log") {
|
||||
|
||||
override def link(mu: Double): Double = math.log(mu)
|
||||
|
||||
override def deriv(mu: Double): Double = 1.0 / mu
|
||||
|
||||
override def unlink(eta: Double): Double = math.exp(eta)
|
||||
}
|
||||
|
||||
private[ml] object Inverse extends Link("inverse") {
|
||||
|
||||
override def link(mu: Double): Double = 1.0 / mu
|
||||
|
||||
override def deriv(mu: Double): Double = -1.0 * math.pow(mu, -2.0)
|
||||
|
||||
override def unlink(eta: Double): Double = 1.0 / eta
|
||||
}
|
||||
|
||||
private[ml] object Probit extends Link("probit") {
|
||||
|
||||
override def link(mu: Double): Double = GD(0.0, 1.0).icdf(mu)
|
||||
|
||||
override def deriv(mu: Double): Double = 1.0 / GD(0.0, 1.0).pdf(GD(0.0, 1.0).icdf(mu))
|
||||
|
||||
override def unlink(eta: Double): Double = GD(0.0, 1.0).cdf(eta)
|
||||
}
|
||||
|
||||
private[ml] object CLogLog extends Link("cloglog") {
|
||||
|
||||
override def link(mu: Double): Double = math.log(-1.0 * math.log(1 - mu))
|
||||
|
||||
override def deriv(mu: Double): Double = 1.0 / ((mu - 1.0) * math.log(1.0 - mu))
|
||||
|
||||
override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta))
|
||||
}
|
||||
|
||||
private[ml] object Sqrt extends Link("sqrt") {
|
||||
|
||||
override def link(mu: Double): Double = math.sqrt(mu)
|
||||
|
||||
override def deriv(mu: Double): Double = 1.0 / (2.0 * math.sqrt(mu))
|
||||
|
||||
override def unlink(eta: Double): Double = math.pow(eta, 2.0)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* :: Experimental ::
|
||||
* Model produced by [[GeneralizedLinearRegression]].
|
||||
*/
|
||||
@Experimental
|
||||
@Since("2.0.0")
|
||||
class GeneralizedLinearRegressionModel private[ml] (
|
||||
@Since("2.0.0") override val uid: String,
|
||||
@Since("2.0.0") val coefficients: Vector,
|
||||
@Since("2.0.0") val intercept: Double)
|
||||
extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
|
||||
with GeneralizedLinearRegressionBase {
|
||||
|
||||
import GeneralizedLinearRegression._
|
||||
|
||||
lazy val familyObj = Family.fromName($(family))
|
||||
lazy val linkObj = if (isDefined(link)) {
|
||||
Link.fromName($(link))
|
||||
} else {
|
||||
familyObj.defaultLink
|
||||
}
|
||||
lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj)
|
||||
|
||||
override protected def predict(features: Vector): Double = {
|
||||
val eta = BLAS.dot(features, coefficients) + intercept
|
||||
familyAndLink.fitted(eta)
|
||||
}
|
||||
|
||||
@Since("2.0.0")
|
||||
override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = {
|
||||
copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra)
|
||||
.setParent(parent)
|
||||
}
|
||||
}
|
|
@ -163,8 +163,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
|||
}.first()
|
||||
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
||||
|
||||
if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= 4096) ||
|
||||
$(solver) == "normal") {
|
||||
if (($(solver) == "auto" && $(elasticNetParam) == 0.0 &&
|
||||
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") {
|
||||
require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " +
|
||||
"solver is used.'")
|
||||
// For low dimensional data, WeightedLeastSquares is more efficiently since the
|
||||
|
|
|
@ -0,0 +1,507 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.regression
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.util.MLTestingUtils
|
||||
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
|
||||
import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors}
|
||||
import org.apache.spark.mllib.random._
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.mllib.util.TestingUtils._
|
||||
import org.apache.spark.sql.{DataFrame, Row}
|
||||
|
||||
class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||
|
||||
private val seed: Int = 42
|
||||
@transient var datasetGaussianIdentity: DataFrame = _
|
||||
@transient var datasetGaussianLog: DataFrame = _
|
||||
@transient var datasetGaussianInverse: DataFrame = _
|
||||
@transient var datasetBinomial: DataFrame = _
|
||||
@transient var datasetPoissonLog: DataFrame = _
|
||||
@transient var datasetPoissonIdentity: DataFrame = _
|
||||
@transient var datasetPoissonSqrt: DataFrame = _
|
||||
@transient var datasetGammaInverse: DataFrame = _
|
||||
@transient var datasetGammaIdentity: DataFrame = _
|
||||
@transient var datasetGammaLog: DataFrame = _
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
|
||||
import GeneralizedLinearRegressionSuite._
|
||||
|
||||
datasetGaussianIdentity = sqlContext.createDataFrame(
|
||||
sc.parallelize(generateGeneralizedLinearRegressionInput(
|
||||
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
|
||||
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
|
||||
family = "gaussian", link = "identity"), 2))
|
||||
|
||||
datasetGaussianLog = sqlContext.createDataFrame(
|
||||
sc.parallelize(generateGeneralizedLinearRegressionInput(
|
||||
intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
|
||||
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
|
||||
family = "gaussian", link = "log"), 2))
|
||||
|
||||
datasetGaussianInverse = sqlContext.createDataFrame(
|
||||
sc.parallelize(generateGeneralizedLinearRegressionInput(
|
||||
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
|
||||
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
|
||||
family = "gaussian", link = "inverse"), 2))
|
||||
|
||||
datasetBinomial = {
|
||||
val nPoints = 10000
|
||||
val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191)
|
||||
val xMean = Array(5.843, 3.057, 3.758, 1.199)
|
||||
val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
|
||||
|
||||
val testData =
|
||||
generateMultinomialLogisticInput(coefficients, xMean, xVariance,
|
||||
addIntercept = true, nPoints, seed)
|
||||
|
||||
sqlContext.createDataFrame(sc.parallelize(testData, 2))
|
||||
}
|
||||
|
||||
datasetPoissonLog = sqlContext.createDataFrame(
|
||||
sc.parallelize(generateGeneralizedLinearRegressionInput(
|
||||
intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
|
||||
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
|
||||
family = "poisson", link = "log"), 2))
|
||||
|
||||
datasetPoissonIdentity = sqlContext.createDataFrame(
|
||||
sc.parallelize(generateGeneralizedLinearRegressionInput(
|
||||
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
|
||||
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
|
||||
family = "poisson", link = "identity"), 2))
|
||||
|
||||
datasetPoissonSqrt = sqlContext.createDataFrame(
|
||||
sc.parallelize(generateGeneralizedLinearRegressionInput(
|
||||
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
|
||||
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
|
||||
family = "poisson", link = "sqrt"), 2))
|
||||
|
||||
datasetGammaInverse = sqlContext.createDataFrame(
|
||||
sc.parallelize(generateGeneralizedLinearRegressionInput(
|
||||
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
|
||||
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
|
||||
family = "gamma", link = "inverse"), 2))
|
||||
|
||||
datasetGammaIdentity = sqlContext.createDataFrame(
|
||||
sc.parallelize(generateGeneralizedLinearRegressionInput(
|
||||
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
|
||||
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
|
||||
family = "gamma", link = "identity"), 2))
|
||||
|
||||
datasetGammaLog = sqlContext.createDataFrame(
|
||||
sc.parallelize(generateGeneralizedLinearRegressionInput(
|
||||
intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
|
||||
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
|
||||
family = "gamma", link = "log"), 2))
|
||||
}
|
||||
|
||||
test("params") {
|
||||
ParamsSuite.checkParams(new GeneralizedLinearRegression)
|
||||
val model = new GeneralizedLinearRegressionModel("genLinReg", Vectors.dense(0.0), 0.0)
|
||||
ParamsSuite.checkParams(model)
|
||||
}
|
||||
|
||||
test("generalized linear regression: default params") {
|
||||
val glr = new GeneralizedLinearRegression
|
||||
assert(glr.getLabelCol === "label")
|
||||
assert(glr.getFeaturesCol === "features")
|
||||
assert(glr.getPredictionCol === "prediction")
|
||||
assert(glr.getFitIntercept)
|
||||
assert(glr.getTol === 1E-6)
|
||||
assert(glr.getWeightCol === "")
|
||||
assert(glr.getRegParam === 0.0)
|
||||
assert(glr.getSolver == "irls")
|
||||
// TODO: Construct model directly instead of via fitting.
|
||||
val model = glr.setFamily("gaussian").setLink("identity")
|
||||
.fit(datasetGaussianIdentity)
|
||||
|
||||
// copied model must have the same parent.
|
||||
MLTestingUtils.checkCopy(model)
|
||||
|
||||
assert(model.getFeaturesCol === "features")
|
||||
assert(model.getPredictionCol === "prediction")
|
||||
assert(model.intercept !== 0.0)
|
||||
assert(model.hasParent)
|
||||
assert(model.getFamily === "gaussian")
|
||||
assert(model.getLink === "identity")
|
||||
}
|
||||
|
||||
test("generalized linear regression: gaussian family against glm") {
|
||||
/*
|
||||
R code:
|
||||
f1 <- data$V1 ~ data$V2 + data$V3 - 1
|
||||
f2 <- data$V1 ~ data$V2 + data$V3
|
||||
|
||||
data <- read.csv("path", header=FALSE)
|
||||
for (formula in c(f1, f2)) {
|
||||
model <- glm(formula, family="gaussian", data=data)
|
||||
print(as.vector(coef(model)))
|
||||
}
|
||||
|
||||
[1] 2.2960999 0.8087933
|
||||
[1] 2.5002642 2.2000403 0.5999485
|
||||
|
||||
data <- read.csv("path", header=FALSE)
|
||||
model1 <- glm(f1, family=gaussian(link=log), data=data, start=c(0,0))
|
||||
model2 <- glm(f2, family=gaussian(link=log), data=data, start=c(0,0,0))
|
||||
print(as.vector(coef(model1)))
|
||||
print(as.vector(coef(model2)))
|
||||
|
||||
[1] 0.23069326 0.07993778
|
||||
[1] 0.25001858 0.22002452 0.05998789
|
||||
|
||||
data <- read.csv("path", header=FALSE)
|
||||
for (formula in c(f1, f2)) {
|
||||
model <- glm(formula, family=gaussian(link=inverse), data=data)
|
||||
print(as.vector(coef(model)))
|
||||
}
|
||||
|
||||
[1] 2.3010179 0.8198976
|
||||
[1] 2.4108902 2.2130248 0.6086152
|
||||
*/
|
||||
|
||||
val expected = Seq(
|
||||
Vectors.dense(0.0, 2.2960999, 0.8087933),
|
||||
Vectors.dense(2.5002642, 2.2000403, 0.5999485),
|
||||
Vectors.dense(0.0, 0.23069326, 0.07993778),
|
||||
Vectors.dense(0.25001858, 0.22002452, 0.05998789),
|
||||
Vectors.dense(0.0, 2.3010179, 0.8198976),
|
||||
Vectors.dense(2.4108902, 2.2130248, 0.6086152))
|
||||
|
||||
import GeneralizedLinearRegression._
|
||||
|
||||
var idx = 0
|
||||
for ((link, dataset) <- Seq(("identity", datasetGaussianIdentity), ("log", datasetGaussianLog),
|
||||
("inverse", datasetGaussianInverse))) {
|
||||
for (fitIntercept <- Seq(false, true)) {
|
||||
val trainer = new GeneralizedLinearRegression().setFamily("gaussian").setLink(link)
|
||||
.setFitIntercept(fitIntercept)
|
||||
val model = trainer.fit(dataset)
|
||||
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
|
||||
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " +
|
||||
s"$link link and fitIntercept = $fitIntercept.")
|
||||
|
||||
val familyLink = new FamilyAndLink(Gaussian, Link.fromName(link))
|
||||
model.transform(dataset).select("features", "prediction").collect().foreach {
|
||||
case Row(features: DenseVector, prediction1: Double) =>
|
||||
val eta = BLAS.dot(features, model.coefficients) + model.intercept
|
||||
val prediction2 = familyLink.fitted(eta)
|
||||
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
|
||||
s"gaussian family, $link link and fitIntercept = $fitIntercept.")
|
||||
}
|
||||
|
||||
idx += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("generalized linear regression: gaussian family against glmnet") {
|
||||
/*
|
||||
R code:
|
||||
library(glmnet)
|
||||
data <- read.csv("path", header=FALSE)
|
||||
label = data$V1
|
||||
features = as.matrix(data.frame(data$V2, data$V3))
|
||||
for (intercept in c(FALSE, TRUE)) {
|
||||
for (lambda in c(0.0, 0.1, 1.0)) {
|
||||
model <- glmnet(features, label, family="gaussian", intercept=intercept,
|
||||
lambda=lambda, alpha=0, thresh=1E-14)
|
||||
print(as.vector(coef(model)))
|
||||
}
|
||||
}
|
||||
|
||||
[1] 0.0000000 2.2961005 0.8087932
|
||||
[1] 0.0000000 2.2130368 0.8309556
|
||||
[1] 0.0000000 1.7176137 0.9610657
|
||||
[1] 2.5002642 2.2000403 0.5999485
|
||||
[1] 3.1106389 2.0935142 0.5712711
|
||||
[1] 6.7597127 1.4581054 0.3994266
|
||||
*/
|
||||
|
||||
val expected = Seq(
|
||||
Vectors.dense(0.0, 2.2961005, 0.8087932),
|
||||
Vectors.dense(0.0, 2.2130368, 0.8309556),
|
||||
Vectors.dense(0.0, 1.7176137, 0.9610657),
|
||||
Vectors.dense(2.5002642, 2.2000403, 0.5999485),
|
||||
Vectors.dense(3.1106389, 2.0935142, 0.5712711),
|
||||
Vectors.dense(6.7597127, 1.4581054, 0.3994266))
|
||||
|
||||
var idx = 0
|
||||
for (fitIntercept <- Seq(false, true);
|
||||
regParam <- Seq(0.0, 0.1, 1.0)) {
|
||||
val trainer = new GeneralizedLinearRegression().setFamily("gaussian")
|
||||
.setFitIntercept(fitIntercept).setRegParam(regParam)
|
||||
val model = trainer.fit(datasetGaussianIdentity)
|
||||
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
|
||||
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " +
|
||||
s"fitIntercept = $fitIntercept and regParam = $regParam.")
|
||||
|
||||
idx += 1
|
||||
}
|
||||
}
|
||||
|
||||
test("generalized linear regression: binomial family against glm") {
|
||||
/*
|
||||
R code:
|
||||
f1 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5 - 1
|
||||
f2 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5
|
||||
data <- read.csv("path", header=FALSE)
|
||||
|
||||
for (formula in c(f1, f2)) {
|
||||
model <- glm(formula, family="binomial", data=data)
|
||||
print(as.vector(coef(model)))
|
||||
}
|
||||
|
||||
[1] -0.3560284 1.3010002 -0.3570805 -0.7406762
|
||||
[1] 2.8367406 -0.5896187 0.8931655 -0.3925169 -0.7996989
|
||||
|
||||
for (formula in c(f1, f2)) {
|
||||
model <- glm(formula, family=binomial(link=probit), data=data)
|
||||
print(as.vector(coef(model)))
|
||||
}
|
||||
|
||||
[1] -0.2134390 0.7800646 -0.2144267 -0.4438358
|
||||
[1] 1.6995366 -0.3524694 0.5332651 -0.2352985 -0.4780850
|
||||
|
||||
for (formula in c(f1, f2)) {
|
||||
model <- glm(formula, family=binomial(link=cloglog), data=data)
|
||||
print(as.vector(coef(model)))
|
||||
}
|
||||
|
||||
[1] -0.2832198 0.8434144 -0.2524727 -0.5293452
|
||||
[1] 1.5063590 -0.4038015 0.6133664 -0.2687882 -0.5541758
|
||||
*/
|
||||
val expected = Seq(
|
||||
Vectors.dense(0.0, -0.3560284, 1.3010002, -0.3570805, -0.7406762),
|
||||
Vectors.dense(2.8367406, -0.5896187, 0.8931655, -0.3925169, -0.7996989),
|
||||
Vectors.dense(0.0, -0.2134390, 0.7800646, -0.2144267, -0.4438358),
|
||||
Vectors.dense(1.6995366, -0.3524694, 0.5332651, -0.2352985, -0.4780850),
|
||||
Vectors.dense(0.0, -0.2832198, 0.8434144, -0.2524727, -0.5293452),
|
||||
Vectors.dense(1.5063590, -0.4038015, 0.6133664, -0.2687882, -0.5541758))
|
||||
|
||||
import GeneralizedLinearRegression._
|
||||
|
||||
var idx = 0
|
||||
for ((link, dataset) <- Seq(("logit", datasetBinomial), ("probit", datasetBinomial),
|
||||
("cloglog", datasetBinomial))) {
|
||||
for (fitIntercept <- Seq(false, true)) {
|
||||
val trainer = new GeneralizedLinearRegression().setFamily("binomial").setLink(link)
|
||||
.setFitIntercept(fitIntercept)
|
||||
val model = trainer.fit(dataset)
|
||||
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1),
|
||||
model.coefficients(2), model.coefficients(3))
|
||||
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with binomial family, " +
|
||||
s"$link link and fitIntercept = $fitIntercept.")
|
||||
|
||||
val familyLink = new FamilyAndLink(Binomial, Link.fromName(link))
|
||||
model.transform(dataset).select("features", "prediction").collect().foreach {
|
||||
case Row(features: DenseVector, prediction1: Double) =>
|
||||
val eta = BLAS.dot(features, model.coefficients) + model.intercept
|
||||
val prediction2 = familyLink.fitted(eta)
|
||||
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
|
||||
s"binomial family, $link link and fitIntercept = $fitIntercept.")
|
||||
}
|
||||
|
||||
idx += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("generalized linear regression: poisson family against glm") {
|
||||
/*
|
||||
R code:
|
||||
f1 <- data$V1 ~ data$V2 + data$V3 - 1
|
||||
f2 <- data$V1 ~ data$V2 + data$V3
|
||||
|
||||
data <- read.csv("path", header=FALSE)
|
||||
for (formula in c(f1, f2)) {
|
||||
model <- glm(formula, family="poisson", data=data)
|
||||
print(as.vector(coef(model)))
|
||||
}
|
||||
|
||||
[1] 0.22999393 0.08047088
|
||||
[1] 0.25022353 0.21998599 0.05998621
|
||||
|
||||
data <- read.csv("path", header=FALSE)
|
||||
for (formula in c(f1, f2)) {
|
||||
model <- glm(formula, family=poisson(link=identity), data=data)
|
||||
print(as.vector(coef(model)))
|
||||
}
|
||||
|
||||
[1] 2.2929501 0.8119415
|
||||
[1] 2.5012730 2.1999407 0.5999107
|
||||
|
||||
data <- read.csv("path", header=FALSE)
|
||||
for (formula in c(f1, f2)) {
|
||||
model <- glm(formula, family=poisson(link=sqrt), data=data)
|
||||
print(as.vector(coef(model)))
|
||||
}
|
||||
|
||||
[1] 2.2958947 0.8090515
|
||||
[1] 2.5000480 2.1999972 0.5999968
|
||||
*/
|
||||
val expected = Seq(
|
||||
Vectors.dense(0.0, 0.22999393, 0.08047088),
|
||||
Vectors.dense(0.25022353, 0.21998599, 0.05998621),
|
||||
Vectors.dense(0.0, 2.2929501, 0.8119415),
|
||||
Vectors.dense(2.5012730, 2.1999407, 0.5999107),
|
||||
Vectors.dense(0.0, 2.2958947, 0.8090515),
|
||||
Vectors.dense(2.5000480, 2.1999972, 0.5999968))
|
||||
|
||||
import GeneralizedLinearRegression._
|
||||
|
||||
var idx = 0
|
||||
for ((link, dataset) <- Seq(("log", datasetPoissonLog), ("identity", datasetPoissonIdentity),
|
||||
("sqrt", datasetPoissonSqrt))) {
|
||||
for (fitIntercept <- Seq(false, true)) {
|
||||
val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link)
|
||||
.setFitIntercept(fitIntercept)
|
||||
val model = trainer.fit(dataset)
|
||||
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
|
||||
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " +
|
||||
s"$link link and fitIntercept = $fitIntercept.")
|
||||
|
||||
val familyLink = new FamilyAndLink(Poisson, Link.fromName(link))
|
||||
model.transform(dataset).select("features", "prediction").collect().foreach {
|
||||
case Row(features: DenseVector, prediction1: Double) =>
|
||||
val eta = BLAS.dot(features, model.coefficients) + model.intercept
|
||||
val prediction2 = familyLink.fitted(eta)
|
||||
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
|
||||
s"poisson family, $link link and fitIntercept = $fitIntercept.")
|
||||
}
|
||||
|
||||
idx += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("generalized linear regression: gamma family against glm") {
|
||||
/*
|
||||
R code:
|
||||
f1 <- data$V1 ~ data$V2 + data$V3 - 1
|
||||
f2 <- data$V1 ~ data$V2 + data$V3
|
||||
|
||||
data <- read.csv("path", header=FALSE)
|
||||
for (formula in c(f1, f2)) {
|
||||
model <- glm(formula, family="Gamma", data=data)
|
||||
print(as.vector(coef(model)))
|
||||
}
|
||||
|
||||
[1] 2.3392419 0.8058058
|
||||
[1] 2.3507700 2.2533574 0.6042991
|
||||
|
||||
data <- read.csv("path", header=FALSE)
|
||||
for (formula in c(f1, f2)) {
|
||||
model <- glm(formula, family=Gamma(link=identity), data=data)
|
||||
print(as.vector(coef(model)))
|
||||
}
|
||||
|
||||
[1] 2.2908883 0.8147796
|
||||
[1] 2.5002406 2.1998346 0.6000059
|
||||
|
||||
data <- read.csv("path", header=FALSE)
|
||||
for (formula in c(f1, f2)) {
|
||||
model <- glm(formula, family=Gamma(link=log), data=data)
|
||||
print(as.vector(coef(model)))
|
||||
}
|
||||
|
||||
[1] 0.22958970 0.08091066
|
||||
[1] 0.25003210 0.21996957 0.06000215
|
||||
*/
|
||||
val expected = Seq(
|
||||
Vectors.dense(0.0, 2.3392419, 0.8058058),
|
||||
Vectors.dense(2.3507700, 2.2533574, 0.6042991),
|
||||
Vectors.dense(0.0, 2.2908883, 0.8147796),
|
||||
Vectors.dense(2.5002406, 2.1998346, 0.6000059),
|
||||
Vectors.dense(0.0, 0.22958970, 0.08091066),
|
||||
Vectors.dense(0.25003210, 0.21996957, 0.06000215))
|
||||
|
||||
import GeneralizedLinearRegression._
|
||||
|
||||
var idx = 0
|
||||
for ((link, dataset) <- Seq(("inverse", datasetGammaInverse),
|
||||
("identity", datasetGammaIdentity), ("log", datasetGammaLog))) {
|
||||
for (fitIntercept <- Seq(false, true)) {
|
||||
val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link)
|
||||
.setFitIntercept(fitIntercept)
|
||||
val model = trainer.fit(dataset)
|
||||
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
|
||||
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gamma family, " +
|
||||
s"$link link and fitIntercept = $fitIntercept.")
|
||||
|
||||
val familyLink = new FamilyAndLink(Gamma, Link.fromName(link))
|
||||
model.transform(dataset).select("features", "prediction").collect().foreach {
|
||||
case Row(features: DenseVector, prediction1: Double) =>
|
||||
val eta = BLAS.dot(features, model.coefficients) + model.intercept
|
||||
val prediction2 = familyLink.fitted(eta)
|
||||
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
|
||||
s"gamma family, $link link and fitIntercept = $fitIntercept.")
|
||||
}
|
||||
|
||||
idx += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object GeneralizedLinearRegressionSuite {
|
||||
|
||||
def generateGeneralizedLinearRegressionInput(
|
||||
intercept: Double,
|
||||
coefficients: Array[Double],
|
||||
xMean: Array[Double],
|
||||
xVariance: Array[Double],
|
||||
nPoints: Int,
|
||||
seed: Int,
|
||||
noiseLevel: Double,
|
||||
family: String,
|
||||
link: String): Seq[LabeledPoint] = {
|
||||
|
||||
val rnd = new Random(seed)
|
||||
def rndElement(i: Int) = {
|
||||
(rnd.nextDouble() - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
|
||||
}
|
||||
val (generator, mean) = family match {
|
||||
case "gaussian" => (new StandardNormalGenerator, 0.0)
|
||||
case "poisson" => (new PoissonGenerator(1.0), 1.0)
|
||||
case "gamma" => (new GammaGenerator(1.0, 1.0), 1.0)
|
||||
}
|
||||
generator.setSeed(seed)
|
||||
|
||||
(0 until nPoints).map { _ =>
|
||||
val features = Vectors.dense(coefficients.indices.map { rndElement(_) }.toArray)
|
||||
val eta = BLAS.dot(Vectors.dense(coefficients), features) + intercept
|
||||
val mu = link match {
|
||||
case "identity" => eta
|
||||
case "log" => math.exp(eta)
|
||||
case "sqrt" => math.pow(eta, 2.0)
|
||||
case "inverse" => 1.0 / eta
|
||||
}
|
||||
val label = mu + noiseLevel * (generator.nextValue() - mean)
|
||||
// Return LabeledPoints with DenseVector
|
||||
LabeledPoint(label, features)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue