From ebd899b8a865395e6f1137163cb508086696879b Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Sun, 7 Oct 2018 10:06:44 -0700 Subject: [PATCH] [SPARK-25321][ML] Revert SPARK-14681 to avoid API breaking change ## What changes were proposed in this pull request? This is the same as #22492, but for the master branch. Revert SPARK-14681 to avoid API-breaking changes. cc: WeichenXu123 ## How was this patch tested? Existing unit tests. Closes #22618 from mengxr/SPARK-25321.master. Authored-by: WeichenXu Signed-off-by: Dongjoon Hyun --- .../DecisionTreeClassifier.scala | 14 +- .../ml/classification/GBTClassifier.scala | 6 +- .../RandomForestClassifier.scala | 6 +- .../ml/regression/DecisionTreeRegressor.scala | 13 +- .../spark/ml/regression/GBTRegressor.scala | 6 +- .../ml/regression/RandomForestRegressor.scala | 6 +- .../scala/org/apache/spark/ml/tree/Node.scala | 249 ++++-------------- .../spark/ml/tree/impl/RandomForest.scala | 10 +- .../org/apache/spark/ml/tree/treeModels.scala | 36 +-- .../DecisionTreeClassifierSuite.scala | 31 +-- .../classification/GBTClassifierSuite.scala | 4 +- .../RandomForestClassifierSuite.scala | 5 +- .../DecisionTreeRegressorSuite.scala | 14 - .../ml/tree/impl/RandomForestSuite.scala | 22 +- .../apache/spark/ml/tree/impl/TreeTests.scala | 12 +- project/MimaExcludes.scala | 7 - 16 files changed, 108 insertions(+), 333 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 8a57bfc029..6648e78d8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -168,7 +168,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi @Since("1.4.0") class DecisionTreeClassificationModel private[ml] ( @Since("1.4.0")override val uid: String, - @Since("1.4.0")override val rootNode: ClassificationNode, + @Since("1.4.0")override val rootNode: Node, @Since("1.6.0")override val numFeatures: Int, @Since("1.5.0")override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] @@ -181,7 +181,7 @@ class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. 
*/ - private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, numClasses: Int) = + private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) override def predict(features: Vector): Double = { @@ -279,9 +279,8 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] - val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true) - val model = new DecisionTreeClassificationModel(metadata.uid, - root.asInstanceOf[ClassificationNode], numFeatures, numClasses) + val root = loadTreeNodes(path, metadata, sparkSession) + val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses) metadata.getAndSetParams(model) model } @@ -296,10 +295,9 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica require(oldModel.algo == OldAlgo.Classification, s"Cannot convert non-classification DecisionTreeModel (old API) to" + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") - val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = true) + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") // Can't infer number of features from old model, so default to -1 - new DecisionTreeClassificationModel(uid, - rootNode.asInstanceOf[ClassificationNode], numFeatures, -1) + new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 33acd99140..62cfa39746 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -412,14 +412,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { override def load(path: String): GBTClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int] val numTrees = (metadata.metadata \ numTreesKey).extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = new DecisionTreeRegressionModel(treeMetadata.uid, - root.asInstanceOf[RegressionNode], numFeatures) + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) treeMetadata.getAndSetParams(tree) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 94887ac346..57132381b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -313,15 +313,15 @@ object RandomForestClassificationModel extends 
MLReadable[RandomForestClassifica override def load(path: String): RandomForestClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, true) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeClassificationModel] = treesData.map { case (treeMetadata, root) => - val tree = new DecisionTreeClassificationModel(treeMetadata.uid, - root.asInstanceOf[ClassificationNode], numFeatures, numClasses) + val tree = + new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses) treeMetadata.getAndSetParams(tree) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 018290f818..6fa656275c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -160,7 +160,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor @Since("1.4.0") class DecisionTreeRegressionModel private[ml] ( override val uid: String, - override val rootNode: RegressionNode, + override val rootNode: Node, override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable { @@ -175,7 +175,7 @@ class DecisionTreeRegressionModel private[ml] ( * Construct a decision tree regression model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: RegressionNode, numFeatures: Int) = + private[ml] def this(rootNode: Node, numFeatures: Int) = this(Identifiable.randomUID("dtr"), rootNode, numFeatures) override def predict(features: Vector): Double = { @@ -279,9 +279,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] - val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false) - val model = new DecisionTreeRegressionModel(metadata.uid, - root.asInstanceOf[RegressionNode], numFeatures) + val root = loadTreeNodes(path, metadata, sparkSession) + val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures) metadata.getAndSetParams(model) model } @@ -296,8 +295,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode require(oldModel.algo == OldAlgo.Regression, s"Cannot convert non-regression DecisionTreeModel (old API) to" + s" DecisionTreeRegressionModel (new API). 
Algo is: ${oldModel.algo}") - val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = false) + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr") - new DecisionTreeRegressionModel(uid, rootNode.asInstanceOf[RegressionNode], numFeatures) + new DecisionTreeRegressionModel(uid, rootNode, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 3305881b0c..07f88d8d5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -338,15 +338,15 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { override def load(path: String): GBTRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = new DecisionTreeRegressionModel(treeMetadata.uid, - root.asInstanceOf[RegressionNode], numFeatures) + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) treeMetadata.getAndSetParams(tree) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 35875724b3..82bf66ff66 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -271,13 +271,13 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode override def load(path: String): RandomForestRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = new DecisionTreeRegressionModel(treeMetadata.uid, - root.asInstanceOf[RegressionNode], numFeatures) + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) treeMetadata.getAndSetParams(tree) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 0242bc7669..d30be452a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -17,16 +17,14 @@ package org.apache.spark.ml.tree -import org.apache.spark.annotation.Since import org.apache.spark.ml.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => 
OldInformationGainStats, - Node => OldNode, Predict => OldPredict} +import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} /** * Decision tree node interface. */ -sealed trait Node extends Serializable { +sealed abstract class Node extends Serializable { // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree // code into the new API and deprecate the old API. SPARK-3727 @@ -86,86 +84,35 @@ private[ml] object Node { /** * Create a new Node from the old Node format, recursively creating child nodes as needed. */ - def fromOld( - oldNode: OldNode, - categoricalFeatures: Map[Int, Int], - isClassification: Boolean): Node = { + def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = { if (oldNode.isLeaf) { // TODO: Once the implementation has been moved to this API, then include sufficient // statistics here. - if (isClassification) { - new ClassificationLeafNode(prediction = oldNode.predict.predict, - impurity = oldNode.impurity, impurityStats = null) - } else { - new RegressionLeafNode(prediction = oldNode.predict.predict, - impurity = oldNode.impurity, impurityStats = null) - } + new LeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) } else { val gain = if (oldNode.stats.nonEmpty) { oldNode.stats.get.gain } else { 0.0 } - if (isClassification) { - new ClassificationInternalNode(prediction = oldNode.predict.predict, - impurity = oldNode.impurity, gain = gain, - leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, true) - .asInstanceOf[ClassificationNode], - rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, true) - .asInstanceOf[ClassificationNode], - split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) - } else { - new RegressionInternalNode(prediction = oldNode.predict.predict, - impurity = oldNode.impurity, gain = gain, - leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, false) - .asInstanceOf[RegressionNode], - rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, false) - .asInstanceOf[RegressionNode], - split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) - } + new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, + gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) } } } -@Since("2.4.0") -sealed trait ClassificationNode extends Node { - - /** - * Get count of training examples for specified label in this node - * @param label label number in the range [0, numClasses) - */ - @Since("2.4.0") - def getLabelCount(label: Int): Double = { - require(label >= 0 && label < impurityStats.stats.length, - "label should be in the range between 0 (inclusive) " + - s"and ${impurityStats.stats.length} (exclusive).") - impurityStats.stats(label) - } -} - -@Since("2.4.0") -sealed trait RegressionNode extends Node { - - /** Number of training data points in this node */ - @Since("2.4.0") - def getCount: Double = impurityStats.stats(0) - - /** Sum over training data points of the labels in this node */ - @Since("2.4.0") - def getSum: Double = impurityStats.stats(1) - - /** Sum over training data points of the square of the labels in this node */ - @Since("2.4.0") - def getSumOfSquares: Double = 
impurityStats.stats(2) -} - -@Since("2.4.0") -sealed trait LeafNode extends Node { - - /** Prediction this node makes. */ - def prediction: Double - - def impurity: Double +/** + * Decision tree leaf node. + * @param prediction Prediction this node makes + * @param impurity Impurity measure at this node (for training data) + */ +class LeafNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override private[ml] val impurityStats: ImpurityCalculator) extends Node { override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" @@ -188,58 +135,32 @@ sealed trait LeafNode extends Node { override private[ml] def maxSplitFeatureIndex(): Int = -1 -} - -/** - * Decision tree leaf node for classification. - */ -@Since("2.4.0") -class ClassificationLeafNode private[ml] ( - override val prediction: Double, - override val impurity: Double, - override private[ml] val impurityStats: ImpurityCalculator) - extends ClassificationNode with LeafNode { - override private[tree] def deepCopy(): Node = { - new ClassificationLeafNode(prediction, impurity, impurityStats) - } -} - -/** - * Decision tree leaf node for regression. - */ -@Since("2.4.0") -class RegressionLeafNode private[ml] ( - override val prediction: Double, - override val impurity: Double, - override private[ml] val impurityStats: ImpurityCalculator) - extends RegressionNode with LeafNode { - - override private[tree] def deepCopy(): Node = { - new RegressionLeafNode(prediction, impurity, impurityStats) + new LeafNode(prediction, impurity, impurityStats) } } /** * Internal Decision Tree node. + * @param prediction Prediction this node would make if it were a leaf node + * @param impurity Impurity measure at this node (for training data) + * @param gain Information gain value. Values less than 0 indicate missing values; + * this quirk will be removed with future updates. + * @param leftChild Left-hand child node + * @param rightChild Right-hand child node + * @param split Information about the test used to split to the left or right child. */ -@Since("2.4.0") -sealed trait InternalNode extends Node { +class InternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + val gain: Double, + val leftChild: Node, + val rightChild: Node, + val split: Split, + override private[ml] val impurityStats: ImpurityCalculator) extends Node { - /** - * Information gain value. Values less than 0 indicate missing values; - * this quirk will be removed with future updates. - */ - def gain: Double - - /** Left-hand child node */ - def leftChild: Node - - /** Right-hand child node */ - def rightChild: Node - - /** Information about the test used to split to the left or right child. */ - def split: Split + // Note to developers: The constructor argument impurityStats should be reconsidered before we + // make the constructor public. We may be able to improve the representation. override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" @@ -284,6 +205,11 @@ sealed trait InternalNode extends Node { math.max(split.featureIndex, math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex())) } + + override private[tree] def deepCopy(): Node = { + new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), rightChild.deepCopy(), + split, impurityStats) + } } private object InternalNode { @@ -314,57 +240,6 @@ private object InternalNode { } } -/** - * Internal Decision Tree node for regression. 
- */ -@Since("2.4.0") -class ClassificationInternalNode private[ml] ( - override val prediction: Double, - override val impurity: Double, - override val gain: Double, - override val leftChild: ClassificationNode, - override val rightChild: ClassificationNode, - override val split: Split, - override private[ml] val impurityStats: ImpurityCalculator) - extends ClassificationNode with InternalNode { - - // Note to developers: The constructor argument impurityStats should be reconsidered before we - // make the constructor public. We may be able to improve the representation. - - override private[tree] def deepCopy(): Node = { - new ClassificationInternalNode(prediction, impurity, gain, - leftChild.deepCopy().asInstanceOf[ClassificationNode], - rightChild.deepCopy().asInstanceOf[ClassificationNode], - split, impurityStats) - } -} - -/** - * Internal Decision Tree node for regression. - */ -@Since("2.4.0") -class RegressionInternalNode private[ml] ( - override val prediction: Double, - override val impurity: Double, - override val gain: Double, - override val leftChild: RegressionNode, - override val rightChild: RegressionNode, - override val split: Split, - override private[ml] val impurityStats: ImpurityCalculator) - extends RegressionNode with InternalNode { - - // Note to developers: The constructor argument impurityStats should be reconsidered before we - // make the constructor public. We may be able to improve the representation. - - override private[tree] def deepCopy(): Node = { - new RegressionInternalNode(prediction, impurity, gain, - leftChild.deepCopy().asInstanceOf[RegressionNode], - rightChild.deepCopy().asInstanceOf[RegressionNode], - split, impurityStats) - } -} - - /** * Version of a node used in learning. This uses vars so that we can modify nodes as we split the * tree by adding children, etc. @@ -390,52 +265,30 @@ private[tree] class LearningNode( var isLeaf: Boolean, var stats: ImpurityStats) extends Serializable { - def toNode(isClassification: Boolean): Node = toNode(isClassification, prune = true) - - def toClassificationNode(prune: Boolean = true): ClassificationNode = { - toNode(true, prune).asInstanceOf[ClassificationNode] - } - - def toRegressionNode(prune: Boolean = true): RegressionNode = { - toNode(false, prune).asInstanceOf[RegressionNode] - } + def toNode: Node = toNode(prune = true) /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ - def toNode(isClassification: Boolean, prune: Boolean): Node = { + def toNode(prune: Boolean = true): Node = { if (!leftChild.isEmpty || !rightChild.isEmpty) { assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. 
Could not convert LearningNode to Node.") - (leftChild.get.toNode(isClassification, prune), - rightChild.get.toNode(isClassification, prune)) match { + (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match { case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction => - if (isClassification) { - new ClassificationLeafNode(l.prediction, stats.impurity, stats.impurityCalculator) - } else { - new RegressionLeafNode(l.prediction, stats.impurity, stats.impurityCalculator) - } + new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator) case (l, r) => - if (isClassification) { - new ClassificationInternalNode(stats.impurityCalculator.predict, stats.impurity, - stats.gain, l.asInstanceOf[ClassificationNode], r.asInstanceOf[ClassificationNode], - split.get, stats.impurityCalculator) - } else { - new RegressionInternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, - l.asInstanceOf[RegressionNode], r.asInstanceOf[RegressionNode], - split.get, stats.impurityCalculator) - } + new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + l, r, split.get, stats.impurityCalculator) } } else { - // Here we want to keep same behavior with the old mllib.DecisionTreeModel - val impurity = if (stats.valid) stats.impurity else -1.0 - if (isClassification) { - new ClassificationLeafNode(stats.impurityCalculator.predict, impurity, + if (stats.valid) { + new LeafNode(stats.impurityCalculator.predict, stats.impurity, stats.impurityCalculator) } else { - new RegressionLeafNode(stats.impurityCalculator.predict, impurity, - stats.impurityCalculator) + // Here we want to keep same behavior with the old mllib.DecisionTreeModel + new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 4cdd17266b..822abd2d35 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -226,23 +226,23 @@ private[spark] object RandomForest extends Logging with Serializable { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toClassificationNode(prune), - numFeatures, strategy.getNumClasses) + new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures, + strategy.getNumClasses) } } else { topNodes.map { rootNode => - new DecisionTreeRegressionModel(uid, rootNode.toRegressionNode(prune), numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures) } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toClassificationNode(prune), numFeatures, + new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures, strategy.getNumClasses) } } else { topNodes.map(rootNode => - new DecisionTreeRegressionModel(rootNode.toRegressionNode(prune), numFeatures)) + new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures)) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index f027b14f1d..4aa4c3617e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -219,10 +219,8 @@ 
private[ml] object TreeEnsembleModel { importances.changeValue(feature, scaledGain, _ + scaledGain) computeFeatureImportance(n.leftChild, importances) computeFeatureImportance(n.rightChild, importances) - case _: LeafNode => + case n: LeafNode => // do nothing - case _ => - throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}") } } @@ -319,8 +317,6 @@ private[ml] object DecisionTreeModelReadWrite { (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats, -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), id) - case _ => - throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}") } } @@ -331,7 +327,7 @@ private[ml] object DecisionTreeModelReadWrite { def loadTreeNodes( path: String, metadata: DefaultParamsReader.Metadata, - sparkSession: SparkSession, isClassification: Boolean): Node = { + sparkSession: SparkSession): Node = { import sparkSession.implicits._ implicit val format = DefaultFormats @@ -343,7 +339,7 @@ private[ml] object DecisionTreeModelReadWrite { val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath).as[NodeData] - buildTreeFromNodes(data.collect(), impurityType, isClassification) + buildTreeFromNodes(data.collect(), impurityType) } /** @@ -352,8 +348,7 @@ private[ml] object DecisionTreeModelReadWrite { * @param impurityType Impurity type for this tree * @return Root node of reconstructed tree */ - def buildTreeFromNodes(data: Array[NodeData], impurityType: String, - isClassification: Boolean): Node = { + def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = { // Load all nodes, sorted by ID. val nodes = data.sortBy(_.id) // Sanity checks; could remove @@ -369,21 +364,10 @@ private[ml] object DecisionTreeModelReadWrite { val node = if (n.leftChild != -1) { val leftChild = finalNodes(n.leftChild) val rightChild = finalNodes(n.rightChild) - if (isClassification) { - new ClassificationInternalNode(n.prediction, n.impurity, n.gain, - leftChild.asInstanceOf[ClassificationNode], rightChild.asInstanceOf[ClassificationNode], - n.split.getSplit, impurityStats) - } else { - new RegressionInternalNode(n.prediction, n.impurity, n.gain, - leftChild.asInstanceOf[RegressionNode], rightChild.asInstanceOf[RegressionNode], - n.split.getSplit, impurityStats) - } + new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild, + n.split.getSplit, impurityStats) } else { - if (isClassification) { - new ClassificationLeafNode(n.prediction, n.impurity, impurityStats) - } else { - new RegressionLeafNode(n.prediction, n.impurity, impurityStats) - } + new LeafNode(n.prediction, n.impurity, impurityStats) } finalNodes(n.id) = node } @@ -437,8 +421,7 @@ private[ml] object EnsembleModelReadWrite { path: String, sql: SparkSession, className: String, - treeClassName: String, - isClassification: Boolean): (Metadata, Array[(Metadata, Node)], Array[Double]) = { + treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = { import sql.implicits._ implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className) @@ -466,8 +449,7 @@ private[ml] object EnsembleModelReadWrite { val rootNodesRDD: RDD[(Int, Node)] = nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map { case (treeID: Int, nodeData: Iterable[NodeData]) => - treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes( - nodeData.toArray, impurityType, isClassification) + treeID -> 
DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) } val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() (metadata, treesMetadata.zip(rootNodes), treesWeights) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index d3dbb4e754..2930f4900d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.ClassificationLeafNode +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -61,8 +61,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new DecisionTreeClassifier) - val model = new DecisionTreeClassificationModel("dtc", - new ClassificationLeafNode(0.0, 0.0, null), 1, 2) + val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2) ParamsSuite.checkParams(model) } @@ -376,32 +375,6 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { testDefaultReadWrite(model) } - - test("label/impurity stats") { - val arr = Array( - LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), - LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), - LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) - val rdd = sc.parallelize(arr) - val df = TreeTests.setMetadata(rdd, Map.empty[Int, Int], 2) - val dt1 = new DecisionTreeClassifier() - .setImpurity("entropy") - .setMaxDepth(2) - .setMinInstancesPerNode(2) - val model1 = dt1.fit(df) - - val rootNode1 = model1.rootNode - assert(Array(rootNode1.getLabelCount(0), rootNode1.getLabelCount(1)) === Array(2.0, 1.0)) - - val dt2 = new DecisionTreeClassifier() - .setImpurity("gini") - .setMaxDepth(2) - .setMinInstancesPerNode(2) - val model2 = dt2.fit(df) - - val rootNode2 = model2.rootNode - assert(Array(rootNode2.getLabelCount(0), rootNode2.getLabelCount(1)) === Array(2.0, 1.0)) - } } private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index e6d2a8e2b9..3049776341 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.RegressionLeafNode +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ @@ -70,7 +70,7 @@ class GBTClassifierSuite extends MLTest with 
DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", - Array(new DecisionTreeRegressionModel("dtr", new RegressionLeafNode(0.0, 0.0, null), 1)), + Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)), Array(1.0), 1, 2) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 3062aa9f3d..ba4a9cf082 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.ClassificationLeafNode +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -71,8 +71,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", - new ClassificationLeafNode(0.0, 0.0, null), 1, 2)), 2, 2) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 9ae27339b1..29a4383965 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -191,20 +191,6 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { TreeTests.allParamSettings ++ Map("maxDepth" -> 0), TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } - - test("label/impurity stats") { - val categoricalFeatures = Map(0 -> 2, 1 -> 2) - val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) - val dtr = new DecisionTreeRegressor() - .setImpurity("variance") - .setMaxDepth(2) - .setMaxBins(8) - val model = dtr.fit(df) - val statInfo = model.rootNode - - assert(statInfo.getCount == 1000.0 && statInfo.getSum == 600.0 - && statInfo.getSumOfSquares == 600.0) - } } private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 4dbbd75d24..743dacf146 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -340,8 +340,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats.impurity > 0.0) // set impurity and predict for child nodes - assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0) - 
assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0) + assert(topNode.leftChild.get.toNode.prediction === 0.0) + assert(topNode.rightChild.get.toNode.prediction === 1.0) assert(topNode.leftChild.get.stats.impurity === 0.0) assert(topNode.rightChild.get.stats.impurity === 0.0) } @@ -382,8 +382,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats.impurity > 0.0) // set impurity and predict for child nodes - assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0) - assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0) + assert(topNode.leftChild.get.toNode.prediction === 0.0) + assert(topNode.rightChild.get.toNode.prediction === 1.0) assert(topNode.leftChild.get.stats.impurity === 0.0) assert(topNode.rightChild.get.stats.impurity === 0.0) } @@ -582,18 +582,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { left right */ val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0)) - val left = new ClassificationLeafNode(0.0, leftImp.calculate(), leftImp) + val left = new LeafNode(0.0, leftImp.calculate(), leftImp) val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0)) - val right = new ClassificationLeafNode(2.0, rightImp.calculate(), rightImp) + val right = new LeafNode(2.0, rightImp.calculate(), rightImp) - val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5), true) + val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5)) val parentImp = parent.impurityStats val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0)) - val left2 = new ClassificationLeafNode(0.0, left2Imp.calculate(), left2Imp) + val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp) - val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0), true) + val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0)) val grandImp = grandParent.impurityStats // Test feature importance computed at different subtrees. 
@@ -618,8 +618,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // Forest consisting of (full tree) + (internal node with 2 leafs) val trees = Array(parent, grandParent).map { root => - new DecisionTreeClassificationModel(root.asInstanceOf[ClassificationNode], - numFeatures = 2, numClasses = 3).asInstanceOf[DecisionTreeModel] + new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3) + .asInstanceOf[DecisionTreeModel] } val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2) val tree2norm = feature0importance + feature1importance diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index 3f03d909d4..b6894b30b0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -159,7 +159,7 @@ private[ml] object TreeTests extends SparkFunSuite { * @param split Split for parent node * @return Parent node with children attached */ - def buildParentNode(left: Node, right: Node, split: Split, isClassification: Boolean): Node = { + def buildParentNode(left: Node, right: Node, split: Split): Node = { val leftImp = left.impurityStats val rightImp = right.impurityStats val parentImp = leftImp.copy.add(rightImp) @@ -168,15 +168,7 @@ private[ml] object TreeTests extends SparkFunSuite { val gain = parentImp.calculate() - (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) val pred = parentImp.predict - if (isClassification) { - new ClassificationInternalNode(pred, parentImp.calculate(), gain, - left.asInstanceOf[ClassificationNode], right.asInstanceOf[ClassificationNode], - split, parentImp) - } else { - new RegressionInternalNode(pred, parentImp.calculate(), gain, - left.asInstanceOf[RegressionNode], right.asInstanceOf[RegressionNode], - split, parentImp) - } + new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp) } /** diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a931738032..0b074fbf64 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -103,13 +103,6 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="), - // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this"), - // [SPARK-7132][ML] Add fit with validation set to spark.ml GBT ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), 
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="),
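
With this revert applied, `DecisionTreeClassificationModel.rootNode` and `DecisionTreeRegressionModel.rootNode` are again typed as the sealed `Node`, with `LeafNode` and `InternalNode` as the only concrete subtypes, so user code written against Spark 2.3 and earlier that pattern-matches on those two classes keeps compiling. Below is a minimal sketch of that usage (the object name, sample-data path, and Spark-session settings are illustrative, not part of this patch):

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}
import org.apache.spark.sql.SparkSession

object TreeTraversalSketch {
  // Count leaves by recursing on the restored sealed Node hierarchy.
  // With only LeafNode and InternalNode as subtypes, this match is exhaustive.
  def countLeaves(node: Node): Int = node match {
    case _: LeafNode     => 1
    case n: InternalNode => countLeaves(n.leftChild) + countLeaves(n.rightChild)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("node-api-sketch")
      .getOrCreate()
    // Assumption: the sample data shipped with Spark; any labeled libsvm file works.
    val data  = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
    val model = new DecisionTreeClassifier().setMaxDepth(3).fit(data)
    // After the revert, rootNode is typed as Node rather than ClassificationNode,
    // so the pre-2.4-style match above compiles unchanged.
    println(s"number of leaves: ${countLeaves(model.rootNode)}")
    spark.stop()
  }
}
```

The trade-off, visible in the Node.scala and MimaExcludes.scala hunks above, is that the node-statistics accessors introduced by SPARK-14681 (`getLabelCount`, `getCount`, `getSum`, `getSumOfSquares`) are removed along with the `ClassificationNode`/`RegressionNode` split, and the corresponding Mima exclusions are no longer needed.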