[SPARK-19400][ML] Allow GLM to handle intercept only model
## What changes were proposed in this pull request? Intercept-only GLM is failing for non-Gaussian family because of reducing an empty array in IWLS. The following code `val maxTolOfCoefficients = oldCoefficients.toArray.reduce { (x, y) => math.max(math.abs(x), math.abs(y))` fails in the intercept-only model because `oldCoefficients` is empty. This PR fixes this issue. yanboliang srowen imatiach-msft zhengruifeng ## How was this patch tested? New test for intercept only model. Author: actuaryzhang <actuaryzhang10@gmail.com> Closes #16740 from actuaryzhang/interceptOnly.
This commit is contained in:
parent
15627ac743
commit
1aeb9f6cba
|
@ -89,7 +89,7 @@ private[ml] class IterativelyReweightedLeastSquares(
|
|||
val oldCoefficients = oldModel.coefficients
|
||||
val coefficients = model.coefficients
|
||||
BLAS.axpy(-1.0, coefficients, oldCoefficients)
|
||||
val maxTolOfCoefficients = oldCoefficients.toArray.reduce { (x, y) =>
|
||||
val maxTolOfCoefficients = oldCoefficients.toArray.foldLeft(0.0) { (x, y) =>
|
||||
math.max(math.abs(x), math.abs(y))
|
||||
}
|
||||
val maxTol = math.max(maxTolOfCoefficients, math.abs(oldModel.intercept - model.intercept))
|
||||
|
|
|
@ -335,6 +335,10 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
|
|||
throw new SparkException(msg)
|
||||
}
|
||||
|
||||
require(numFeatures > 0 || $(fitIntercept),
|
||||
"GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " +
|
||||
"set to false. To fit a model with 0 features, fitIntercept must be set to true." )
|
||||
|
||||
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
||||
val instances: RDD[Instance] =
|
||||
dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
|
||||
|
|
|
@ -743,6 +743,61 @@ class GeneralizedLinearRegressionSuite
|
|||
}
|
||||
}
|
||||
|
||||
test("generalized linear regression: intercept only") {
|
||||
/*
|
||||
R code:
|
||||
|
||||
library(statmod)
|
||||
y <- c(1.0, 0.5, 0.7, 0.3)
|
||||
w <- c(1, 2, 3, 4)
|
||||
for (fam in list(gaussian(), poisson(), binomial(), Gamma(), tweedie(1.6))) {
|
||||
model1 <- glm(y ~ 1, family = fam)
|
||||
model2 <- glm(y ~ 1, family = fam, weights = w)
|
||||
print(as.vector(c(coef(model1), coef(model2))))
|
||||
}
|
||||
[1] 0.625 0.530
|
||||
[1] -0.4700036 -0.6348783
|
||||
[1] 0.5108256 0.1201443
|
||||
[1] 1.600000 1.886792
|
||||
[1] 1.325782 1.463641
|
||||
*/
|
||||
|
||||
val dataset = Seq(
|
||||
Instance(1.0, 1.0, Vectors.zeros(0)),
|
||||
Instance(0.5, 2.0, Vectors.zeros(0)),
|
||||
Instance(0.7, 3.0, Vectors.zeros(0)),
|
||||
Instance(0.3, 4.0, Vectors.zeros(0))
|
||||
).toDF()
|
||||
|
||||
val expected = Seq(0.625, 0.530, -0.4700036, -0.6348783, 0.5108256, 0.1201443,
|
||||
1.600000, 1.886792, 1.325782, 1.463641)
|
||||
|
||||
import GeneralizedLinearRegression._
|
||||
|
||||
var idx = 0
|
||||
for (family <- Seq("gaussian", "poisson", "binomial", "gamma", "tweedie")) {
|
||||
for (useWeight <- Seq(false, true)) {
|
||||
val trainer = new GeneralizedLinearRegression().setFamily(family)
|
||||
if (useWeight) trainer.setWeightCol("weight")
|
||||
if (family == "tweedie") trainer.setVariancePower(1.6)
|
||||
val model = trainer.fit(dataset)
|
||||
val actual = model.intercept
|
||||
assert(actual ~== expected(idx) absTol 1E-3, "Model mismatch: intercept only GLM with " +
|
||||
s"useWeight = $useWeight and family = $family.")
|
||||
assert(model.coefficients === new DenseVector(Array.empty[Double]))
|
||||
idx += 1
|
||||
}
|
||||
}
|
||||
|
||||
// throw exception for empty model
|
||||
val trainer = new GeneralizedLinearRegression().setFitIntercept(false)
|
||||
withClue("Specified model is empty with neither intercept nor feature") {
|
||||
intercept[IllegalArgumentException] {
|
||||
trainer.fit(dataset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("glm summary: gaussian family with weight") {
|
||||
/*
|
||||
R code:
|
||||
|
|
Loading…
Reference in a new issue