diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 07e98a142b..d30be452a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -19,8 +19,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{ImpurityStats, - InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} +import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} /** * Decision tree node interface. @@ -266,15 +265,23 @@ private[tree] class LearningNode( var isLeaf: Boolean, var stats: ImpurityStats) extends Serializable { + def toNode: Node = toNode(prune = true) + /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ - def toNode: Node = { - if (leftChild.nonEmpty) { - assert(rightChild.nonEmpty && split.nonEmpty && stats != null, + def toNode(prune: Boolean = true): Node = { + + if (!leftChild.isEmpty || !rightChild.isEmpty) { + assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") - new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, - leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) + (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match { + case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction => + new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + case (l, r) => + new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + l, r, split.get, stats.impurityCalculator) + } } else { if (stats.valid) { new LeafNode(stats.impurityCalculator.predict, stats.impurity, @@ -283,7 +290,6 @@ private[tree] class LearningNode( // Here we want to keep same behavior with the old mllib.DecisionTreeModel new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) } - } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index acfc6399c5..8e514f11e7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -92,6 +92,7 @@ private[spark] object RandomForest extends Logging { featureSubsetStrategy: String, seed: Long, instr: Option[Instrumentation[_]], + prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None): Array[DecisionTreeModel] = { val timer = new TimeTracker() @@ -223,22 +224,23 @@ private[spark] object RandomForest extends Logging { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures, + new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures, strategy.getNumClasses) } } else { topNodes.map { rootNode => - new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures) } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toNode, numFeatures, + new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures, strategy.getNumClasses) } } else { - topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures)) + topNodes.map(rootNode => + new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 98c879ece6..38b265d626 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -280,44 +280,6 @@ class DecisionTreeClassifierSuite dt.fit(df) } - test("Use soft prediction for binary classification with ordered categorical features") { - // The following dataset is set up such that the best split is {1} vs. {0, 2}. - // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. - val arr = Array( - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(1.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(1.0, Vectors.dense(2.0))) - val data = sc.parallelize(arr) - val df = TreeTests.setMetadata(data, Map(0 -> 3), 2) - - // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. - val dt = new DecisionTreeClassifier() - .setImpurity("gini") - .setMaxDepth(1) - .setMaxBins(3) - val model = dt.fit(df) - model.rootNode match { - case n: InternalNode => - n.split match { - case s: CategoricalSplit => - assert(s.leftCategories === Array(1.0)) - case other => - fail(s"All splits should be categorical, but got ${other.getClass.getName}: $other.") - } - case other => - fail(s"Root node should be an internal node, but got ${other.getClass.getName}: $other.") - } - } - test("Feature importance with toy data") { val dt = new DecisionTreeClassifier() .setImpurity("gini") diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index dbe2ea931f..5f0d26eb5c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.tree.impl +import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.SparkFunSuite @@ -38,6 +39,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestSuite.mapToVec + private val seed = 42 + ///////////////////////////////////////////////////////////////////////////// // Tests for split calculation ///////////////////////////////////////////////////////////////////////////// @@ -320,10 +323,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.isLeaf === false) assert(topNode.stats === null) - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) + val nodesForGroup = Map(0 -> Array(topNode)) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new RandomForest.NodeIndexInfo(0, None) + )) val nodeStack = new mutable.ArrayStack[(Int, LearningNode)] RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) @@ -362,10 +365,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.isLeaf === false) assert(topNode.stats === null) - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) + val nodesForGroup = Map(0 -> Array(topNode)) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new RandomForest.NodeIndexInfo(0, None) + )) val nodeStack = new mutable.ArrayStack[(Int, LearningNode)] RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) @@ -407,7 +410,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 42, instr = None).head + seed = 42, instr = None, prune = false).head + model.rootNode match { case n: InternalNode => n.split match { case s: CategoricalSplit => @@ -631,13 +635,89 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } + + /////////////////////////////////////////////////////////////////////////////// + // Tests for pruning of redundant subtrees (generated by a split improving the + // impurity measure, but always leading to the same prediction). + /////////////////////////////////////////////////////////////////////////////// + + test("SPARK-3159 tree model redundancy - classification") { + // The following dataset is set up such that splitting over feature_1 for points having + // feature_0 = 0 improves the impurity measure, despite the prediction will always be 0 + // in both branches. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) + ) + val rdd = sc.parallelize(arr) + + val numClasses = 2 + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4, + numClasses = numClasses, maxBins = 32) + + val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None).head + + val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false).head + + assert(prunedTree.numNodes === 5) + assert(unprunedTree.numNodes === 7) + + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) + } + + test("SPARK-3159 tree model redundancy - regression") { + // The following dataset is set up such that splitting over feature_0 for points having + // feature_1 = 1 improves the impurity measure, despite the prediction will always be 0.5 + // in both branches. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.5, Vectors.dense(1.0, 1.0)) + ) + val rdd = sc.parallelize(arr) + + val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4, + numClasses = 0, maxBins = 32) + + val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None).head + + val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false).head + + assert(prunedTree.numNodes === 3) + assert(unprunedTree.numNodes === 5) + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) + } } private object RandomForestSuite { - def mapToVec(map: Map[Int, Double]): Vector = { val size = (map.keys.toSeq :+ 0).max + 1 val (indices, values) = map.toSeq.sortBy(_._1).unzip Vectors.sparse(size, indices.toArray, values.toArray) } + + @tailrec + private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = { + if (nodes.isEmpty) { + acc + } + else { + nodes.head match { + case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc) + case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 441d0f7614..bc59f3f412 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -363,10 +363,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { // if a split does not satisfy min instances per node requirements, // this split is invalid, even though the information gain of split is large. val arr = Array( - LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), - LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), - LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), - LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 0.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, @@ -541,7 +541,7 @@ object DecisionTreeSuite extends SparkFunSuite { Array[LabeledPoint] = { val arr = new Array[LabeledPoint](3000) for (i <- 0 until 3000) { - if (i < 1000) { + if (i < 1001) { arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) } else if (i < 2000) { arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))