[SPARK-6025] [MLlib] Add helper method evaluateEachIteration to extract learning curve

Added evaluateEachIteration to allow the user to manually extract the error for each iteration of GradientBoosting. The internal optimisation can be dealt with later.
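A minimal usage sketch of the new helper (the model, validationData, and AbsoluteError values below are illustrative placeholders, not part of this patch):

    // model: a trained GradientBoostedTreesModel; validationData: RDD[LabeledPoint].
    // AbsoluteError is one of the built-in losses in org.apache.spark.mllib.tree.loss.
    val lossPerIteration: Array[Double] =
      model.evaluateEachIteration(validationData, AbsoluteError)
    // lossPerIteration(i) is the mean loss of the ensemble built from the first i + 1 trees.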

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #4906 from MechCoder/spark-6025 and squashes the following commits:

67146ab [MechCoder] Minor
352001f [MechCoder] Minor
6e8aa10 [MechCoder] Made the following changes: used mapPartitions instead of map, refactored computeError, and unpersisted broadcast variables
bc99ac6 [MechCoder] Refactor the method and stuff
dbda033 [MechCoder] [SPARK-6025] Add helper method evaluateEachIteration to extract learning curve
Authored by MechCoder on 2015-03-20 17:14:09 -07:00; committed by Joseph K. Bradley
parent a95043b178
commit 25e271d9fb
7 changed files with 96 additions and 46 deletions


@@ -464,8 +464,8 @@ first one being the training dataset and the second being the validation dataset
 The training is stopped when the improvement in the validation error is not more than a certain tolerance
 (supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error
 decreases initially and later increases. There might be cases in which the validation error does not change monotonically,
-and the user is advised to set a large enough negative tolerance and examine the validation curve to tune the number of
-iterations.
+and the user is advised to set a large enough negative tolerance and examine the validation curve using `evaluateEachIteration`
+(which gives the error or loss per iteration) to tune the number of iterations.
 
 ### Examples
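A rough sketch of how the per-iteration losses returned by evaluateEachIteration can be used to pick the number of iterations, assuming model is a trained GradientBoostedTreesModel, validationData is an RDD[LabeledPoint], and loss is one of the built-in Loss objects (all placeholder names):

    // One entry per boosting iteration, evaluated on held-out data.
    val validationLoss: Array[Double] = model.evaluateEachIteration(validationData, loss)
    // Entry i covers the ensemble of the first i + 1 trees, so add 1 to the argmin.
    val bestNumIterations = validationLoss.indexOf(validationLoss.min) + 1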


@@ -47,18 +47,9 @@ object AbsoluteError extends Loss {
     if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
   }
 
-  /**
-   * Method to calculate loss of the base learner for the gradient boosting calculation.
-   * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
-   * purposes.
-   * @param model Ensemble model
-   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @return Mean absolute error of model on data
-   */
-  override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
-    data.map { y =>
-      val err = model.predict(y.features) - y.label
-      math.abs(err)
-    }.mean()
-  }
+  override def computeError(prediction: Double, label: Double): Double = {
+    val err = label - prediction
+    math.abs(err)
+  }
 }


@@ -50,20 +50,10 @@ object LogLoss extends Loss {
     - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
   }
 
-  /**
-   * Method to calculate loss of the base learner for the gradient boosting calculation.
-   * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
-   * purposes.
-   * @param model Ensemble model
-   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @return Mean log loss of model on data
-   */
-  override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
-    data.map { case point =>
-      val prediction = model.predict(point.features)
-      val margin = 2.0 * point.label * prediction
-      // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
-      2.0 * MLUtils.log1pExp(-margin)
-    }.mean()
-  }
+  override def computeError(prediction: Double, label: Double): Double = {
+    val margin = 2.0 * label * prediction
+    // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
+    2.0 * MLUtils.log1pExp(-margin)
+  }
 }
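The comment kept in the new computeError refers to a numerically stable form of log(1 + exp(x)). A standalone sketch of that standard rewrite (shown for illustration only; it is an assumption that MLUtils.log1pExp takes exactly this shape):

    // Stable log(1 + exp(x)): for large positive x, exp(x) overflows, so use
    // x + log1p(exp(-x)); math.log1p keeps precision when its argument is tiny.
    def log1pExp(x: Double): Double =
      if (x > 0) x + math.log1p(math.exp(-x)) else math.log1p(math.exp(x))

    // Per-point log loss as in the patch: 2 * log(1 + exp(-margin)), with margin = 2 * label * prediction.
    def logLossPoint(prediction: Double, label: Double): Double =
      2.0 * log1pExp(-2.0 * label * prediction)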


@@ -47,6 +47,18 @@ trait Loss extends Serializable {
    * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    * @return Measure of model error on data
    */
-  def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double
+  def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
+    data.map(point => computeError(model.predict(point.features), point.label)).mean()
+  }
+
+  /**
+   * Method to calculate loss when the predictions are already known.
+   * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
+   * predicted values from previously fit trees.
+   * @param prediction Predicted label.
+   * @param label True label.
+   * @return Measure of model error on datapoint.
+   */
+  def computeError(prediction: Double, label: Double): Double
 }
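The refactoring leaves the per-point computeError as the only method a concrete loss must supply; the RDD overload is now just the mean of per-point errors. A self-contained plain-Scala sketch of that pattern (illustrative names, not the actual Spark trait):

    // Mimic of the pattern: the aggregate error is defined once as the mean of
    // per-point errors, so each loss only implements the per-point form.
    trait PointwiseLoss {
      def computeError(prediction: Double, label: Double): Double
      def computeError(predictionsAndLabels: Seq[(Double, Double)]): Double = {
        val errors = predictionsAndLabels.map { case (p, l) => computeError(p, l) }
        errors.sum / errors.size
      }
    }

    object PointwiseSquaredError extends PointwiseLoss {
      override def computeError(prediction: Double, label: Double): Double = {
        val err = prediction - label
        err * err
      }
    }

    // PointwiseSquaredError.computeError(Seq((1.5, 1.0), (0.0, 1.0))) == (0.25 + 1.0) / 2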


@@ -47,18 +47,9 @@ object SquaredError extends Loss {
     2.0 * (model.predict(point.features) - point.label)
   }
 
-  /**
-   * Method to calculate loss of the base learner for the gradient boosting calculation.
-   * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
-   * purposes.
-   * @param model Ensemble model
-   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @return Mean squared error of model on data
-   */
-  override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
-    data.map { y =>
-      val err = model.predict(y.features) - y.label
-      err * err
-    }.mean()
-  }
+  override def computeError(prediction: Double, label: Double): Double = {
+    val err = prediction - label
+    err * err
+  }
 }


@@ -28,9 +28,11 @@ import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.mllib.tree.loss.Loss
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
@@ -108,6 +110,58 @@ class GradientBoostedTreesModel(
   }
 
   override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+
+  /**
+   * Method to compute error or loss for every iteration of gradient boosting.
+   * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @param loss evaluation metric.
+   * @return an array with index i having the losses or errors for the ensemble
+   *         containing the first i+1 trees
+   */
+  def evaluateEachIteration(
+      data: RDD[LabeledPoint],
+      loss: Loss): Array[Double] = {
+
+    val sc = data.sparkContext
+    val remappedData = algo match {
+      case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+      case _ => data
+    }
+    val numIterations = trees.length
+    val evaluationArray = Array.fill(numIterations)(0.0)
+
+    var predictionAndError: RDD[(Double, Double)] = remappedData.map { i =>
+      val pred = treeWeights(0) * trees(0).predict(i.features)
+      val error = loss.computeError(pred, i.label)
+      (pred, error)
+    }
+    evaluationArray(0) = predictionAndError.values.mean()
+
+    // Avoid the model being copied across numIterations.
+    val broadcastTrees = sc.broadcast(trees)
+    val broadcastWeights = sc.broadcast(treeWeights)
+
+    (1 until numIterations).map { nTree =>
+      predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
+        val currentTree = broadcastTrees.value(nTree)
+        val currentTreeWeight = broadcastWeights.value(nTree)
+        iter.map {
+          case (point, (pred, error)) => {
+            val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
+            val newError = loss.computeError(newPred, point.label)
+            (newPred, newError)
+          }
+        }
+      }
+      evaluationArray(nTree) = predictionAndError.values.mean()
+    }
+
+    broadcastTrees.unpersist()
+    broadcastWeights.unpersist()
+
+    evaluationArray
+  }
+
 }
 
 object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
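The essential trick in evaluateEachIteration is that the prediction of the first k + 1 trees is built incrementally from the cached prediction of the first k trees, so every point is scored by every tree exactly once instead of re-scoring the whole ensemble per iteration. A plain-Scala sketch of that recurrence for a single point (the "trees" below are stand-in scoring functions, not Spark DecisionTreeModels):

    // Stand-ins: each tree is a scoring function Double => Double with a weight.
    val trees: Array[Double => Double] = Array(x => x, x => 0.5 * x, x => -0.1 * x)
    val treeWeights: Array[Double] = Array(1.0, 0.8, 0.8)

    def squaredError(prediction: Double, label: Double): Double = {
      val err = prediction - label
      err * err
    }

    val (feature, label) = (2.0, 1.5)

    // predictions(k) is the ensemble prediction using the first k + 1 trees,
    // obtained by adding one weighted tree per step to the running prediction.
    val predictions = trees.indices.scanLeft(0.0) { (pred, k) =>
      pred + treeWeights(k) * trees(k)(feature)
    }.tail

    // The per-iteration loss curve for this single point.
    val lossPerIteration = predictions.map(p => squaredError(p, label))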


@@ -175,10 +175,11 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
         new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
       val gbtValidate = new GradientBoostedTrees(boostingStrategy)
         .runWithValidation(trainRdd, validateRdd)
-      assert(gbtValidate.numTrees !== numIterations)
+      val numTrees = gbtValidate.numTrees
+      assert(numTrees !== numIterations)
 
       // Test that it performs better on the validation dataset.
-      val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
+      val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
       val (errorWithoutValidation, errorWithValidation) = {
         if (algo == Classification) {
           val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
@@ -188,6 +189,17 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
         }
       }
       assert(errorWithValidation <= errorWithoutValidation)
+
+      // Test that results from evaluateEachIteration comply with runWithValidation.
+      // Note that convergenceTol is set to 0.0
+      val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
+      assert(evaluationArray.length === numIterations)
+      assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
+      var i = 1
+      while (i < numTrees) {
+        assert(evaluationArray(i) <= evaluationArray(i - 1))
+        i += 1
+      }
     }
   }
 }