[SPARK-9963] [ML] RandomForest cleanup: replace predictNodeIndex with predictImpl

predictNodeIndex is moved to LearningNode and renamed predictImpl for consistency with Node.predictImpl

Author: Luvsandondov Lkhamsuren <lkhamsurenl@gmail.com>

Closes #8609 from lkhamsurenl/SPARK-9963.
This commit is contained in:
Luvsandondov Lkhamsuren 2015-10-17 10:07:42 -07:00 committed by Joseph K. Bradley
parent e1e77b22b3
commit cca2258685
2 changed files with 38 additions and 43 deletions

View file

@ -279,6 +279,43 @@ private[tree] class LearningNode(
}
}
/**
 * Get the node index corresponding to this data point.
 * Mimics prediction: walks the example from this node down to a leaf or
 * not-yet-split node, and returns that node's index.
 *
 * @param binnedFeatures Binned feature vector for the data point.
 * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
 * @return Leaf index if the data point reaches a leaf; otherwise the index of the
 *         last node reachable in the tree matching this example.
 *         Note: This is the global node index, i.e., the index used in the tree.
 *         It differs from the index used while training a particular group of
 *         nodes in one call to [[findBestSplits()]].
 */
def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = {
  // Tail-recursive descent; an execution trace is identical to the plain
  // recursive formulation, but the stack stays constant-depth.
  @scala.annotation.tailrec
  def descend(node: LearningNode): Int = {
    if (node.isLeaf || node.split.isEmpty) {
      // Leaf or unsplit node: its id is the answer.
      node.id
    } else {
      val currentSplit = node.split.get
      val goLeft = currentSplit.shouldGoLeft(
        binnedFeatures(currentSplit.featureIndex), splits(currentSplit.featureIndex))
      node.leftChild match {
        case Some(left) =>
          // Children exist: keep descending on the side the split selects.
          descend(if (goLeft) left else node.rightChild.get)
        case None =>
          // Split chosen but children not yet grown: return the index the
          // matching child will have in the next layer of nodes to train.
          if (goLeft) {
            LearningNode.leftChildIndex(node.id)
          } else {
            LearningNode.rightChildIndex(node.id)
          }
      }
    }
  }
  descend(this)
}
}
private[tree] object LearningNode {

View file

@ -205,47 +205,6 @@ private[ml] object RandomForest extends Logging {
}
}
/**
 * Get the node index corresponding to this data point.
 * Mimics prediction: walks the example from the given node down to a leaf or
 * not-yet-split node, and returns that node's index.
 *
 * @param node Node in tree from which to classify the given data point.
 * @param binnedFeatures Binned feature vector for the data point.
 * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
 * @return Leaf index if the data point reaches a leaf; otherwise the index of the
 *         last node reachable in the tree matching this example.
 *         Note: This is the global node index, i.e., the index used in the tree.
 *         It differs from the index used while training a particular group of
 *         nodes in one call to [[findBestSplits()]].
 */
private def predictNodeIndex(
    node: LearningNode,
    binnedFeatures: Array[Int],
    splits: Array[Array[Split]]): Int = {
  if (node.isLeaf || node.split.isEmpty) {
    // Leaf or unsplit node: its id is the answer.
    node.id
  } else {
    val currentSplit = node.split.get
    val goLeft = currentSplit.shouldGoLeft(
      binnedFeatures(currentSplit.featureIndex), splits(currentSplit.featureIndex))
    node.leftChild match {
      case Some(left) =>
        // Children exist: recurse into the side the split selects.
        val next = if (goLeft) left else node.rightChild.get
        predictNodeIndex(next, binnedFeatures, splits)
      case None =>
        // Split chosen but children not yet grown: return the index the
        // matching child will have in the next layer of nodes to train.
        if (goLeft) {
          LearningNode.leftChildIndex(node.id)
        } else {
          LearningNode.rightChildIndex(node.id)
        }
    }
  }
}
/**
* Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
*
@ -453,8 +412,7 @@ private[ml] object RandomForest extends Logging {
agg: Array[DTStatsAggregator],
baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex =
predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits)
val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
}
agg