[SPARK-6885] [ML] decision tree support predict class probabilities
Add support for predicting class probabilities to decision trees. The prediction-probability function is implemented with reference to the old DecisionTree API and the [sklearn API](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/tree.py#L593). DecisionTreeClassificationModel now inherits from ProbabilisticClassificationModel; predictRaw returns the raw counts vector, and raw2probabilityInPlace/predictProbability return the probabilities for each prediction. Author: Yanbo Liang <ybliang8@gmail.com> Closes #7694 from yanboliang/spark-6885 and squashes the following commits: 08d5b7f [Yanbo Liang] fix ImpurityStats null parameters and raw2probabilityInPlace sum = 0 issue 2174278 [Yanbo Liang] solve merge conflicts 7e90ba8 [Yanbo Liang] fix typos 33ae183 [Yanbo Liang] fix annotation ff043d3 [Yanbo Liang] raw2probabilityInPlace should operate in-place c32d6ce [Yanbo Liang] optimize calculateImpurityStats function again 6167fb0 [Yanbo Liang] optimize calculateImpurityStats function fbbe2ec [Yanbo Liang] eliminate duplicated struct and code beb1634 [Yanbo Liang] try to eliminate impurityStats for each LearningNode 99e8943 [Yanbo Liang] code optimization 5ec3323 [Yanbo Liang] implement InformationGainAndImpurityStats 227c91b [Yanbo Liang] refactor LearningNode to store ImpurityCalculator d746ffc [Yanbo Liang] decision tree support predict class probabilities
This commit is contained in:
parent
4011a94715
commit
e8bdcdeabb
|
@ -18,12 +18,11 @@
|
|||
package org.apache.spark.ml.classification
|
||||
|
||||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.ml.{PredictionModel, Predictor}
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
|
||||
import org.apache.spark.ml.tree.impl.RandomForest
|
||||
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
|
||||
import org.apache.spark.mllib.linalg.Vector
|
||||
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
||||
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
|
||||
|
@ -39,7 +38,7 @@ import org.apache.spark.sql.DataFrame
|
|||
*/
|
||||
@Experimental
|
||||
final class DecisionTreeClassifier(override val uid: String)
|
||||
extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
|
||||
extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
|
||||
with DecisionTreeParams with TreeClassifierParams {
|
||||
|
||||
def this() = this(Identifiable.randomUID("dtc"))
|
||||
|
@ -106,8 +105,9 @@ object DecisionTreeClassifier {
|
|||
@Experimental
|
||||
final class DecisionTreeClassificationModel private[ml] (
|
||||
override val uid: String,
|
||||
override val rootNode: Node)
|
||||
extends PredictionModel[Vector, DecisionTreeClassificationModel]
|
||||
override val rootNode: Node,
|
||||
override val numClasses: Int)
|
||||
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
|
||||
with DecisionTreeModel with Serializable {
|
||||
|
||||
require(rootNode != null,
|
||||
|
@ -117,14 +117,36 @@ final class DecisionTreeClassificationModel private[ml] (
|
|||
* Construct a decision tree classification model.
|
||||
* @param rootNode Root node of tree, with other nodes attached.
|
||||
*/
|
||||
def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode)
|
||||
def this(rootNode: Node, numClasses: Int) =
|
||||
this(Identifiable.randomUID("dtc"), rootNode, numClasses)
|
||||
|
||||
override protected def predict(features: Vector): Double = {
|
||||
rootNode.predict(features)
|
||||
rootNode.predictImpl(features).prediction
|
||||
}
|
||||
|
||||
override protected def predictRaw(features: Vector): Vector = {
|
||||
Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone())
|
||||
}
|
||||
|
||||
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
|
||||
rawPrediction match {
|
||||
case dv: DenseVector =>
|
||||
var i = 0
|
||||
val size = dv.size
|
||||
val sum = dv.values.sum
|
||||
while (i < size) {
|
||||
dv.values(i) = if (sum != 0) dv.values(i) / sum else 0.0
|
||||
i += 1
|
||||
}
|
||||
dv
|
||||
case sv: SparseVector =>
|
||||
throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" +
|
||||
" raw2probabilityInPlace encountered SparseVector")
|
||||
}
|
||||
}
|
||||
|
||||
override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
|
||||
copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra)
|
||||
copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra)
|
||||
}
|
||||
|
||||
override def toString: String = {
|
||||
|
@ -149,6 +171,6 @@ private[ml] object DecisionTreeClassificationModel {
|
|||
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
|
||||
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
|
||||
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
|
||||
new DecisionTreeClassificationModel(uid, rootNode)
|
||||
new DecisionTreeClassificationModel(uid, rootNode, -1)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -190,7 +190,7 @@ final class GBTClassificationModel(
|
|||
override protected def predict(features: Vector): Double = {
|
||||
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
|
||||
// Classifies by thresholding sum of weighted tree predictions
|
||||
val treePredictions = _trees.map(_.rootNode.predict(features))
|
||||
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
|
||||
val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
|
||||
if (prediction > 0.0) 1.0 else 0.0
|
||||
}
|
||||
|
|
|
@ -160,7 +160,7 @@ final class RandomForestClassificationModel private[ml] (
|
|||
// Ignore the weights since all are 1.0 for now.
|
||||
val votes = new Array[Double](numClasses)
|
||||
_trees.view.foreach { tree =>
|
||||
val prediction = tree.rootNode.predict(features).toInt
|
||||
val prediction = tree.rootNode.predictImpl(features).prediction.toInt
|
||||
votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
|
||||
}
|
||||
Vectors.dense(votes)
|
||||
|
|
|
@ -110,7 +110,7 @@ final class DecisionTreeRegressionModel private[ml] (
|
|||
def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
|
||||
|
||||
override protected def predict(features: Vector): Double = {
|
||||
rootNode.predict(features)
|
||||
rootNode.predictImpl(features).prediction
|
||||
}
|
||||
|
||||
override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
|
||||
|
|
|
@ -180,7 +180,7 @@ final class GBTRegressionModel(
|
|||
override protected def predict(features: Vector): Double = {
|
||||
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
|
||||
// Predicts the sum of weighted tree predictions (no thresholding for regression)
|
||||
val treePredictions = _trees.map(_.rootNode.predict(features))
|
||||
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
|
||||
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
|
||||
}
|
||||
|
||||
|
|
|
@ -143,7 +143,7 @@ final class RandomForestRegressionModel private[ml] (
|
|||
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
|
||||
// Predict average of tree predictions.
|
||||
// Ignore the weights since all are 1.0 for now.
|
||||
_trees.map(_.rootNode.predict(features)).sum / numTrees
|
||||
_trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
|
||||
}
|
||||
|
||||
override def copy(extra: ParamMap): RandomForestRegressionModel = {
|
||||
|
|
|
@ -19,8 +19,9 @@ package org.apache.spark.ml.tree
|
|||
|
||||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.mllib.linalg.Vector
|
||||
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
|
||||
import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
|
||||
Node => OldNode, Predict => OldPredict}
|
||||
Node => OldNode, Predict => OldPredict, ImpurityStats}
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
|
@ -38,8 +39,15 @@ sealed abstract class Node extends Serializable {
|
|||
/** Impurity measure at this node (for training data) */
|
||||
def impurity: Double
|
||||
|
||||
/**
|
||||
* Statistics aggregated from training data at this node, used to compute prediction, impurity,
|
||||
* and probabilities.
|
||||
* For classification, the array of class counts must be normalized to a probability distribution.
|
||||
*/
|
||||
private[tree] def impurityStats: ImpurityCalculator
|
||||
|
||||
/** Recursive prediction helper method */
|
||||
private[ml] def predict(features: Vector): Double = prediction
|
||||
private[ml] def predictImpl(features: Vector): LeafNode
|
||||
|
||||
/**
|
||||
* Get the number of nodes in tree below this node, including leaf nodes.
|
||||
|
@ -75,7 +83,8 @@ private[ml] object Node {
|
|||
if (oldNode.isLeaf) {
|
||||
// TODO: Once the implementation has been moved to this API, then include sufficient
|
||||
// statistics here.
|
||||
new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity)
|
||||
new LeafNode(prediction = oldNode.predict.predict,
|
||||
impurity = oldNode.impurity, impurityStats = null)
|
||||
} else {
|
||||
val gain = if (oldNode.stats.nonEmpty) {
|
||||
oldNode.stats.get.gain
|
||||
|
@ -85,7 +94,7 @@ private[ml] object Node {
|
|||
new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
|
||||
gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
|
||||
rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
|
||||
split = Split.fromOld(oldNode.split.get, categoricalFeatures))
|
||||
split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -99,11 +108,13 @@ private[ml] object Node {
|
|||
@DeveloperApi
|
||||
final class LeafNode private[ml] (
|
||||
override val prediction: Double,
|
||||
override val impurity: Double) extends Node {
|
||||
override val impurity: Double,
|
||||
override val impurityStats: ImpurityCalculator) extends Node {
|
||||
|
||||
override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)"
|
||||
override def toString: String =
|
||||
s"LeafNode(prediction = $prediction, impurity = $impurity)"
|
||||
|
||||
override private[ml] def predict(features: Vector): Double = prediction
|
||||
override private[ml] def predictImpl(features: Vector): LeafNode = this
|
||||
|
||||
override private[tree] def numDescendants: Int = 0
|
||||
|
||||
|
@ -115,9 +126,8 @@ final class LeafNode private[ml] (
|
|||
override private[tree] def subtreeDepth: Int = 0
|
||||
|
||||
override private[ml] def toOld(id: Int): OldNode = {
|
||||
// NOTE: We do NOT store 'prob' in the new API currently.
|
||||
new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true,
|
||||
None, None, None, None)
|
||||
new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)),
|
||||
impurity, isLeaf = true, None, None, None, None)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -139,17 +149,18 @@ final class InternalNode private[ml] (
|
|||
val gain: Double,
|
||||
val leftChild: Node,
|
||||
val rightChild: Node,
|
||||
val split: Split) extends Node {
|
||||
val split: Split,
|
||||
override val impurityStats: ImpurityCalculator) extends Node {
|
||||
|
||||
override def toString: String = {
|
||||
s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
|
||||
}
|
||||
|
||||
override private[ml] def predict(features: Vector): Double = {
|
||||
override private[ml] def predictImpl(features: Vector): LeafNode = {
|
||||
if (split.shouldGoLeft(features)) {
|
||||
leftChild.predict(features)
|
||||
leftChild.predictImpl(features)
|
||||
} else {
|
||||
rightChild.predict(features)
|
||||
rightChild.predictImpl(features)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -172,9 +183,8 @@ final class InternalNode private[ml] (
|
|||
override private[ml] def toOld(id: Int): OldNode = {
|
||||
assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
|
||||
+ " since the old API does not support deep trees.")
|
||||
// NOTE: We do NOT store 'prob' in the new API currently.
|
||||
new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false,
|
||||
Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
|
||||
new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity,
|
||||
isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
|
||||
Some(rightChild.toOld(OldNode.rightChildIndex(id))),
|
||||
Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
|
||||
new OldPredict(leftChild.prediction, prob = 0.0),
|
||||
|
@ -223,36 +233,36 @@ private object InternalNode {
|
|||
*
|
||||
* @param id We currently use the same indexing as the old implementation in
|
||||
* [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
|
||||
* @param predictionStats Predicted label + class probability (for classification).
|
||||
* We will later modify this to store aggregate statistics for labels
|
||||
* to provide all class probabilities (for classification) and maybe a
|
||||
* distribution (for regression).
|
||||
* @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree,
|
||||
* so that we do not need to consider splitting it further.
|
||||
* @param stats Old structure for storing stats about information gain, prediction, etc.
|
||||
* This is legacy and will be modified in the future.
|
||||
* @param stats Impurity statistics for this node.
|
||||
*/
|
||||
private[tree] class LearningNode(
|
||||
var id: Int,
|
||||
var predictionStats: OldPredict,
|
||||
var impurity: Double,
|
||||
var leftChild: Option[LearningNode],
|
||||
var rightChild: Option[LearningNode],
|
||||
var split: Option[Split],
|
||||
var isLeaf: Boolean,
|
||||
var stats: Option[OldInformationGainStats]) extends Serializable {
|
||||
var stats: ImpurityStats) extends Serializable {
|
||||
|
||||
/**
|
||||
* Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
|
||||
*/
|
||||
def toNode: Node = {
|
||||
if (leftChild.nonEmpty) {
|
||||
assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty,
|
||||
assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
|
||||
"Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
|
||||
new InternalNode(predictionStats.predict, impurity, stats.get.gain,
|
||||
leftChild.get.toNode, rightChild.get.toNode, split.get)
|
||||
new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
|
||||
leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator)
|
||||
} else {
|
||||
new LeafNode(predictionStats.predict, impurity)
|
||||
if (stats.valid) {
|
||||
new LeafNode(stats.impurityCalculator.predict, stats.impurity,
|
||||
stats.impurityCalculator)
|
||||
} else {
|
||||
// Here we want to keep the same behavior as the old mllib.DecisionTreeModel
|
||||
new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -263,16 +273,14 @@ private[tree] object LearningNode {
|
|||
/** Create a node with some of its fields set. */
|
||||
def apply(
|
||||
id: Int,
|
||||
predictionStats: OldPredict,
|
||||
impurity: Double,
|
||||
isLeaf: Boolean): LearningNode = {
|
||||
new LearningNode(id, predictionStats, impurity, None, None, None, false, None)
|
||||
isLeaf: Boolean,
|
||||
stats: ImpurityStats): LearningNode = {
|
||||
new LearningNode(id, None, None, None, false, stats)
|
||||
}
|
||||
|
||||
/** Create an empty node with the given node index. Values must be set later on. */
|
||||
def emptyNode(nodeIndex: Int): LearningNode = {
|
||||
new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN,
|
||||
None, None, None, false, None)
|
||||
new LearningNode(nodeIndex, None, None, None, false, null)
|
||||
}
|
||||
|
||||
// The below indexing methods were copied from spark.mllib.tree.model.Node
|
||||
|
|
|
@ -31,7 +31,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O
|
|||
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
|
||||
TimeTracker}
|
||||
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
|
||||
import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict}
|
||||
import org.apache.spark.mllib.tree.model.ImpurityStats
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
|
||||
|
@ -180,13 +180,17 @@ private[ml] object RandomForest extends Logging {
|
|||
parentUID match {
|
||||
case Some(uid) =>
|
||||
if (strategy.algo == OldAlgo.Classification) {
|
||||
topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode))
|
||||
topNodes.map { rootNode =>
|
||||
new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses)
|
||||
}
|
||||
} else {
|
||||
topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
|
||||
}
|
||||
case None =>
|
||||
if (strategy.algo == OldAlgo.Classification) {
|
||||
topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode))
|
||||
topNodes.map { rootNode =>
|
||||
new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses)
|
||||
}
|
||||
} else {
|
||||
topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
|
||||
}
|
||||
|
@ -549,9 +553,9 @@ private[ml] object RandomForest extends Logging {
|
|||
}
|
||||
|
||||
// find best split for each node
|
||||
val (split: Split, stats: InformationGainStats, predict: Predict) =
|
||||
val (split: Split, stats: ImpurityStats) =
|
||||
binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
|
||||
(nodeIndex, (split, stats, predict))
|
||||
(nodeIndex, (split, stats))
|
||||
}.collectAsMap()
|
||||
|
||||
timer.stop("chooseSplits")
|
||||
|
@ -568,17 +572,15 @@ private[ml] object RandomForest extends Logging {
|
|||
val nodeIndex = node.id
|
||||
val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
|
||||
val aggNodeIndex = nodeInfo.nodeIndexInGroup
|
||||
val (split: Split, stats: InformationGainStats, predict: Predict) =
|
||||
val (split: Split, stats: ImpurityStats) =
|
||||
nodeToBestSplits(aggNodeIndex)
|
||||
logDebug("best split = " + split)
|
||||
|
||||
// Extract info for this node. Create children if not leaf.
|
||||
val isLeaf =
|
||||
(stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
|
||||
node.predictionStats = predict
|
||||
node.isLeaf = isLeaf
|
||||
node.stats = Some(stats)
|
||||
node.impurity = stats.impurity
|
||||
node.stats = stats
|
||||
logDebug("Node = " + node)
|
||||
|
||||
if (!isLeaf) {
|
||||
|
@ -587,9 +589,9 @@ private[ml] object RandomForest extends Logging {
|
|||
val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
|
||||
val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
|
||||
node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
|
||||
stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
|
||||
leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
|
||||
node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
|
||||
stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
|
||||
rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))
|
||||
|
||||
if (nodeIdCache.nonEmpty) {
|
||||
val nodeIndexUpdater = NodeIndexUpdater(
|
||||
|
@ -621,28 +623,44 @@ private[ml] object RandomForest extends Logging {
|
|||
}
|
||||
|
||||
/**
|
||||
* Calculate the information gain for a given (feature, split) based upon left/right aggregates.
|
||||
* Calculate the impurity statistics for a given (feature, split) based upon left/right aggregates.
|
||||
* @param stats the recycled impurity statistics shared across all splits of this feature;
|
||||
* only 'impurity' and 'impurityCalculator' are valid between each iteration
|
||||
* @param leftImpurityCalculator left node aggregates for this (feature, split)
|
||||
* @param rightImpurityCalculator right node aggregate for this (feature, split)
|
||||
* @return information gain and statistics for split
|
||||
* @param metadata learning and dataset metadata for DecisionTree
|
||||
* @return Impurity statistics for this (feature, split)
|
||||
*/
|
||||
private def calculateGainForSplit(
|
||||
private def calculateImpurityStats(
|
||||
stats: ImpurityStats,
|
||||
leftImpurityCalculator: ImpurityCalculator,
|
||||
rightImpurityCalculator: ImpurityCalculator,
|
||||
metadata: DecisionTreeMetadata,
|
||||
impurity: Double): InformationGainStats = {
|
||||
metadata: DecisionTreeMetadata): ImpurityStats = {
|
||||
|
||||
val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
|
||||
leftImpurityCalculator.copy.add(rightImpurityCalculator)
|
||||
} else {
|
||||
stats.impurityCalculator
|
||||
}
|
||||
|
||||
val impurity: Double = if (stats == null) {
|
||||
parentImpurityCalculator.calculate()
|
||||
} else {
|
||||
stats.impurity
|
||||
}
|
||||
|
||||
val leftCount = leftImpurityCalculator.count
|
||||
val rightCount = rightImpurityCalculator.count
|
||||
|
||||
val totalCount = leftCount + rightCount
|
||||
|
||||
// If left child or right child doesn't satisfy minimum instances per node,
|
||||
// then this split is invalid, return invalid information gain stats.
|
||||
if ((leftCount < metadata.minInstancesPerNode) ||
|
||||
(rightCount < metadata.minInstancesPerNode)) {
|
||||
return InformationGainStats.invalidInformationGainStats
|
||||
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
|
||||
}
|
||||
|
||||
val totalCount = leftCount + rightCount
|
||||
|
||||
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
|
||||
val rightImpurity = rightImpurityCalculator.calculate()
|
||||
|
||||
|
@ -654,39 +672,11 @@ private[ml] object RandomForest extends Logging {
|
|||
// if information gain doesn't satisfy minimum information gain,
|
||||
// then this split is invalid, return invalid information gain stats.
|
||||
if (gain < metadata.minInfoGain) {
|
||||
return InformationGainStats.invalidInformationGainStats
|
||||
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
|
||||
}
|
||||
|
||||
// calculate left and right predict
|
||||
val leftPredict = calculatePredict(leftImpurityCalculator)
|
||||
val rightPredict = calculatePredict(rightImpurityCalculator)
|
||||
|
||||
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
|
||||
leftPredict, rightPredict)
|
||||
}
|
||||
|
||||
private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
|
||||
val predict = impurityCalculator.predict
|
||||
val prob = impurityCalculator.prob(predict)
|
||||
new Predict(predict, prob)
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate predict value for current node, given stats of any split.
|
||||
* Note that this function is called only once for each node.
|
||||
* @param leftImpurityCalculator left node aggregates for a split
|
||||
* @param rightImpurityCalculator right node aggregates for a split
|
||||
* @return predict value and impurity for current node
|
||||
*/
|
||||
private def calculatePredictImpurity(
|
||||
leftImpurityCalculator: ImpurityCalculator,
|
||||
rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
|
||||
val parentNodeAgg = leftImpurityCalculator.copy
|
||||
parentNodeAgg.add(rightImpurityCalculator)
|
||||
val predict = calculatePredict(parentNodeAgg)
|
||||
val impurity = parentNodeAgg.calculate()
|
||||
|
||||
(predict, impurity)
|
||||
new ImpurityStats(gain, impurity, parentImpurityCalculator,
|
||||
leftImpurityCalculator, rightImpurityCalculator)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -698,14 +688,14 @@ private[ml] object RandomForest extends Logging {
|
|||
binAggregates: DTStatsAggregator,
|
||||
splits: Array[Array[Split]],
|
||||
featuresForNode: Option[Array[Int]],
|
||||
node: LearningNode): (Split, InformationGainStats, Predict) = {
|
||||
node: LearningNode): (Split, ImpurityStats) = {
|
||||
|
||||
// Calculate prediction and impurity if current node is top node
|
||||
// Calculate InformationGain and ImpurityStats if current node is top node
|
||||
val level = LearningNode.indexToLevel(node.id)
|
||||
var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) {
|
||||
None
|
||||
var gainAndImpurityStats: ImpurityStats = if (level ==0) {
|
||||
null
|
||||
} else {
|
||||
Some((node.predictionStats, node.impurity))
|
||||
node.stats
|
||||
}
|
||||
|
||||
// For each (feature, split), calculate the gain, and select the best (feature, split).
|
||||
|
@ -734,11 +724,9 @@ private[ml] object RandomForest extends Logging {
|
|||
val rightChildStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
|
||||
rightChildStats.subtract(leftChildStats)
|
||||
predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
|
||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
||||
val gainStats = calculateGainForSplit(leftChildStats,
|
||||
rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
|
||||
(splitIdx, gainStats)
|
||||
gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
|
||||
leftChildStats, rightChildStats, binAggregates.metadata)
|
||||
(splitIdx, gainAndImpurityStats)
|
||||
}.maxBy(_._2.gain)
|
||||
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
|
||||
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
|
||||
|
@ -750,11 +738,9 @@ private[ml] object RandomForest extends Logging {
|
|||
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
|
||||
val rightChildStats =
|
||||
binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
|
||||
predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
|
||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
||||
val gainStats = calculateGainForSplit(leftChildStats,
|
||||
rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
|
||||
(splitIndex, gainStats)
|
||||
gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
|
||||
leftChildStats, rightChildStats, binAggregates.metadata)
|
||||
(splitIndex, gainAndImpurityStats)
|
||||
}.maxBy(_._2.gain)
|
||||
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
|
||||
} else {
|
||||
|
@ -825,11 +811,9 @@ private[ml] object RandomForest extends Logging {
|
|||
val rightChildStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
|
||||
rightChildStats.subtract(leftChildStats)
|
||||
predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
|
||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
||||
val gainStats = calculateGainForSplit(leftChildStats,
|
||||
rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
|
||||
(splitIndex, gainStats)
|
||||
gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
|
||||
leftChildStats, rightChildStats, binAggregates.metadata)
|
||||
(splitIndex, gainAndImpurityStats)
|
||||
}.maxBy(_._2.gain)
|
||||
val categoriesForSplit =
|
||||
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
|
||||
|
@ -839,7 +823,7 @@ private[ml] object RandomForest extends Logging {
|
|||
}
|
||||
}.maxBy(_._2.gain)
|
||||
|
||||
(bestSplit, bestSplitStats, predictionAndImpurity.get._1)
|
||||
(bestSplit, bestSplitStats)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -118,7 +118,7 @@ private[tree] class EntropyAggregator(numClasses: Int)
|
|||
* (node, feature, bin).
|
||||
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
||||
*/
|
||||
private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
||||
private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
||||
|
||||
/**
|
||||
* Make a deep copy of this [[ImpurityCalculator]].
|
||||
|
|
|
@ -114,7 +114,7 @@ private[tree] class GiniAggregator(numClasses: Int)
|
|||
* (node, feature, bin).
|
||||
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
||||
*/
|
||||
private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
||||
private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
||||
|
||||
/**
|
||||
* Make a deep copy of this [[ImpurityCalculator]].
|
||||
|
|
|
@ -95,7 +95,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
|
|||
* (node, feature, bin).
|
||||
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
||||
*/
|
||||
private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) {
|
||||
private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable {
|
||||
|
||||
/**
|
||||
* Make a deep copy of this [[ImpurityCalculator]].
|
||||
|
|
|
@ -98,7 +98,7 @@ private[tree] class VarianceAggregator()
|
|||
* (node, feature, bin).
|
||||
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
||||
*/
|
||||
private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
||||
private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
||||
|
||||
require(stats.size == 3,
|
||||
s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.mllib.tree.model
|
||||
|
||||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
|
@ -66,7 +67,6 @@ class InformationGainStats(
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
private[spark] object InformationGainStats {
|
||||
/**
|
||||
* An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
|
||||
|
@ -76,3 +76,62 @@ private[spark] object InformationGainStats {
|
|||
val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
|
||||
new Predict(0.0, 0.0), new Predict(0.0, 0.0))
|
||||
}
|
||||
|
||||
/**
 * :: DeveloperApi ::
 * Impurity summary attached to a candidate split of a tree node.
 *
 * @param gain information gain value of the split
 * @param impurity impurity of the current node
 * @param impurityCalculator impurity statistics for the current node
 * @param leftImpurityCalculator impurity statistics for the left child node; may be null
 * @param rightImpurityCalculator impurity statistics for the right child node; may be null
 * @param valid whether the current split satisfies minimum info gain or
 *              minimum number of instances per node
 */
@DeveloperApi
private[spark] class ImpurityStats(
    val gain: Double,
    val impurity: Double,
    val impurityCalculator: ImpurityCalculator,
    val leftImpurityCalculator: ImpurityCalculator,
    val rightImpurityCalculator: ImpurityCalculator,
    val valid: Boolean = true) extends Serializable {

  /** Impurity of the left child, or -1.0 when no left calculator is set. */
  def leftImpurity: Double = childImpurity(leftImpurityCalculator)

  /** Impurity of the right child, or -1.0 when no right calculator is set. */
  def rightImpurity: Double = childImpurity(rightImpurityCalculator)

  override def toString: String = {
    val left = leftImpurity
    val right = rightImpurity
    s"gain = $gain, impurity = $impurity, left impurity = $left, right impurity = $right"
  }

  // Option(null) is None, so a missing calculator falls through to the -1.0 sentinel.
  private def childImpurity(calculator: ImpurityCalculator): Double =
    Option(calculator).fold(-1.0)(_.calculate())
}
|
||||
|
||||
private[spark] object ImpurityStats {

  /**
   * Build the [[org.apache.spark.mllib.tree.model.ImpurityStats]] object that denotes
   * a split which does not satisfy minimum info gain or minimum number of instances
   * per node: gain is Double.MinValue, both child calculators are null and
   * `valid` is false.
   */
  def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
    val nodeImpurity = impurityCalculator.calculate()
    new ImpurityStats(
      gain = Double.MinValue,
      impurity = nodeImpurity,
      impurityCalculator = impurityCalculator,
      leftImpurityCalculator = null,
      rightImpurityCalculator = null,
      valid = false)
  }

  /**
   * Build an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object in which only
   * 'impurity' and 'impurityCalculator' are defined: gain is NaN, both child
   * calculators are null and `valid` keeps its default.
   */
  def getEmptyImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
    val nodeImpurity = impurityCalculator.calculate()
    new ImpurityStats(
      gain = Double.NaN,
      impurity = nodeImpurity,
      impurityCalculator = impurityCalculator,
      leftImpurityCalculator = null,
      rightImpurityCalculator = null)
  }
}
|
||||
|
|
|
@ -21,12 +21,13 @@ import org.apache.spark.SparkFunSuite
|
|||
import org.apache.spark.ml.impl.TreeTests
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.tree.LeafNode
|
||||
import org.apache.spark.mllib.linalg.Vectors
|
||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.DataFrame
|
||||
import org.apache.spark.sql.Row
|
||||
|
||||
class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||
|
||||
|
@ -57,7 +58,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
|||
|
||||
test("params") {
|
||||
ParamsSuite.checkParams(new DecisionTreeClassifier)
|
||||
val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))
|
||||
val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)
|
||||
ParamsSuite.checkParams(model)
|
||||
}
|
||||
|
||||
|
@ -231,6 +232,31 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
|||
compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
|
||||
}
|
||||
|
||||
test("predictRaw and predictProbability") {
  // Fit a depth-4 Gini tree on the multiclass fixture, then check that the three
  // output columns agree with each other row by row.
  val rdd = continuousDataPointsForMulticlassRDD
  val numClasses = 3
  val categoricalFeatures = Map(0 -> 3)
  val classifier = new DecisionTreeClassifier()
    .setImpurity("Gini")
    .setMaxDepth(4)
    .setMaxBins(100)

  val dataset: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
  val fitted = classifier.fit(dataset)

  val scored = fitted.transform(dataset)
  val outputRows = scored
    .select(fitted.getPredictionCol, fitted.getRawPredictionCol, fitted.getProbabilityCol)
    .collect()

  for (row <- outputRows) {
    row match {
      case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
        // The predicted label must be the index of the largest raw value.
        assert(pred === rawPred.argmax,
          s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
        // The probability column must equal the raw vector divided by its sum.
        val total = rawPred.toArray.sum
        val normalized = Vectors.dense(rawPred.toArray.map(_ / total))
        assert(normalized === probPred, "probability prediction mismatch")
    }
  }
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Tests of model save/load
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -58,7 +58,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
test("params") {
|
||||
ParamsSuite.checkParams(new GBTClassifier)
|
||||
val model = new GBTClassificationModel("gbtc",
|
||||
Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))),
|
||||
Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))),
|
||||
Array(1.0))
|
||||
ParamsSuite.checkParams(model)
|
||||
}
|
||||
|
|
|
@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
|||
test("params") {
|
||||
ParamsSuite.checkParams(new RandomForestClassifier)
|
||||
val model = new RandomForestClassificationModel("rfc",
|
||||
Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2)
|
||||
Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2)
|
||||
ParamsSuite.checkParams(model)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue