[SPARK-3159][ML] Add decision tree pruning

## What changes were proposed in this pull request?

Added subtree pruning in the translation from `LearningNode` to `Node`: a learning node whose subtree yields the same prediction at every leaf is now translated into a single `LeafNode` instead of a (redundant) `InternalNode`, as sketched below.
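
The rule is easiest to see in isolation. Here is a minimal sketch of the pruning step, using simplified stand-in types rather than the real `org.apache.spark.ml.tree` classes: after both children are converted, an internal node whose children are leaves with identical predictions is collapsed into a single leaf, so single-prediction subtrees fold up bottom-up.

```scala
// Simplified stand-ins for illustration only; the real change lives in
// LearningNode.toNode (see the Node.scala diff below).
sealed trait SimpleNode
case class Leaf(prediction: Double) extends SimpleNode
case class Internal(prediction: Double, left: SimpleNode, right: SimpleNode) extends SimpleNode

def convert(node: SimpleNode, prune: Boolean = true): SimpleNode = node match {
  case Internal(pred, left, right) =>
    (convert(left, prune), convert(right, prune)) match {
      // Both children are leaves predicting the same value, so the split is
      // redundant: collapse the internal node into a single leaf.
      case (l: Leaf, r: Leaf) if prune && l.prediction == r.prediction => Leaf(l.prediction)
      case (l, r) => Internal(pred, l, r)
    }
  case leaf => leaf
}
```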

## How was this patch tested?

Added two unit tests in "mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala":
- test("SPARK-3159 tree model redundancy - classification")
- test("SPARK-3159 tree model redundancy - regression")

Four existing unit tests that rely on the tree structure (the existence of a specific redundant subtree) had to be adapted, since the subtrees they inspect are now pruned away. This was fixed by adding an extra _prune_ parameter that can disable pruning for testing, used as shown below.
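
As an illustration, the new redundancy tests train the same single tree twice and compare node counts; the snippet below is condensed from the classification test in the diff (the `rdd` and `strategy` setup appears in the RandomForestSuite hunk further down):

```scala
// Same data, same seed: the only difference is the (test-only) prune flag.
val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1,
  featureSubsetStrategy = "auto", seed = 42, instr = None).head
val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1,
  featureSubsetStrategy = "auto", seed = 42, instr = None, prune = false).head

// The redundant split survives only in the unpruned tree.
assert(prunedTree.numNodes === 5)
assert(unprunedTree.numNodes === 7)
```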

Author: Alessandro Solimando <18898964+asolimando@users.noreply.github.com>

Closes #20632 from asolimando/master.
Alessandro Solimando authored on 2018-03-02 16:24:29 -08:00, committed by sethah
parent 487377e693
commit 9e26473c0f
5 changed files with 115 additions and 65 deletions

mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala

@@ -19,8 +19,7 @@ package org.apache.spark.ml.tree

 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
-import org.apache.spark.mllib.tree.model.{ImpurityStats,
-  InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict}
+import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict}

 /**
  * Decision tree node interface.

@@ -266,15 +265,23 @@ private[tree] class LearningNode(
     var isLeaf: Boolean,
     var stats: ImpurityStats) extends Serializable {

+  def toNode: Node = toNode(prune = true)
+
   /**
    * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
    */
-  def toNode: Node = {
-    if (leftChild.nonEmpty) {
-      assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
+  def toNode(prune: Boolean = true): Node = {
+
+    if (!leftChild.isEmpty || !rightChild.isEmpty) {
+      assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null,
         "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
-      new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
-        leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator)
+      (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match {
+        case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction =>
+          new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
+        case (l, r) =>
+          new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
+            l, r, split.get, stats.impurityCalculator)
+      }
     } else {
       if (stats.valid) {
         new LeafNode(stats.impurityCalculator.predict, stats.impurity,

@@ -283,7 +290,6 @@ private[tree] class LearningNode(
         // Here we want to keep same behavior with the old mllib.DecisionTreeModel
         new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
       }
-
     }
   }

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala

@@ -92,6 +92,7 @@ private[spark] object RandomForest extends Logging {
       featureSubsetStrategy: String,
       seed: Long,
       instr: Option[Instrumentation[_]],
+      prune: Boolean = true, // exposed for testing only, real trees are always pruned
       parentUID: Option[String] = None): Array[DecisionTreeModel] = {

     val timer = new TimeTracker()

@@ -223,22 +224,23 @@ private[spark] object RandomForest extends Logging {
      case Some(uid) =>
        if (strategy.algo == OldAlgo.Classification) {
          topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures,
+            new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures,
              strategy.getNumClasses)
          }
        } else {
          topNodes.map { rootNode =>
-            new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures)
+            new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures)
          }
        }
      case None =>
        if (strategy.algo == OldAlgo.Classification) {
          topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(rootNode.toNode, numFeatures,
+            new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures,
              strategy.getNumClasses)
          }
        } else {
-          topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures))
+          topNodes.map(rootNode =>
+            new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures))
        }
    }
  }

mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala

@@ -280,44 +280,6 @@ class DecisionTreeClassifierSuite
     dt.fit(df)
   }

-  test("Use soft prediction for binary classification with ordered categorical features") {
-    // The following dataset is set up such that the best split is {1} vs. {0, 2}.
-    // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
-    val arr = Array(
-      LabeledPoint(0.0, Vectors.dense(0.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0)),
-      LabeledPoint(1.0, Vectors.dense(0.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(2.0)),
-      LabeledPoint(0.0, Vectors.dense(2.0)),
-      LabeledPoint(0.0, Vectors.dense(2.0)),
-      LabeledPoint(1.0, Vectors.dense(2.0)))
-    val data = sc.parallelize(arr)
-    val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)
-
-    // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
-    val dt = new DecisionTreeClassifier()
-      .setImpurity("gini")
-      .setMaxDepth(1)
-      .setMaxBins(3)
-
-    val model = dt.fit(df)
-    model.rootNode match {
-      case n: InternalNode =>
-        n.split match {
-          case s: CategoricalSplit =>
-            assert(s.leftCategories === Array(1.0))
-          case other =>
-            fail(s"All splits should be categorical, but got ${other.getClass.getName}: $other.")
-        }
-      case other =>
-        fail(s"Root node should be an internal node, but got ${other.getClass.getName}: $other.")
-    }
-  }
-
   test("Feature importance with toy data") {
     val dt = new DecisionTreeClassifier()
       .setImpurity("gini")

mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala

@@ -17,6 +17,7 @@
 package org.apache.spark.ml.tree.impl

+import scala.annotation.tailrec
 import scala.collection.mutable

 import org.apache.spark.SparkFunSuite

@@ -38,6 +39,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
   import RandomForestSuite.mapToVec

+  private val seed = 42
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests for split calculation
   /////////////////////////////////////////////////////////////////////////////
@@ -320,10 +323,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.isLeaf === false)
     assert(topNode.stats === null)

-    val nodesForGroup = Map((0, Array(topNode)))
-    val treeToNodeToIndexInfo = Map((0, Map(
-      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
-    )))
+    val nodesForGroup = Map(0 -> Array(topNode))
+    val treeToNodeToIndexInfo = Map(0 -> Map(
+      topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+    ))
     val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
     RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
       nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)

@@ -362,10 +365,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.isLeaf === false)
     assert(topNode.stats === null)

-    val nodesForGroup = Map((0, Array(topNode)))
-    val treeToNodeToIndexInfo = Map((0, Map(
-      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
-    )))
+    val nodesForGroup = Map(0 -> Array(topNode))
+    val treeToNodeToIndexInfo = Map(0 -> Map(
+      topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+    ))
     val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
     RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
       nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
@@ -407,7 +410,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
     val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 42, instr = None).head
+      seed = 42, instr = None, prune = false).head
+
     model.rootNode match {
       case n: InternalNode => n.split match {
         case s: CategoricalSplit =>
@@ -631,13 +635,89 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
     assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
+
+  ///////////////////////////////////////////////////////////////////////////////
+  // Tests for pruning of redundant subtrees (generated by a split improving the
+  // impurity measure, but always leading to the same prediction).
+  ///////////////////////////////////////////////////////////////////////////////
+
+  test("SPARK-3159 tree model redundancy - classification") {
+    // The following dataset is set up such that splitting over feature_1 for points having
+    // feature_0 = 0 improves the impurity measure, even though the prediction will always
+    // be 0 in both branches.
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
+    )
+    val rdd = sc.parallelize(arr)
+
+    val numClasses = 2
+    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4,
+      numClasses = numClasses, maxBins = 32)
+
+    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None).head
+    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None, prune = false).head
+
+    assert(prunedTree.numNodes === 5)
+    assert(unprunedTree.numNodes === 7)
+    assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
+  }
+
+  test("SPARK-3159 tree model redundancy - regression") {
+    // The following dataset is set up such that splitting over feature_0 for points having
+    // feature_1 = 1 improves the impurity measure, even though the prediction will always
+    // be 0.5 in both branches.
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
+      LabeledPoint(0.5, Vectors.dense(1.0, 1.0))
+    )
+    val rdd = sc.parallelize(arr)
+
+    val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4,
+      numClasses = 0, maxBins = 32)
+
+    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None).head
+    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None, prune = false).head
+
+    assert(prunedTree.numNodes === 3)
+    assert(unprunedTree.numNodes === 5)
+    assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
+  }
 }

 private object RandomForestSuite {
   def mapToVec(map: Map[Int, Double]): Vector = {
     val size = (map.keys.toSeq :+ 0).max + 1
     val (indices, values) = map.toSeq.sortBy(_._1).unzip
     Vectors.sparse(size, indices.toArray, values.toArray)
   }
+
+  @tailrec
+  private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = {
+    if (nodes.isEmpty) {
+      acc
+    } else {
+      nodes.head match {
+        case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc)
+        case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count)
+      }
+    }
+  }
 }

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

@@ -363,10 +363,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     // if a split does not satisfy min instances per node requirements,
     // this split is invalid, even though the information gain of split is large.
     val arr = Array(
-      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
-      LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)))
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 0.0)))
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini,

@@ -541,7 +541,7 @@ object DecisionTreeSuite extends SparkFunSuite {
       Array[LabeledPoint] = {
     val arr = new Array[LabeledPoint](3000)
     for (i <- 0 until 3000) {
-      if (i < 1000) {
+      if (i < 1001) {
        arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
      } else if (i < 2000) {
        arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))