[SPARK-9963] [ML] RandomForest cleanup: replace predictNodeIndex with predictImpl
predictNodeIndex is moved to LearningNode and renamed predictImpl, for consistency with Node.predictImpl. Author: Luvsandondov Lkhamsuren <lkhamsurenl@gmail.com>. Closes #8609 from lkhamsurenl/SPARK-9963.
This commit is contained in:
parent
e1e77b22b3
commit
cca2258685
|
@ -279,6 +279,43 @@ private[tree] class LearningNode(
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Get the node index corresponding to this data point.
 * This function mimics prediction, passing an example from the root node down to a leaf
 * or unsplit node; that node's index is returned.
 *
 * @param binnedFeatures Binned feature vector for data point.
 * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
 * @return Leaf index if the data point reaches a leaf.
 *         Otherwise, last node reachable in tree matching this example.
 *         Note: This is the global node index, i.e., the index used in the tree.
 *         This index is different from the index used during training a particular
 *         group of nodes on one call to [[findBestSplits()]].
 */
def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = {
  this.split match {
    // An internal (or not-yet-split) node with a chosen split: route the example.
    case Some(s) if !this.isLeaf =>
      val fIdx = s.featureIndex
      val goLeft = s.shouldGoLeft(binnedFeatures(fIdx), splits(fIdx))
      this.leftChild match {
        case None =>
          // Not yet split. Return next layer of nodes to train.
          if (goLeft) {
            LearningNode.leftChildIndex(this.id)
          } else {
            LearningNode.rightChildIndex(this.id)
          }
        case Some(left) =>
          // Recurse into the matching child until a leaf/unsplit node is reached.
          if (goLeft) {
            left.predictImpl(binnedFeatures, splits)
          } else {
            this.rightChild.get.predictImpl(binnedFeatures, splits)
          }
      }
    // Leaf, or no split chosen: this node is where the example stops.
    case _ =>
      this.id
  }
}
|
||||
|
||||
}
|
||||
|
||||
private[tree] object LearningNode {
|
||||
|
|
|
@ -205,47 +205,6 @@ private[ml] object RandomForest extends Logging {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Get the node index corresponding to this data point.
 * This function mimics prediction, passing an example from the root node down to a leaf
 * or unsplit node; that node's index is returned.
 *
 * @param node Node in tree from which to classify the given data point.
 * @param binnedFeatures Binned feature vector for data point.
 * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
 * @return Leaf index if the data point reaches a leaf.
 *         Otherwise, last node reachable in tree matching this example.
 *         Note: This is the global node index, i.e., the index used in the tree.
 *         This index is different from the index used during training a particular
 *         group of nodes on one call to [[findBestSplits()]].
 */
private def predictNodeIndex(
    node: LearningNode,
    binnedFeatures: Array[Int],
    splits: Array[Array[Split]]): Int = {
  node.split match {
    // Internal (or not-yet-split) node with a chosen split: route the example.
    case Some(s) if !node.isLeaf =>
      val fIdx = s.featureIndex
      val goLeft = s.shouldGoLeft(binnedFeatures(fIdx), splits(fIdx))
      node.leftChild match {
        case None =>
          // Not yet split. Return index from next layer of nodes to train.
          if (goLeft) {
            LearningNode.leftChildIndex(node.id)
          } else {
            LearningNode.rightChildIndex(node.id)
          }
        case Some(left) =>
          // Descend into the matching child until a leaf/unsplit node is reached.
          if (goLeft) {
            predictNodeIndex(left, binnedFeatures, splits)
          } else {
            predictNodeIndex(node.rightChild.get, binnedFeatures, splits)
          }
      }
    // Leaf, or no split chosen: this node is where the example stops.
    case _ =>
      node.id
  }
}
|
||||
|
||||
/**
|
||||
* Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
|
||||
*
|
||||
|
@ -453,8 +412,7 @@ private[ml] object RandomForest extends Logging {
|
|||
agg: Array[DTStatsAggregator],
|
||||
baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
|
||||
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
|
||||
val nodeIndex =
|
||||
predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits)
|
||||
val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
|
||||
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
|
||||
}
|
||||
agg
|
||||
|
|
Loading…
Reference in a new issue