[SPARK-14479][ML] GLM supports output link prediction
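## What changes were proposed in this pull request?

Generalized linear regression (GLM) can now output the link prediction (the linear predictor) as an additional column, controlled by the new `linkPredictionCol` param. The column is only produced when the param is set to a non-empty name.

## How was this patch tested?

Unit tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12287 from yanboliang/spark-14479.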
This commit is contained in:
parent
f25a3ea8d3
commit
4e726227a3
|
@ -78,6 +78,20 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
|
|||
@Since("2.0.0")
|
||||
def getLink: String = $(link)
|
||||
|
||||
/**
|
||||
* Param for link prediction (linear predictor) column name.
|
||||
* Default is the empty string, which means the link prediction column is not output.
|
||||
* @group param
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
final val linkPredictionCol: Param[String] = new Param[String](this, "linkPredictionCol",
|
||||
"link prediction (linear predictor) column name")
|
||||
setDefault(linkPredictionCol, "")
|
||||
|
||||
/** @group getParam */
|
||||
@Since("2.0.0")
|
||||
def getLinkPredictionCol: String = $(linkPredictionCol)
|
||||
|
||||
import GeneralizedLinearRegression._
|
||||
|
||||
@Since("2.0.0")
|
||||
|
@ -93,7 +107,12 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
|
|||
Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " +
|
||||
s"with ${$(family)} family does not support ${$(link)} link function.")
|
||||
}
|
||||
super.validateAndTransformSchema(schema, fitting, featuresDataType)
|
||||
val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
|
||||
if ($(linkPredictionCol).nonEmpty) {
|
||||
SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType)
|
||||
} else {
|
||||
newSchema
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -196,6 +215,13 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
|
|||
def setSolver(value: String): this.type = set(solver, value)
|
||||
setDefault(solver -> "irls")
|
||||
|
||||
/**
|
||||
* Sets the link prediction (linear predictor) column name.
|
||||
* @group setParam
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value)
|
||||
|
||||
override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = {
|
||||
val familyObj = Family.fromName($(family))
|
||||
val linkObj = if (isDefined(link)) {
|
||||
|
@ -666,6 +692,13 @@ class GeneralizedLinearRegressionModel private[ml] (
|
|||
extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
|
||||
with GeneralizedLinearRegressionBase with MLWritable {
|
||||
|
||||
/**
|
||||
* Sets the link prediction (linear predictor) column name.
|
||||
* @group setParam
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value)
|
||||
|
||||
import GeneralizedLinearRegression._
|
||||
|
||||
lazy val familyObj = Family.fromName($(family))
|
||||
|
@ -677,10 +710,35 @@ class GeneralizedLinearRegressionModel private[ml] (
|
|||
lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj)
|
||||
|
||||
override protected def predict(features: Vector): Double = {
|
||||
val eta = BLAS.dot(features, coefficients) + intercept
|
||||
val eta = predictLink(features)
|
||||
familyAndLink.fitted(eta)
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates the link prediction (linear predictor) for the given feature vector: the dot product of the features and the coefficients, plus the intercept.
|
||||
*/
|
||||
private def predictLink(features: Vector): Double = {
|
||||
BLAS.dot(features, coefficients) + intercept
|
||||
}
|
||||
|
||||
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||
transformSchema(dataset.schema)
|
||||
transformImpl(dataset)
|
||||
}
|
||||
|
||||
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
|
||||
val predictUDF = udf { (features: Vector) => predict(features) }
|
||||
val predictLinkUDF = udf { (features: Vector) => predictLink(features) }
|
||||
var output = dataset
|
||||
if ($(predictionCol).nonEmpty) {
|
||||
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
|
||||
}
|
||||
if ($(linkPredictionCol).nonEmpty) {
|
||||
output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol))))
|
||||
}
|
||||
output.toDF
|
||||
}
|
||||
|
||||
private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None
|
||||
|
||||
/**
|
||||
|
|
|
@ -247,20 +247,24 @@ class GeneralizedLinearRegressionSuite
|
|||
("inverse", datasetGaussianInverse))) {
|
||||
for (fitIntercept <- Seq(false, true)) {
|
||||
val trainer = new GeneralizedLinearRegression().setFamily("gaussian").setLink(link)
|
||||
.setFitIntercept(fitIntercept)
|
||||
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
|
||||
val model = trainer.fit(dataset)
|
||||
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
|
||||
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " +
|
||||
s"$link link and fitIntercept = $fitIntercept.")
|
||||
|
||||
val familyLink = new FamilyAndLink(Gaussian, Link.fromName(link))
|
||||
model.transform(dataset).select("features", "prediction").collect().foreach {
|
||||
case Row(features: DenseVector, prediction1: Double) =>
|
||||
val eta = BLAS.dot(features, model.coefficients) + model.intercept
|
||||
val prediction2 = familyLink.fitted(eta)
|
||||
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
|
||||
s"gaussian family, $link link and fitIntercept = $fitIntercept.")
|
||||
}
|
||||
model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
|
||||
.foreach {
|
||||
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
|
||||
val eta = BLAS.dot(features, model.coefficients) + model.intercept
|
||||
val prediction2 = familyLink.fitted(eta)
|
||||
val linkPrediction2 = eta
|
||||
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
|
||||
s"gaussian family, $link link and fitIntercept = $fitIntercept.")
|
||||
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
|
||||
s"GLM with gaussian family, $link link and fitIntercept = $fitIntercept.")
|
||||
}
|
||||
|
||||
idx += 1
|
||||
}
|
||||
|
@ -358,7 +362,7 @@ class GeneralizedLinearRegressionSuite
|
|||
("cloglog", datasetBinomial))) {
|
||||
for (fitIntercept <- Seq(false, true)) {
|
||||
val trainer = new GeneralizedLinearRegression().setFamily("binomial").setLink(link)
|
||||
.setFitIntercept(fitIntercept)
|
||||
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
|
||||
val model = trainer.fit(dataset)
|
||||
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1),
|
||||
model.coefficients(2), model.coefficients(3))
|
||||
|
@ -366,13 +370,17 @@ class GeneralizedLinearRegressionSuite
|
|||
s"$link link and fitIntercept = $fitIntercept.")
|
||||
|
||||
val familyLink = new FamilyAndLink(Binomial, Link.fromName(link))
|
||||
model.transform(dataset).select("features", "prediction").collect().foreach {
|
||||
case Row(features: DenseVector, prediction1: Double) =>
|
||||
val eta = BLAS.dot(features, model.coefficients) + model.intercept
|
||||
val prediction2 = familyLink.fitted(eta)
|
||||
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
|
||||
s"binomial family, $link link and fitIntercept = $fitIntercept.")
|
||||
}
|
||||
model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
|
||||
.foreach {
|
||||
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
|
||||
val eta = BLAS.dot(features, model.coefficients) + model.intercept
|
||||
val prediction2 = familyLink.fitted(eta)
|
||||
val linkPrediction2 = eta
|
||||
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
|
||||
s"binomial family, $link link and fitIntercept = $fitIntercept.")
|
||||
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
|
||||
s"GLM with binomial family, $link link and fitIntercept = $fitIntercept.")
|
||||
}
|
||||
|
||||
idx += 1
|
||||
}
|
||||
|
@ -427,20 +435,24 @@ class GeneralizedLinearRegressionSuite
|
|||
("sqrt", datasetPoissonSqrt))) {
|
||||
for (fitIntercept <- Seq(false, true)) {
|
||||
val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link)
|
||||
.setFitIntercept(fitIntercept)
|
||||
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
|
||||
val model = trainer.fit(dataset)
|
||||
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
|
||||
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " +
|
||||
s"$link link and fitIntercept = $fitIntercept.")
|
||||
|
||||
val familyLink = new FamilyAndLink(Poisson, Link.fromName(link))
|
||||
model.transform(dataset).select("features", "prediction").collect().foreach {
|
||||
case Row(features: DenseVector, prediction1: Double) =>
|
||||
val eta = BLAS.dot(features, model.coefficients) + model.intercept
|
||||
val prediction2 = familyLink.fitted(eta)
|
||||
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
|
||||
s"poisson family, $link link and fitIntercept = $fitIntercept.")
|
||||
}
|
||||
model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
|
||||
.foreach {
|
||||
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
|
||||
val eta = BLAS.dot(features, model.coefficients) + model.intercept
|
||||
val prediction2 = familyLink.fitted(eta)
|
||||
val linkPrediction2 = eta
|
||||
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
|
||||
s"poisson family, $link link and fitIntercept = $fitIntercept.")
|
||||
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
|
||||
s"GLM with poisson family, $link link and fitIntercept = $fitIntercept.")
|
||||
}
|
||||
|
||||
idx += 1
|
||||
}
|
||||
|
@ -495,20 +507,24 @@ class GeneralizedLinearRegressionSuite
|
|||
("identity", datasetGammaIdentity), ("log", datasetGammaLog))) {
|
||||
for (fitIntercept <- Seq(false, true)) {
|
||||
val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link)
|
||||
.setFitIntercept(fitIntercept)
|
||||
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
|
||||
val model = trainer.fit(dataset)
|
||||
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
|
||||
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gamma family, " +
|
||||
s"$link link and fitIntercept = $fitIntercept.")
|
||||
|
||||
val familyLink = new FamilyAndLink(Gamma, Link.fromName(link))
|
||||
model.transform(dataset).select("features", "prediction").collect().foreach {
|
||||
case Row(features: DenseVector, prediction1: Double) =>
|
||||
val eta = BLAS.dot(features, model.coefficients) + model.intercept
|
||||
val prediction2 = familyLink.fitted(eta)
|
||||
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
|
||||
s"gamma family, $link link and fitIntercept = $fitIntercept.")
|
||||
}
|
||||
model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
|
||||
.foreach {
|
||||
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
|
||||
val eta = BLAS.dot(features, model.coefficients) + model.intercept
|
||||
val prediction2 = familyLink.fitted(eta)
|
||||
val linkPrediction2 = eta
|
||||
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
|
||||
s"gamma family, $link link and fitIntercept = $fitIntercept.")
|
||||
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
|
||||
s"GLM with gamma family, $link link and fitIntercept = $fitIntercept.")
|
||||
}
|
||||
|
||||
idx += 1
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue