[SPARK-19155][ML] Make family case insensitive in GLM
## What changes were proposed in this pull request? This is a supplement to PR #16516 which did not make the value from `getFamily` case insensitive. Current tests of poisson/binomial glm with weight fail when specifying 'Poisson' or 'Binomial', because the calculation of `dispersion` and `pValue` checks the value of family retrieved from `getFamily` ``` model.getFamily == Binomial.name || model.getFamily == Poisson.name ``` ## How was this patch tested? Update existing tests for 'Poisson' and 'Binomial'. yanboliang felixcheung imatiach-msft Author: actuaryzhang <actuaryzhang10@gmail.com> Closes #16675 from actuaryzhang/family.
This commit is contained in:
parent
de6ad3dfa7
commit
f067acefab
|
@ -1044,7 +1044,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
|
|||
*/
|
||||
@Since("2.0.0")
|
||||
lazy val dispersion: Double = if (
|
||||
model.getFamily == Binomial.name || model.getFamily == Poisson.name) {
|
||||
model.getFamily.toLowerCase == Binomial.name ||
|
||||
model.getFamily.toLowerCase == Poisson.name) {
|
||||
1.0
|
||||
} else {
|
||||
val rss = pearsonResiduals.agg(sum(pow(col("pearsonResiduals"), 2.0))).first().getDouble(0)
|
||||
|
@ -1147,7 +1148,8 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] (
|
|||
@Since("2.0.0")
|
||||
lazy val pValues: Array[Double] = {
|
||||
if (isNormalSolver) {
|
||||
if (model.getFamily == Binomial.name || model.getFamily == Poisson.name) {
|
||||
if (model.getFamily.toLowerCase == Binomial.name ||
|
||||
model.getFamily.toLowerCase == Poisson.name) {
|
||||
tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) }
|
||||
} else {
|
||||
tValues.map { x =>
|
||||
|
|
|
@ -758,7 +758,7 @@ class GeneralizedLinearRegressionSuite
|
|||
0.028480 0.069123 0.935495 -0.049613
|
||||
*/
|
||||
val trainer = new GeneralizedLinearRegression()
|
||||
.setFamily("binomial")
|
||||
.setFamily("Binomial")
|
||||
.setWeightCol("weight")
|
||||
.setFitIntercept(false)
|
||||
|
||||
|
@ -875,7 +875,7 @@ class GeneralizedLinearRegressionSuite
|
|||
-0.4378554 0.2189277 0.1459518 -0.1094638
|
||||
*/
|
||||
val trainer = new GeneralizedLinearRegression()
|
||||
.setFamily("poisson")
|
||||
.setFamily("Poisson")
|
||||
.setWeightCol("weight")
|
||||
.setFitIntercept(true)
|
||||
|
||||
|
|
Loading…
Reference in a new issue