[SPARK-9963] [ML] RandomForest cleanup: replace predictNodeIndex with predictImpl
predictNodeIndex is moved to LearningNode and renamed predictImpl for consistency with Node.predictImpl Author: Luvsandondov Lkhamsuren <lkhamsurenl@gmail.com> Closes #8609 from lkhamsurenl/SPARK-9963.
This commit is contained in:
parent
e1e77b22b3
commit
cca2258685
|
@ -279,6 +279,43 @@ private[tree] class LearningNode(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the node index corresponding to this data point.
|
||||||
|
* This function mimics prediction, passing an example from the root node down to a leaf
|
||||||
|
* or unsplit node; that node's index is returned.
|
||||||
|
*
|
||||||
|
* @param binnedFeatures Binned feature vector for data point.
|
||||||
|
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
|
||||||
|
* @return Leaf index if the data point reaches a leaf.
|
||||||
|
* Otherwise, last node reachable in tree matching this example.
|
||||||
|
* Note: This is the global node index, i.e., the index used in the tree.
|
||||||
|
* This index is different from the index used during training a particular
|
||||||
|
* group of nodes on one call to [[findBestSplits()]].
|
||||||
|
*/
|
||||||
|
def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = {
|
||||||
|
if (this.isLeaf || this.split.isEmpty) {
|
||||||
|
this.id
|
||||||
|
} else {
|
||||||
|
val split = this.split.get
|
||||||
|
val featureIndex = split.featureIndex
|
||||||
|
val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
|
||||||
|
if (this.leftChild.isEmpty) {
|
||||||
|
// Not yet split. Return next layer of nodes to train
|
||||||
|
if (splitLeft) {
|
||||||
|
LearningNode.leftChildIndex(this.id)
|
||||||
|
} else {
|
||||||
|
LearningNode.rightChildIndex(this.id)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (splitLeft) {
|
||||||
|
this.leftChild.get.predictImpl(binnedFeatures, splits)
|
||||||
|
} else {
|
||||||
|
this.rightChild.get.predictImpl(binnedFeatures, splits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private[tree] object LearningNode {
|
private[tree] object LearningNode {
|
||||||
|
|
|
@ -205,47 +205,6 @@ private[ml] object RandomForest extends Logging {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the node index corresponding to this data point.
|
|
||||||
* This function mimics prediction, passing an example from the root node down to a leaf
|
|
||||||
* or unsplit node; that node's index is returned.
|
|
||||||
*
|
|
||||||
* @param node Node in tree from which to classify the given data point.
|
|
||||||
* @param binnedFeatures Binned feature vector for data point.
|
|
||||||
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
|
|
||||||
* @return Leaf index if the data point reaches a leaf.
|
|
||||||
* Otherwise, last node reachable in tree matching this example.
|
|
||||||
* Note: This is the global node index, i.e., the index used in the tree.
|
|
||||||
* This index is different from the index used during training a particular
|
|
||||||
* group of nodes on one call to [[findBestSplits()]].
|
|
||||||
*/
|
|
||||||
private def predictNodeIndex(
|
|
||||||
node: LearningNode,
|
|
||||||
binnedFeatures: Array[Int],
|
|
||||||
splits: Array[Array[Split]]): Int = {
|
|
||||||
if (node.isLeaf || node.split.isEmpty) {
|
|
||||||
node.id
|
|
||||||
} else {
|
|
||||||
val split = node.split.get
|
|
||||||
val featureIndex = split.featureIndex
|
|
||||||
val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
|
|
||||||
if (node.leftChild.isEmpty) {
|
|
||||||
// Not yet split. Return index from next layer of nodes to train
|
|
||||||
if (splitLeft) {
|
|
||||||
LearningNode.leftChildIndex(node.id)
|
|
||||||
} else {
|
|
||||||
LearningNode.rightChildIndex(node.id)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (splitLeft) {
|
|
||||||
predictNodeIndex(node.leftChild.get, binnedFeatures, splits)
|
|
||||||
} else {
|
|
||||||
predictNodeIndex(node.rightChild.get, binnedFeatures, splits)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
|
* Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
|
||||||
*
|
*
|
||||||
|
@ -453,8 +412,7 @@ private[ml] object RandomForest extends Logging {
|
||||||
agg: Array[DTStatsAggregator],
|
agg: Array[DTStatsAggregator],
|
||||||
baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
|
baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
|
||||||
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
|
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
|
||||||
val nodeIndex =
|
val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
|
||||||
predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits)
|
|
||||||
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
|
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
|
||||||
}
|
}
|
||||||
agg
|
agg
|
||||||
|
|
Loading…
Reference in a new issue