[SPARK-8538] [SPARK-8539] [ML] Linear Regression Training and Testing Results

Adds results (e.g. objective value at each iteration, residuals) on training and user-specified test sets for LinearRegressionModel.

Notes to Reviewers:
 * Are the `*TrainingResults` and `Results` classes too specialized for `LinearRegressionModel`? Where would be an appropriate level of abstraction?
 * Please check that the `transient` annotations are correct; the datasets should not be copied or retained during serialization.
 * Any thoughts on `RDD`s versus `DataFrame`s? If using `DataFrame`s, suggested schemas for each intermediate step? Also, how to create a "local DataFrame" without a `sqlContext`?

Author: Feynman Liang <fliang@databricks.com>

Closes #7099 from feynmanliang/SPARK-8538 and squashes the following commits:

d219fa4 [Feynman Liang] Update docs
4a42680 [Feynman Liang] Change Summary to hold values, move transient annotations down to metrics and predictions DF
6300031 [Feynman Liang] Code review changes
0a5e762 [Feynman Liang] Fix build error
e71102d [Feynman Liang] Merge branch 'master' into SPARK-8538
3367489 [Feynman Liang] Merge branch 'master' into SPARK-8538
70f267c [Feynman Liang] Make TrainingSummary transient and remove Serializable from *Summary and RegressionMetrics
1d9ea42 [Feynman Liang] Fix failing Java test
a65dfda [Feynman Liang] Make TrainingSummary and metrics serializable, prediction dataframe transient
0a605d8 [Feynman Liang] Replace Params from LinearRegression*Summary with private constructor vals
c2fe835 [Feynman Liang] Optimize imports
02d8a70 [Feynman Liang] Add Params to LinearModel*Summary, refactor tests and add test for evaluate()
8f999f4 [Feynman Liang] Refactor from jkbradley code review
072e948 [Feynman Liang] Style
509ae36 [Feynman Liang] Use DFs and localize serialization to LinearRegressionModel
9509c79 [Feynman Liang] Fix imports
b2bbaa3 [Feynman Liang] Refactored LinearRegressionResults API to be more private
ffceaec [Feynman Liang] Merge branch 'master' into SPARK-8538
1cedb2b [Feynman Liang] Add test for decreasing objective trace
dab0aff [Feynman Liang] Add LinearRegressionTrainingResults tests, make test suite code copy+pasteable
97b0a81 [Feynman Liang] Add LinearRegressionModel.evaluate() to get results on test sets
dc51bce [Feynman Liang] Style guide fixes
521f397 [Feynman Liang] Use RDD[(Double, Double)] instead of DF
2ff5710 [Feynman Liang] Add training results and model summary to ML LinearRegression
This commit is contained in:
Feynman Liang 2015-07-09 16:21:21 -07:00 committed by Joseph K. Bradley
parent e29ce319fa
commit a0cc3e5aa3
2 changed files with 192 additions and 6 deletions

View file

@ -22,18 +22,20 @@ import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
import org.apache.spark.{SparkException, Logging}
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.StatCounter
@ -139,7 +141,16 @@ class LinearRegression(override val uid: String)
logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " +
s"and the intercept will be the mean of the label; as a result, training is not needed.")
if (handlePersistence) instances.unpersist()
return new LinearRegressionModel(uid, Vectors.sparse(numFeatures, Seq()), yMean)
val weights = Vectors.sparse(numFeatures, Seq())
val intercept = yMean
val model = new LinearRegressionModel(uid, weights, intercept)
val trainingSummary = new LinearRegressionTrainingSummary(
model.transform(dataset).select($(predictionCol), $(labelCol)),
$(predictionCol),
$(labelCol),
Array(0D))
return copyValues(model.setSummary(trainingSummary))
}
val featuresMean = summarizer.mean.toArray
@ -178,7 +189,6 @@ class LinearRegression(override val uid: String)
state = states.next()
arrayBuilder += state.adjustedValue
}
if (state == null) {
val msg = s"${optimizer.getClass.getName} failed."
logError(msg)
@ -209,7 +219,13 @@ class LinearRegression(override val uid: String)
if (handlePersistence) instances.unpersist()
copyValues(new LinearRegressionModel(uid, weights, intercept))
val model = copyValues(new LinearRegressionModel(uid, weights, intercept))
val trainingSummary = new LinearRegressionTrainingSummary(
model.transform(dataset).select($(predictionCol), $(labelCol)),
$(predictionCol),
$(labelCol),
objectiveHistory)
model.setSummary(trainingSummary)
}
/** Delegates to `defaultCopy`, forwarding `extra` unchanged. */
override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
@ -227,15 +243,126 @@ class LinearRegressionModel private[ml] (
extends RegressionModel[Vector, LinearRegressionModel]
with LinearRegressionParams {
// Summary of the training run; stays None until `setSummary` stores one.
private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
/**
 * Returns the summary (e.g. residuals, MSE, R-squared) of the model on the training set.
 * Throws a `SparkException` (whose cause is a `NullPointerException`) when no training
 * summary has been stored on this model.
 */
def summary: LinearRegressionTrainingSummary = trainingSummary.getOrElse {
  // `getOrElse` takes its argument by name, so this only runs when the option is empty.
  throw new SparkException(
    "No training summary available for this LinearRegressionModel",
    new NullPointerException())
}
/**
 * Stores `summary` as this model's training summary and returns this model so calls
 * can be chained.
 */
private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = {
  // The parameter shadows the `summary` accessor here; this wraps the argument.
  trainingSummary = Some(summary)
  this
}
/** True when a training summary has been stored on this model instance. */
def hasSummary: Boolean = trainingSummary.nonEmpty
/**
 * Evaluates the model on a test set.
 * @param dataset Test dataset to evaluate the model on.
 * @return a [[LinearRegressionSummary]] built from the label column and the predictions
 *         computed from the features column of `dataset`.
 */
// TODO: decide on a good name before exposing to public API
private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = {
  val predictUdf = udf { features: Vector => predict(features) }
  val labelColumn = col($(labelCol))
  val predictionColumn = predictUdf(col($(featuresCol))).as($(predictionCol))
  new LinearRegressionSummary(
    dataset.select(labelColumn, predictionColumn),
    $(predictionCol),
    $(labelCol))
}
/** Linear prediction: dot product of `features` and `weights`, plus `intercept`. */
override protected def predict(features: Vector): Double =
  dot(features, weights) + intercept
/**
 * Creates a new `LinearRegressionModel` with the same `uid`, `weights` and `intercept`,
 * copies param values onto it via `copyValues` (forwarding `extra`), and carries over the
 * training summary when one is present.
 * @param extra extra param values passed through to `copyValues`.
 */
override def copy(extra: ParamMap): LinearRegressionModel = {
  // Forward `extra`; calling `copyValues` without it silently drops caller-supplied overrides.
  val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept), extra)
  // Carry the summary over only when it exists, without `isDefined` / `.get`.
  trainingSummary.foreach(s => newModel.setSummary(s))
  newModel
}
}
/**
 * :: Experimental ::
 * Linear regression training results.
 * @param predictions predictions outputted by the model's `transform` method.
 * @param predictionCol name of the prediction column in `predictions`.
 * @param labelCol name of the label column in `predictions`.
 * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
 */
@Experimental
class LinearRegressionTrainingSummary private[regression] (
    predictions: DataFrame,
    predictionCol: String,
    labelCol: String,
    val objectiveHistory: Array[Double])
  extends LinearRegressionSummary(predictions, predictionCol, labelCol) {

  /** Number of training iterations until termination (length of `objectiveHistory`). */
  val totalIterations: Int = objectiveHistory.length
}
/**
 * :: Experimental ::
 * Linear regression results evaluated on a dataset.
 * @param predictions predictions outputted by the model's `transform` method.
 * @param predictionCol name of the prediction column in `predictions`.
 * @param labelCol name of the label column in `predictions`.
 */
@Experimental
class LinearRegressionSummary private[regression] (
    @transient val predictions: DataFrame,
    val predictionCol: String,
    val labelCol: String) extends Serializable {

  // Built eagerly: every metric value below is read from this object during construction.
  @transient private val metrics = {
    val predictionAndLabel = predictions
      .select(predictionCol, labelCol)
      .map { case Row(pred: Double, label: Double) => (pred, label) }
    new RegressionMetrics(predictionAndLabel)
  }

  /**
   * Returns the explained variance regression score.
   * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
   * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
   */
  val explainedVariance: Double = metrics.explainedVariance

  /**
   * Returns the mean absolute error, which is a risk function corresponding to the
   * expected value of the absolute error loss or l1-norm loss.
   */
  val meanAbsoluteError: Double = metrics.meanAbsoluteError

  /**
   * Returns the mean squared error, which is a risk function corresponding to the
   * expected value of the squared error loss or quadratic loss.
   */
  val meanSquaredError: Double = metrics.meanSquaredError

  /**
   * Returns the root mean squared error, which is defined as the square root of
   * the mean squared error.
   */
  val rootMeanSquaredError: Double = metrics.rootMeanSquaredError

  /**
   * Returns R^2^, the coefficient of determination.
   * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
   */
  val r2: Double = metrics.r2

  /** Residuals (predicted value - label value), in a single column named "residuals". */
  @transient lazy val residuals: DataFrame = {
    val residualOf = udf { (pred: Double, label: Double) => pred - label }
    val residualColumn = residualOf(col(predictionCol), col(labelCol))
    predictions.select(residualColumn.as("residuals"))
  }
}
/**
* LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function,
* as used in linear regression for samples in sparse or dense vector in a online fashion.

View file

@ -289,4 +289,63 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
test("linear regression model training summary") {
  val trainer = new LinearRegression
  val model = trainer.fit(dataset)

  // Training results for the model should be available
  assert(model.hasSummary)

  // Residuals in [[LinearRegressionSummary]] should equal those manually computed
  val manualResiduals = dataset.select("features", "label")
    .map { case Row(features: DenseVector, label: Double) =>
      val prediction =
        features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
      prediction - label
    }
  // Pair each manual residual with the summary's residual and compare them one by one.
  manualResiduals
    .zip(model.summary.residuals.map(_.getDouble(0)))
    .collect()
    .foreach { case (manualResidual: Double, resultResidual: Double) =>
      assert(manualResidual ~== resultResidual relTol 1E-5)
    }

  /*
     Use the following R code to generate model training results.

     predictions <- predict(fit, newx=features)
     residuals <- predictions - label
     > mean(residuals^2) # MSE
     [1] 0.009720325
     > mean(abs(residuals)) # MAD
     [1] 0.07863206
     > cor(predictions, label)^2# r^2
             [,1]
     s0 0.9998749
   */
  // NOTE(review): the R output above prints 0.009720325 for the MSE; the literal below
  // differs in its trailing digits but lies within relTol -- confirm which was intended.
  assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5)
  assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5)
  assert(model.summary.r2 ~== 0.9998749 relTol 1E-5)

  // Objective function should be non-increasing from one iteration to the next.
  // Zipping with `drop(1)` yields no pairs for a history shorter than two entries,
  // so the check holds vacuously instead of indexing past the end of a window.
  val history = model.summary.objectiveHistory
  assert(history.zip(history.drop(1)).forall { case (prev, next) => prev >= next })
}
test("linear regression model testset evaluation summary") {
  val trainer = new LinearRegression
  val model = trainer.fit(dataset)

  // Evaluating on training dataset should yield results summary equal to training summary
  val testSummary = model.evaluate(dataset)
  assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5)
  assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5)

  // Compare residuals pairwise. Each pair is asserted explicitly: a bare `forall` here
  // would only produce a Boolean that nothing checks, so a mismatch could never fail the test.
  model.summary.residuals.select("residuals").collect()
    .zip(testSummary.residuals.select("residuals").collect())
    .foreach { case (Row(r1: Double), Row(r2: Double)) =>
      assert(r1 ~== r2 relTol 1E-5)
    }
}
}