[SPARK-5133] [ML] Added featureImportance to RandomForestClassifier and Regressor
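Added `featureImportances` to RandomForestClassificationModel and RandomForestRegressionModel (the models produced by RandomForestClassifier and RandomForestRegressor).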
This follows the scikit-learn implementation here: sklearn/tree/_tree.pyx, line 3341 (scikit-learn commit a95203b249).
CC: yanboliang Would you mind taking a look? Thanks!
Author: Joseph K. Bradley <joseph@databricks.com>
Author: Feynman Liang <fliang@databricks.com>
Closes #7838 from jkbradley/dt-feature-importance and squashes the following commits:
72a167a [Joseph K. Bradley] fixed unit test
86cea5f [Joseph K. Bradley] Modified RF featuresImportances to return Vector instead of Map
5aa74f0 [Joseph K. Bradley] finally fixed unit test for real
33df5db [Joseph K. Bradley] fix unit test
42a2d3b [Joseph K. Bradley] fix unit test
fe94e72 [Joseph K. Bradley] modified feature importance unit tests
cc693ee [Feynman Liang] Add classifier tests
79a6f87 [Feynman Liang] Compare dense vectors in test
21d01fc [Feynman Liang] Added failing SKLearn test
ac0b254 [Joseph K. Bradley] Added featureImportance to RandomForestClassifier/Regressor. Need to add unit tests
This commit is contained in:
parent
703e44bff1
commit
ff9169a002
|
@ -95,7 +95,8 @@ final class RandomForestClassifier(override val uid: String)
|
|||
val trees =
|
||||
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
|
||||
.map(_.asInstanceOf[DecisionTreeClassificationModel])
|
||||
new RandomForestClassificationModel(trees, numClasses)
|
||||
val numFeatures = oldDataset.first().features.size
|
||||
new RandomForestClassificationModel(trees, numFeatures, numClasses)
|
||||
}
|
||||
|
||||
override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
|
||||
|
@ -118,11 +119,13 @@ object RandomForestClassifier {
|
|||
* features.
|
||||
* @param _trees Decision trees in the ensemble.
|
||||
* Warning: These have null parents.
|
||||
* @param numFeatures Number of features used by this model
|
||||
*/
|
||||
@Experimental
|
||||
final class RandomForestClassificationModel private[ml] (
|
||||
override val uid: String,
|
||||
private val _trees: Array[DecisionTreeClassificationModel],
|
||||
val numFeatures: Int,
|
||||
override val numClasses: Int)
|
||||
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
|
||||
with TreeEnsembleModel with Serializable {
|
||||
|
@ -133,8 +136,8 @@ final class RandomForestClassificationModel private[ml] (
|
|||
* Construct a random forest classification model, with all trees weighted equally.
|
||||
* @param trees Component trees
|
||||
*/
|
||||
def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) =
|
||||
this(Identifiable.randomUID("rfc"), trees, numClasses)
|
||||
def this(trees: Array[DecisionTreeClassificationModel], numFeatures: Int, numClasses: Int) =
|
||||
this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
|
||||
|
||||
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
|
||||
|
||||
|
@ -182,13 +185,30 @@ final class RandomForestClassificationModel private[ml] (
|
|||
}
|
||||
|
||||
override def copy(extra: ParamMap): RandomForestClassificationModel = {
|
||||
copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
|
||||
copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
|
||||
}
|
||||
|
||||
override def toString: String = {
|
||||
s"RandomForestClassificationModel with $numTrees trees"
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimate of the importance of each feature.
|
||||
*
|
||||
* This generalizes the idea of "Gini" importance to other losses,
|
||||
* following the explanation of Gini importance from "Random Forests" documentation
|
||||
* by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
|
||||
*
|
||||
* This feature importance is calculated as follows:
|
||||
* - Average over trees:
|
||||
* - importance(feature j) = sum (over nodes which split on feature j) of the gain,
|
||||
* where gain is scaled by the number of instances passing through node
|
||||
* - Normalize importances for tree to sum to 1.
|
||||
* - Normalize feature importance vector to sum to 1.
|
||||
*/
|
||||
lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
|
||||
|
||||
/** (private[ml]) Convert to a model in the old API */
|
||||
private[ml] def toOld: OldRandomForestModel = {
|
||||
new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
|
||||
|
@ -210,6 +230,6 @@ private[ml] object RandomForestClassificationModel {
|
|||
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
|
||||
}
|
||||
val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
|
||||
new RandomForestClassificationModel(uid, newTrees, numClasses)
|
||||
new RandomForestClassificationModel(uid, newTrees, -1, numClasses)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestMo
|
|||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.DataFrame
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.DoubleType
|
||||
|
||||
|
||||
/**
|
||||
* :: Experimental ::
|
||||
|
@ -87,7 +87,8 @@ final class RandomForestRegressor(override val uid: String)
|
|||
val trees =
|
||||
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
|
||||
.map(_.asInstanceOf[DecisionTreeRegressionModel])
|
||||
new RandomForestRegressionModel(trees)
|
||||
val numFeatures = oldDataset.first().features.size
|
||||
new RandomForestRegressionModel(trees, numFeatures)
|
||||
}
|
||||
|
||||
override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
|
||||
|
@ -108,11 +109,13 @@ object RandomForestRegressor {
|
|||
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
|
||||
* It supports both continuous and categorical features.
|
||||
* @param _trees Decision trees in the ensemble.
|
||||
* @param numFeatures Number of features used by this model
|
||||
*/
|
||||
@Experimental
|
||||
final class RandomForestRegressionModel private[ml] (
|
||||
override val uid: String,
|
||||
private val _trees: Array[DecisionTreeRegressionModel])
|
||||
private val _trees: Array[DecisionTreeRegressionModel],
|
||||
val numFeatures: Int)
|
||||
extends PredictionModel[Vector, RandomForestRegressionModel]
|
||||
with TreeEnsembleModel with Serializable {
|
||||
|
||||
|
@ -122,7 +125,8 @@ final class RandomForestRegressionModel private[ml] (
|
|||
* Construct a random forest regression model, with all trees weighted equally.
|
||||
* @param trees Component trees
|
||||
*/
|
||||
def this(trees: Array[DecisionTreeRegressionModel]) = this(Identifiable.randomUID("rfr"), trees)
|
||||
def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
|
||||
this(Identifiable.randomUID("rfr"), trees, numFeatures)
|
||||
|
||||
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
|
||||
|
||||
|
@ -147,13 +151,30 @@ final class RandomForestRegressionModel private[ml] (
|
|||
}
|
||||
|
||||
override def copy(extra: ParamMap): RandomForestRegressionModel = {
|
||||
copyValues(new RandomForestRegressionModel(uid, _trees), extra)
|
||||
copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra)
|
||||
}
|
||||
|
||||
override def toString: String = {
|
||||
s"RandomForestRegressionModel with $numTrees trees"
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimate of the importance of each feature.
|
||||
*
|
||||
* This generalizes the idea of "Gini" importance to other losses,
|
||||
* following the explanation of Gini importance from "Random Forests" documentation
|
||||
* by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
|
||||
*
|
||||
* This feature importance is calculated as follows:
|
||||
* - Average over trees:
|
||||
* - importance(feature j) = sum (over nodes which split on feature j) of the gain,
|
||||
* where gain is scaled by the number of instances passing through node
|
||||
* - Normalize importances for tree to sum to 1.
|
||||
* - Normalize feature importance vector to sum to 1.
|
||||
*/
|
||||
lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
|
||||
|
||||
/** (private[ml]) Convert to a model in the old API */
|
||||
private[ml] def toOld: OldRandomForestModel = {
|
||||
new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
|
||||
|
@ -173,6 +194,6 @@ private[ml] object RandomForestRegressionModel {
|
|||
// parent for each tree is null since there is no good way to set this.
|
||||
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
|
||||
}
|
||||
new RandomForestRegressionModel(parent.uid, newTrees)
|
||||
new RandomForestRegressionModel(parent.uid, newTrees, -1)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ sealed abstract class Node extends Serializable {
|
|||
* and probabilities.
|
||||
* For classification, the array of class counts must be normalized to a probability distribution.
|
||||
*/
|
||||
private[tree] def impurityStats: ImpurityCalculator
|
||||
private[ml] def impurityStats: ImpurityCalculator
|
||||
|
||||
/** Recursive prediction helper method */
|
||||
private[ml] def predictImpl(features: Vector): LeafNode
|
||||
|
@ -72,6 +72,12 @@ sealed abstract class Node extends Serializable {
|
|||
* @param id Node ID using old format IDs
|
||||
*/
|
||||
private[ml] def toOld(id: Int): OldNode
|
||||
|
||||
/**
|
||||
* Trace down the tree, and return the largest feature index used in any split.
|
||||
* @return Max feature index used in a split, or -1 if there are no splits (single leaf node).
|
||||
*/
|
||||
private[ml] def maxSplitFeatureIndex(): Int
|
||||
}
|
||||
|
||||
private[ml] object Node {
|
||||
|
@ -109,7 +115,7 @@ private[ml] object Node {
|
|||
final class LeafNode private[ml] (
|
||||
override val prediction: Double,
|
||||
override val impurity: Double,
|
||||
override val impurityStats: ImpurityCalculator) extends Node {
|
||||
override private[ml] val impurityStats: ImpurityCalculator) extends Node {
|
||||
|
||||
override def toString: String =
|
||||
s"LeafNode(prediction = $prediction, impurity = $impurity)"
|
||||
|
@ -129,6 +135,8 @@ final class LeafNode private[ml] (
|
|||
new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)),
|
||||
impurity, isLeaf = true, None, None, None, None)
|
||||
}
|
||||
|
||||
// A leaf node has no split, so it reports -1, the "no splits" value.
override private[ml] def maxSplitFeatureIndex(): Int = -1
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -150,7 +158,7 @@ final class InternalNode private[ml] (
|
|||
val leftChild: Node,
|
||||
val rightChild: Node,
|
||||
val split: Split,
|
||||
override val impurityStats: ImpurityCalculator) extends Node {
|
||||
override private[ml] val impurityStats: ImpurityCalculator) extends Node {
|
||||
|
||||
override def toString: String = {
|
||||
s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
|
||||
|
@ -190,6 +198,11 @@ final class InternalNode private[ml] (
|
|||
new OldPredict(leftChild.prediction, prob = 0.0),
|
||||
new OldPredict(rightChild.prediction, prob = 0.0))))
|
||||
}
|
||||
|
||||
/**
 * Largest feature index used by this node's split or by any split in either subtree.
 * Children that contain no splits report -1 and therefore never win the comparison.
 */
override private[ml] def maxSplitFeatureIndex(): Int = {
  val largestBelow =
    math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex())
  math.max(split.featureIndex, largestBelow)
}
|
||||
}
|
||||
|
||||
private object InternalNode {
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.apache.spark.Logging
|
|||
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
|
||||
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
|
||||
import org.apache.spark.ml.tree._
|
||||
import org.apache.spark.mllib.linalg.{Vectors, Vector}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
||||
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
|
||||
|
@ -34,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
|
|||
import org.apache.spark.mllib.tree.model.ImpurityStats
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
import org.apache.spark.util.collection.OpenHashMap
|
||||
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
|
||||
|
||||
|
||||
|
@ -1113,4 +1115,94 @@ private[ml] object RandomForest extends Logging {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Given a Random Forest model, compute the importance of each feature.
 * This generalizes the idea of "Gini" importance to other losses,
 * following the explanation of Gini importance from "Random Forests" documentation
 * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
 *
 * This feature importance is calculated as follows:
 *  - For each tree:
 *     - importance(feature j) = sum (over nodes which split on feature j) of the gain,
 *       where gain is scaled by the number of instances passing through node
 *     - Normalize the importances for the tree to sum to 1
 *       (a tree whose importances sum to 0, e.g. a single leaf, contributes nothing).
 *  - Sum the per-tree vectors and normalize the result to sum to 1.
 *
 * Note: This should not be used with Gradient-Boosted Trees.  It only makes sense for
 *       independently trained trees.
 * @param trees  Unweighted forest of trees
 * @param numFeatures  Number of features in model (even if not all are explicitly used by
 *                     the model).
 *                     If -1, then numFeatures is set based on the max feature index in all trees
 *                     (and is 0 when the forest is empty or contains no splits).
 * @return  Feature importance values, of length numFeatures.
 */
private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
  val totalImportances = new OpenHashMap[Int, Double]()
  trees.foreach { tree =>
    // Aggregate feature importance vector for this tree
    val importances = new OpenHashMap[Int, Double]()
    computeFeatureImportance(tree.rootNode, importances)
    // Normalize importance vector for this tree, and add it to total.
    // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
    val treeNorm = importances.map(_._2).sum
    if (treeNorm != 0) {
      importances.foreach { case (idx, impt) =>
        val normImpt = impt / treeNorm
        totalImportances.changeValue(idx, normImpt, _ + normImpt)
      }
    }
  }
  // Normalize importances
  normalizeMapValues(totalImportances)
  // Construct vector
  val d = if (numFeatures != -1) {
    numFeatures
  } else {
    // Find max feature index used in trees.  Folding from -1 (rather than calling `.max`)
    // keeps this well-defined for an empty forest, which then yields d = 0.
    val maxFeatureIndex =
      trees.foldLeft(-1) { (currentMax, tree) =>
        math.max(currentMax, tree.maxSplitFeatureIndex())
      }
    maxFeatureIndex + 1
  }
  if (d == 0) {
    assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" +
      s" importance: No splits in forest, but some non-zero importances.")
  }
  // Sparse vectors need their indices in increasing order.
  val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
  Vectors.sparse(d, indices.toArray, values.toArray)
}
|
||||
|
||||
/**
 * Walk the subtree rooted at `node` and accumulate, per feature, the gain of every split
 * on that feature, weighted by the instance count recorded in the splitting node's
 * impurity statistics.  Leaf nodes contribute nothing.
 * @param node  Root of the subtree to process
 * @param importances  Running per-feature totals; updated in place by this method
 */
private[impl] def computeFeatureImportance(
    node: Node,
    importances: OpenHashMap[Int, Double]): Unit = node match {
  case _: LeafNode =>
    // No split here, hence nothing to add.
    ()
  case internal: InternalNode =>
    val weightedGain = internal.gain * internal.impurityStats.count
    importances.changeValue(internal.split.featureIndex, weightedGain, _ + weightedGain)
    computeFeatureImportance(internal.leftChild, importances)
    computeFeatureImportance(internal.rightChild, importances)
}
|
||||
|
||||
/**
 * Rescale the values of the given map, in place, so that they sum to 1.
 * A map whose values sum to 0 (including an empty map) is left untouched.
 * @param map  Map with non-negative values.
 */
private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
  val sum = map.iterator.map(_._2).sum
  if (sum != 0) {
    // Take a snapshot of the keys first so the map is not iterated while being updated.
    val snapshot = map.iterator.map(_._1).toArray
    for (key <- snapshot) {
      map.changeValue(key, 0.0, _ / sum)
    }
  }
}
|
||||
|
||||
}
|
||||
|
|
|
@ -53,6 +53,12 @@ private[ml] trait DecisionTreeModel {
|
|||
val header = toString + "\n"
|
||||
header + rootNode.subtreeToString(2)
|
||||
}
|
||||
|
||||
/**
|
||||
* Trace down the tree, and return the largest feature index used in any split.
|
||||
* @return Max feature index used in a split, or -1 if there are no splits (single leaf node).
|
||||
*/
|
||||
private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex()
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD;
|
|||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.ml.impl.TreeTests;
|
||||
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
|
||||
import org.apache.spark.mllib.linalg.Vector;
|
||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||
import org.apache.spark.sql.DataFrame;
|
||||
|
||||
|
@ -85,6 +86,7 @@ public class JavaRandomForestClassifierSuite implements Serializable {
|
|||
model.toDebugString();
|
||||
model.trees();
|
||||
model.treeWeights();
|
||||
Vector importances = model.featureImportances();
|
||||
|
||||
/*
|
||||
// TODO: Add test once save/load are implemented. SPARK-6725
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD;
|
|||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
|
||||
import org.apache.spark.ml.impl.TreeTests;
|
||||
import org.apache.spark.mllib.linalg.Vector;
|
||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||
import org.apache.spark.sql.DataFrame;
|
||||
|
||||
|
@ -85,6 +86,7 @@ public class JavaRandomForestRegressorSuite implements Serializable {
|
|||
model.toDebugString();
|
||||
model.trees();
|
||||
model.treeWeights();
|
||||
Vector importances = model.featureImportances();
|
||||
|
||||
/*
|
||||
// TODO: Add test once save/load are implemented. SPARK-6725
|
||||
|
|
|
@ -67,7 +67,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
|||
test("params") {
|
||||
ParamsSuite.checkParams(new RandomForestClassifier)
|
||||
val model = new RandomForestClassificationModel("rfc",
|
||||
Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2)
|
||||
Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2, 2)
|
||||
ParamsSuite.checkParams(model)
|
||||
}
|
||||
|
||||
|
@ -149,6 +149,35 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
|||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Tests of feature importance
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
test("Feature importance with toy data") {
|
||||
val numClasses = 2
|
||||
val rf = new RandomForestClassifier()
|
||||
.setImpurity("Gini")
|
||||
.setMaxDepth(3)
|
||||
.setNumTrees(3)
|
||||
.setFeatureSubsetStrategy("all")
|
||||
.setSubsamplingRate(1.0)
|
||||
.setSeed(123)
|
||||
|
||||
// In this data, feature 1 is very important.
|
||||
val data: RDD[LabeledPoint] = sc.parallelize(Seq(
|
||||
new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
|
||||
new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
|
||||
new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
|
||||
new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
|
||||
new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
|
||||
))
|
||||
val categoricalFeatures = Map.empty[Int, Int]
|
||||
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
|
||||
|
||||
val importances = rf.fit(df).featureImportances
|
||||
val mostImportantFeature = importances.argmax
|
||||
assert(mostImportantFeature === 1)
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Tests of model save/load
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -124,4 +124,22 @@ private[ml] object TreeTests extends SparkFunSuite {
|
|||
"checkEqual failed since the two tree ensembles were not identical")
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Test helper: create the internal node that sits above two existing subtrees.
 * The parent's impurity statistics are the sum of the children's statistics, and its gain is
 * the parent impurity minus the count-weighted average of the children's impurities.
 * @param left  Left child
 * @param right  Right child
 * @param split  Split for parent node
 * @return  Parent node with children attached
 */
def buildParentNode(left: Node, right: Node, split: Split): Node = {
  val leftStats = left.impurityStats
  val rightStats = right.impurityStats
  // Copy before adding so the left child's statistics are not modified.
  val mergedStats = leftStats.copy.add(rightStats)
  val totalCount = mergedStats.count.toDouble
  val parentImpurity = mergedStats.calculate()
  val childImpurity =
    (leftStats.count / totalCount) * leftStats.calculate() +
      (rightStats.count / totalCount) * rightStats.calculate()
  new InternalNode(
    mergedStats.predict,
    parentImpurity,
    parentImpurity - childImpurity,
    left,
    right,
    split,
    mergedStats)
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.ml.regression
|
|||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.impl.TreeTests
|
||||
import org.apache.spark.mllib.linalg.Vectors
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
|
||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
|
||||
|
@ -26,7 +27,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
|
|||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.DataFrame
|
||||
|
||||
|
||||
/**
|
||||
* Test suite for [[RandomForestRegressor]].
|
||||
*/
|
||||
|
@ -71,6 +71,31 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
|
|||
regressionTestWithContinuousFeatures(rf)
|
||||
}
|
||||
|
||||
test("Feature importance with toy data") {
|
||||
val rf = new RandomForestRegressor()
|
||||
.setImpurity("variance")
|
||||
.setMaxDepth(3)
|
||||
.setNumTrees(3)
|
||||
.setFeatureSubsetStrategy("all")
|
||||
.setSubsamplingRate(1.0)
|
||||
.setSeed(123)
|
||||
|
||||
// In this data, feature 1 is very important.
|
||||
val data: RDD[LabeledPoint] = sc.parallelize(Seq(
|
||||
new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
|
||||
new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
|
||||
new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
|
||||
new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
|
||||
new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
|
||||
))
|
||||
val categoricalFeatures = Map.empty[Int, Int]
|
||||
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
|
||||
|
||||
val importances = rf.fit(df).featureImportances
|
||||
val mostImportantFeature = importances.argmax
|
||||
assert(mostImportantFeature === 1)
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Tests of model save/load
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.tree.impl
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
|
||||
import org.apache.spark.ml.impl.TreeTests
|
||||
import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node}
|
||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.mllib.tree.impurity.GiniCalculator
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.mllib.util.TestingUtils._
|
||||
import org.apache.spark.util.collection.OpenHashMap
|
||||
|
||||
/**
|
||||
* Test suite for [[RandomForest]].
|
||||
*/
|
||||
class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||
|
||||
import RandomForestSuite.mapToVec
|
||||
|
||||
test("computeFeatureImportance, featureImportances") {
|
||||
/* Build tree for testing, with this structure:
|
||||
grandParent
|
||||
left2 parent
|
||||
left right
|
||||
*/
|
||||
val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
|
||||
val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
|
||||
|
||||
val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
|
||||
val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
|
||||
|
||||
val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
|
||||
val parentImp = parent.impurityStats
|
||||
|
||||
val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
|
||||
val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
|
||||
|
||||
val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
|
||||
val grandImp = grandParent.impurityStats
|
||||
|
||||
// Test feature importance computed at different subtrees.
|
||||
def testNode(node: Node, expected: Map[Int, Double]): Unit = {
|
||||
val map = new OpenHashMap[Int, Double]()
|
||||
RandomForest.computeFeatureImportance(node, map)
|
||||
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
|
||||
}
|
||||
|
||||
// Leaf node
|
||||
testNode(left, Map.empty[Int, Double])
|
||||
|
||||
// Internal node with 2 leaf children
|
||||
val feature0importance = parentImp.calculate() * parentImp.count -
|
||||
(leftImp.calculate() * leftImp.count + rightImp.calculate() * rightImp.count)
|
||||
testNode(parent, Map(0 -> feature0importance))
|
||||
|
||||
// Full tree
|
||||
val feature1importance = grandImp.calculate() * grandImp.count -
|
||||
(left2Imp.calculate() * left2Imp.count + parentImp.calculate() * parentImp.count)
|
||||
testNode(grandParent, Map(0 -> feature0importance, 1 -> feature1importance))
|
||||
|
||||
// Forest consisting of (full tree) + (internal node with 2 leafs)
|
||||
val trees = Array(parent, grandParent).map { root =>
|
||||
new DecisionTreeClassificationModel(root, numClasses = 3).asInstanceOf[DecisionTreeModel]
|
||||
}
|
||||
val importances: Vector = RandomForest.featureImportances(trees, 2)
|
||||
val tree2norm = feature0importance + feature1importance
|
||||
val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0,
|
||||
(feature1importance / tree2norm) / 2.0)
|
||||
assert(importances ~== expected relTol 0.01)
|
||||
}
|
||||
|
||||
test("normalizeMapValues") {
|
||||
val map = new OpenHashMap[Int, Double]()
|
||||
map(0) = 1.0
|
||||
map(2) = 2.0
|
||||
RandomForest.normalizeMapValues(map)
|
||||
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
|
||||
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private object RandomForestSuite {
|
||||
|
||||
/**
 * Convert an (index -> value) map into a sparse vector.
 * The vector size is one more than the largest key, and is 1 for an empty map.
 */
def mapToVec(map: Map[Int, Double]): Vector = {
  val dim = map.keys.foldLeft(0) { (largest, key) => math.max(largest, key) } + 1
  // Sparse vectors need their indices in increasing order.
  val ordered = map.toSeq.sortBy { case (index, _) => index }
  val indexArray = ordered.map(_._1).toArray
  val valueArray = ordered.map(_._2).toArray
  Vectors.sparse(dim, indexArray, valueArray)
}
|
||||
}
|
Loading…
Reference in a new issue