[SPARK-18166][MLLIB] Fix Poisson GLM bug due to wrong requirement of response values

## What changes were proposed in this pull request?

The current implementation of the Poisson GLM accepts only strictly positive response values. This is incorrect, since the support of the Poisson distribution includes zero. The bug is fixed by changing the check on the response variable from `require(y > 0.0, ...)` to `require(y >= 0.0, ...)`.

mengxr  srowen

Author: actuaryzhang <actuaryzhang10@gmail.com>
Author: actuaryzhang <actuaryzhang@uber.com>

Closes #15683 from actuaryzhang/master.
This commit is contained in:
actuaryzhang 2016-11-14 12:08:06 +01:00 committed by Sean Owen
parent f95b124c68
commit ae6cddb787
No known key found for this signature in database
GPG key ID: BEB3956D6717BDDC
2 changed files with 47 additions and 2 deletions

View file

@ -501,8 +501,8 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
val defaultLink: Link = Log
/**
 * Validates a single response value for the Poisson family and returns it unchanged.
 *
 * Zero is a valid Poisson outcome, so only negative responses are rejected.
 *
 * @param y      response value of the observation; must be non-negative
 * @param weight weight of the observation (not used by this check)
 * @return `y`, unchanged
 * @throws IllegalArgumentException if `y` is negative
 */
override def initialize(y: Double, weight: Double): Double = {
  // A strict `y > 0.0` check would wrongly reject zero counts, so the bound is inclusive.
  require(y >= 0.0, "The response variable of Poisson family " +
    s"should be non-negative, but got $y")
  y
}

View file

@ -44,6 +44,7 @@ class GeneralizedLinearRegressionSuite
@transient var datasetGaussianInverse: DataFrame = _
@transient var datasetBinomial: DataFrame = _
@transient var datasetPoissonLog: DataFrame = _
@transient var datasetPoissonLogWithZero: DataFrame = _
@transient var datasetPoissonIdentity: DataFrame = _
@transient var datasetPoissonSqrt: DataFrame = _
@transient var datasetGammaInverse: DataFrame = _
@ -88,6 +89,12 @@ class GeneralizedLinearRegressionSuite
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
family = "poisson", link = "log").toDF()
datasetPoissonLogWithZero = generateGeneralizedLinearRegressionInput(
intercept = -1.5, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
xVariance = Array(0.7, 1.2), nPoints = 100, seed, noiseLevel = 0.01,
family = "poisson", link = "log")
.map{x => LabeledPoint(if (x.label < 0.7) 0.0 else x.label, x.features)}.toDF()
datasetPoissonIdentity = generateGeneralizedLinearRegressionInput(
intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
@ -139,6 +146,10 @@ class GeneralizedLinearRegressionSuite
label + "," + features.toArray.mkString(",")
}.repartition(1).saveAsTextFile(
"target/tmp/GeneralizedLinearRegressionSuite/datasetPoissonLog")
datasetPoissonLogWithZero.rdd.map { case Row(label: Double, features: Vector) =>
label + "," + features.toArray.mkString(",")
}.repartition(1).saveAsTextFile(
"target/tmp/GeneralizedLinearRegressionSuite/datasetPoissonLogWithZero")
datasetPoissonIdentity.rdd.map { case Row(label: Double, features: Vector) =>
label + "," + features.toArray.mkString(",")
}.repartition(1).saveAsTextFile(
@ -456,6 +467,40 @@ class GeneralizedLinearRegressionSuite
}
}
test("generalized linear regression: poisson family against glm (with zero values)") {
  /*
     Checks the fitted Poisson/log model on a dataset containing zero-valued responses
     against coefficients obtained from R's glm, with and without an intercept.

     R code:
     f1 <- data$V1 ~ data$V2 + data$V3 - 1
     f2 <- data$V1 ~ data$V2 + data$V3
     data <- read.csv("path", header=FALSE)
     for (formula in c(f1, f2)) {
       model <- glm(formula, family="poisson", data=data)
       print(as.vector(coef(model)))
     }
     [1] 0.4272661 -0.1565423
     [1] -3.6911354 0.6214301 0.1295814
   */
  import GeneralizedLinearRegression._

  val link = "log"

  // Each fitIntercept setting is paired with the (intercept, coef0, coef1) vector from R;
  // the run without an intercept is expected to report an intercept of 0.0.
  val cases = Seq(
    false -> Vectors.dense(0.0, 0.4272661, -0.1565423),
    true -> Vectors.dense(-3.6911354, 0.6214301, 0.1295814))

  cases.foreach { case (fitIntercept, expectedCoefs) =>
    val model = new GeneralizedLinearRegression()
      .setFamily("poisson")
      .setLink(link)
      .setFitIntercept(fitIntercept)
      .setLinkPredictionCol("linkPrediction")
      .fit(datasetPoissonLogWithZero)
    val fitted = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
    assert(fitted ~= expectedCoefs absTol 1e-4, "Model mismatch: GLM with poisson family, " +
      s"$link link and fitIntercept = $fitIntercept (with zero values).")
  }
}
test("generalized linear regression: gamma family against glm") {
/*
R code: