[SPARK-6885] [ML] decision tree support predict class probabilities

Decision trees now support predicting class probabilities.
The probability prediction is implemented with reference to the old DecisionTree API and the [sklearn API](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/tree.py#L593).
I make DecisionTreeClassificationModel inherit from ProbabilisticClassificationModel, make predictRaw return the vector of raw class counts, and make raw2probabilityInPlace/predictProbability return the probability of each class.
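
For reviewers, a minimal usage sketch of the new behavior (assuming a hypothetical `training` DataFrame with the default `label`/`features` columns; the column names below are the spark.ml defaults):

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

// `training` is a hypothetical DataFrame with the default "label" and
// "features" columns used by spark.ml predictors.
val dt = new DecisionTreeClassifier()
  .setImpurity("gini")
  .setMaxDepth(4)

val model = dt.fit(training)

// With this change the model is a ProbabilisticClassificationModel, so
// transform() emits "rawPrediction" (the per-class counts at the matched leaf)
// and "probability" (those counts normalized to sum to 1) next to "prediction".
model.transform(training)
  .select("prediction", "rawPrediction", "probability")
  .show()
```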

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #7694 from yanboliang/spark-6885 and squashes the following commits:

08d5b7f [Yanbo Liang] fix ImpurityStats null parameters and raw2probabilityInPlace sum = 0 issue
2174278 [Yanbo Liang] solve merge conflicts
7e90ba8 [Yanbo Liang] fix typos
33ae183 [Yanbo Liang] fix annotation
ff043d3 [Yanbo Liang] raw2probabilityInPlace should operate in-place
c32d6ce [Yanbo Liang] optimize calculateImpurityStats function again
6167fb0 [Yanbo Liang] optimize calculateImpurityStats function
fbbe2ec [Yanbo Liang] eliminate duplicated struct and code
beb1634 [Yanbo Liang] try to eliminate impurityStats for each LearningNode
99e8943 [Yanbo Liang] code optimization
5ec3323 [Yanbo Liang] implement InformationGainAndImpurityStats
227c91b [Yanbo Liang] refactor LearningNode to store ImpurityCalculator
d746ffc [Yanbo Liang] decision tree support predict class probabilities
Yanbo Liang 2015-07-31 11:56:52 -07:00 committed by Joseph K. Bradley
parent 4011a94715
commit e8bdcdeabb
16 changed files with 229 additions and 130 deletions

View file

@@ -18,12 +18,11 @@
 package org.apache.spark.ml.classification
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
@@ -39,7 +38,7 @@ import org.apache.spark.sql.DataFrame
  */
 @Experimental
 final class DecisionTreeClassifier(override val uid: String)
-  extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
+  extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
   with DecisionTreeParams with TreeClassifierParams {
 
   def this() = this(Identifiable.randomUID("dtc"))
@@ -106,8 +105,9 @@ object DecisionTreeClassifier {
 @Experimental
 final class DecisionTreeClassificationModel private[ml] (
     override val uid: String,
-    override val rootNode: Node)
-  extends PredictionModel[Vector, DecisionTreeClassificationModel]
+    override val rootNode: Node,
+    override val numClasses: Int)
+  extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
   with DecisionTreeModel with Serializable {
 
   require(rootNode != null,
@@ -117,14 +117,36 @@ final class DecisionTreeClassificationModel private[ml] (
    * Construct a decision tree classification model.
    * @param rootNode  Root node of tree, with other nodes attached.
    */
-  def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode)
+  def this(rootNode: Node, numClasses: Int) =
+    this(Identifiable.randomUID("dtc"), rootNode, numClasses)
 
   override protected def predict(features: Vector): Double = {
-    rootNode.predict(features)
+    rootNode.predictImpl(features).prediction
+  }
+
+  override protected def predictRaw(features: Vector): Vector = {
+    Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone())
+  }
+
+  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+    rawPrediction match {
+      case dv: DenseVector =>
+        var i = 0
+        val size = dv.size
+        val sum = dv.values.sum
+        while (i < size) {
+          dv.values(i) = if (sum != 0) dv.values(i) / sum else 0.0
+          i += 1
+        }
+        dv
+      case sv: SparseVector =>
+        throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" +
+          " raw2probabilityInPlace encountered SparseVector")
+    }
   }
 
   override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
-    copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra)
+    copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra)
   }
 
   override def toString: String = {
@@ -149,6 +171,6 @@ private[ml] object DecisionTreeClassificationModel {
       s" DecisionTreeClassificationModel (new API).  Algo is: ${oldModel.algo}")
     val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
-    new DecisionTreeClassificationModel(uid, rootNode)
+    new DecisionTreeClassificationModel(uid, rootNode, -1)
   }
 }

View file

@@ -190,7 +190,7 @@ final class GBTClassificationModel(
   override protected def predict(features: Vector): Double = {
     // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
     // Classifies by thresholding sum of weighted tree predictions
-    val treePredictions = _trees.map(_.rootNode.predict(features))
+    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
     val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
     if (prediction > 0.0) 1.0 else 0.0
   }

View file

@@ -160,7 +160,7 @@ final class RandomForestClassificationModel private[ml] (
     // Ignore the weights since all are 1.0 for now.
     val votes = new Array[Double](numClasses)
     _trees.view.foreach { tree =>
-      val prediction = tree.rootNode.predict(features).toInt
+      val prediction = tree.rootNode.predictImpl(features).prediction.toInt
       votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
     }
     Vectors.dense(votes)

View file

@@ -110,7 +110,7 @@ final class DecisionTreeRegressionModel private[ml] (
   def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
 
   override protected def predict(features: Vector): Double = {
-    rootNode.predict(features)
+    rootNode.predictImpl(features).prediction
   }
 
   override def copy(extra: ParamMap): DecisionTreeRegressionModel = {

View file

@@ -180,7 +180,7 @@ final class GBTRegressionModel(
   override protected def predict(features: Vector): Double = {
     // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
     // Classifies by thresholding sum of weighted tree predictions
-    val treePredictions = _trees.map(_.rootNode.predict(features))
+    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
     blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
   }

View file

@@ -143,7 +143,7 @@ final class RandomForestRegressionModel private[ml] (
     // TODO: When we add a generic Bagging class, handle transform there.  SPARK-7128
     // Predict average of tree predictions.
     // Ignore the weights since all are 1.0 for now.
-    _trees.map(_.rootNode.predict(features)).sum / numTrees
+    _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
   }
 
   override def copy(extra: ParamMap): RandomForestRegressionModel = {

View file

@@ -19,8 +19,9 @@ package org.apache.spark.ml.tree
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
-  Node => OldNode, Predict => OldPredict}
+  Node => OldNode, Predict => OldPredict, ImpurityStats}
 
 /**
  * :: DeveloperApi ::
@@ -38,8 +39,15 @@ sealed abstract class Node extends Serializable {
   /** Impurity measure at this node (for training data) */
   def impurity: Double
 
+  /**
+   * Statistics aggregated from training data at this node, used to compute prediction, impurity,
+   * and probabilities.
+   * For classification, the array of class counts must be normalized to a probability distribution.
+   */
+  private[tree] def impurityStats: ImpurityCalculator
+
   /** Recursive prediction helper method */
-  private[ml] def predict(features: Vector): Double = prediction
+  private[ml] def predictImpl(features: Vector): LeafNode
 
   /**
    * Get the number of nodes in tree below this node, including leaf nodes.
@@ -75,7 +83,8 @@ private[ml] object Node {
     if (oldNode.isLeaf) {
       // TODO: Once the implementation has been moved to this API, then include sufficient
       //       statistics here.
-      new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity)
+      new LeafNode(prediction = oldNode.predict.predict,
+        impurity = oldNode.impurity, impurityStats = null)
     } else {
       val gain = if (oldNode.stats.nonEmpty) {
         oldNode.stats.get.gain
@@ -85,7 +94,7 @@ private[ml] object Node {
       new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
         gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
         rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
-        split = Split.fromOld(oldNode.split.get, categoricalFeatures))
+        split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
     }
   }
 }
@@ -99,11 +108,13 @@ private[ml] object Node {
 @DeveloperApi
 final class LeafNode private[ml] (
     override val prediction: Double,
-    override val impurity: Double) extends Node {
+    override val impurity: Double,
+    override val impurityStats: ImpurityCalculator) extends Node {
 
-  override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)"
+  override def toString: String =
+    s"LeafNode(prediction = $prediction, impurity = $impurity)"
 
-  override private[ml] def predict(features: Vector): Double = prediction
+  override private[ml] def predictImpl(features: Vector): LeafNode = this
 
   override private[tree] def numDescendants: Int = 0
 
@@ -115,9 +126,8 @@ final class LeafNode private[ml] (
   override private[tree] def subtreeDepth: Int = 0
 
   override private[ml] def toOld(id: Int): OldNode = {
-    // NOTE: We do NOT store 'prob' in the new API currently.
-    new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true,
-      None, None, None, None)
+    new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)),
+      impurity, isLeaf = true, None, None, None, None)
   }
 }
 
@@ -139,17 +149,18 @@ final class InternalNode private[ml] (
     val gain: Double,
     val leftChild: Node,
     val rightChild: Node,
-    val split: Split) extends Node {
+    val split: Split,
+    override val impurityStats: ImpurityCalculator) extends Node {
 
   override def toString: String = {
     s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
   }
 
-  override private[ml] def predict(features: Vector): Double = {
+  override private[ml] def predictImpl(features: Vector): LeafNode = {
     if (split.shouldGoLeft(features)) {
-      leftChild.predict(features)
+      leftChild.predictImpl(features)
     } else {
-      rightChild.predict(features)
+      rightChild.predictImpl(features)
     }
   }
 
@@ -172,9 +183,8 @@ final class InternalNode private[ml] (
   override private[ml] def toOld(id: Int): OldNode = {
     assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
       + " since the old API does not support deep trees.")
-    // NOTE: We do NOT store 'prob' in the new API currently.
-    new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false,
-      Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
+    new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity,
+      isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
       Some(rightChild.toOld(OldNode.rightChildIndex(id))),
       Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
         new OldPredict(leftChild.prediction, prob = 0.0),
@@ -223,36 +233,36 @@ private object InternalNode {
  *
  * @param id  We currently use the same indexing as the old implementation in
  *            [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
- * @param predictionStats  Predicted label + class probability (for classification).
- *                         We will later modify this to store aggregate statistics for labels
- *                         to provide all class probabilities (for classification) and maybe a
- *                         distribution (for regression).
  * @param isLeaf  Indicates whether this node will definitely be a leaf in the learned tree,
  *                so that we do not need to consider splitting it further.
- * @param stats  Old structure for storing stats about information gain, prediction, etc.
- *               This is legacy and will be modified in the future.
+ * @param stats  Impurity statistics for this node.
  */
 private[tree] class LearningNode(
     var id: Int,
-    var predictionStats: OldPredict,
-    var impurity: Double,
     var leftChild: Option[LearningNode],
     var rightChild: Option[LearningNode],
     var split: Option[Split],
     var isLeaf: Boolean,
-    var stats: Option[OldInformationGainStats]) extends Serializable {
+    var stats: ImpurityStats) extends Serializable {
 
   /**
    * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
    */
   def toNode: Node = {
     if (leftChild.nonEmpty) {
-      assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty,
+      assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
         "Unknown error during Decision Tree learning.  Could not convert LearningNode to Node.")
-      new InternalNode(predictionStats.predict, impurity, stats.get.gain,
-        leftChild.get.toNode, rightChild.get.toNode, split.get)
+      new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
+        leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator)
     } else {
-      new LeafNode(predictionStats.predict, impurity)
+      if (stats.valid) {
+        new LeafNode(stats.impurityCalculator.predict, stats.impurity,
+          stats.impurityCalculator)
+      } else {
+        // Here we want to keep the same behavior as the old mllib.DecisionTreeModel
+        new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
+      }
     }
   }
 
@@ -263,16 +273,14 @@ private[tree] object LearningNode {
   /** Create a node with some of its fields set. */
   def apply(
       id: Int,
-      predictionStats: OldPredict,
-      impurity: Double,
-      isLeaf: Boolean): LearningNode = {
-    new LearningNode(id, predictionStats, impurity, None, None, None, false, None)
+      isLeaf: Boolean,
+      stats: ImpurityStats): LearningNode = {
+    new LearningNode(id, None, None, None, false, stats)
   }
 
   /** Create an empty node with the given node index.  Values must be set later on. */
   def emptyNode(nodeIndex: Int): LearningNode = {
-    new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN,
-      None, None, None, false, None)
+    new LearningNode(nodeIndex, None, None, None, false, null)
   }
 
   // The below indexing methods were copied from spark.mllib.tree.model.Node

View file

@@ -31,7 +31,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O
 import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
   TimeTracker}
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
-import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict}
+import org.apache.spark.mllib.tree.model.ImpurityStats
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
@@ -180,13 +180,17 @@ private[ml] object RandomForest extends Logging {
     parentUID match {
       case Some(uid) =>
         if (strategy.algo == OldAlgo.Classification) {
-          topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode))
+          topNodes.map { rootNode =>
+            new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses)
+          }
         } else {
           topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
         }
       case None =>
         if (strategy.algo == OldAlgo.Classification) {
-          topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode))
+          topNodes.map { rootNode =>
+            new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses)
+          }
         } else {
           topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
         }
@@ -549,9 +553,9 @@ private[ml] object RandomForest extends Logging {
         }
 
         // find best split for each node
-        val (split: Split, stats: InformationGainStats, predict: Predict) =
+        val (split: Split, stats: ImpurityStats) =
           binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
-        (nodeIndex, (split, stats, predict))
+        (nodeIndex, (split, stats))
       }.collectAsMap()
 
     timer.stop("chooseSplits")
@@ -568,17 +572,15 @@ private[ml] object RandomForest extends Logging {
         val nodeIndex = node.id
         val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
        val aggNodeIndex = nodeInfo.nodeIndexInGroup
-        val (split: Split, stats: InformationGainStats, predict: Predict) =
+        val (split: Split, stats: ImpurityStats) =
          nodeToBestSplits(aggNodeIndex)
        logDebug("best split = " + split)

        // Extract info for this node.  Create children if not leaf.
        val isLeaf =
          (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
-        node.predictionStats = predict
        node.isLeaf = isLeaf
-        node.stats = Some(stats)
-        node.impurity = stats.impurity
+        node.stats = stats
        logDebug("Node = " + node)

        if (!isLeaf) {
@@ -587,9 +589,9 @@ private[ml] object RandomForest extends Logging {
          val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
          val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
          node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
-            stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
+            leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
          node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
-            stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
+            rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))

          if (nodeIdCache.nonEmpty) {
            val nodeIndexUpdater = NodeIndexUpdater(
@@ -621,28 +623,44 @@ private[ml] object RandomForest extends Logging {
  }

  /**
-   * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
+   * Calculate the impurity statistics for a given (feature, split) based upon left/right
+   * aggregates.
+   * @param stats the recycled impurity statistics for all of this feature's splits;
+   *              only 'impurity' and 'impurityCalculator' are valid between each iteration
   * @param leftImpurityCalculator left node aggregates for this (feature, split)
   * @param rightImpurityCalculator right node aggregate for this (feature, split)
-   * @return information gain and statistics for split
+   * @param metadata learning and dataset metadata for DecisionTree
+   * @return impurity statistics for this (feature, split)
   */
-  private def calculateGainForSplit(
+  private def calculateImpurityStats(
+      stats: ImpurityStats,
      leftImpurityCalculator: ImpurityCalculator,
      rightImpurityCalculator: ImpurityCalculator,
-      metadata: DecisionTreeMetadata,
-      impurity: Double): InformationGainStats = {
+      metadata: DecisionTreeMetadata): ImpurityStats = {
+
+    val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
+      leftImpurityCalculator.copy.add(rightImpurityCalculator)
+    } else {
+      stats.impurityCalculator
+    }
+
+    val impurity: Double = if (stats == null) {
+      parentImpurityCalculator.calculate()
+    } else {
+      stats.impurity
+    }
+
    val leftCount = leftImpurityCalculator.count
    val rightCount = rightImpurityCalculator.count

+    val totalCount = leftCount + rightCount
+
    // If left child or right child doesn't satisfy minimum instances per node,
    // then this split is invalid, return invalid information gain stats.
    if ((leftCount < metadata.minInstancesPerNode) ||
        (rightCount < metadata.minInstancesPerNode)) {
-      return InformationGainStats.invalidInformationGainStats
+      return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
    }

-    val totalCount = leftCount + rightCount
-
    val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
    val rightImpurity = rightImpurityCalculator.calculate()
@@ -654,39 +672,11 @@ private[ml] object RandomForest extends Logging {
    // if information gain doesn't satisfy minimum information gain,
    // then this split is invalid, return invalid information gain stats.
    if (gain < metadata.minInfoGain) {
-      return InformationGainStats.invalidInformationGainStats
+      return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
    }

-    // calculate left and right predict
-    val leftPredict = calculatePredict(leftImpurityCalculator)
-    val rightPredict = calculatePredict(rightImpurityCalculator)
-
-    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
-      leftPredict, rightPredict)
-  }
-
-  private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
-    val predict = impurityCalculator.predict
-    val prob = impurityCalculator.prob(predict)
-    new Predict(predict, prob)
-  }
-
-  /**
-   * Calculate predict value for current node, given stats of any split.
-   * Note that this function is called only once for each node.
-   * @param leftImpurityCalculator left node aggregates for a split
-   * @param rightImpurityCalculator right node aggregates for a split
-   * @return predict value and impurity for current node
-   */
-  private def calculatePredictImpurity(
-      leftImpurityCalculator: ImpurityCalculator,
-      rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
-    val parentNodeAgg = leftImpurityCalculator.copy
-    parentNodeAgg.add(rightImpurityCalculator)
-    val predict = calculatePredict(parentNodeAgg)
-    val impurity = parentNodeAgg.calculate()
-
-    (predict, impurity)
+    new ImpurityStats(gain, impurity, parentImpurityCalculator,
+      leftImpurityCalculator, rightImpurityCalculator)
  }

  /**
@@ -698,14 +688,14 @@ private[ml] object RandomForest extends Logging {
      binAggregates: DTStatsAggregator,
      splits: Array[Array[Split]],
      featuresForNode: Option[Array[Int]],
-      node: LearningNode): (Split, InformationGainStats, Predict) = {
+      node: LearningNode): (Split, ImpurityStats) = {

-    // Calculate prediction and impurity if current node is top node
+    // Calculate InformationGain and ImpurityStats if current node is top node
    val level = LearningNode.indexToLevel(node.id)
-    var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) {
-      None
+    var gainAndImpurityStats: ImpurityStats = if (level == 0) {
+      null
    } else {
-      Some((node.predictionStats, node.impurity))
+      node.stats
    }

    // For each (feature, split), calculate the gain, and select the best (feature, split).
@@ -734,11 +724,9 @@ private[ml] object RandomForest extends Logging {
            val rightChildStats =
              binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
            rightChildStats.subtract(leftChildStats)
-            predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
-              calculatePredictImpurity(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
-            (splitIdx, gainStats)
+            gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+              leftChildStats, rightChildStats, binAggregates.metadata)
+            (splitIdx, gainAndImpurityStats)
          }.maxBy(_._2.gain)
          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
        } else if (binAggregates.metadata.isUnordered(featureIndex)) {
@@ -750,11 +738,9 @@ private[ml] object RandomForest extends Logging {
            val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
            val rightChildStats =
              binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
-            predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
-              calculatePredictImpurity(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
-            (splitIndex, gainStats)
+            gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+              leftChildStats, rightChildStats, binAggregates.metadata)
+            (splitIndex, gainAndImpurityStats)
          }.maxBy(_._2.gain)
          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
        } else {
@@ -825,11 +811,9 @@ private[ml] object RandomForest extends Logging {
            val rightChildStats =
              binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
            rightChildStats.subtract(leftChildStats)
-            predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
-              calculatePredictImpurity(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
-            (splitIndex, gainStats)
+            gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+              leftChildStats, rightChildStats, binAggregates.metadata)
+            (splitIndex, gainAndImpurityStats)
          }.maxBy(_._2.gain)
          val categoriesForSplit =
            categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
@@ -839,7 +823,7 @@ private[ml] object RandomForest extends Logging {
      }
    }.maxBy(_._2.gain)

-    (bestSplit, bestSplitStats, predictionAndImpurity.get._1)
+    (bestSplit, bestSplitStats)
  }

  /**

View file

@@ -118,7 +118,7 @@ private[tree] class EntropyAggregator(numClasses: Int)
  * (node, feature, bin).
  * @param stats  Array of sufficient statistics for a (node, feature, bin).
  */
-private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
 
   /**
    * Make a deep copy of this [[ImpurityCalculator]].

View file

@@ -114,7 +114,7 @@ private[tree] class GiniAggregator(numClasses: Int)
  * (node, feature, bin).
  * @param stats  Array of sufficient statistics for a (node, feature, bin).
  */
-private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
 
   /**
    * Make a deep copy of this [[ImpurityCalculator]].

View file

@@ -95,7 +95,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
  * (node, feature, bin).
  * @param stats  Array of sufficient statistics for a (node, feature, bin).
  */
-private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) {
+private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable {
 
   /**
    * Make a deep copy of this [[ImpurityCalculator]].

View file

@@ -98,7 +98,7 @@ private[tree] class VarianceAggregator()
  * (node, feature, bin).
  * @param stats  Array of sufficient statistics for a (node, feature, bin).
  */
-private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
 
   require(stats.size == 3,
     s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +

View file

@@ -18,6 +18,7 @@
 package org.apache.spark.mllib.tree.model
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 
 /**
  * :: DeveloperApi ::
@@ -66,7 +67,6 @@ class InformationGainStats(
   }
 }
 
-
 private[spark] object InformationGainStats {
   /**
    * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
@@ -76,3 +76,62 @@ private[spark] object InformationGainStats {
   val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
     new Predict(0.0, 0.0), new Predict(0.0, 0.0))
 }
+
+/**
+ * :: DeveloperApi ::
+ * Impurity statistics for each split
+ * @param gain information gain value
+ * @param impurity current node impurity
+ * @param impurityCalculator impurity statistics for current node
+ * @param leftImpurityCalculator impurity statistics for left child node
+ * @param rightImpurityCalculator impurity statistics for right child node
+ * @param valid whether the current split satisfies minimum info gain or
+ *              minimum number of instances per node
+ */
+@DeveloperApi
+private[spark] class ImpurityStats(
+    val gain: Double,
+    val impurity: Double,
+    val impurityCalculator: ImpurityCalculator,
+    val leftImpurityCalculator: ImpurityCalculator,
+    val rightImpurityCalculator: ImpurityCalculator,
+    val valid: Boolean = true) extends Serializable {
+
+  override def toString: String = {
+    s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " +
+      s"right impurity = $rightImpurity"
+  }
+
+  def leftImpurity: Double = if (leftImpurityCalculator != null) {
+    leftImpurityCalculator.calculate()
+  } else {
+    -1.0
+  }
+
+  def rightImpurity: Double = if (rightImpurityCalculator != null) {
+    rightImpurityCalculator.calculate()
+  } else {
+    -1.0
+  }
+}
+
+private[spark] object ImpurityStats {
+
+  /**
+   * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object to
+   * denote that the current split doesn't satisfy minimum info gain or
+   * minimum number of instances per node.
+   */
+  def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
+    new ImpurityStats(Double.MinValue, impurityCalculator.calculate(),
+      impurityCalculator, null, null, false)
+  }
+
+  /**
+   * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object in which
+   * only 'impurity' and 'impurityCalculator' are defined.
+   */
+  def getEmptyImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
+    new ImpurityStats(Double.NaN, impurityCalculator.calculate(), impurityCalculator, null, null)
+  }
+}

View file

@@ -21,12 +21,13 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.impl.TreeTests
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.tree.LeafNode
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Row
 
 class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -57,7 +58,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
   test("params") {
     ParamsSuite.checkParams(new DecisionTreeClassifier)
-    val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))
+    val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)
     ParamsSuite.checkParams(model)
   }
 
@@ -231,6 +232,31 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
     compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
   }
 
+  test("predictRaw and predictProbability") {
+    val rdd = continuousDataPointsForMulticlassRDD
+    val dt = new DecisionTreeClassifier()
+      .setImpurity("Gini")
+      .setMaxDepth(4)
+      .setMaxBins(100)
+    val categoricalFeatures = Map(0 -> 3)
+    val numClasses = 3
+
+    val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+    val newTree = dt.fit(newData)
+
+    val predictions = newTree.transform(newData)
+      .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol)
+      .collect()
+
+    predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
+      assert(pred === rawPred.argmax,
+        s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
+      val sum = rawPred.toArray.sum
+      assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
+        "probability prediction mismatch")
+    }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////

View file

@@ -58,7 +58,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
   test("params") {
     ParamsSuite.checkParams(new GBTClassifier)
     val model = new GBTClassificationModel("gbtc",
-      Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))),
+      Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))),
       Array(1.0))
     ParamsSuite.checkParams(model)
   }

View file

@@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
   test("params") {
     ParamsSuite.checkParams(new RandomForestClassifier)
     val model = new RandomForestClassificationModel("rfc",
-      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2)
+      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2)
     ParamsSuite.checkParams(model)
   }