[SPARK-5133] [ML] Added featureImportance to RandomForestClassifier and Regressor

Added featureImportance to RandomForestClassifier and Regressor.

This follows the scikit-learn implementation here: [a95203b249/sklearn/tree/_tree.pyx (L3341)]
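For reviewers, a minimal usage sketch of the new API; df is an assumed training DataFrame with the default "label" and "features" columns, and the same featureImportances member is added on the regression side:

import org.apache.spark.ml.classification.RandomForestClassifier

val rf = new RandomForestClassifier()
  .setNumTrees(20)
  .setMaxDepth(5)
val model = rf.fit(df)
// Vector of length numFeatures, normalized to sum to 1
val importances = model.featureImportances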

CC: yanboliang  Would you mind taking a look?  Thanks!

Author: Joseph K. Bradley <joseph@databricks.com>
Author: Feynman Liang <fliang@databricks.com>

Closes #7838 from jkbradley/dt-feature-importance and squashes the following commits:

72a167a [Joseph K. Bradley] fixed unit test
86cea5f [Joseph K. Bradley] Modified RF featuresImportances to return Vector instead of Map
5aa74f0 [Joseph K. Bradley] finally fixed unit test for real
33df5db [Joseph K. Bradley] fix unit test
42a2d3b [Joseph K. Bradley] fix unit test
fe94e72 [Joseph K. Bradley] modified feature importance unit tests
cc693ee [Feynman Liang] Add classifier tests
79a6f87 [Feynman Liang] Compare dense vectors in test
21d01fc [Feynman Liang] Added failing SKLearn test
ac0b254 [Joseph K. Bradley] Added featureImportance to RandomForestClassifier/Regressor.  Need to add unit tests
Committed by Joseph K. Bradley on 2015-08-03 12:17:46 -07:00
parent 703e44bff1
commit ff9169a002
11 changed files with 351 additions and 16 deletions


@@ -95,7 +95,8 @@ final class RandomForestClassifier(override val uid: String)
val trees =
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
.map(_.asInstanceOf[DecisionTreeClassificationModel])
new RandomForestClassificationModel(trees, numClasses)
val numFeatures = oldDataset.first().features.size
new RandomForestClassificationModel(trees, numFeatures, numClasses)
}
override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
@@ -118,11 +119,13 @@ object RandomForestClassifier {
* features.
* @param _trees Decision trees in the ensemble.
* Warning: These have null parents.
* @param numFeatures Number of features used by this model
*/
@Experimental
final class RandomForestClassificationModel private[ml] (
override val uid: String,
private val _trees: Array[DecisionTreeClassificationModel],
val numFeatures: Int,
override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
@@ -133,8 +136,8 @@ final class RandomForestClassificationModel private[ml] (
* Construct a random forest classification model, with all trees weighted equally.
* @param trees Component trees
*/
def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) =
this(Identifiable.randomUID("rfc"), trees, numClasses)
def this(trees: Array[DecisionTreeClassificationModel], numFeatures: Int, numClasses: Int) =
this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
@@ -182,13 +185,30 @@ final class RandomForestClassificationModel private[ml] (
}
override def copy(extra: ParamMap): RandomForestClassificationModel = {
copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
}
override def toString: String = {
s"RandomForestClassificationModel with $numTrees trees"
}
/**
* Estimate of the importance of each feature.
*
* This generalizes the idea of "Gini" importance to other losses,
* following the explanation of Gini importance from "Random Forests" documentation
* by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
*
* This feature importance is calculated as follows:
* - Average over trees:
* - importance(feature j) = sum (over nodes which split on feature j) of the gain,
* where gain is scaled by the number of instances passing through node
* - Normalize importances for tree based on total number of training instances used
* to build tree.
* - Normalize feature importance vector to sum to 1.
*/
lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
@@ -210,6 +230,6 @@ private[ml] object RandomForestClassificationModel {
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
new RandomForestClassificationModel(uid, newTrees, numClasses)
new RandomForestClassificationModel(uid, newTrees, -1, numClasses)
}
}
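To make the per-tree rule in the featureImportances doc comment above concrete, here is a small sketch with made-up gains and instance counts (not values from the patch):

// One tree: root splits on feature 1 (gain 0.2 over 10 instances),
// its child splits on feature 0 (gain 0.1 over 4 instances).
val scaledGains = Map(1 -> 0.2 * 10, 0 -> 0.1 * 4) // gain scaled by instance count
val treeNorm = scaledGains.values.sum              // 2.4
val perTree = scaledGains.map { case (f, g) => (f, g / treeNorm) }
// perTree: Map(1 -> 0.833..., 0 -> 0.166...)
// These per-tree maps are summed across trees, and the totals are renormalized to sum to 1.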


@@ -30,7 +30,7 @@ import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestMo
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
/**
* :: Experimental ::
@@ -87,7 +87,8 @@ final class RandomForestRegressor(override val uid: String)
val trees =
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
.map(_.asInstanceOf[DecisionTreeRegressionModel])
new RandomForestRegressionModel(trees)
val numFeatures = oldDataset.first().features.size
new RandomForestRegressionModel(trees, numFeatures)
}
override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
@@ -108,11 +109,13 @@ object RandomForestRegressor {
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
* It supports both continuous and categorical features.
* @param _trees Decision trees in the ensemble.
* @param numFeatures Number of features used by this model
*/
@Experimental
final class RandomForestRegressionModel private[ml] (
override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel])
private val _trees: Array[DecisionTreeRegressionModel],
val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
with TreeEnsembleModel with Serializable {
@@ -122,7 +125,8 @@ final class RandomForestRegressionModel private[ml] (
* Construct a random forest regression model, with all trees weighted equally.
* @param trees Component trees
*/
def this(trees: Array[DecisionTreeRegressionModel]) = this(Identifiable.randomUID("rfr"), trees)
def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
this(Identifiable.randomUID("rfr"), trees, numFeatures)
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
@@ -147,13 +151,30 @@ final class RandomForestRegressionModel private[ml] (
}
override def copy(extra: ParamMap): RandomForestRegressionModel = {
copyValues(new RandomForestRegressionModel(uid, _trees), extra)
copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra)
}
override def toString: String = {
s"RandomForestRegressionModel with $numTrees trees"
}
/**
* Estimate of the importance of each feature.
*
* This generalizes the idea of "Gini" importance to other losses,
* following the explanation of Gini importance from "Random Forests" documentation
* by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
*
* This feature importance is calculated as follows:
* - Average over trees:
* - importance(feature j) = sum (over nodes which split on feature j) of the gain,
* where gain is scaled by the number of instances passing through node
* - Normalize importances for tree based on total number of training instances used
* to build tree.
* - Normalize feature importance vector to sum to 1.
*/
lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
@@ -173,6 +194,6 @@ private[ml] object RandomForestRegressionModel {
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
new RandomForestRegressionModel(parent.uid, newTrees)
new RandomForestRegressionModel(parent.uid, newTrees, -1)
}
}


@@ -44,7 +44,7 @@ sealed abstract class Node extends Serializable {
* and probabilities.
* For classification, the array of class counts must be normalized to a probability distribution.
*/
private[tree] def impurityStats: ImpurityCalculator
private[ml] def impurityStats: ImpurityCalculator
/** Recursive prediction helper method */
private[ml] def predictImpl(features: Vector): LeafNode
@@ -72,6 +72,12 @@ sealed abstract class Node extends Serializable {
* @param id Node ID using old format IDs
*/
private[ml] def toOld(id: Int): OldNode
/**
* Trace down the tree, and return the largest feature index used in any split.
* @return Max feature index used in a split, or -1 if there are no splits (single leaf node).
*/
private[ml] def maxSplitFeatureIndex(): Int
}
private[ml] object Node {
@@ -109,7 +115,7 @@ private[ml] object Node {
final class LeafNode private[ml] (
override val prediction: Double,
override val impurity: Double,
override val impurityStats: ImpurityCalculator) extends Node {
override private[ml] val impurityStats: ImpurityCalculator) extends Node {
override def toString: String =
s"LeafNode(prediction = $prediction, impurity = $impurity)"
@@ -129,6 +135,8 @@ final class LeafNode private[ml] (
new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)),
impurity, isLeaf = true, None, None, None, None)
}
override private[ml] def maxSplitFeatureIndex(): Int = -1
}
/**
@@ -150,7 +158,7 @@ final class InternalNode private[ml] (
val leftChild: Node,
val rightChild: Node,
val split: Split,
override val impurityStats: ImpurityCalculator) extends Node {
override private[ml] val impurityStats: ImpurityCalculator) extends Node {
override def toString: String = {
s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
@@ -190,6 +198,11 @@ final class InternalNode private[ml] (
new OldPredict(leftChild.prediction, prob = 0.0),
new OldPredict(rightChild.prediction, prob = 0.0))))
}
override private[ml] def maxSplitFeatureIndex(): Int = {
math.max(split.featureIndex,
math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex()))
}
}
private object InternalNode {


@@ -26,6 +26,7 @@ import org.apache.spark.Logging
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
@@ -34,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
@@ -1113,4 +1115,94 @@ private[ml] object RandomForest extends Logging {
}
}
/**
* Given a Random Forest model, compute the importance of each feature.
* This generalizes the idea of "Gini" importance to other losses,
* following the explanation of Gini importance from "Random Forests" documentation
* by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
*
* This feature importance is calculated as follows:
* - Average over trees:
* - importance(feature j) = sum (over nodes which split on feature j) of the gain,
* where gain is scaled by the number of instances passing through node
* - Normalize importances for tree based on total number of training instances used
* to build tree.
* - Normalize feature importance vector to sum to 1.
*
* Note: This should not be used with Gradient-Boosted Trees. It only makes sense for
* independently trained trees.
* @param trees Unweighted forest of trees
* @param numFeatures Number of features in model (even if not all are explicitly used by
* the model).
* If -1, then numFeatures is set based on the max feature index in all trees.
* @return Feature importance values, of length numFeatures.
*/
private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
val totalImportances = new OpenHashMap[Int, Double]()
trees.foreach { tree =>
// Aggregate feature importance vector for this tree
val importances = new OpenHashMap[Int, Double]()
computeFeatureImportance(tree.rootNode, importances)
// Normalize importance vector for this tree, and add it to total.
// TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
val treeNorm = importances.map(_._2).sum
if (treeNorm != 0) {
importances.foreach { case (idx, impt) =>
val normImpt = impt / treeNorm
totalImportances.changeValue(idx, normImpt, _ + normImpt)
}
}
}
// Normalize importances
normalizeMapValues(totalImportances)
// Construct vector
val d = if (numFeatures != -1) {
numFeatures
} else {
// Find max feature index used in trees
val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
maxFeatureIndex + 1
}
if (d == 0) {
assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" +
s" importance: No splits in forest, but some non-zero importances.")
}
val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
Vectors.sparse(d, indices.toArray, values.toArray)
}
/**
* Recursive method for computing feature importances for one tree.
* This walks down the tree, adding to the importance of 1 feature at each node.
* @param node Current node in recursion
* @param importances Aggregate feature importances, modified by this method
*/
private[impl] def computeFeatureImportance(
node: Node,
importances: OpenHashMap[Int, Double]): Unit = {
node match {
case n: InternalNode =>
val feature = n.split.featureIndex
val scaledGain = n.gain * n.impurityStats.count
importances.changeValue(feature, scaledGain, _ + scaledGain)
computeFeatureImportance(n.leftChild, importances)
computeFeatureImportance(n.rightChild, importances)
case n: LeafNode =>
// do nothing
}
}
/**
* Normalize the values of this map to sum to 1, in place.
* If all values are 0, this method does nothing.
* @param map Map with non-negative values.
*/
private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
val total = map.map(_._2).sum
if (total != 0) {
val keys = map.iterator.map(_._1).toArray
keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
}
}
}
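A standalone sketch of the final assembly step in featureImportances above, with illustrative values; the real code aggregates into an OpenHashMap and sorts by feature index the same way:

import org.apache.spark.mllib.linalg.Vectors

val totalImportances = Map(1 -> 0.7, 3 -> 0.3)   // normalized importances keyed by feature index
val (indices, values) = totalImportances.toSeq.sortBy(_._1).unzip
val importanceVec = Vectors.sparse(5, indices.toArray, values.toArray)
// importanceVec: (5,[1,3],[0.7,0.3]). With numFeatures == -1, the size would instead be
// the max split feature index across all trees plus one.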


@@ -53,6 +53,12 @@ private[ml] trait DecisionTreeModel {
val header = toString + "\n"
header + rootNode.subtreeToString(2)
}
/**
* Trace down the tree, and return the largest feature index used in any split.
* @return Max feature index used in a split, or -1 if there are no splits (single leaf node).
*/
private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex()
}
/**


@@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
@@ -85,6 +86,7 @@ public class JavaRandomForestClassifierSuite implements Serializable {
model.toDebugString();
model.trees();
model.treeWeights();
Vector importances = model.featureImportances();
/*
// TODO: Add test once save/load are implemented. SPARK-6725


@@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
@@ -85,6 +86,7 @@ public class JavaRandomForestRegressorSuite implements Serializable {
model.toDebugString();
model.trees();
model.treeWeights();
Vector importances = model.featureImportances();
/*
// TODO: Add test once save/load are implemented. SPARK-6725


@@ -67,7 +67,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
test("params") {
ParamsSuite.checkParams(new RandomForestClassifier)
val model = new RandomForestClassificationModel("rfc",
Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2)
Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2, 2)
ParamsSuite.checkParams(model)
}
@@ -149,6 +149,35 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
}
}
/////////////////////////////////////////////////////////////////////////////
// Tests of feature importance
/////////////////////////////////////////////////////////////////////////////
test("Feature importance with toy data") {
val numClasses = 2
val rf = new RandomForestClassifier()
.setImpurity("Gini")
.setMaxDepth(3)
.setNumTrees(3)
.setFeatureSubsetStrategy("all")
.setSubsamplingRate(1.0)
.setSeed(123)
// In this data, feature 1 is very important.
val data: RDD[LabeledPoint] = sc.parallelize(Seq(
new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
))
val categoricalFeatures = Map.empty[Int, Int]
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
val importances = rf.fit(df).featureImportances
val mostImportantFeature = importances.argmax
assert(mostImportantFeature === 1)
}
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////


@@ -124,4 +124,22 @@ private[ml] object TreeTests extends SparkFunSuite {
"checkEqual failed since the two tree ensembles were not identical")
}
}
/**
* Helper method for constructing a tree for testing.
* Given left, right children, construct a parent node.
* @param split Split for parent node
* @return Parent node with children attached
*/
def buildParentNode(left: Node, right: Node, split: Split): Node = {
val leftImp = left.impurityStats
val rightImp = right.impurityStats
val parentImp = leftImp.copy.add(rightImp)
val leftWeight = leftImp.count / parentImp.count.toDouble
val rightWeight = rightImp.count / parentImp.count.toDouble
val gain = parentImp.calculate() -
(leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
val pred = parentImp.predict
new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp)
}
}
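As a sanity check on the gain formula used in buildParentNode above, the same quantity can be computed by hand with Gini impurity and the class counts that the RandomForestSuite test below feeds into it (a reviewer's illustration, not code from the patch):

val leftCounts = Array(3.0, 2.0, 1.0)
val rightCounts = Array(1.0, 2.0, 5.0)
def gini(c: Array[Double]): Double = { val n = c.sum; 1.0 - c.map(x => (x / n) * (x / n)).sum }
val parentCounts = leftCounts.zip(rightCounts).map { case (l, r) => l + r }
val (nL, nR, nP) = (leftCounts.sum, rightCounts.sum, parentCounts.sum)
val gain = gini(parentCounts) - (nL / nP * gini(leftCounts) + nR / nP * gini(rightCounts))
// gain ≈ 0.0876: Gini(parent) ≈ 0.653, Gini(left) ≈ 0.611, Gini(right) ≈ 0.531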


@@ -19,6 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -26,7 +27,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
/**
* Test suite for [[RandomForestRegressor]].
*/
@@ -71,6 +71,31 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
regressionTestWithContinuousFeatures(rf)
}
test("Feature importance with toy data") {
val rf = new RandomForestRegressor()
.setImpurity("variance")
.setMaxDepth(3)
.setNumTrees(3)
.setFeatureSubsetStrategy("all")
.setSubsamplingRate(1.0)
.setSeed(123)
// In this data, feature 1 is very important.
val data: RDD[LabeledPoint] = sc.parallelize(Seq(
new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
))
val categoricalFeatures = Map.empty[Int, Int]
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
val importances = rf.fit(df).featureImportances
val mostImportantFeature = importances.argmax
assert(mostImportantFeature === 1)
}
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////


@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.tree.impl
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.tree.impurity.GiniCalculator
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.collection.OpenHashMap
/**
* Test suite for [[RandomForest]].
*/
class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
import RandomForestSuite.mapToVec
test("computeFeatureImportance, featureImportances") {
/* Build tree for testing, with this structure:
grandParent
left2 parent
left right
*/
val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
val parentImp = parent.impurityStats
val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
val grandImp = grandParent.impurityStats
// Test feature importance computed at different subtrees.
def testNode(node: Node, expected: Map[Int, Double]): Unit = {
val map = new OpenHashMap[Int, Double]()
RandomForest.computeFeatureImportance(node, map)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}
// Leaf node
testNode(left, Map.empty[Int, Double])
// Internal node with 2 leaf children
val feature0importance = parentImp.calculate() * parentImp.count -
(leftImp.calculate() * leftImp.count + rightImp.calculate() * rightImp.count)
testNode(parent, Map(0 -> feature0importance))
// Full tree
val feature1importance = grandImp.calculate() * grandImp.count -
(left2Imp.calculate() * left2Imp.count + parentImp.calculate() * parentImp.count)
testNode(grandParent, Map(0 -> feature0importance, 1 -> feature1importance))
// Forest consisting of (full tree) + (internal node with 2 leafs)
val trees = Array(parent, grandParent).map { root =>
new DecisionTreeClassificationModel(root, numClasses = 3).asInstanceOf[DecisionTreeModel]
}
val importances: Vector = RandomForest.featureImportances(trees, 2)
val tree2norm = feature0importance + feature1importance
val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0,
(feature1importance / tree2norm) / 2.0)
assert(importances ~== expected relTol 0.01)
}
test("normalizeMapValues") {
val map = new OpenHashMap[Int, Double]()
map(0) = 1.0
map(2) = 2.0
RandomForest.normalizeMapValues(map)
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}
}
private object RandomForestSuite {
def mapToVec(map: Map[Int, Double]): Vector = {
val size = (map.keys.toSeq :+ 0).max + 1
val (indices, values) = map.toSeq.sortBy(_._1).unzip
Vectors.sparse(size, indices.toArray, values.toArray)
}
}
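For context, the expected vector in the forest test above can be derived by hand (a reviewer's note; names are taken from the test code):

// tree "parent":      splits only on feature 0 -> per-tree importances normalize to {0 -> 1.0}
// tree "grandParent": {0 -> feature0importance, 1 -> feature1importance}
//                     -> normalized {0 -> f0 / tree2norm, 1 -> f1 / tree2norm}
// summed over both trees: {0 -> 1 + f0 / tree2norm, 1 -> f1 / tree2norm}, which sums to 2,
// so the final normalization divides by 2, giving the expected dense vector asserted above.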