[SPARK-6885] [ML] decision tree support predict class probabilities

Add support for predicting class probabilities in decision trees.
Implement the probability prediction function with reference to the old DecisionTree API and the [sklearn API](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/tree.py#L593).
DecisionTreeClassificationModel now inherits from ProbabilisticClassificationModel; predictRaw returns the raw class-count vector, and raw2probabilityInPlace/predictProbability return the class probabilities for each prediction.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #7694 from yanboliang/spark-6885 and squashes the following commits:

08d5b7f [Yanbo Liang] fix ImpurityStats null parameters and raw2probabilityInPlace sum = 0 issue
2174278 [Yanbo Liang] solve merge conflicts
7e90ba8 [Yanbo Liang] fix typos
33ae183 [Yanbo Liang] fix annotation
ff043d3 [Yanbo Liang] raw2probabilityInPlace should operate in-place
c32d6ce [Yanbo Liang] optimize calculateImpurityStats function again
6167fb0 [Yanbo Liang] optimize calculateImpurityStats function
fbbe2ec [Yanbo Liang] eliminate duplicated struct and code
beb1634 [Yanbo Liang] try to eliminate impurityStats for each LearningNode
99e8943 [Yanbo Liang] code optimization
5ec3323 [Yanbo Liang] implement InformationGainAndImpurityStats
227c91b [Yanbo Liang] refactor LearningNode to store ImpurityCalculator
d746ffc [Yanbo Liang] decision tree support predict class probabilities
This commit is contained in:
Yanbo Liang 2015-07-31 11:56:52 -07:00 committed by Joseph K. Bradley
parent 4011a94715
commit e8bdcdeabb
16 changed files with 229 additions and 130 deletions

View file

@ -18,12 +18,11 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
@ -39,7 +38,7 @@ import org.apache.spark.sql.DataFrame
*/
@Experimental
final class DecisionTreeClassifier(override val uid: String)
extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
with DecisionTreeParams with TreeClassifierParams {
def this() = this(Identifiable.randomUID("dtc"))
@ -106,8 +105,9 @@ object DecisionTreeClassifier {
@Experimental
final class DecisionTreeClassificationModel private[ml] (
override val uid: String,
override val rootNode: Node)
extends PredictionModel[Vector, DecisionTreeClassificationModel]
override val rootNode: Node,
override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
with DecisionTreeModel with Serializable {
require(rootNode != null,
@ -117,14 +117,36 @@ final class DecisionTreeClassificationModel private[ml] (
* Construct a decision tree classification model.
* @param rootNode Root node of tree, with other nodes attached.
*/
def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode)
def this(rootNode: Node, numClasses: Int) =
this(Identifiable.randomUID("dtc"), rootNode, numClasses)
override protected def predict(features: Vector): Double = {
rootNode.predict(features)
rootNode.predictImpl(features).prediction
}
/**
 * Raw prediction for one feature vector: a dense vector built from a copy of the
 * `stats` array held by the leaf that `features` is routed to.
 *
 * @param features feature vector to route through the tree
 * @return a new dense vector wrapping a clone of the leaf's statistics array
 */
override protected def predictRaw(features: Vector): Vector = {
  val leaf = rootNode.predictImpl(features)
  // Clone the array; presumably this keeps later in-place normalization from
  // mutating the statistics stored on the tree node -- TODO confirm.
  // NOTE(review): Node.fromOld passes impurityStats = null, so this dereference
  // would presumably fail for models converted from the old API -- confirm.
  val rawCounts = leaf.impurityStats.stats.clone()
  Vectors.dense(rawCounts)
}
/**
 * Convert a raw prediction vector into probabilities, overwriting the vector's values.
 *
 * Each entry is divided by the sum of all entries. When that sum is 0, every entry
 * is set to 0.0 instead of dividing.
 *
 * @param rawPrediction raw prediction vector; only [[DenseVector]] is supported
 * @return the same [[DenseVector]] instance, with its values rewritten
 * @throws RuntimeException if `rawPrediction` is a [[SparseVector]]
 */
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
  rawPrediction match {
    case sv: SparseVector =>
      throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" +
        " raw2probabilityInPlace encountered SparseVector")
    case dv: DenseVector =>
      val values = dv.values
      val numValues = dv.size
      val total = values.sum
      var idx = 0
      if (total != 0) {
        while (idx < numValues) {
          values(idx) = values(idx) / total
          idx += 1
        }
      } else {
        // Zero total: avoid 0/0 and report every entry as 0.0.
        while (idx < numValues) {
          values(idx) = 0.0
          idx += 1
        }
      }
      dv
  }
}
override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra)
copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra)
}
override def toString: String = {
@ -149,6 +171,6 @@ private[ml] object DecisionTreeClassificationModel {
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
new DecisionTreeClassificationModel(uid, rootNode)
new DecisionTreeClassificationModel(uid, rootNode, -1)
}
}

View file

@ -190,7 +190,7 @@ final class GBTClassificationModel(
override protected def predict(features: Vector): Double = {
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
val treePredictions = _trees.map(_.rootNode.predict(features))
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
if (prediction > 0.0) 1.0 else 0.0
}

View file

@ -160,7 +160,7 @@ final class RandomForestClassificationModel private[ml] (
// Ignore the weights since all are 1.0 for now.
val votes = new Array[Double](numClasses)
_trees.view.foreach { tree =>
val prediction = tree.rootNode.predict(features).toInt
val prediction = tree.rootNode.predictImpl(features).prediction.toInt
votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
}
Vectors.dense(votes)

View file

@ -110,7 +110,7 @@ final class DecisionTreeRegressionModel private[ml] (
def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
override protected def predict(features: Vector): Double = {
rootNode.predict(features)
rootNode.predictImpl(features).prediction
}
override def copy(extra: ParamMap): DecisionTreeRegressionModel = {

View file

@ -180,7 +180,7 @@ final class GBTRegressionModel(
override protected def predict(features: Vector): Double = {
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
val treePredictions = _trees.map(_.rootNode.predict(features))
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
}

View file

@ -143,7 +143,7 @@ final class RandomForestRegressionModel private[ml] (
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
_trees.map(_.rootNode.predict(features)).sum / numTrees
_trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
}
override def copy(extra: ParamMap): RandomForestRegressionModel = {

View file

@ -19,8 +19,9 @@ package org.apache.spark.ml.tree
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
Node => OldNode, Predict => OldPredict}
Node => OldNode, Predict => OldPredict, ImpurityStats}
/**
* :: DeveloperApi ::
@ -38,8 +39,15 @@ sealed abstract class Node extends Serializable {
/** Impurity measure at this node (for training data) */
def impurity: Double
/**
* Statistics aggregated from training data at this node, used to compute prediction, impurity,
* and probabilities.
* For classification, the array of class counts must be normalized to a probability distribution.
*/
private[tree] def impurityStats: ImpurityCalculator
/** Recursive prediction helper method */
private[ml] def predict(features: Vector): Double = prediction
private[ml] def predictImpl(features: Vector): LeafNode
/**
* Get the number of nodes in tree below this node, including leaf nodes.
@ -75,7 +83,8 @@ private[ml] object Node {
if (oldNode.isLeaf) {
// TODO: Once the implementation has been moved to this API, then include sufficient
// statistics here.
new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity)
new LeafNode(prediction = oldNode.predict.predict,
impurity = oldNode.impurity, impurityStats = null)
} else {
val gain = if (oldNode.stats.nonEmpty) {
oldNode.stats.get.gain
@ -85,7 +94,7 @@ private[ml] object Node {
new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
split = Split.fromOld(oldNode.split.get, categoricalFeatures))
split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
}
}
}
@ -99,11 +108,13 @@ private[ml] object Node {
@DeveloperApi
final class LeafNode private[ml] (
override val prediction: Double,
override val impurity: Double) extends Node {
override val impurity: Double,
override val impurityStats: ImpurityCalculator) extends Node {
override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)"
override def toString: String =
s"LeafNode(prediction = $prediction, impurity = $impurity)"
override private[ml] def predict(features: Vector): Double = prediction
override private[ml] def predictImpl(features: Vector): LeafNode = this
override private[tree] def numDescendants: Int = 0
@ -115,9 +126,8 @@ final class LeafNode private[ml] (
override private[tree] def subtreeDepth: Int = 0
override private[ml] def toOld(id: Int): OldNode = {
// NOTE: We do NOT store 'prob' in the new API currently.
new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true,
None, None, None, None)
new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)),
impurity, isLeaf = true, None, None, None, None)
}
}
@ -139,17 +149,18 @@ final class InternalNode private[ml] (
val gain: Double,
val leftChild: Node,
val rightChild: Node,
val split: Split) extends Node {
val split: Split,
override val impurityStats: ImpurityCalculator) extends Node {
override def toString: String = {
s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
}
override private[ml] def predict(features: Vector): Double = {
override private[ml] def predictImpl(features: Vector): LeafNode = {
if (split.shouldGoLeft(features)) {
leftChild.predict(features)
leftChild.predictImpl(features)
} else {
rightChild.predict(features)
rightChild.predictImpl(features)
}
}
@ -172,9 +183,8 @@ final class InternalNode private[ml] (
override private[ml] def toOld(id: Int): OldNode = {
assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
+ " since the old API does not support deep trees.")
// NOTE: We do NOT store 'prob' in the new API currently.
new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false,
Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity,
isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
Some(rightChild.toOld(OldNode.rightChildIndex(id))),
Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
new OldPredict(leftChild.prediction, prob = 0.0),
@ -223,36 +233,36 @@ private object InternalNode {
*
* @param id We currently use the same indexing as the old implementation in
* [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
* @param predictionStats Predicted label + class probability (for classification).
* We will later modify this to store aggregate statistics for labels
* to provide all class probabilities (for classification) and maybe a
* distribution (for regression).
* @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree,
* so that we do not need to consider splitting it further.
* @param stats Old structure for storing stats about information gain, prediction, etc.
* This is legacy and will be modified in the future.
* @param stats Impurity statistics for this node.
*/
private[tree] class LearningNode(
var id: Int,
var predictionStats: OldPredict,
var impurity: Double,
var leftChild: Option[LearningNode],
var rightChild: Option[LearningNode],
var split: Option[Split],
var isLeaf: Boolean,
var stats: Option[OldInformationGainStats]) extends Serializable {
var stats: ImpurityStats) extends Serializable {
/**
* Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
*/
def toNode: Node = {
if (leftChild.nonEmpty) {
assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty,
assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
"Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
new InternalNode(predictionStats.predict, impurity, stats.get.gain,
leftChild.get.toNode, rightChild.get.toNode, split.get)
new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator)
} else {
new LeafNode(predictionStats.predict, impurity)
if (stats.valid) {
new LeafNode(stats.impurityCalculator.predict, stats.impurity,
stats.impurityCalculator)
} else {
// Here we want to keep same behavior with the old mllib.DecisionTreeModel
new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
}
}
}
@ -263,16 +273,14 @@ private[tree] object LearningNode {
/** Create a node with some of its fields set. */
def apply(
id: Int,
predictionStats: OldPredict,
impurity: Double,
isLeaf: Boolean): LearningNode = {
new LearningNode(id, predictionStats, impurity, None, None, None, false, None)
isLeaf: Boolean,
stats: ImpurityStats): LearningNode = {
new LearningNode(id, None, None, None, false, stats)
}
/** Create an empty node with the given node index. Values must be set later on. */
def emptyNode(nodeIndex: Int): LearningNode = {
new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN,
None, None, None, false, None)
new LearningNode(nodeIndex, None, None, None, false, null)
}
// The below indexing methods were copied from spark.mllib.tree.model.Node

View file

@ -31,7 +31,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
TimeTracker}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict}
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
@ -180,13 +180,17 @@ private[ml] object RandomForest extends Logging {
parentUID match {
case Some(uid) =>
if (strategy.algo == OldAlgo.Classification) {
topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode))
topNodes.map { rootNode =>
new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses)
}
} else {
topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
}
case None =>
if (strategy.algo == OldAlgo.Classification) {
topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode))
topNodes.map { rootNode =>
new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses)
}
} else {
topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
}
@ -549,9 +553,9 @@ private[ml] object RandomForest extends Logging {
}
// find best split for each node
val (split: Split, stats: InformationGainStats, predict: Predict) =
val (split: Split, stats: ImpurityStats) =
binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
(nodeIndex, (split, stats, predict))
(nodeIndex, (split, stats))
}.collectAsMap()
timer.stop("chooseSplits")
@ -568,17 +572,15 @@ private[ml] object RandomForest extends Logging {
val nodeIndex = node.id
val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val (split: Split, stats: InformationGainStats, predict: Predict) =
val (split: Split, stats: ImpurityStats) =
nodeToBestSplits(aggNodeIndex)
logDebug("best split = " + split)
// Extract info for this node. Create children if not leaf.
val isLeaf =
(stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
node.predictionStats = predict
node.isLeaf = isLeaf
node.stats = Some(stats)
node.impurity = stats.impurity
node.stats = stats
logDebug("Node = " + node)
if (!isLeaf) {
@ -587,9 +589,9 @@ private[ml] object RandomForest extends Logging {
val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))
if (nodeIdCache.nonEmpty) {
val nodeIndexUpdater = NodeIndexUpdater(
@ -621,28 +623,44 @@ private[ml] object RandomForest extends Logging {
}
/**
* Calculate the information gain for a given (feature, split) based upon left/right aggregates.
* Calculate the impurity statistics for a given (feature, split) based upon left/right aggregates.
* @param stats the impurity statistics reused across all splits of this feature;
* only 'impurity' and 'impurityCalculator' are valid between iterations
* @param leftImpurityCalculator left node aggregates for this (feature, split)
* @param rightImpurityCalculator right node aggregate for this (feature, split)
* @return information gain and statistics for split
* @param metadata learning and dataset metadata for DecisionTree
* @return Impurity statistics for this (feature, split)
*/
private def calculateGainForSplit(
private def calculateImpurityStats(
stats: ImpurityStats,
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator,
metadata: DecisionTreeMetadata,
impurity: Double): InformationGainStats = {
metadata: DecisionTreeMetadata): ImpurityStats = {
val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
leftImpurityCalculator.copy.add(rightImpurityCalculator)
} else {
stats.impurityCalculator
}
val impurity: Double = if (stats == null) {
parentImpurityCalculator.calculate()
} else {
stats.impurity
}
val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count
val totalCount = leftCount + rightCount
// If left child or right child doesn't satisfy minimum instances per node,
// then this split is invalid, return invalid information gain stats.
if ((leftCount < metadata.minInstancesPerNode) ||
(rightCount < metadata.minInstancesPerNode)) {
return InformationGainStats.invalidInformationGainStats
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
}
val totalCount = leftCount + rightCount
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()
@ -654,39 +672,11 @@ private[ml] object RandomForest extends Logging {
// if information gain doesn't satisfy minimum information gain,
// then this split is invalid, return invalid information gain stats.
if (gain < metadata.minInfoGain) {
return InformationGainStats.invalidInformationGainStats
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
}
// calculate left and right predict
val leftPredict = calculatePredict(leftImpurityCalculator)
val rightPredict = calculatePredict(rightImpurityCalculator)
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
leftPredict, rightPredict)
}
private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
val predict = impurityCalculator.predict
val prob = impurityCalculator.prob(predict)
new Predict(predict, prob)
}
/**
* Calculate predict value for current node, given stats of any split.
* Note that this function is called only once for each node.
* @param leftImpurityCalculator left node aggregates for a split
* @param rightImpurityCalculator right node aggregates for a split
* @return predict value and impurity for current node
*/
private def calculatePredictImpurity(
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
val parentNodeAgg = leftImpurityCalculator.copy
parentNodeAgg.add(rightImpurityCalculator)
val predict = calculatePredict(parentNodeAgg)
val impurity = parentNodeAgg.calculate()
(predict, impurity)
new ImpurityStats(gain, impurity, parentImpurityCalculator,
leftImpurityCalculator, rightImpurityCalculator)
}
/**
@ -698,14 +688,14 @@ private[ml] object RandomForest extends Logging {
binAggregates: DTStatsAggregator,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]],
node: LearningNode): (Split, InformationGainStats, Predict) = {
node: LearningNode): (Split, ImpurityStats) = {
// Calculate prediction and impurity if current node is top node
// Calculate InformationGain and ImpurityStats if current node is top node
val level = LearningNode.indexToLevel(node.id)
var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) {
None
var gainAndImpurityStats: ImpurityStats = if (level ==0) {
null
} else {
Some((node.predictionStats, node.impurity))
node.stats
}
// For each (feature, split), calculate the gain, and select the best (feature, split).
@ -734,11 +724,9 @@ private[ml] object RandomForest extends Logging {
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
(splitIdx, gainStats)
gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
leftChildStats, rightChildStats, binAggregates.metadata)
(splitIdx, gainAndImpurityStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
@ -750,11 +738,9 @@ private[ml] object RandomForest extends Logging {
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats =
binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
(splitIndex, gainStats)
gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
leftChildStats, rightChildStats, binAggregates.metadata)
(splitIndex, gainAndImpurityStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
@ -825,11 +811,9 @@ private[ml] object RandomForest extends Logging {
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
(splitIndex, gainStats)
gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
leftChildStats, rightChildStats, binAggregates.metadata)
(splitIndex, gainAndImpurityStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
@ -839,7 +823,7 @@ private[ml] object RandomForest extends Logging {
}
}.maxBy(_._2.gain)
(bestSplit, bestSplitStats, predictionAndImpurity.get._1)
(bestSplit, bestSplitStats)
}
/**

View file

@ -118,7 +118,7 @@ private[tree] class EntropyAggregator(numClasses: Int)
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
/**
* Make a deep copy of this [[ImpurityCalculator]].

View file

@ -114,7 +114,7 @@ private[tree] class GiniAggregator(numClasses: Int)
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
/**
* Make a deep copy of this [[ImpurityCalculator]].

View file

@ -95,7 +95,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) {
private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable {
/**
* Make a deep copy of this [[ImpurityCalculator]].

View file

@ -98,7 +98,7 @@ private[tree] class VarianceAggregator()
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
require(stats.size == 3,
s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +

View file

@ -18,6 +18,7 @@
package org.apache.spark.mllib.tree.model
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
/**
* :: DeveloperApi ::
@ -66,7 +67,6 @@ class InformationGainStats(
}
}
private[spark] object InformationGainStats {
/**
* An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
@ -76,3 +76,62 @@ private[spark] object InformationGainStats {
val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
new Predict(0.0, 0.0), new Predict(0.0, 0.0))
}
/**
 * :: DeveloperApi ::
 * Impurity statistics for a single split of a node.
 * @param gain information gain value
 * @param impurity impurity of the current node
 * @param impurityCalculator impurity statistics for the current node
 * @param leftImpurityCalculator impurity statistics for the left child node (may be null)
 * @param rightImpurityCalculator impurity statistics for the right child node (may be null)
 * @param valid whether the current split satisfies minimum info gain or
 *              minimum number of instances per node
 */
@DeveloperApi
private[spark] class ImpurityStats(
    val gain: Double,
    val impurity: Double,
    val impurityCalculator: ImpurityCalculator,
    val leftImpurityCalculator: ImpurityCalculator,
    val rightImpurityCalculator: ImpurityCalculator,
    val valid: Boolean = true) extends Serializable {

  /** Impurity of the left child, or -1.0 when there is no left-child calculator. */
  def leftImpurity: Double = childImpurity(leftImpurityCalculator)

  /** Impurity of the right child, or -1.0 when there is no right-child calculator. */
  def rightImpurity: Double = childImpurity(rightImpurityCalculator)

  override def toString: String = {
    val left = leftImpurity
    val right = rightImpurity
    s"gain = $gain, impurity = $impurity, left impurity = $left, right impurity = $right"
  }

  /** Shared null-guard for the child accessors: -1.0 stands in for a missing calculator. */
  private def childImpurity(calculator: ImpurityCalculator): Double = {
    if (calculator == null) -1.0 else calculator.calculate()
  }
}
private[spark] object ImpurityStats {

  /**
   * Build an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object denoting that
   * the current split does not satisfy the minimum info gain or the minimum number of
   * instances per node. The result has `gain = Double.MinValue`, `valid = false`,
   * and no child calculators.
   *
   * @param impurityCalculator impurity statistics for the current node
   */
  def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
    new ImpurityStats(
      gain = Double.MinValue,
      impurity = impurityCalculator.calculate(),
      impurityCalculator = impurityCalculator,
      leftImpurityCalculator = null,
      rightImpurityCalculator = null,
      valid = false)
  }

  /**
   * Build an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object in which
   * only 'impurity' and 'impurityCalculator' are defined. The gain is `Double.NaN`
   * and both child calculators are null.
   *
   * @param impurityCalculator impurity statistics for the current node
   */
  def getEmptyImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
    new ImpurityStats(
      gain = Double.NaN,
      impurity = impurityCalculator.calculate(),
      impurityCalculator = impurityCalculator,
      leftImpurityCalculator = null,
      rightImpurityCalculator = null)
  }
}

View file

@ -21,12 +21,13 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
@ -57,7 +58,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
test("params") {
ParamsSuite.checkParams(new DecisionTreeClassifier)
val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))
val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)
ParamsSuite.checkParams(model)
}
@ -231,6 +232,31 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
}
// Fits a classifier on the multiclass fixture, then checks for every row that the
// prediction, raw prediction and probability columns agree with each other.
test("predictRaw and predictProbability") {
val rdd = continuousDataPointsForMulticlassRDD
val dt = new DecisionTreeClassifier()
.setImpurity("Gini")
.setMaxDepth(4)
.setMaxBins(100)
val categoricalFeatures = Map(0 -> 3)
val numClasses = 3
// Attach the categorical-feature and class-count metadata before fitting.
val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
val newTree = dt.fit(newData)
// Score the training data itself and keep only the three output columns under test.
val predictions = newTree.transform(newData)
.select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol)
.collect()
predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
// The predicted label must be the index of the largest raw prediction entry.
assert(pred === rawPred.argmax,
s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
// The probability vector must equal the raw prediction divided by its sum.
val sum = rawPred.toArray.sum
assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
"probability prediction mismatch")
}
}
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////

View file

@ -58,7 +58,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
ParamsSuite.checkParams(new GBTClassifier)
val model = new GBTClassificationModel("gbtc",
Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))),
Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))),
Array(1.0))
ParamsSuite.checkParams(model)
}

View file

@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
test("params") {
ParamsSuite.checkParams(new RandomForestClassifier)
val model = new RandomForestClassificationModel("rfc",
Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2)
Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2)
ParamsSuite.checkParams(model)
}