[SPARK-6885] [ML] decision tree support predict class probabilities
Add support for predicting class probabilities to decision trees. The prediction-probability function is implemented with reference to the old DecisionTree API and the [sklearn API](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/tree.py#L593). DecisionTreeClassificationModel now inherits from ProbabilisticClassificationModel, predictRaw returns the raw class-counts vector, and raw2probabilityInPlace/predictProbability return the probabilities for each prediction. Author: Yanbo Liang <ybliang8@gmail.com> Closes #7694 from yanboliang/spark-6885 and squashes the following commits: 08d5b7f [Yanbo Liang] fix ImpurityStats null parameters and raw2probabilityInPlace sum = 0 issue 2174278 [Yanbo Liang] solve merge conflicts 7e90ba8 [Yanbo Liang] fix typos 33ae183 [Yanbo Liang] fix annotation ff043d3 [Yanbo Liang] raw2probabilityInPlace should operate in-place c32d6ce [Yanbo Liang] optimize calculateImpurityStats function again 6167fb0 [Yanbo Liang] optimize calculateImpurityStats function fbbe2ec [Yanbo Liang] eliminate duplicated struct and code beb1634 [Yanbo Liang] try to eliminate impurityStats for each LearningNode 99e8943 [Yanbo Liang] code optimization 5ec3323 [Yanbo Liang] implement InformationGainAndImpurityStats 227c91b [Yanbo Liang] refactor LearningNode to store ImpurityCalculator d746ffc [Yanbo Liang] decision tree support predict class probabilities
This commit is contained in:
parent
4011a94715
commit
e8bdcdeabb
|
@ -18,12 +18,11 @@
|
||||||
package org.apache.spark.ml.classification
|
package org.apache.spark.ml.classification
|
||||||
|
|
||||||
import org.apache.spark.annotation.Experimental
|
import org.apache.spark.annotation.Experimental
|
||||||
import org.apache.spark.ml.{PredictionModel, Predictor}
|
|
||||||
import org.apache.spark.ml.param.ParamMap
|
import org.apache.spark.ml.param.ParamMap
|
||||||
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
|
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
|
||||||
import org.apache.spark.ml.tree.impl.RandomForest
|
import org.apache.spark.ml.tree.impl.RandomForest
|
||||||
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
|
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
|
||||||
import org.apache.spark.mllib.linalg.Vector
|
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
||||||
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
|
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
|
||||||
|
@ -39,7 +38,7 @@ import org.apache.spark.sql.DataFrame
|
||||||
*/
|
*/
|
||||||
@Experimental
|
@Experimental
|
||||||
final class DecisionTreeClassifier(override val uid: String)
|
final class DecisionTreeClassifier(override val uid: String)
|
||||||
extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
|
extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
|
||||||
with DecisionTreeParams with TreeClassifierParams {
|
with DecisionTreeParams with TreeClassifierParams {
|
||||||
|
|
||||||
def this() = this(Identifiable.randomUID("dtc"))
|
def this() = this(Identifiable.randomUID("dtc"))
|
||||||
|
@ -106,8 +105,9 @@ object DecisionTreeClassifier {
|
||||||
@Experimental
|
@Experimental
|
||||||
final class DecisionTreeClassificationModel private[ml] (
|
final class DecisionTreeClassificationModel private[ml] (
|
||||||
override val uid: String,
|
override val uid: String,
|
||||||
override val rootNode: Node)
|
override val rootNode: Node,
|
||||||
extends PredictionModel[Vector, DecisionTreeClassificationModel]
|
override val numClasses: Int)
|
||||||
|
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
|
||||||
with DecisionTreeModel with Serializable {
|
with DecisionTreeModel with Serializable {
|
||||||
|
|
||||||
require(rootNode != null,
|
require(rootNode != null,
|
||||||
|
@ -117,14 +117,36 @@ final class DecisionTreeClassificationModel private[ml] (
|
||||||
* Construct a decision tree classification model.
|
* Construct a decision tree classification model.
|
||||||
* @param rootNode Root node of tree, with other nodes attached.
|
* @param rootNode Root node of tree, with other nodes attached.
|
||||||
*/
|
*/
|
||||||
def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode)
|
def this(rootNode: Node, numClasses: Int) =
|
||||||
|
this(Identifiable.randomUID("dtc"), rootNode, numClasses)
|
||||||
|
|
||||||
override protected def predict(features: Vector): Double = {
|
override protected def predict(features: Vector): Double = {
|
||||||
rootNode.predict(features)
|
rootNode.predictImpl(features).prediction
|
||||||
|
}
|
||||||
|
|
||||||
|
override protected def predictRaw(features: Vector): Vector = {
|
||||||
|
Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
|
||||||
|
rawPrediction match {
|
||||||
|
case dv: DenseVector =>
|
||||||
|
var i = 0
|
||||||
|
val size = dv.size
|
||||||
|
val sum = dv.values.sum
|
||||||
|
while (i < size) {
|
||||||
|
dv.values(i) = if (sum != 0) dv.values(i) / sum else 0.0
|
||||||
|
i += 1
|
||||||
|
}
|
||||||
|
dv
|
||||||
|
case sv: SparseVector =>
|
||||||
|
throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" +
|
||||||
|
" raw2probabilityInPlace encountered SparseVector")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
|
override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
|
||||||
copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra)
|
copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
|
@ -149,6 +171,6 @@ private[ml] object DecisionTreeClassificationModel {
|
||||||
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
|
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
|
||||||
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
|
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
|
||||||
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
|
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
|
||||||
new DecisionTreeClassificationModel(uid, rootNode)
|
new DecisionTreeClassificationModel(uid, rootNode, -1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -190,7 +190,7 @@ final class GBTClassificationModel(
|
||||||
override protected def predict(features: Vector): Double = {
|
override protected def predict(features: Vector): Double = {
|
||||||
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
|
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
|
||||||
// Classifies by thresholding sum of weighted tree predictions
|
// Classifies by thresholding sum of weighted tree predictions
|
||||||
val treePredictions = _trees.map(_.rootNode.predict(features))
|
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
|
||||||
val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
|
val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
|
||||||
if (prediction > 0.0) 1.0 else 0.0
|
if (prediction > 0.0) 1.0 else 0.0
|
||||||
}
|
}
|
||||||
|
|
|
@ -160,7 +160,7 @@ final class RandomForestClassificationModel private[ml] (
|
||||||
// Ignore the weights since all are 1.0 for now.
|
// Ignore the weights since all are 1.0 for now.
|
||||||
val votes = new Array[Double](numClasses)
|
val votes = new Array[Double](numClasses)
|
||||||
_trees.view.foreach { tree =>
|
_trees.view.foreach { tree =>
|
||||||
val prediction = tree.rootNode.predict(features).toInt
|
val prediction = tree.rootNode.predictImpl(features).prediction.toInt
|
||||||
votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
|
votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
|
||||||
}
|
}
|
||||||
Vectors.dense(votes)
|
Vectors.dense(votes)
|
||||||
|
|
|
@ -110,7 +110,7 @@ final class DecisionTreeRegressionModel private[ml] (
|
||||||
def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
|
def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
|
||||||
|
|
||||||
override protected def predict(features: Vector): Double = {
|
override protected def predict(features: Vector): Double = {
|
||||||
rootNode.predict(features)
|
rootNode.predictImpl(features).prediction
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
|
override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
|
||||||
|
|
|
@ -180,7 +180,7 @@ final class GBTRegressionModel(
|
||||||
override protected def predict(features: Vector): Double = {
|
override protected def predict(features: Vector): Double = {
|
||||||
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
|
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
|
||||||
// Classifies by thresholding sum of weighted tree predictions
|
// Classifies by thresholding sum of weighted tree predictions
|
||||||
val treePredictions = _trees.map(_.rootNode.predict(features))
|
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
|
||||||
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
|
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -143,7 +143,7 @@ final class RandomForestRegressionModel private[ml] (
|
||||||
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
|
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
|
||||||
// Predict average of tree predictions.
|
// Predict average of tree predictions.
|
||||||
// Ignore the weights since all are 1.0 for now.
|
// Ignore the weights since all are 1.0 for now.
|
||||||
_trees.map(_.rootNode.predict(features)).sum / numTrees
|
_trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): RandomForestRegressionModel = {
|
override def copy(extra: ParamMap): RandomForestRegressionModel = {
|
||||||
|
|
|
@ -19,8 +19,9 @@ package org.apache.spark.ml.tree
|
||||||
|
|
||||||
import org.apache.spark.annotation.DeveloperApi
|
import org.apache.spark.annotation.DeveloperApi
|
||||||
import org.apache.spark.mllib.linalg.Vector
|
import org.apache.spark.mllib.linalg.Vector
|
||||||
|
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
|
||||||
import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
|
import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
|
||||||
Node => OldNode, Predict => OldPredict}
|
Node => OldNode, Predict => OldPredict, ImpurityStats}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* :: DeveloperApi ::
|
* :: DeveloperApi ::
|
||||||
|
@ -38,8 +39,15 @@ sealed abstract class Node extends Serializable {
|
||||||
/** Impurity measure at this node (for training data) */
|
/** Impurity measure at this node (for training data) */
|
||||||
def impurity: Double
|
def impurity: Double
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Statistics aggregated from training data at this node, used to compute prediction, impurity,
|
||||||
|
* and probabilities.
|
||||||
|
* For classification, the array of class counts must be normalized to a probability distribution.
|
||||||
|
*/
|
||||||
|
private[tree] def impurityStats: ImpurityCalculator
|
||||||
|
|
||||||
/** Recursive prediction helper method */
|
/** Recursive prediction helper method */
|
||||||
private[ml] def predict(features: Vector): Double = prediction
|
private[ml] def predictImpl(features: Vector): LeafNode
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the number of nodes in tree below this node, including leaf nodes.
|
* Get the number of nodes in tree below this node, including leaf nodes.
|
||||||
|
@ -75,7 +83,8 @@ private[ml] object Node {
|
||||||
if (oldNode.isLeaf) {
|
if (oldNode.isLeaf) {
|
||||||
// TODO: Once the implementation has been moved to this API, then include sufficient
|
// TODO: Once the implementation has been moved to this API, then include sufficient
|
||||||
// statistics here.
|
// statistics here.
|
||||||
new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity)
|
new LeafNode(prediction = oldNode.predict.predict,
|
||||||
|
impurity = oldNode.impurity, impurityStats = null)
|
||||||
} else {
|
} else {
|
||||||
val gain = if (oldNode.stats.nonEmpty) {
|
val gain = if (oldNode.stats.nonEmpty) {
|
||||||
oldNode.stats.get.gain
|
oldNode.stats.get.gain
|
||||||
|
@ -85,7 +94,7 @@ private[ml] object Node {
|
||||||
new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
|
new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
|
||||||
gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
|
gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
|
||||||
rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
|
rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
|
||||||
split = Split.fromOld(oldNode.split.get, categoricalFeatures))
|
split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -99,11 +108,13 @@ private[ml] object Node {
|
||||||
@DeveloperApi
|
@DeveloperApi
|
||||||
final class LeafNode private[ml] (
|
final class LeafNode private[ml] (
|
||||||
override val prediction: Double,
|
override val prediction: Double,
|
||||||
override val impurity: Double) extends Node {
|
override val impurity: Double,
|
||||||
|
override val impurityStats: ImpurityCalculator) extends Node {
|
||||||
|
|
||||||
override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)"
|
override def toString: String =
|
||||||
|
s"LeafNode(prediction = $prediction, impurity = $impurity)"
|
||||||
|
|
||||||
override private[ml] def predict(features: Vector): Double = prediction
|
override private[ml] def predictImpl(features: Vector): LeafNode = this
|
||||||
|
|
||||||
override private[tree] def numDescendants: Int = 0
|
override private[tree] def numDescendants: Int = 0
|
||||||
|
|
||||||
|
@ -115,9 +126,8 @@ final class LeafNode private[ml] (
|
||||||
override private[tree] def subtreeDepth: Int = 0
|
override private[tree] def subtreeDepth: Int = 0
|
||||||
|
|
||||||
override private[ml] def toOld(id: Int): OldNode = {
|
override private[ml] def toOld(id: Int): OldNode = {
|
||||||
// NOTE: We do NOT store 'prob' in the new API currently.
|
new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)),
|
||||||
new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true,
|
impurity, isLeaf = true, None, None, None, None)
|
||||||
None, None, None, None)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,17 +149,18 @@ final class InternalNode private[ml] (
|
||||||
val gain: Double,
|
val gain: Double,
|
||||||
val leftChild: Node,
|
val leftChild: Node,
|
||||||
val rightChild: Node,
|
val rightChild: Node,
|
||||||
val split: Split) extends Node {
|
val split: Split,
|
||||||
|
override val impurityStats: ImpurityCalculator) extends Node {
|
||||||
|
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
|
s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
|
||||||
}
|
}
|
||||||
|
|
||||||
override private[ml] def predict(features: Vector): Double = {
|
override private[ml] def predictImpl(features: Vector): LeafNode = {
|
||||||
if (split.shouldGoLeft(features)) {
|
if (split.shouldGoLeft(features)) {
|
||||||
leftChild.predict(features)
|
leftChild.predictImpl(features)
|
||||||
} else {
|
} else {
|
||||||
rightChild.predict(features)
|
rightChild.predictImpl(features)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,9 +183,8 @@ final class InternalNode private[ml] (
|
||||||
override private[ml] def toOld(id: Int): OldNode = {
|
override private[ml] def toOld(id: Int): OldNode = {
|
||||||
assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
|
assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
|
||||||
+ " since the old API does not support deep trees.")
|
+ " since the old API does not support deep trees.")
|
||||||
// NOTE: We do NOT store 'prob' in the new API currently.
|
new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity,
|
||||||
new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false,
|
isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
|
||||||
Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
|
|
||||||
Some(rightChild.toOld(OldNode.rightChildIndex(id))),
|
Some(rightChild.toOld(OldNode.rightChildIndex(id))),
|
||||||
Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
|
Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
|
||||||
new OldPredict(leftChild.prediction, prob = 0.0),
|
new OldPredict(leftChild.prediction, prob = 0.0),
|
||||||
|
@ -223,36 +233,36 @@ private object InternalNode {
|
||||||
*
|
*
|
||||||
* @param id We currently use the same indexing as the old implementation in
|
* @param id We currently use the same indexing as the old implementation in
|
||||||
* [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
|
* [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
|
||||||
* @param predictionStats Predicted label + class probability (for classification).
|
|
||||||
* We will later modify this to store aggregate statistics for labels
|
|
||||||
* to provide all class probabilities (for classification) and maybe a
|
|
||||||
* distribution (for regression).
|
|
||||||
* @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree,
|
* @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree,
|
||||||
* so that we do not need to consider splitting it further.
|
* so that we do not need to consider splitting it further.
|
||||||
* @param stats Old structure for storing stats about information gain, prediction, etc.
|
* @param stats Impurity statistics for this node.
|
||||||
* This is legacy and will be modified in the future.
|
|
||||||
*/
|
*/
|
||||||
private[tree] class LearningNode(
|
private[tree] class LearningNode(
|
||||||
var id: Int,
|
var id: Int,
|
||||||
var predictionStats: OldPredict,
|
|
||||||
var impurity: Double,
|
|
||||||
var leftChild: Option[LearningNode],
|
var leftChild: Option[LearningNode],
|
||||||
var rightChild: Option[LearningNode],
|
var rightChild: Option[LearningNode],
|
||||||
var split: Option[Split],
|
var split: Option[Split],
|
||||||
var isLeaf: Boolean,
|
var isLeaf: Boolean,
|
||||||
var stats: Option[OldInformationGainStats]) extends Serializable {
|
var stats: ImpurityStats) extends Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
|
* Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
|
||||||
*/
|
*/
|
||||||
def toNode: Node = {
|
def toNode: Node = {
|
||||||
if (leftChild.nonEmpty) {
|
if (leftChild.nonEmpty) {
|
||||||
assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty,
|
assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
|
||||||
"Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
|
"Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
|
||||||
new InternalNode(predictionStats.predict, impurity, stats.get.gain,
|
new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
|
||||||
leftChild.get.toNode, rightChild.get.toNode, split.get)
|
leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator)
|
||||||
} else {
|
} else {
|
||||||
new LeafNode(predictionStats.predict, impurity)
|
if (stats.valid) {
|
||||||
|
new LeafNode(stats.impurityCalculator.predict, stats.impurity,
|
||||||
|
stats.impurityCalculator)
|
||||||
|
} else {
|
||||||
|
// Here we want to keep the same behavior as the old mllib.DecisionTreeModel
|
||||||
|
new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -263,16 +273,14 @@ private[tree] object LearningNode {
|
||||||
/** Create a node with some of its fields set. */
|
/** Create a node with some of its fields set. */
|
||||||
def apply(
|
def apply(
|
||||||
id: Int,
|
id: Int,
|
||||||
predictionStats: OldPredict,
|
isLeaf: Boolean,
|
||||||
impurity: Double,
|
stats: ImpurityStats): LearningNode = {
|
||||||
isLeaf: Boolean): LearningNode = {
|
new LearningNode(id, None, None, None, false, stats)
|
||||||
new LearningNode(id, predictionStats, impurity, None, None, None, false, None)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create an empty node with the given node index. Values must be set later on. */
|
/** Create an empty node with the given node index. Values must be set later on. */
|
||||||
def emptyNode(nodeIndex: Int): LearningNode = {
|
def emptyNode(nodeIndex: Int): LearningNode = {
|
||||||
new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN,
|
new LearningNode(nodeIndex, None, None, None, false, null)
|
||||||
None, None, None, false, None)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// The below indexing methods were copied from spark.mllib.tree.model.Node
|
// The below indexing methods were copied from spark.mllib.tree.model.Node
|
||||||
|
|
|
@ -31,7 +31,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O
|
||||||
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
|
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
|
||||||
TimeTracker}
|
TimeTracker}
|
||||||
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
|
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
|
||||||
import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict}
|
import org.apache.spark.mllib.tree.model.ImpurityStats
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.storage.StorageLevel
|
import org.apache.spark.storage.StorageLevel
|
||||||
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
|
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
|
||||||
|
@ -180,13 +180,17 @@ private[ml] object RandomForest extends Logging {
|
||||||
parentUID match {
|
parentUID match {
|
||||||
case Some(uid) =>
|
case Some(uid) =>
|
||||||
if (strategy.algo == OldAlgo.Classification) {
|
if (strategy.algo == OldAlgo.Classification) {
|
||||||
topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode))
|
topNodes.map { rootNode =>
|
||||||
|
new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
|
topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
|
||||||
}
|
}
|
||||||
case None =>
|
case None =>
|
||||||
if (strategy.algo == OldAlgo.Classification) {
|
if (strategy.algo == OldAlgo.Classification) {
|
||||||
topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode))
|
topNodes.map { rootNode =>
|
||||||
|
new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
|
topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
|
||||||
}
|
}
|
||||||
|
@ -549,9 +553,9 @@ private[ml] object RandomForest extends Logging {
|
||||||
}
|
}
|
||||||
|
|
||||||
// find best split for each node
|
// find best split for each node
|
||||||
val (split: Split, stats: InformationGainStats, predict: Predict) =
|
val (split: Split, stats: ImpurityStats) =
|
||||||
binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
|
binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
|
||||||
(nodeIndex, (split, stats, predict))
|
(nodeIndex, (split, stats))
|
||||||
}.collectAsMap()
|
}.collectAsMap()
|
||||||
|
|
||||||
timer.stop("chooseSplits")
|
timer.stop("chooseSplits")
|
||||||
|
@ -568,17 +572,15 @@ private[ml] object RandomForest extends Logging {
|
||||||
val nodeIndex = node.id
|
val nodeIndex = node.id
|
||||||
val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
|
val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
|
||||||
val aggNodeIndex = nodeInfo.nodeIndexInGroup
|
val aggNodeIndex = nodeInfo.nodeIndexInGroup
|
||||||
val (split: Split, stats: InformationGainStats, predict: Predict) =
|
val (split: Split, stats: ImpurityStats) =
|
||||||
nodeToBestSplits(aggNodeIndex)
|
nodeToBestSplits(aggNodeIndex)
|
||||||
logDebug("best split = " + split)
|
logDebug("best split = " + split)
|
||||||
|
|
||||||
// Extract info for this node. Create children if not leaf.
|
// Extract info for this node. Create children if not leaf.
|
||||||
val isLeaf =
|
val isLeaf =
|
||||||
(stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
|
(stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
|
||||||
node.predictionStats = predict
|
|
||||||
node.isLeaf = isLeaf
|
node.isLeaf = isLeaf
|
||||||
node.stats = Some(stats)
|
node.stats = stats
|
||||||
node.impurity = stats.impurity
|
|
||||||
logDebug("Node = " + node)
|
logDebug("Node = " + node)
|
||||||
|
|
||||||
if (!isLeaf) {
|
if (!isLeaf) {
|
||||||
|
@ -587,9 +589,9 @@ private[ml] object RandomForest extends Logging {
|
||||||
val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
|
val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
|
||||||
val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
|
val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
|
||||||
node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
|
node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
|
||||||
stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
|
leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
|
||||||
node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
|
node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
|
||||||
stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
|
rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))
|
||||||
|
|
||||||
if (nodeIdCache.nonEmpty) {
|
if (nodeIdCache.nonEmpty) {
|
||||||
val nodeIndexUpdater = NodeIndexUpdater(
|
val nodeIndexUpdater = NodeIndexUpdater(
|
||||||
|
@ -621,28 +623,44 @@ private[ml] object RandomForest extends Logging {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculate the information gain for a given (feature, split) based upon left/right aggregates.
|
* Calculate the impurity statistics for a given (feature, split) based upon left/right aggregates.
|
||||||
|
* @param stats the recycled impurity statistics shared by all splits of this feature,
|
||||||
|
* only 'impurity' and 'impurityCalculator' are valid between each iteration
|
||||||
* @param leftImpurityCalculator left node aggregates for this (feature, split)
|
* @param leftImpurityCalculator left node aggregates for this (feature, split)
|
||||||
* @param rightImpurityCalculator right node aggregate for this (feature, split)
|
* @param rightImpurityCalculator right node aggregate for this (feature, split)
|
||||||
* @return information gain and statistics for split
|
* @param metadata learning and dataset metadata for DecisionTree
|
||||||
|
* @return Impurity statistics for this (feature, split)
|
||||||
*/
|
*/
|
||||||
private def calculateGainForSplit(
|
private def calculateImpurityStats(
|
||||||
|
stats: ImpurityStats,
|
||||||
leftImpurityCalculator: ImpurityCalculator,
|
leftImpurityCalculator: ImpurityCalculator,
|
||||||
rightImpurityCalculator: ImpurityCalculator,
|
rightImpurityCalculator: ImpurityCalculator,
|
||||||
metadata: DecisionTreeMetadata,
|
metadata: DecisionTreeMetadata): ImpurityStats = {
|
||||||
impurity: Double): InformationGainStats = {
|
|
||||||
|
val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
|
||||||
|
leftImpurityCalculator.copy.add(rightImpurityCalculator)
|
||||||
|
} else {
|
||||||
|
stats.impurityCalculator
|
||||||
|
}
|
||||||
|
|
||||||
|
val impurity: Double = if (stats == null) {
|
||||||
|
parentImpurityCalculator.calculate()
|
||||||
|
} else {
|
||||||
|
stats.impurity
|
||||||
|
}
|
||||||
|
|
||||||
val leftCount = leftImpurityCalculator.count
|
val leftCount = leftImpurityCalculator.count
|
||||||
val rightCount = rightImpurityCalculator.count
|
val rightCount = rightImpurityCalculator.count
|
||||||
|
|
||||||
|
val totalCount = leftCount + rightCount
|
||||||
|
|
||||||
// If left child or right child doesn't satisfy minimum instances per node,
|
// If left child or right child doesn't satisfy minimum instances per node,
|
||||||
// then this split is invalid, return invalid information gain stats.
|
// then this split is invalid, return invalid information gain stats.
|
||||||
if ((leftCount < metadata.minInstancesPerNode) ||
|
if ((leftCount < metadata.minInstancesPerNode) ||
|
||||||
(rightCount < metadata.minInstancesPerNode)) {
|
(rightCount < metadata.minInstancesPerNode)) {
|
||||||
return InformationGainStats.invalidInformationGainStats
|
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
|
||||||
}
|
}
|
||||||
|
|
||||||
val totalCount = leftCount + rightCount
|
|
||||||
|
|
||||||
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
|
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
|
||||||
val rightImpurity = rightImpurityCalculator.calculate()
|
val rightImpurity = rightImpurityCalculator.calculate()
|
||||||
|
|
||||||
|
@ -654,39 +672,11 @@ private[ml] object RandomForest extends Logging {
|
||||||
// if information gain doesn't satisfy minimum information gain,
|
// if information gain doesn't satisfy minimum information gain,
|
||||||
// then this split is invalid, return invalid information gain stats.
|
// then this split is invalid, return invalid impurity stats.
|
||||||
if (gain < metadata.minInfoGain) {
|
if (gain < metadata.minInfoGain) {
|
||||||
return InformationGainStats.invalidInformationGainStats
|
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculate left and right predict
|
new ImpurityStats(gain, impurity, parentImpurityCalculator,
|
||||||
val leftPredict = calculatePredict(leftImpurityCalculator)
|
leftImpurityCalculator, rightImpurityCalculator)
|
||||||
val rightPredict = calculatePredict(rightImpurityCalculator)
|
|
||||||
|
|
||||||
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
|
|
||||||
leftPredict, rightPredict)
|
|
||||||
}
|
|
||||||
|
|
||||||
private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
|
|
||||||
val predict = impurityCalculator.predict
|
|
||||||
val prob = impurityCalculator.prob(predict)
|
|
||||||
new Predict(predict, prob)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Calculate predict value for current node, given stats of any split.
|
|
||||||
* Note that this function is called only once for each node.
|
|
||||||
* @param leftImpurityCalculator left node aggregates for a split
|
|
||||||
* @param rightImpurityCalculator right node aggregates for a split
|
|
||||||
* @return predict value and impurity for current node
|
|
||||||
*/
|
|
||||||
private def calculatePredictImpurity(
|
|
||||||
leftImpurityCalculator: ImpurityCalculator,
|
|
||||||
rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
|
|
||||||
val parentNodeAgg = leftImpurityCalculator.copy
|
|
||||||
parentNodeAgg.add(rightImpurityCalculator)
|
|
||||||
val predict = calculatePredict(parentNodeAgg)
|
|
||||||
val impurity = parentNodeAgg.calculate()
|
|
||||||
|
|
||||||
(predict, impurity)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -698,14 +688,14 @@ private[ml] object RandomForest extends Logging {
|
||||||
binAggregates: DTStatsAggregator,
|
binAggregates: DTStatsAggregator,
|
||||||
splits: Array[Array[Split]],
|
splits: Array[Array[Split]],
|
||||||
featuresForNode: Option[Array[Int]],
|
featuresForNode: Option[Array[Int]],
|
||||||
node: LearningNode): (Split, InformationGainStats, Predict) = {
|
node: LearningNode): (Split, ImpurityStats) = {
|
||||||
|
|
||||||
// Calculate prediction and impurity if current node is top node
|
// Calculate InformationGain and ImpurityStats if current node is top node
|
||||||
val level = LearningNode.indexToLevel(node.id)
|
val level = LearningNode.indexToLevel(node.id)
|
||||||
var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) {
|
var gainAndImpurityStats: ImpurityStats = if (level ==0) {
|
||||||
None
|
null
|
||||||
} else {
|
} else {
|
||||||
Some((node.predictionStats, node.impurity))
|
node.stats
|
||||||
}
|
}
|
||||||
|
|
||||||
// For each (feature, split), calculate the gain, and select the best (feature, split).
|
// For each (feature, split), calculate the gain, and select the best (feature, split).
|
||||||
|
@ -734,11 +724,9 @@ private[ml] object RandomForest extends Logging {
|
||||||
val rightChildStats =
|
val rightChildStats =
|
||||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
|
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
|
||||||
rightChildStats.subtract(leftChildStats)
|
rightChildStats.subtract(leftChildStats)
|
||||||
predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
|
gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
|
||||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
leftChildStats, rightChildStats, binAggregates.metadata)
|
||||||
val gainStats = calculateGainForSplit(leftChildStats,
|
(splitIdx, gainAndImpurityStats)
|
||||||
rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
|
|
||||||
(splitIdx, gainStats)
|
|
||||||
}.maxBy(_._2.gain)
|
}.maxBy(_._2.gain)
|
||||||
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
|
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
|
||||||
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
|
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
|
||||||
|
@ -750,11 +738,9 @@ private[ml] object RandomForest extends Logging {
|
||||||
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
|
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
|
||||||
val rightChildStats =
|
val rightChildStats =
|
||||||
binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
|
binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
|
||||||
predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
|
gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
|
||||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
leftChildStats, rightChildStats, binAggregates.metadata)
|
||||||
val gainStats = calculateGainForSplit(leftChildStats,
|
(splitIndex, gainAndImpurityStats)
|
||||||
rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
|
|
||||||
(splitIndex, gainStats)
|
|
||||||
}.maxBy(_._2.gain)
|
}.maxBy(_._2.gain)
|
||||||
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
|
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
|
||||||
} else {
|
} else {
|
||||||
|
@ -825,11 +811,9 @@ private[ml] object RandomForest extends Logging {
|
||||||
val rightChildStats =
|
val rightChildStats =
|
||||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
|
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
|
||||||
rightChildStats.subtract(leftChildStats)
|
rightChildStats.subtract(leftChildStats)
|
||||||
predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
|
gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
|
||||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
leftChildStats, rightChildStats, binAggregates.metadata)
|
||||||
val gainStats = calculateGainForSplit(leftChildStats,
|
(splitIndex, gainAndImpurityStats)
|
||||||
rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
|
|
||||||
(splitIndex, gainStats)
|
|
||||||
}.maxBy(_._2.gain)
|
}.maxBy(_._2.gain)
|
||||||
val categoriesForSplit =
|
val categoriesForSplit =
|
||||||
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
|
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
|
||||||
|
@ -839,7 +823,7 @@ private[ml] object RandomForest extends Logging {
|
||||||
}
|
}
|
||||||
}.maxBy(_._2.gain)
|
}.maxBy(_._2.gain)
|
||||||
|
|
||||||
(bestSplit, bestSplitStats, predictionAndImpurity.get._1)
|
(bestSplit, bestSplitStats)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -118,7 +118,7 @@ private[tree] class EntropyAggregator(numClasses: Int)
|
||||||
* (node, feature, bin).
|
* (node, feature, bin).
|
||||||
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
||||||
*/
|
*/
|
||||||
private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Make a deep copy of this [[ImpurityCalculator]].
|
* Make a deep copy of this [[ImpurityCalculator]].
|
||||||
|
|
|
@ -114,7 +114,7 @@ private[tree] class GiniAggregator(numClasses: Int)
|
||||||
* (node, feature, bin).
|
* (node, feature, bin).
|
||||||
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
||||||
*/
|
*/
|
||||||
private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Make a deep copy of this [[ImpurityCalculator]].
|
* Make a deep copy of this [[ImpurityCalculator]].
|
||||||
|
|
|
@ -95,7 +95,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
|
||||||
* (node, feature, bin).
|
* (node, feature, bin).
|
||||||
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
||||||
*/
|
*/
|
||||||
private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) {
|
private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Make a deep copy of this [[ImpurityCalculator]].
|
* Make a deep copy of this [[ImpurityCalculator]].
|
||||||
|
|
|
@ -98,7 +98,7 @@ private[tree] class VarianceAggregator()
|
||||||
* (node, feature, bin).
|
* (node, feature, bin).
|
||||||
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
||||||
*/
|
*/
|
||||||
private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
||||||
|
|
||||||
require(stats.size == 3,
|
require(stats.size == 3,
|
||||||
s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
|
s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
package org.apache.spark.mllib.tree.model
|
package org.apache.spark.mllib.tree.model
|
||||||
|
|
||||||
import org.apache.spark.annotation.DeveloperApi
|
import org.apache.spark.annotation.DeveloperApi
|
||||||
|
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* :: DeveloperApi ::
|
* :: DeveloperApi ::
|
||||||
|
@ -66,7 +67,6 @@ class InformationGainStats(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private[spark] object InformationGainStats {
|
private[spark] object InformationGainStats {
|
||||||
/**
|
/**
|
||||||
* An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
|
* An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
|
||||||
|
@ -76,3 +76,62 @@ private[spark] object InformationGainStats {
|
||||||
val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
|
val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
|
||||||
new Predict(0.0, 0.0), new Predict(0.0, 0.0))
|
new Predict(0.0, 0.0), new Predict(0.0, 0.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * :: DeveloperApi ::
 * Impurity statistics describing one candidate split of a node.
 *
 * The left/right calculators may be `null` (the factories in the companion object pass
 * `null` for both); in that case the matching impurity accessor reports `-1.0`.
 *
 * @param gain information gain value
 * @param impurity current node impurity
 * @param impurityCalculator impurity statistics for current node
 * @param leftImpurityCalculator impurity statistics for left child node
 * @param rightImpurityCalculator impurity statistics for right child node
 * @param valid whether the current split satisfies minimum info gain or
 *              minimum number of instances per node
 */
@DeveloperApi
private[spark] class ImpurityStats(
    val gain: Double,
    val impurity: Double,
    val impurityCalculator: ImpurityCalculator,
    val leftImpurityCalculator: ImpurityCalculator,
    val rightImpurityCalculator: ImpurityCalculator,
    val valid: Boolean = true) extends Serializable {

  /** Impurity of the left child, or `-1.0` when no left calculator is attached. */
  def leftImpurity: Double = impurityOrSentinel(leftImpurityCalculator)

  /** Impurity of the right child, or `-1.0` when no right calculator is attached. */
  def rightImpurity: Double = impurityOrSentinel(rightImpurityCalculator)

  override def toString: String = {
    s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " +
      s"right impurity = $rightImpurity"
  }

  // A null calculator maps to the -1.0 sentinel; otherwise the impurity is recomputed on
  // every call, exactly as a direct call to calculate() would.
  private def impurityOrSentinel(calculator: ImpurityCalculator): Double =
    Option(calculator).fold(-1.0)(_.calculate())
}
|
||||||
|
|
||||||
|
private[spark] object ImpurityStats {

  /**
   * Build an [[org.apache.spark.mllib.tree.model.ImpurityStats]] marking a split that does
   * not satisfy the minimum info gain or the minimum number of instances per node.
   *
   * The result has `gain = Double.MinValue`, `valid = false`, and no child calculators;
   * `impurity` is computed from the supplied calculator.
   *
   * @param impurityCalculator impurity statistics for the current node
   */
  def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
    val nodeImpurity = impurityCalculator.calculate()
    new ImpurityStats(
      gain = Double.MinValue,
      impurity = nodeImpurity,
      impurityCalculator = impurityCalculator,
      leftImpurityCalculator = null,
      rightImpurityCalculator = null,
      valid = false)
  }

  /**
   * Build an [[org.apache.spark.mllib.tree.model.ImpurityStats]] in which only 'impurity'
   * and 'impurityCalculator' are defined: the gain is `Double.NaN` and both child
   * calculators are absent.
   *
   * @param impurityCalculator impurity statistics for the current node
   */
  def getEmptyImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
    val nodeImpurity = impurityCalculator.calculate()
    new ImpurityStats(
      gain = Double.NaN,
      impurity = nodeImpurity,
      impurityCalculator = impurityCalculator,
      leftImpurityCalculator = null,
      rightImpurityCalculator = null)
  }
}
|
||||||
|
|
|
@ -21,12 +21,13 @@ import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.impl.TreeTests
|
import org.apache.spark.ml.impl.TreeTests
|
||||||
import org.apache.spark.ml.param.ParamsSuite
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.apache.spark.ml.tree.LeafNode
|
import org.apache.spark.ml.tree.LeafNode
|
||||||
import org.apache.spark.mllib.linalg.Vectors
|
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
|
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.DataFrame
|
||||||
|
import org.apache.spark.sql.Row
|
||||||
|
|
||||||
class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
|
class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
|
@ -57,7 +58,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
||||||
|
|
||||||
test("params") {
|
test("params") {
|
||||||
ParamsSuite.checkParams(new DecisionTreeClassifier)
|
ParamsSuite.checkParams(new DecisionTreeClassifier)
|
||||||
val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))
|
val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)
|
||||||
ParamsSuite.checkParams(model)
|
ParamsSuite.checkParams(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -231,6 +232,31 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
||||||
compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
|
compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("predictRaw and predictProbability") {
  val numClasses = 3
  val categoricalFeatures = Map(0 -> 3)
  val classifier = new DecisionTreeClassifier()
    .setImpurity("Gini")
    .setMaxDepth(4)
    .setMaxBins(100)

  // Fit on the multiclass data set and score the same rows that were used for training.
  val labeledData: DataFrame =
    TreeTests.setMetadata(continuousDataPointsForMulticlassRDD, categoricalFeatures, numClasses)
  val model = classifier.fit(labeledData)

  val scoredRows = model.transform(labeledData)
    .select(model.getPredictionCol, model.getRawPredictionCol, model.getProbabilityCol)
    .collect()

  scoredRows.foreach {
    case Row(prediction: Double, rawPrediction: Vector, probability: Vector) =>
      // The predicted label must be the index of the largest raw value.
      val labelFromRaw = rawPrediction.argmax
      assert(prediction === labelFromRaw,
        s"Expected prediction $prediction but calculated $labelFromRaw from rawPrediction.")

      // The probability column must equal the raw vector normalised by its sum.
      val rawValues = rawPrediction.toArray
      val total = rawValues.sum
      val normalized = Vectors.dense(rawValues.map(_ / total))
      assert(normalized === probability, "probability prediction mismatch")
  }
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Tests of model save/load
|
// Tests of model save/load
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -58,7 +58,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
test("params") {
|
test("params") {
|
||||||
ParamsSuite.checkParams(new GBTClassifier)
|
ParamsSuite.checkParams(new GBTClassifier)
|
||||||
val model = new GBTClassificationModel("gbtc",
|
val model = new GBTClassificationModel("gbtc",
|
||||||
Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))),
|
Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))),
|
||||||
Array(1.0))
|
Array(1.0))
|
||||||
ParamsSuite.checkParams(model)
|
ParamsSuite.checkParams(model)
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
||||||
test("params") {
|
test("params") {
|
||||||
ParamsSuite.checkParams(new RandomForestClassifier)
|
ParamsSuite.checkParams(new RandomForestClassifier)
|
||||||
val model = new RandomForestClassificationModel("rfc",
|
val model = new RandomForestClassificationModel("rfc",
|
||||||
Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2)
|
Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2)
|
||||||
ParamsSuite.checkParams(model)
|
ParamsSuite.checkParams(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue