[SPARK-25321][ML] Revert SPARK-14681 to avoid API breaking change

## What changes were proposed in this pull request?

This is the same as #22492 but for the master branch. Revert SPARK-14681 to avoid API breaking changes.
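For reference, a minimal sketch of the user-facing surface this revert restores. The snippet is illustrative (the toy DataFrame and object name are made up); the restored part is that `rootNode` is again typed as the generic `Node`, exactly as in 2.3:

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.tree.Node
import org.apache.spark.sql.SparkSession

object RootNodeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._
    // Toy training data with the usual label/features columns.
    val df = Seq(
      (0.0, Vectors.dense(0.0)),
      (1.0, Vectors.dense(1.0))
    ).toDF("label", "features")
    val model = new DecisionTreeClassifier().fit(df)
    // With SPARK-14681 reverted, this compiles exactly as it did in 2.3:
    val root: Node = model.rootNode
    println(root)
    spark.stop()
  }
}
```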

cc: WeichenXu123

## How was this patch tested?

Existing unit tests.

Closes #22618 from mengxr/SPARK-25321.master.

Authored-by: WeichenXu <weichen.xu@databricks.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
WeichenXu authored on 2018-10-07 10:06:44 -07:00; committed by Dongjoon Hyun
parent 669ade3a8e
commit ebd899b8a8
16 changed files with 108 additions and 333 deletions

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

@@ -168,7 +168,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi
 @Since("1.4.0")
 class DecisionTreeClassificationModel private[ml] (
     @Since("1.4.0")override val uid: String,
-    @Since("1.4.0")override val rootNode: ClassificationNode,
+    @Since("1.4.0")override val rootNode: Node,
     @Since("1.6.0")override val numFeatures: Int,
     @Since("1.5.0")override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
@@ -181,7 +181,7 @@ class DecisionTreeClassificationModel private[ml] (
    * Construct a decision tree classification model.
    * @param rootNode Root node of tree, with other nodes attached.
    */
-  private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, numClasses: Int) =
+  private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
     this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
   override def predict(features: Vector): Double = {
@@ -279,9 +279,8 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numClasses = (metadata.metadata \ "numClasses").extract[Int]
-      val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true)
-      val model = new DecisionTreeClassificationModel(metadata.uid,
-        root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
+      val root = loadTreeNodes(path, metadata, sparkSession)
+      val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses)
       metadata.getAndSetParams(model)
       model
     }
@@ -296,10 +295,9 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
     require(oldModel.algo == OldAlgo.Classification,
       s"Cannot convert non-classification DecisionTreeModel (old API) to" +
       s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
-    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = true)
+    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
     // Can't infer number of features from old model, so default to -1
-    new DecisionTreeClassificationModel(uid,
-      rootNode.asInstanceOf[ClassificationNode], numFeatures, -1)
+    new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1)
   }
 }
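The hunks above only touch the persistence path, so a save/load round trip is a quick sanity check. A hedged sketch, continuing from the fitted `model` in the snippet near the top (the path is arbitrary):

```scala
import org.apache.spark.ml.classification.DecisionTreeClassificationModel

// Write the model out and read it back through the loader patched above.
model.write.overwrite().save("/tmp/dtc-model")
val loaded = DecisionTreeClassificationModel.load("/tmp/dtc-model")
// loadTreeNodes now returns the generic Node, so no asInstanceOf cast is needed.
assert(loaded.rootNode.prediction == model.rootNode.prediction)
```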

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

@@ -412,14 +412,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
     override def load(path: String): GBTClassificationModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
       val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
       val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
       val trees: Array[DecisionTreeRegressionModel] = treesData.map {
         case (treeMetadata, root) =>
-          val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
-            root.asInstanceOf[RegressionNode], numFeatures)
+          val tree =
+            new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
           treeMetadata.getAndSetParams(tree)
           tree
       }
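Note that the loader above rebuilds DecisionTreeRegressionModel instances even though this is the classifier: GBT base learners are regression trees fit to the loss gradient. A hedged usage sketch, assuming the toy `df` from the first snippet:

```scala
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.regression.DecisionTreeRegressionModel

val gbtModel = new GBTClassifier().setMaxIter(3).fit(df)
// Regression trees plus per-tree weights, matching what loadImpl reads back.
val stages: Array[DecisionTreeRegressionModel] = gbtModel.trees
println(s"${stages.length} trees, weights = ${gbtModel.treeWeights.mkString(", ")}")
```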

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

@@ -313,15 +313,15 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
     override def load(path: String): RandomForestClassificationModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, true)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numClasses = (metadata.metadata \ "numClasses").extract[Int]
       val numTrees = (metadata.metadata \ "numTrees").extract[Int]
       val trees: Array[DecisionTreeClassificationModel] = treesData.map {
         case (treeMetadata, root) =>
-          val tree = new DecisionTreeClassificationModel(treeMetadata.uid,
-            root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
+          val tree =
+            new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses)
           treeMetadata.getAndSetParams(tree)
           tree
       }
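As with GBT, the loader rebuilds one model per tree. A short sketch of the public surface this feeds (again assuming the toy `df`):

```scala
import org.apache.spark.ml.classification.RandomForestClassifier

val rfModel = new RandomForestClassifier().setNumTrees(3).fit(df)
// Every ensemble member now exposes the same generic rootNode: Node as a single tree.
rfModel.trees.foreach(t => println(s"depth = ${t.depth}, root = ${t.rootNode}"))
```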

mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala

@@ -160,7 +160,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
 @Since("1.4.0")
 class DecisionTreeRegressionModel private[ml] (
     override val uid: String,
-    override val rootNode: RegressionNode,
+    override val rootNode: Node,
     override val numFeatures: Int)
   extends PredictionModel[Vector, DecisionTreeRegressionModel]
   with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable {
@@ -175,7 +175,7 @@ class DecisionTreeRegressionModel private[ml] (
    * Construct a decision tree regression model.
    * @param rootNode Root node of tree, with other nodes attached.
    */
-  private[ml] def this(rootNode: RegressionNode, numFeatures: Int) =
+  private[ml] def this(rootNode: Node, numFeatures: Int) =
     this(Identifiable.randomUID("dtr"), rootNode, numFeatures)
   override def predict(features: Vector): Double = {
@@ -279,9 +279,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
     implicit val format = DefaultFormats
     val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
     val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
-    val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false)
-    val model = new DecisionTreeRegressionModel(metadata.uid,
-      root.asInstanceOf[RegressionNode], numFeatures)
+    val root = loadTreeNodes(path, metadata, sparkSession)
+    val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures)
     metadata.getAndSetParams(model)
     model
   }
@@ -296,8 +295,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
     require(oldModel.algo == OldAlgo.Regression,
       s"Cannot convert non-regression DecisionTreeModel (old API) to" +
       s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
-    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = false)
+    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
-    new DecisionTreeRegressionModel(uid, rootNode.asInstanceOf[RegressionNode], numFeatures)
+    new DecisionTreeRegressionModel(uid, rootNode, numFeatures)
   }
 }

mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala

@@ -338,15 +338,15 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
     override def load(path: String): GBTRegressionModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numTrees = (metadata.metadata \ "numTrees").extract[Int]
       val trees: Array[DecisionTreeRegressionModel] = treesData.map {
         case (treeMetadata, root) =>
-          val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
-            root.asInstanceOf[RegressionNode], numFeatures)
+          val tree =
+            new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
           treeMetadata.getAndSetParams(tree)
           tree
       }

mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala

@@ -271,13 +271,13 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode
     override def load(path: String): RandomForestRegressionModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numTrees = (metadata.metadata \ "numTrees").extract[Int]
       val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) =>
-        val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
-          root.asInstanceOf[RegressionNode], numFeatures)
+        val tree =
+          new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
         treeMetadata.getAndSetParams(tree)
         tree
       }

mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala

@@ -17,16 +17,14 @@
 package org.apache.spark.ml.tree
-import org.apache.spark.annotation.Since
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
-import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats,
-  Node => OldNode, Predict => OldPredict}
+import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict}
 /**
  * Decision tree node interface.
  */
-sealed trait Node extends Serializable {
+sealed abstract class Node extends Serializable {
   // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree
   // code into the new API and deprecate the old API. SPARK-3727
@@ -86,86 +84,35 @@ private[ml] object Node {
   /**
    * Create a new Node from the old Node format, recursively creating child nodes as needed.
    */
-  def fromOld(
-      oldNode: OldNode,
-      categoricalFeatures: Map[Int, Int],
-      isClassification: Boolean): Node = {
+  def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
     if (oldNode.isLeaf) {
       // TODO: Once the implementation has been moved to this API, then include sufficient
       // statistics here.
-      if (isClassification) {
-        new ClassificationLeafNode(prediction = oldNode.predict.predict,
-          impurity = oldNode.impurity, impurityStats = null)
-      } else {
-        new RegressionLeafNode(prediction = oldNode.predict.predict,
-          impurity = oldNode.impurity, impurityStats = null)
-      }
+      new LeafNode(prediction = oldNode.predict.predict,
+        impurity = oldNode.impurity, impurityStats = null)
     } else {
       val gain = if (oldNode.stats.nonEmpty) {
         oldNode.stats.get.gain
       } else {
         0.0
       }
-      if (isClassification) {
-        new ClassificationInternalNode(prediction = oldNode.predict.predict,
-          impurity = oldNode.impurity, gain = gain,
-          leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, true)
-            .asInstanceOf[ClassificationNode],
-          rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, true)
-            .asInstanceOf[ClassificationNode],
-          split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
-      } else {
-        new RegressionInternalNode(prediction = oldNode.predict.predict,
-          impurity = oldNode.impurity, gain = gain,
-          leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, false)
-            .asInstanceOf[RegressionNode],
-          rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, false)
-            .asInstanceOf[RegressionNode],
-          split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
-      }
+      new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
+        gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
+        rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
+        split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
     }
   }
 }
-@Since("2.4.0")
-sealed trait ClassificationNode extends Node {
-  /**
-   * Get count of training examples for specified label in this node
-   * @param label label number in the range [0, numClasses)
-   */
-  @Since("2.4.0")
-  def getLabelCount(label: Int): Double = {
-    require(label >= 0 && label < impurityStats.stats.length,
-      "label should be in the range between 0 (inclusive) " +
-      s"and ${impurityStats.stats.length} (exclusive).")
-    impurityStats.stats(label)
-  }
-}
-@Since("2.4.0")
-sealed trait RegressionNode extends Node {
-  /** Number of training data points in this node */
-  @Since("2.4.0")
-  def getCount: Double = impurityStats.stats(0)
-  /** Sum over training data points of the labels in this node */
-  @Since("2.4.0")
-  def getSum: Double = impurityStats.stats(1)
-  /** Sum over training data points of the square of the labels in this node */
-  @Since("2.4.0")
-  def getSumOfSquares: Double = impurityStats.stats(2)
-}
-@Since("2.4.0")
-sealed trait LeafNode extends Node {
-  /** Prediction this node makes. */
-  def prediction: Double
-  def impurity: Double
+/**
+ * Decision tree leaf node.
+ * @param prediction Prediction this node makes
+ * @param impurity Impurity measure at this node (for training data)
+ */
+class LeafNode private[ml] (
+    override val prediction: Double,
+    override val impurity: Double,
+    override private[ml] val impurityStats: ImpurityCalculator) extends Node {
   override def toString: String =
     s"LeafNode(prediction = $prediction, impurity = $impurity)"
@@ -188,58 +135,32 @@ sealed trait LeafNode extends Node {
   override private[ml] def maxSplitFeatureIndex(): Int = -1
-}
-/**
- * Decision tree leaf node for classification.
- */
-@Since("2.4.0")
-class ClassificationLeafNode private[ml] (
-    override val prediction: Double,
-    override val impurity: Double,
-    override private[ml] val impurityStats: ImpurityCalculator)
-  extends ClassificationNode with LeafNode {
-  override private[tree] def deepCopy(): Node = {
-    new ClassificationLeafNode(prediction, impurity, impurityStats)
-  }
-}
-/**
- * Decision tree leaf node for regression.
- */
-@Since("2.4.0")
-class RegressionLeafNode private[ml] (
-    override val prediction: Double,
-    override val impurity: Double,
-    override private[ml] val impurityStats: ImpurityCalculator)
-  extends RegressionNode with LeafNode {
   override private[tree] def deepCopy(): Node = {
-    new RegressionLeafNode(prediction, impurity, impurityStats)
+    new LeafNode(prediction, impurity, impurityStats)
   }
 }
 /**
  * Internal Decision Tree node.
+ * @param prediction Prediction this node would make if it were a leaf node
+ * @param impurity Impurity measure at this node (for training data)
+ * @param gain Information gain value. Values less than 0 indicate missing values;
+ *             this quirk will be removed with future updates.
+ * @param leftChild Left-hand child node
+ * @param rightChild Right-hand child node
+ * @param split Information about the test used to split to the left or right child.
  */
-@Since("2.4.0")
-sealed trait InternalNode extends Node {
+class InternalNode private[ml] (
+    override val prediction: Double,
+    override val impurity: Double,
+    val gain: Double,
+    val leftChild: Node,
+    val rightChild: Node,
+    val split: Split,
+    override private[ml] val impurityStats: ImpurityCalculator) extends Node {
-  /**
-   * Information gain value. Values less than 0 indicate missing values;
-   * this quirk will be removed with future updates.
-   */
-  def gain: Double
-  /** Left-hand child node */
-  def leftChild: Node
-  /** Right-hand child node */
-  def rightChild: Node
-  /** Information about the test used to split to the left or right child. */
-  def split: Split
+  // Note to developers: The constructor argument impurityStats should be reconsidered before we
+  // make the constructor public. We may be able to improve the representation.
   override def toString: String = {
     s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
@@ -284,6 +205,11 @@ sealed trait InternalNode extends Node {
     math.max(split.featureIndex,
       math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex()))
   }
+  override private[tree] def deepCopy(): Node = {
+    new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), rightChild.deepCopy(),
+      split, impurityStats)
+  }
 }
 private object InternalNode {
@@ -314,57 +240,6 @@ private object InternalNode {
   }
 }
-/**
- * Internal Decision Tree node for classification.
- */
-@Since("2.4.0")
-class ClassificationInternalNode private[ml] (
-    override val prediction: Double,
-    override val impurity: Double,
-    override val gain: Double,
-    override val leftChild: ClassificationNode,
-    override val rightChild: ClassificationNode,
-    override val split: Split,
-    override private[ml] val impurityStats: ImpurityCalculator)
-  extends ClassificationNode with InternalNode {
-  // Note to developers: The constructor argument impurityStats should be reconsidered before we
-  // make the constructor public. We may be able to improve the representation.
-  override private[tree] def deepCopy(): Node = {
-    new ClassificationInternalNode(prediction, impurity, gain,
-      leftChild.deepCopy().asInstanceOf[ClassificationNode],
-      rightChild.deepCopy().asInstanceOf[ClassificationNode],
-      split, impurityStats)
-  }
-}
-/**
- * Internal Decision Tree node for regression.
- */
-@Since("2.4.0")
-class RegressionInternalNode private[ml] (
-    override val prediction: Double,
-    override val impurity: Double,
-    override val gain: Double,
-    override val leftChild: RegressionNode,
-    override val rightChild: RegressionNode,
-    override val split: Split,
-    override private[ml] val impurityStats: ImpurityCalculator)
-  extends RegressionNode with InternalNode {
-  // Note to developers: The constructor argument impurityStats should be reconsidered before we
-  // make the constructor public. We may be able to improve the representation.
-  override private[tree] def deepCopy(): Node = {
-    new RegressionInternalNode(prediction, impurity, gain,
-      leftChild.deepCopy().asInstanceOf[RegressionNode],
-      rightChild.deepCopy().asInstanceOf[RegressionNode],
-      split, impurityStats)
-  }
-}
 /**
  * Version of a node used in learning. This uses vars so that we can modify nodes as we split the
  * tree by adding children, etc.
@@ -390,52 +265,30 @@ private[tree] class LearningNode(
     var isLeaf: Boolean,
     var stats: ImpurityStats) extends Serializable {
-  def toNode(isClassification: Boolean): Node = toNode(isClassification, prune = true)
-  def toClassificationNode(prune: Boolean = true): ClassificationNode = {
-    toNode(true, prune).asInstanceOf[ClassificationNode]
-  }
-  def toRegressionNode(prune: Boolean = true): RegressionNode = {
-    toNode(false, prune).asInstanceOf[RegressionNode]
-  }
+  def toNode: Node = toNode(prune = true)
   /**
    * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
    */
-  def toNode(isClassification: Boolean, prune: Boolean): Node = {
+  def toNode(prune: Boolean = true): Node = {
     if (!leftChild.isEmpty || !rightChild.isEmpty) {
       assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null,
         "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
-      (leftChild.get.toNode(isClassification, prune),
-        rightChild.get.toNode(isClassification, prune)) match {
+      (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match {
         case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction =>
-          if (isClassification) {
-            new ClassificationLeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
-          } else {
-            new RegressionLeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
-          }
+          new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
         case (l, r) =>
-          if (isClassification) {
-            new ClassificationInternalNode(stats.impurityCalculator.predict, stats.impurity,
-              stats.gain, l.asInstanceOf[ClassificationNode], r.asInstanceOf[ClassificationNode],
-              split.get, stats.impurityCalculator)
-          } else {
-            new RegressionInternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
-              l.asInstanceOf[RegressionNode], r.asInstanceOf[RegressionNode],
-              split.get, stats.impurityCalculator)
-          }
+          new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
+            l, r, split.get, stats.impurityCalculator)
       }
     } else {
-      // Here we want to keep same behavior with the old mllib.DecisionTreeModel
-      val impurity = if (stats.valid) stats.impurity else -1.0
-      if (isClassification) {
-        new ClassificationLeafNode(stats.impurityCalculator.predict, impurity,
+      if (stats.valid) {
+        new LeafNode(stats.impurityCalculator.predict, stats.impurity,
           stats.impurityCalculator)
       } else {
-        new RegressionLeafNode(stats.impurityCalculator.predict, impurity,
-          stats.impurityCalculator)
+        // Here we want to keep same behavior with the old mllib.DecisionTreeModel
+        new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
      }
    }
  }
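With the hierarchy collapsed back to Node, LeafNode, and InternalNode, callers distinguish node kinds by pattern matching rather than by static type. A minimal sketch (not part of the patch) that counts the leaves under a root:

```scala
import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}

// LeafNode and InternalNode are the only concrete subclasses of the sealed Node,
// so this match is exhaustive.
def numLeaves(node: Node): Int = node match {
  case _: LeafNode     => 1
  case n: InternalNode => numLeaves(n.leftChild) + numLeaves(n.rightChild)
}
// e.g. numLeaves(model.rootNode) on any fitted decision tree model.
```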

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala

@@ -226,23 +226,23 @@ private[spark] object RandomForest extends Logging with Serializable {
       case Some(uid) =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(uid, rootNode.toClassificationNode(prune),
-              numFeatures, strategy.getNumClasses)
+            new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures,
+              strategy.getNumClasses)
           }
         } else {
           topNodes.map { rootNode =>
-            new DecisionTreeRegressionModel(uid, rootNode.toRegressionNode(prune), numFeatures)
+            new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures)
           }
         }
       case None =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(rootNode.toClassificationNode(prune), numFeatures,
+            new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures,
               strategy.getNumClasses)
           }
         } else {
           topNodes.map(rootNode =>
-            new DecisionTreeRegressionModel(rootNode.toRegressionNode(prune), numFeatures))
+            new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures))
         }
     }
   }

mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala

@@ -219,10 +219,8 @@ private[ml] object TreeEnsembleModel {
         importances.changeValue(feature, scaledGain, _ + scaledGain)
         computeFeatureImportance(n.leftChild, importances)
         computeFeatureImportance(n.rightChild, importances)
-      case _: LeafNode =>
+      case n: LeafNode =>
       // do nothing
-      case _ =>
-        throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}")
     }
   }
@@ -319,8 +317,6 @@ private[ml] object DecisionTreeModelReadWrite {
         (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats,
           -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
           id)
-      case _ =>
-        throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}")
     }
   }
@@ -331,7 +327,7 @@ private[ml] object DecisionTreeModelReadWrite {
   def loadTreeNodes(
       path: String,
       metadata: DefaultParamsReader.Metadata,
-      sparkSession: SparkSession, isClassification: Boolean): Node = {
+      sparkSession: SparkSession): Node = {
     import sparkSession.implicits._
     implicit val format = DefaultFormats
@@ -343,7 +339,7 @@ private[ml] object DecisionTreeModelReadWrite {
     val dataPath = new Path(path, "data").toString
     val data = sparkSession.read.parquet(dataPath).as[NodeData]
-    buildTreeFromNodes(data.collect(), impurityType, isClassification)
+    buildTreeFromNodes(data.collect(), impurityType)
   }
   /**
@@ -352,8 +348,7 @@ private[ml] object DecisionTreeModelReadWrite {
    * @param impurityType Impurity type for this tree
    * @return Root node of reconstructed tree
    */
-  def buildTreeFromNodes(data: Array[NodeData], impurityType: String,
-      isClassification: Boolean): Node = {
+  def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = {
     // Load all nodes, sorted by ID.
     val nodes = data.sortBy(_.id)
     // Sanity checks; could remove
@@ -369,21 +364,10 @@ private[ml] object DecisionTreeModelReadWrite {
       val node = if (n.leftChild != -1) {
         val leftChild = finalNodes(n.leftChild)
         val rightChild = finalNodes(n.rightChild)
-        if (isClassification) {
-          new ClassificationInternalNode(n.prediction, n.impurity, n.gain,
-            leftChild.asInstanceOf[ClassificationNode], rightChild.asInstanceOf[ClassificationNode],
-            n.split.getSplit, impurityStats)
-        } else {
-          new RegressionInternalNode(n.prediction, n.impurity, n.gain,
-            leftChild.asInstanceOf[RegressionNode], rightChild.asInstanceOf[RegressionNode],
-            n.split.getSplit, impurityStats)
-        }
+        new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild,
+          n.split.getSplit, impurityStats)
       } else {
-        if (isClassification) {
-          new ClassificationLeafNode(n.prediction, n.impurity, impurityStats)
-        } else {
-          new RegressionLeafNode(n.prediction, n.impurity, impurityStats)
-        }
+        new LeafNode(n.prediction, n.impurity, impurityStats)
       }
       finalNodes(n.id) = node
     }
@@ -437,8 +421,7 @@ private[ml] object EnsembleModelReadWrite {
       path: String,
       sql: SparkSession,
       className: String,
-      treeClassName: String,
-      isClassification: Boolean): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
+      treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
     import sql.implicits._
     implicit val format = DefaultFormats
     val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className)
@@ -466,8 +449,7 @@ private[ml] object EnsembleModelReadWrite {
     val rootNodesRDD: RDD[(Int, Node)] =
       nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map {
         case (treeID: Int, nodeData: Iterable[NodeData]) =>
-          treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(
-            nodeData.toArray, impurityType, isClassification)
+          treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
       }
     val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
     (metadata, treesMetadata.zip(rootNodes), treesWeights)
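For orientation: each tree is persisted as a parquet dataset of flat NodeData rows (id, prediction, impurity, impurity stats, gain, child ids, split) that buildTreeFromNodes stitches back together by id. A hedged way to peek at that layout, assuming a model saved under the path from the earlier save/load sketch and a `spark` session in scope:

```scala
// The "data" subdirectory holds the flattened node rows the reader reconstructs.
spark.read.parquet("/tmp/dtc-model/data").show(truncate = false)
```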

mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala

@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.ClassificationLeafNode
+import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
@@ -61,8 +61,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
   test("params") {
     ParamsSuite.checkParams(new DecisionTreeClassifier)
-    val model = new DecisionTreeClassificationModel("dtc",
-      new ClassificationLeafNode(0.0, 0.0, null), 1, 2)
+    val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)
     ParamsSuite.checkParams(model)
   }
@@ -376,32 +375,6 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
     testDefaultReadWrite(model)
   }
-  test("label/impurity stats") {
-    val arr = Array(
-      LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
-      LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
-      LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
-    val rdd = sc.parallelize(arr)
-    val df = TreeTests.setMetadata(rdd, Map.empty[Int, Int], 2)
-    val dt1 = new DecisionTreeClassifier()
-      .setImpurity("entropy")
-      .setMaxDepth(2)
-      .setMinInstancesPerNode(2)
-    val model1 = dt1.fit(df)
-    val rootNode1 = model1.rootNode
-    assert(Array(rootNode1.getLabelCount(0), rootNode1.getLabelCount(1)) === Array(2.0, 1.0))
-    val dt2 = new DecisionTreeClassifier()
-      .setImpurity("gini")
-      .setMaxDepth(2)
-      .setMinInstancesPerNode(2)
-    val model2 = dt2.fit(df)
-    val rootNode2 = model2.rootNode
-    assert(Array(rootNode2.getLabelCount(0), rootNode2.getLabelCount(1)) === Array(2.0, 1.0))
-  }
 }
 private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {

mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala

@@ -24,7 +24,7 @@ import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
-import org.apache.spark.ml.tree.RegressionLeafNode
+import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
@@ -70,7 +70,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
   test("params") {
     ParamsSuite.checkParams(new GBTClassifier)
     val model = new GBTClassificationModel("gbtc",
-      Array(new DecisionTreeRegressionModel("dtr", new RegressionLeafNode(0.0, 0.0, null), 1)),
+      Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)),
       Array(1.0), 1, 2)
     ParamsSuite.checkParams(model)
   }

mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala

@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.ClassificationLeafNode
+import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
@@ -71,8 +71,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
   test("params") {
     ParamsSuite.checkParams(new RandomForestClassifier)
     val model = new RandomForestClassificationModel("rfc",
-      Array(new DecisionTreeClassificationModel("dtc",
-        new ClassificationLeafNode(0.0, 0.0, null), 1, 2)), 2, 2)
+      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2)
     ParamsSuite.checkParams(model)
   }

mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala

@@ -191,20 +191,6 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
       TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
       TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
   }
-  test("label/impurity stats") {
-    val categoricalFeatures = Map(0 -> 2, 1 -> 2)
-    val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
-    val dtr = new DecisionTreeRegressor()
-      .setImpurity("variance")
-      .setMaxDepth(2)
-      .setMaxBins(8)
-    val model = dtr.fit(df)
-    val statInfo = model.rootNode
-    assert(statInfo.getCount == 1000.0 && statInfo.getSum == 600.0
-      && statInfo.getSumOfSquares == 600.0)
-  }
 }
 private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {

mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala

@@ -340,8 +340,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.stats.impurity > 0.0)
     // set impurity and predict for child nodes
-    assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0)
-    assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0)
+    assert(topNode.leftChild.get.toNode.prediction === 0.0)
+    assert(topNode.rightChild.get.toNode.prediction === 1.0)
     assert(topNode.leftChild.get.stats.impurity === 0.0)
     assert(topNode.rightChild.get.stats.impurity === 0.0)
   }
@@ -382,8 +382,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.stats.impurity > 0.0)
     // set impurity and predict for child nodes
-    assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0)
-    assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0)
+    assert(topNode.leftChild.get.toNode.prediction === 0.0)
+    assert(topNode.rightChild.get.toNode.prediction === 1.0)
    assert(topNode.leftChild.get.stats.impurity === 0.0)
    assert(topNode.rightChild.get.stats.impurity === 0.0)
  }
@@ -582,18 +582,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
          left      right
      */
     val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
-    val left = new ClassificationLeafNode(0.0, leftImp.calculate(), leftImp)
+    val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
     val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
-    val right = new ClassificationLeafNode(2.0, rightImp.calculate(), rightImp)
+    val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
-    val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5), true)
+    val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
     val parentImp = parent.impurityStats
     val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
-    val left2 = new ClassificationLeafNode(0.0, left2Imp.calculate(), left2Imp)
+    val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
-    val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0), true)
+    val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
     val grandImp = grandParent.impurityStats
     // Test feature importance computed at different subtrees.
@@ -618,8 +618,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     // Forest consisting of (full tree) + (internal node with 2 leafs)
     val trees = Array(parent, grandParent).map { root =>
-      new DecisionTreeClassificationModel(root.asInstanceOf[ClassificationNode],
-        numFeatures = 2, numClasses = 3).asInstanceOf[DecisionTreeModel]
+      new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3)
+        .asInstanceOf[DecisionTreeModel]
     }
     val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2)
     val tree2norm = feature0importance + feature1importance
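The suite drives the internal TreeEnsembleModel.featureImportances helper directly; on a fitted model the public counterpart is a one-liner (sketch, reusing `rfModel` from the earlier snippet):

```scala
// Importances are normalized to sum to 1 across features.
println(rfModel.featureImportances)
```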

mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala

@@ -159,7 +159,7 @@ private[ml] object TreeTests extends SparkFunSuite {
    * @param split Split for parent node
    * @return Parent node with children attached
    */
-  def buildParentNode(left: Node, right: Node, split: Split, isClassification: Boolean): Node = {
+  def buildParentNode(left: Node, right: Node, split: Split): Node = {
     val leftImp = left.impurityStats
     val rightImp = right.impurityStats
     val parentImp = leftImp.copy.add(rightImp)
@@ -168,15 +168,7 @@ private[ml] object TreeTests extends SparkFunSuite {
     val gain = parentImp.calculate() -
       (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
     val pred = parentImp.predict
-    if (isClassification) {
-      new ClassificationInternalNode(pred, parentImp.calculate(), gain,
-        left.asInstanceOf[ClassificationNode], right.asInstanceOf[ClassificationNode],
-        split, parentImp)
-    } else {
-      new RegressionInternalNode(pred, parentImp.calculate(), gain,
-        left.asInstanceOf[RegressionNode], right.asInstanceOf[RegressionNode],
-        split, parentImp)
-    }
+    new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp)
   }
   /**
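The gain computed in buildParentNode is the parent impurity minus the child impurities weighted by their instance counts. A self-contained sketch that checks the arithmetic with the Gini counts used in RandomForestSuite above (Array(3, 2, 1) on the left, Array(1, 2, 5) on the right):

```scala
// Gini impurity: 1 - sum over classes of (count / total)^2.
def gini(counts: Array[Double]): Double = {
  val total = counts.sum
  1.0 - counts.map(c => (c / total) * (c / total)).sum
}

val left = Array(3.0, 2.0, 1.0)
val right = Array(1.0, 2.0, 5.0)
val parent = left.zip(right).map { case (l, r) => l + r }
val gain = gini(parent) -
  (left.sum / parent.sum * gini(left) + right.sum / parent.sum * gini(right))
// gain comes out to roughly 0.088 for these counts.
```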

project/MimaExcludes.scala

@@ -103,13 +103,6 @@ object MimaExcludes {
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"),
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="),
-    // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes
-    ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"),
-    ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"),
-    ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"),
-    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"),
-    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this"),
     // [SPARK-7132][ML] Add fit with validation set to spark.ml GBT
     ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"),
     ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="),