[SPARK-14479][ML] GLM supports output link prediction

## What changes were proposed in this pull request?
GLM supports output link prediction.
## How was this patch tested?
unit test.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12287 from yanboliang/spark-14479.
This commit is contained in:
Yanbo Liang 2016-04-21 17:31:33 -07:00 committed by Xiangrui Meng
parent f25a3ea8d3
commit 4e726227a3
2 changed files with 108 additions and 34 deletions

View file

@ -78,6 +78,20 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
@Since("2.0.0")
def getLink: String = $(link)
/**
* Param for link prediction (linear predictor) column name.
* Default is empty, which means we do not output link prediction.
* @group param
*/
@Since("2.0.0")
final val linkPredictionCol: Param[String] = new Param[String](this, "linkPredictionCol",
"link prediction (linear predictor) column name")
setDefault(linkPredictionCol, "")
/** @group getParam */
@Since("2.0.0")
def getLinkPredictionCol: String = $(linkPredictionCol)
import GeneralizedLinearRegression._
@Since("2.0.0")
@ -93,7 +107,12 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " +
s"with ${$(family)} family does not support ${$(link)} link function.")
}
super.validateAndTransformSchema(schema, fitting, featuresDataType)
val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
if ($(linkPredictionCol).nonEmpty) {
SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType)
} else {
newSchema
}
}
}
@ -196,6 +215,13 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
def setSolver(value: String): this.type = set(solver, value)
setDefault(solver -> "irls")
/**
* Sets the link prediction (linear predictor) column name.
* @group setParam
*/
@Since("2.0.0")
def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value)
override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = {
val familyObj = Family.fromName($(family))
val linkObj = if (isDefined(link)) {
@ -666,6 +692,13 @@ class GeneralizedLinearRegressionModel private[ml] (
extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
with GeneralizedLinearRegressionBase with MLWritable {
/**
* Sets the link prediction (linear predictor) column name.
* @group setParam
*/
@Since("2.0.0")
def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value)
import GeneralizedLinearRegression._
lazy val familyObj = Family.fromName($(family))
@ -677,10 +710,35 @@ class GeneralizedLinearRegressionModel private[ml] (
lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj)
override protected def predict(features: Vector): Double = {
val eta = BLAS.dot(features, coefficients) + intercept
val eta = predictLink(features)
familyAndLink.fitted(eta)
}
/**
* Calculate the link prediction (linear predictor) of the given instance.
*/
private def predictLink(features: Vector): Double = {
BLAS.dot(features, coefficients) + intercept
}
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
transformImpl(dataset)
}
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Vector) => predict(features) }
val predictLinkUDF = udf { (features: Vector) => predictLink(features) }
var output = dataset
if ($(predictionCol).nonEmpty) {
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
if ($(linkPredictionCol).nonEmpty) {
output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol))))
}
output.toDF
}
private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None
/**

View file

@ -247,20 +247,24 @@ class GeneralizedLinearRegressionSuite
("inverse", datasetGaussianInverse))) {
for (fitIntercept <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily("gaussian").setLink(link)
.setFitIntercept(fitIntercept)
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " +
s"$link link and fitIntercept = $fitIntercept.")
val familyLink = new FamilyAndLink(Gaussian, Link.fromName(link))
model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"gaussian family, $link link and fitIntercept = $fitIntercept.")
}
model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
.foreach {
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
val linkPrediction2 = eta
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"gaussian family, $link link and fitIntercept = $fitIntercept.")
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
s"GLM with gaussian family, $link link and fitIntercept = $fitIntercept.")
}
idx += 1
}
@ -358,7 +362,7 @@ class GeneralizedLinearRegressionSuite
("cloglog", datasetBinomial))) {
for (fitIntercept <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily("binomial").setLink(link)
.setFitIntercept(fitIntercept)
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1),
model.coefficients(2), model.coefficients(3))
@ -366,13 +370,17 @@ class GeneralizedLinearRegressionSuite
s"$link link and fitIntercept = $fitIntercept.")
val familyLink = new FamilyAndLink(Binomial, Link.fromName(link))
model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"binomial family, $link link and fitIntercept = $fitIntercept.")
}
model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
.foreach {
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
val linkPrediction2 = eta
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"binomial family, $link link and fitIntercept = $fitIntercept.")
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
s"GLM with binomial family, $link link and fitIntercept = $fitIntercept.")
}
idx += 1
}
@ -427,20 +435,24 @@ class GeneralizedLinearRegressionSuite
("sqrt", datasetPoissonSqrt))) {
for (fitIntercept <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link)
.setFitIntercept(fitIntercept)
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " +
s"$link link and fitIntercept = $fitIntercept.")
val familyLink = new FamilyAndLink(Poisson, Link.fromName(link))
model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"poisson family, $link link and fitIntercept = $fitIntercept.")
}
model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
.foreach {
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
val linkPrediction2 = eta
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"poisson family, $link link and fitIntercept = $fitIntercept.")
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
s"GLM with poisson family, $link link and fitIntercept = $fitIntercept.")
}
idx += 1
}
@ -495,20 +507,24 @@ class GeneralizedLinearRegressionSuite
("identity", datasetGammaIdentity), ("log", datasetGammaLog))) {
for (fitIntercept <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link)
.setFitIntercept(fitIntercept)
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gamma family, " +
s"$link link and fitIntercept = $fitIntercept.")
val familyLink = new FamilyAndLink(Gamma, Link.fromName(link))
model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"gamma family, $link link and fitIntercept = $fitIntercept.")
}
model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
.foreach {
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
val linkPrediction2 = eta
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"gamma family, $link link and fitIntercept = $fitIntercept.")
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
s"GLM with gamma family, $link link and fitIntercept = $fitIntercept.")
}
idx += 1
}