[SPARK-25321][ML] Revert SPARK-14681 to avoid API breaking change
## What changes were proposed in this pull request? This is the same as #22492 but for master branch. Revert SPARK-14681 to avoid API breaking changes. cc: WeichenXu123 ## How was this patch tested? Existing unit tests. Closes #22618 from mengxr/SPARK-25321.master. Authored-by: WeichenXu <weichen.xu@databricks.com> Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
This commit is contained in:
parent
669ade3a8e
commit
ebd899b8a8
|
@ -168,7 +168,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi
|
|||
@Since("1.4.0")
|
||||
class DecisionTreeClassificationModel private[ml] (
|
||||
@Since("1.4.0")override val uid: String,
|
||||
@Since("1.4.0")override val rootNode: ClassificationNode,
|
||||
@Since("1.4.0")override val rootNode: Node,
|
||||
@Since("1.6.0")override val numFeatures: Int,
|
||||
@Since("1.5.0")override val numClasses: Int)
|
||||
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
|
||||
|
@ -181,7 +181,7 @@ class DecisionTreeClassificationModel private[ml] (
|
|||
* Construct a decision tree classification model.
|
||||
* @param rootNode Root node of tree, with other nodes attached.
|
||||
*/
|
||||
private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, numClasses: Int) =
|
||||
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
|
||||
this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
|
||||
|
||||
override def predict(features: Vector): Double = {
|
||||
|
@ -279,9 +279,8 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
|
|||
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
|
||||
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
|
||||
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
|
||||
val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true)
|
||||
val model = new DecisionTreeClassificationModel(metadata.uid,
|
||||
root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
|
||||
val root = loadTreeNodes(path, metadata, sparkSession)
|
||||
val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses)
|
||||
metadata.getAndSetParams(model)
|
||||
model
|
||||
}
|
||||
|
@ -296,10 +295,9 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
|
|||
require(oldModel.algo == OldAlgo.Classification,
|
||||
s"Cannot convert non-classification DecisionTreeModel (old API) to" +
|
||||
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
|
||||
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = true)
|
||||
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
|
||||
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
|
||||
// Can't infer number of features from old model, so default to -1
|
||||
new DecisionTreeClassificationModel(uid,
|
||||
rootNode.asInstanceOf[ClassificationNode], numFeatures, -1)
|
||||
new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -412,14 +412,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
|
|||
override def load(path: String): GBTClassificationModel = {
|
||||
implicit val format = DefaultFormats
|
||||
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
|
||||
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
|
||||
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
|
||||
val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
|
||||
val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
|
||||
|
||||
val trees: Array[DecisionTreeRegressionModel] = treesData.map {
|
||||
case (treeMetadata, root) =>
|
||||
val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
|
||||
root.asInstanceOf[RegressionNode], numFeatures)
|
||||
val tree =
|
||||
new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
|
||||
treeMetadata.getAndSetParams(tree)
|
||||
tree
|
||||
}
|
||||
|
|
|
@ -313,15 +313,15 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
|
|||
override def load(path: String): RandomForestClassificationModel = {
|
||||
implicit val format = DefaultFormats
|
||||
val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
|
||||
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, true)
|
||||
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
|
||||
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
|
||||
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
|
||||
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
|
||||
|
||||
val trees: Array[DecisionTreeClassificationModel] = treesData.map {
|
||||
case (treeMetadata, root) =>
|
||||
val tree = new DecisionTreeClassificationModel(treeMetadata.uid,
|
||||
root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
|
||||
val tree =
|
||||
new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses)
|
||||
treeMetadata.getAndSetParams(tree)
|
||||
tree
|
||||
}
|
||||
|
|
|
@ -160,7 +160,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
|
|||
@Since("1.4.0")
|
||||
class DecisionTreeRegressionModel private[ml] (
|
||||
override val uid: String,
|
||||
override val rootNode: RegressionNode,
|
||||
override val rootNode: Node,
|
||||
override val numFeatures: Int)
|
||||
extends PredictionModel[Vector, DecisionTreeRegressionModel]
|
||||
with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable {
|
||||
|
@ -175,7 +175,7 @@ class DecisionTreeRegressionModel private[ml] (
|
|||
* Construct a decision tree regression model.
|
||||
* @param rootNode Root node of tree, with other nodes attached.
|
||||
*/
|
||||
private[ml] def this(rootNode: RegressionNode, numFeatures: Int) =
|
||||
private[ml] def this(rootNode: Node, numFeatures: Int) =
|
||||
this(Identifiable.randomUID("dtr"), rootNode, numFeatures)
|
||||
|
||||
override def predict(features: Vector): Double = {
|
||||
|
@ -279,9 +279,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
|
|||
implicit val format = DefaultFormats
|
||||
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
|
||||
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
|
||||
val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false)
|
||||
val model = new DecisionTreeRegressionModel(metadata.uid,
|
||||
root.asInstanceOf[RegressionNode], numFeatures)
|
||||
val root = loadTreeNodes(path, metadata, sparkSession)
|
||||
val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures)
|
||||
metadata.getAndSetParams(model)
|
||||
model
|
||||
}
|
||||
|
@ -296,8 +295,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
|
|||
require(oldModel.algo == OldAlgo.Regression,
|
||||
s"Cannot convert non-regression DecisionTreeModel (old API) to" +
|
||||
s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
|
||||
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = false)
|
||||
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
|
||||
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
|
||||
new DecisionTreeRegressionModel(uid, rootNode.asInstanceOf[RegressionNode], numFeatures)
|
||||
new DecisionTreeRegressionModel(uid, rootNode, numFeatures)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -338,15 +338,15 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
|
|||
override def load(path: String): GBTRegressionModel = {
|
||||
implicit val format = DefaultFormats
|
||||
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
|
||||
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
|
||||
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
|
||||
|
||||
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
|
||||
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
|
||||
|
||||
val trees: Array[DecisionTreeRegressionModel] = treesData.map {
|
||||
case (treeMetadata, root) =>
|
||||
val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
|
||||
root.asInstanceOf[RegressionNode], numFeatures)
|
||||
val tree =
|
||||
new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
|
||||
treeMetadata.getAndSetParams(tree)
|
||||
tree
|
||||
}
|
||||
|
|
|
@ -271,13 +271,13 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode
|
|||
override def load(path: String): RandomForestRegressionModel = {
|
||||
implicit val format = DefaultFormats
|
||||
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
|
||||
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
|
||||
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
|
||||
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
|
||||
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
|
||||
|
||||
val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) =>
|
||||
val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
|
||||
root.asInstanceOf[RegressionNode], numFeatures)
|
||||
val tree =
|
||||
new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
|
||||
treeMetadata.getAndSetParams(tree)
|
||||
tree
|
||||
}
|
||||
|
|
|
@ -17,16 +17,14 @@
|
|||
|
||||
package org.apache.spark.ml.tree
|
||||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
|
||||
import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats,
|
||||
Node => OldNode, Predict => OldPredict}
|
||||
import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict}
|
||||
|
||||
/**
|
||||
* Decision tree node interface.
|
||||
*/
|
||||
sealed trait Node extends Serializable {
|
||||
sealed abstract class Node extends Serializable {
|
||||
|
||||
// TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree
|
||||
// code into the new API and deprecate the old API. SPARK-3727
|
||||
|
@ -86,86 +84,35 @@ private[ml] object Node {
|
|||
/**
|
||||
* Create a new Node from the old Node format, recursively creating child nodes as needed.
|
||||
*/
|
||||
def fromOld(
|
||||
oldNode: OldNode,
|
||||
categoricalFeatures: Map[Int, Int],
|
||||
isClassification: Boolean): Node = {
|
||||
def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
|
||||
if (oldNode.isLeaf) {
|
||||
// TODO: Once the implementation has been moved to this API, then include sufficient
|
||||
// statistics here.
|
||||
if (isClassification) {
|
||||
new ClassificationLeafNode(prediction = oldNode.predict.predict,
|
||||
impurity = oldNode.impurity, impurityStats = null)
|
||||
} else {
|
||||
new RegressionLeafNode(prediction = oldNode.predict.predict,
|
||||
impurity = oldNode.impurity, impurityStats = null)
|
||||
}
|
||||
new LeafNode(prediction = oldNode.predict.predict,
|
||||
impurity = oldNode.impurity, impurityStats = null)
|
||||
} else {
|
||||
val gain = if (oldNode.stats.nonEmpty) {
|
||||
oldNode.stats.get.gain
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
if (isClassification) {
|
||||
new ClassificationInternalNode(prediction = oldNode.predict.predict,
|
||||
impurity = oldNode.impurity, gain = gain,
|
||||
leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, true)
|
||||
.asInstanceOf[ClassificationNode],
|
||||
rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, true)
|
||||
.asInstanceOf[ClassificationNode],
|
||||
split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
|
||||
} else {
|
||||
new RegressionInternalNode(prediction = oldNode.predict.predict,
|
||||
impurity = oldNode.impurity, gain = gain,
|
||||
leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, false)
|
||||
.asInstanceOf[RegressionNode],
|
||||
rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, false)
|
||||
.asInstanceOf[RegressionNode],
|
||||
split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
|
||||
}
|
||||
new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
|
||||
gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
|
||||
rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
|
||||
split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Since("2.4.0")
|
||||
sealed trait ClassificationNode extends Node {
|
||||
|
||||
/**
|
||||
* Get count of training examples for specified label in this node
|
||||
* @param label label number in the range [0, numClasses)
|
||||
*/
|
||||
@Since("2.4.0")
|
||||
def getLabelCount(label: Int): Double = {
|
||||
require(label >= 0 && label < impurityStats.stats.length,
|
||||
"label should be in the range between 0 (inclusive) " +
|
||||
s"and ${impurityStats.stats.length} (exclusive).")
|
||||
impurityStats.stats(label)
|
||||
}
|
||||
}
|
||||
|
||||
@Since("2.4.0")
|
||||
sealed trait RegressionNode extends Node {
|
||||
|
||||
/** Number of training data points in this node */
|
||||
@Since("2.4.0")
|
||||
def getCount: Double = impurityStats.stats(0)
|
||||
|
||||
/** Sum over training data points of the labels in this node */
|
||||
@Since("2.4.0")
|
||||
def getSum: Double = impurityStats.stats(1)
|
||||
|
||||
/** Sum over training data points of the square of the labels in this node */
|
||||
@Since("2.4.0")
|
||||
def getSumOfSquares: Double = impurityStats.stats(2)
|
||||
}
|
||||
|
||||
@Since("2.4.0")
|
||||
sealed trait LeafNode extends Node {
|
||||
|
||||
/** Prediction this node makes. */
|
||||
def prediction: Double
|
||||
|
||||
def impurity: Double
|
||||
/**
|
||||
* Decision tree leaf node.
|
||||
* @param prediction Prediction this node makes
|
||||
* @param impurity Impurity measure at this node (for training data)
|
||||
*/
|
||||
class LeafNode private[ml] (
|
||||
override val prediction: Double,
|
||||
override val impurity: Double,
|
||||
override private[ml] val impurityStats: ImpurityCalculator) extends Node {
|
||||
|
||||
override def toString: String =
|
||||
s"LeafNode(prediction = $prediction, impurity = $impurity)"
|
||||
|
@ -188,58 +135,32 @@ sealed trait LeafNode extends Node {
|
|||
|
||||
override private[ml] def maxSplitFeatureIndex(): Int = -1
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Decision tree leaf node for classification.
|
||||
*/
|
||||
@Since("2.4.0")
|
||||
class ClassificationLeafNode private[ml] (
|
||||
override val prediction: Double,
|
||||
override val impurity: Double,
|
||||
override private[ml] val impurityStats: ImpurityCalculator)
|
||||
extends ClassificationNode with LeafNode {
|
||||
|
||||
override private[tree] def deepCopy(): Node = {
|
||||
new ClassificationLeafNode(prediction, impurity, impurityStats)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Decision tree leaf node for regression.
|
||||
*/
|
||||
@Since("2.4.0")
|
||||
class RegressionLeafNode private[ml] (
|
||||
override val prediction: Double,
|
||||
override val impurity: Double,
|
||||
override private[ml] val impurityStats: ImpurityCalculator)
|
||||
extends RegressionNode with LeafNode {
|
||||
|
||||
override private[tree] def deepCopy(): Node = {
|
||||
new RegressionLeafNode(prediction, impurity, impurityStats)
|
||||
new LeafNode(prediction, impurity, impurityStats)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal Decision Tree node.
|
||||
* @param prediction Prediction this node would make if it were a leaf node
|
||||
* @param impurity Impurity measure at this node (for training data)
|
||||
* @param gain Information gain value. Values less than 0 indicate missing values;
|
||||
* this quirk will be removed with future updates.
|
||||
* @param leftChild Left-hand child node
|
||||
* @param rightChild Right-hand child node
|
||||
* @param split Information about the test used to split to the left or right child.
|
||||
*/
|
||||
@Since("2.4.0")
|
||||
sealed trait InternalNode extends Node {
|
||||
class InternalNode private[ml] (
|
||||
override val prediction: Double,
|
||||
override val impurity: Double,
|
||||
val gain: Double,
|
||||
val leftChild: Node,
|
||||
val rightChild: Node,
|
||||
val split: Split,
|
||||
override private[ml] val impurityStats: ImpurityCalculator) extends Node {
|
||||
|
||||
/**
|
||||
* Information gain value. Values less than 0 indicate missing values;
|
||||
* this quirk will be removed with future updates.
|
||||
*/
|
||||
def gain: Double
|
||||
|
||||
/** Left-hand child node */
|
||||
def leftChild: Node
|
||||
|
||||
/** Right-hand child node */
|
||||
def rightChild: Node
|
||||
|
||||
/** Information about the test used to split to the left or right child. */
|
||||
def split: Split
|
||||
// Note to developers: The constructor argument impurityStats should be reconsidered before we
|
||||
// make the constructor public. We may be able to improve the representation.
|
||||
|
||||
override def toString: String = {
|
||||
s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
|
||||
|
@ -284,6 +205,11 @@ sealed trait InternalNode extends Node {
|
|||
math.max(split.featureIndex,
|
||||
math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex()))
|
||||
}
|
||||
|
||||
override private[tree] def deepCopy(): Node = {
|
||||
new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), rightChild.deepCopy(),
|
||||
split, impurityStats)
|
||||
}
|
||||
}
|
||||
|
||||
private object InternalNode {
|
||||
|
@ -314,57 +240,6 @@ private object InternalNode {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal Decision Tree node for regression.
|
||||
*/
|
||||
@Since("2.4.0")
|
||||
class ClassificationInternalNode private[ml] (
|
||||
override val prediction: Double,
|
||||
override val impurity: Double,
|
||||
override val gain: Double,
|
||||
override val leftChild: ClassificationNode,
|
||||
override val rightChild: ClassificationNode,
|
||||
override val split: Split,
|
||||
override private[ml] val impurityStats: ImpurityCalculator)
|
||||
extends ClassificationNode with InternalNode {
|
||||
|
||||
// Note to developers: The constructor argument impurityStats should be reconsidered before we
|
||||
// make the constructor public. We may be able to improve the representation.
|
||||
|
||||
override private[tree] def deepCopy(): Node = {
|
||||
new ClassificationInternalNode(prediction, impurity, gain,
|
||||
leftChild.deepCopy().asInstanceOf[ClassificationNode],
|
||||
rightChild.deepCopy().asInstanceOf[ClassificationNode],
|
||||
split, impurityStats)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal Decision Tree node for regression.
|
||||
*/
|
||||
@Since("2.4.0")
|
||||
class RegressionInternalNode private[ml] (
|
||||
override val prediction: Double,
|
||||
override val impurity: Double,
|
||||
override val gain: Double,
|
||||
override val leftChild: RegressionNode,
|
||||
override val rightChild: RegressionNode,
|
||||
override val split: Split,
|
||||
override private[ml] val impurityStats: ImpurityCalculator)
|
||||
extends RegressionNode with InternalNode {
|
||||
|
||||
// Note to developers: The constructor argument impurityStats should be reconsidered before we
|
||||
// make the constructor public. We may be able to improve the representation.
|
||||
|
||||
override private[tree] def deepCopy(): Node = {
|
||||
new RegressionInternalNode(prediction, impurity, gain,
|
||||
leftChild.deepCopy().asInstanceOf[RegressionNode],
|
||||
rightChild.deepCopy().asInstanceOf[RegressionNode],
|
||||
split, impurityStats)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Version of a node used in learning. This uses vars so that we can modify nodes as we split the
|
||||
* tree by adding children, etc.
|
||||
|
@ -390,52 +265,30 @@ private[tree] class LearningNode(
|
|||
var isLeaf: Boolean,
|
||||
var stats: ImpurityStats) extends Serializable {
|
||||
|
||||
def toNode(isClassification: Boolean): Node = toNode(isClassification, prune = true)
|
||||
|
||||
def toClassificationNode(prune: Boolean = true): ClassificationNode = {
|
||||
toNode(true, prune).asInstanceOf[ClassificationNode]
|
||||
}
|
||||
|
||||
def toRegressionNode(prune: Boolean = true): RegressionNode = {
|
||||
toNode(false, prune).asInstanceOf[RegressionNode]
|
||||
}
|
||||
def toNode: Node = toNode(prune = true)
|
||||
|
||||
/**
|
||||
* Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
|
||||
*/
|
||||
def toNode(isClassification: Boolean, prune: Boolean): Node = {
|
||||
def toNode(prune: Boolean = true): Node = {
|
||||
|
||||
if (!leftChild.isEmpty || !rightChild.isEmpty) {
|
||||
assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null,
|
||||
"Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
|
||||
(leftChild.get.toNode(isClassification, prune),
|
||||
rightChild.get.toNode(isClassification, prune)) match {
|
||||
(leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match {
|
||||
case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction =>
|
||||
if (isClassification) {
|
||||
new ClassificationLeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
|
||||
} else {
|
||||
new RegressionLeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
|
||||
}
|
||||
new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
|
||||
case (l, r) =>
|
||||
if (isClassification) {
|
||||
new ClassificationInternalNode(stats.impurityCalculator.predict, stats.impurity,
|
||||
stats.gain, l.asInstanceOf[ClassificationNode], r.asInstanceOf[ClassificationNode],
|
||||
split.get, stats.impurityCalculator)
|
||||
} else {
|
||||
new RegressionInternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
|
||||
l.asInstanceOf[RegressionNode], r.asInstanceOf[RegressionNode],
|
||||
split.get, stats.impurityCalculator)
|
||||
}
|
||||
new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
|
||||
l, r, split.get, stats.impurityCalculator)
|
||||
}
|
||||
} else {
|
||||
// Here we want to keep same behavior with the old mllib.DecisionTreeModel
|
||||
val impurity = if (stats.valid) stats.impurity else -1.0
|
||||
if (isClassification) {
|
||||
new ClassificationLeafNode(stats.impurityCalculator.predict, impurity,
|
||||
if (stats.valid) {
|
||||
new LeafNode(stats.impurityCalculator.predict, stats.impurity,
|
||||
stats.impurityCalculator)
|
||||
} else {
|
||||
new RegressionLeafNode(stats.impurityCalculator.predict, impurity,
|
||||
stats.impurityCalculator)
|
||||
// Here we want to keep same behavior with the old mllib.DecisionTreeModel
|
||||
new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -226,23 +226,23 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
case Some(uid) =>
|
||||
if (strategy.algo == OldAlgo.Classification) {
|
||||
topNodes.map { rootNode =>
|
||||
new DecisionTreeClassificationModel(uid, rootNode.toClassificationNode(prune),
|
||||
numFeatures, strategy.getNumClasses)
|
||||
new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures,
|
||||
strategy.getNumClasses)
|
||||
}
|
||||
} else {
|
||||
topNodes.map { rootNode =>
|
||||
new DecisionTreeRegressionModel(uid, rootNode.toRegressionNode(prune), numFeatures)
|
||||
new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures)
|
||||
}
|
||||
}
|
||||
case None =>
|
||||
if (strategy.algo == OldAlgo.Classification) {
|
||||
topNodes.map { rootNode =>
|
||||
new DecisionTreeClassificationModel(rootNode.toClassificationNode(prune), numFeatures,
|
||||
new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures,
|
||||
strategy.getNumClasses)
|
||||
}
|
||||
} else {
|
||||
topNodes.map(rootNode =>
|
||||
new DecisionTreeRegressionModel(rootNode.toRegressionNode(prune), numFeatures))
|
||||
new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -219,10 +219,8 @@ private[ml] object TreeEnsembleModel {
|
|||
importances.changeValue(feature, scaledGain, _ + scaledGain)
|
||||
computeFeatureImportance(n.leftChild, importances)
|
||||
computeFeatureImportance(n.rightChild, importances)
|
||||
case _: LeafNode =>
|
||||
case n: LeafNode =>
|
||||
// do nothing
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -319,8 +317,6 @@ private[ml] object DecisionTreeModelReadWrite {
|
|||
(Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats,
|
||||
-1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
|
||||
id)
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -331,7 +327,7 @@ private[ml] object DecisionTreeModelReadWrite {
|
|||
def loadTreeNodes(
|
||||
path: String,
|
||||
metadata: DefaultParamsReader.Metadata,
|
||||
sparkSession: SparkSession, isClassification: Boolean): Node = {
|
||||
sparkSession: SparkSession): Node = {
|
||||
import sparkSession.implicits._
|
||||
implicit val format = DefaultFormats
|
||||
|
||||
|
@ -343,7 +339,7 @@ private[ml] object DecisionTreeModelReadWrite {
|
|||
|
||||
val dataPath = new Path(path, "data").toString
|
||||
val data = sparkSession.read.parquet(dataPath).as[NodeData]
|
||||
buildTreeFromNodes(data.collect(), impurityType, isClassification)
|
||||
buildTreeFromNodes(data.collect(), impurityType)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -352,8 +348,7 @@ private[ml] object DecisionTreeModelReadWrite {
|
|||
* @param impurityType Impurity type for this tree
|
||||
* @return Root node of reconstructed tree
|
||||
*/
|
||||
def buildTreeFromNodes(data: Array[NodeData], impurityType: String,
|
||||
isClassification: Boolean): Node = {
|
||||
def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = {
|
||||
// Load all nodes, sorted by ID.
|
||||
val nodes = data.sortBy(_.id)
|
||||
// Sanity checks; could remove
|
||||
|
@ -369,21 +364,10 @@ private[ml] object DecisionTreeModelReadWrite {
|
|||
val node = if (n.leftChild != -1) {
|
||||
val leftChild = finalNodes(n.leftChild)
|
||||
val rightChild = finalNodes(n.rightChild)
|
||||
if (isClassification) {
|
||||
new ClassificationInternalNode(n.prediction, n.impurity, n.gain,
|
||||
leftChild.asInstanceOf[ClassificationNode], rightChild.asInstanceOf[ClassificationNode],
|
||||
n.split.getSplit, impurityStats)
|
||||
} else {
|
||||
new RegressionInternalNode(n.prediction, n.impurity, n.gain,
|
||||
leftChild.asInstanceOf[RegressionNode], rightChild.asInstanceOf[RegressionNode],
|
||||
n.split.getSplit, impurityStats)
|
||||
}
|
||||
new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild,
|
||||
n.split.getSplit, impurityStats)
|
||||
} else {
|
||||
if (isClassification) {
|
||||
new ClassificationLeafNode(n.prediction, n.impurity, impurityStats)
|
||||
} else {
|
||||
new RegressionLeafNode(n.prediction, n.impurity, impurityStats)
|
||||
}
|
||||
new LeafNode(n.prediction, n.impurity, impurityStats)
|
||||
}
|
||||
finalNodes(n.id) = node
|
||||
}
|
||||
|
@ -437,8 +421,7 @@ private[ml] object EnsembleModelReadWrite {
|
|||
path: String,
|
||||
sql: SparkSession,
|
||||
className: String,
|
||||
treeClassName: String,
|
||||
isClassification: Boolean): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
|
||||
treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
|
||||
import sql.implicits._
|
||||
implicit val format = DefaultFormats
|
||||
val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className)
|
||||
|
@ -466,8 +449,7 @@ private[ml] object EnsembleModelReadWrite {
|
|||
val rootNodesRDD: RDD[(Int, Node)] =
|
||||
nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map {
|
||||
case (treeID: Int, nodeData: Iterable[NodeData]) =>
|
||||
treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(
|
||||
nodeData.toArray, impurityType, isClassification)
|
||||
treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
|
||||
}
|
||||
val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
|
||||
(metadata, treesMetadata.zip(rootNodes), treesWeights)
|
||||
|
|
|
@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
|
|||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.tree.ClassificationLeafNode
|
||||
import org.apache.spark.ml.tree.LeafNode
|
||||
import org.apache.spark.ml.tree.impl.TreeTests
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
|
||||
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
|
||||
|
@ -61,8 +61,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
|
|||
|
||||
test("params") {
|
||||
ParamsSuite.checkParams(new DecisionTreeClassifier)
|
||||
val model = new DecisionTreeClassificationModel("dtc",
|
||||
new ClassificationLeafNode(0.0, 0.0, null), 1, 2)
|
||||
val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)
|
||||
ParamsSuite.checkParams(model)
|
||||
}
|
||||
|
||||
|
@ -376,32 +375,6 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
|
|||
|
||||
testDefaultReadWrite(model)
|
||||
}
|
||||
|
||||
test("label/impurity stats") {
|
||||
val arr = Array(
|
||||
LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
|
||||
LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
|
||||
LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
|
||||
val rdd = sc.parallelize(arr)
|
||||
val df = TreeTests.setMetadata(rdd, Map.empty[Int, Int], 2)
|
||||
val dt1 = new DecisionTreeClassifier()
|
||||
.setImpurity("entropy")
|
||||
.setMaxDepth(2)
|
||||
.setMinInstancesPerNode(2)
|
||||
val model1 = dt1.fit(df)
|
||||
|
||||
val rootNode1 = model1.rootNode
|
||||
assert(Array(rootNode1.getLabelCount(0), rootNode1.getLabelCount(1)) === Array(2.0, 1.0))
|
||||
|
||||
val dt2 = new DecisionTreeClassifier()
|
||||
.setImpurity("gini")
|
||||
.setMaxDepth(2)
|
||||
.setMinInstancesPerNode(2)
|
||||
val model2 = dt2.fit(df)
|
||||
|
||||
val rootNode2 = model2.rootNode
|
||||
assert(Array(rootNode2.getLabelCount(0), rootNode2.getLabelCount(1)) === Array(2.0, 1.0))
|
||||
}
|
||||
}
|
||||
|
||||
private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
|
||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.ml.feature.LabeledPoint
|
|||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
|
||||
import org.apache.spark.ml.tree.RegressionLeafNode
|
||||
import org.apache.spark.ml.tree.LeafNode
|
||||
import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests}
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
|
||||
import org.apache.spark.ml.util.TestingUtils._
|
||||
|
@ -70,7 +70,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
|
|||
test("params") {
|
||||
ParamsSuite.checkParams(new GBTClassifier)
|
||||
val model = new GBTClassificationModel("gbtc",
|
||||
Array(new DecisionTreeRegressionModel("dtr", new RegressionLeafNode(0.0, 0.0, null), 1)),
|
||||
Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)),
|
||||
Array(1.0), 1, 2)
|
||||
ParamsSuite.checkParams(model)
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
|
|||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.tree.ClassificationLeafNode
|
||||
import org.apache.spark.ml.tree.LeafNode
|
||||
import org.apache.spark.ml.tree.impl.TreeTests
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
|
||||
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
|
||||
|
@ -71,8 +71,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
|
|||
test("params") {
|
||||
ParamsSuite.checkParams(new RandomForestClassifier)
|
||||
val model = new RandomForestClassificationModel("rfc",
|
||||
Array(new DecisionTreeClassificationModel("dtc",
|
||||
new ClassificationLeafNode(0.0, 0.0, null), 1, 2)), 2, 2)
|
||||
Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2)
|
||||
ParamsSuite.checkParams(model)
|
||||
}
|
||||
|
||||
|
|
|
@ -191,20 +191,6 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
|
|||
TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
|
||||
TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
|
||||
}
|
||||
|
||||
test("label/impurity stats") {
|
||||
val categoricalFeatures = Map(0 -> 2, 1 -> 2)
|
||||
val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
|
||||
val dtr = new DecisionTreeRegressor()
|
||||
.setImpurity("variance")
|
||||
.setMaxDepth(2)
|
||||
.setMaxBins(8)
|
||||
val model = dtr.fit(df)
|
||||
val statInfo = model.rootNode
|
||||
|
||||
assert(statInfo.getCount == 1000.0 && statInfo.getSum == 600.0
|
||||
&& statInfo.getSumOfSquares == 600.0)
|
||||
}
|
||||
}
|
||||
|
||||
private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
|
||||
|
|
|
@ -340,8 +340,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
assert(topNode.stats.impurity > 0.0)
|
||||
|
||||
// set impurity and predict for child nodes
|
||||
assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0)
|
||||
assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0)
|
||||
assert(topNode.leftChild.get.toNode.prediction === 0.0)
|
||||
assert(topNode.rightChild.get.toNode.prediction === 1.0)
|
||||
assert(topNode.leftChild.get.stats.impurity === 0.0)
|
||||
assert(topNode.rightChild.get.stats.impurity === 0.0)
|
||||
}
|
||||
|
@ -382,8 +382,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
assert(topNode.stats.impurity > 0.0)
|
||||
|
||||
// set impurity and predict for child nodes
|
||||
assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0)
|
||||
assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0)
|
||||
assert(topNode.leftChild.get.toNode.prediction === 0.0)
|
||||
assert(topNode.rightChild.get.toNode.prediction === 1.0)
|
||||
assert(topNode.leftChild.get.stats.impurity === 0.0)
|
||||
assert(topNode.rightChild.get.stats.impurity === 0.0)
|
||||
}
|
||||
|
@ -582,18 +582,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
left right
|
||||
*/
|
||||
val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
|
||||
val left = new ClassificationLeafNode(0.0, leftImp.calculate(), leftImp)
|
||||
val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
|
||||
|
||||
val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
|
||||
val right = new ClassificationLeafNode(2.0, rightImp.calculate(), rightImp)
|
||||
val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
|
||||
|
||||
val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5), true)
|
||||
val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
|
||||
val parentImp = parent.impurityStats
|
||||
|
||||
val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
|
||||
val left2 = new ClassificationLeafNode(0.0, left2Imp.calculate(), left2Imp)
|
||||
val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
|
||||
|
||||
val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0), true)
|
||||
val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
|
||||
val grandImp = grandParent.impurityStats
|
||||
|
||||
// Test feature importance computed at different subtrees.
|
||||
|
@ -618,8 +618,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
|
||||
// Forest consisting of (full tree) + (internal node with 2 leafs)
|
||||
val trees = Array(parent, grandParent).map { root =>
|
||||
new DecisionTreeClassificationModel(root.asInstanceOf[ClassificationNode],
|
||||
numFeatures = 2, numClasses = 3).asInstanceOf[DecisionTreeModel]
|
||||
new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3)
|
||||
.asInstanceOf[DecisionTreeModel]
|
||||
}
|
||||
val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2)
|
||||
val tree2norm = feature0importance + feature1importance
|
||||
|
|
|
@ -159,7 +159,7 @@ private[ml] object TreeTests extends SparkFunSuite {
|
|||
* @param split Split for parent node
|
||||
* @return Parent node with children attached
|
||||
*/
|
||||
def buildParentNode(left: Node, right: Node, split: Split, isClassification: Boolean): Node = {
|
||||
def buildParentNode(left: Node, right: Node, split: Split): Node = {
|
||||
val leftImp = left.impurityStats
|
||||
val rightImp = right.impurityStats
|
||||
val parentImp = leftImp.copy.add(rightImp)
|
||||
|
@ -168,15 +168,7 @@ private[ml] object TreeTests extends SparkFunSuite {
|
|||
val gain = parentImp.calculate() -
|
||||
(leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
|
||||
val pred = parentImp.predict
|
||||
if (isClassification) {
|
||||
new ClassificationInternalNode(pred, parentImp.calculate(), gain,
|
||||
left.asInstanceOf[ClassificationNode], right.asInstanceOf[ClassificationNode],
|
||||
split, parentImp)
|
||||
} else {
|
||||
new RegressionInternalNode(pred, parentImp.calculate(), gain,
|
||||
left.asInstanceOf[RegressionNode], right.asInstanceOf[RegressionNode],
|
||||
split, parentImp)
|
||||
}
|
||||
new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -103,13 +103,6 @@ object MimaExcludes {
|
|||
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"),
|
||||
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="),
|
||||
|
||||
// [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes
|
||||
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"),
|
||||
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"),
|
||||
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"),
|
||||
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"),
|
||||
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this"),
|
||||
|
||||
// [SPARK-7132][ML] Add fit with validation set to spark.ml GBT
|
||||
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"),
|
||||
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="),
|
||||
|
|
Loading…
Reference in a new issue