[SPARK-32298][ML] tree models prediction optimization
### What changes were proposed in this pull request? use while-loop instead of the recursive way ### Why are the changes needed? 3% ~ 10% faster ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? existing testsuites Closes #29095 from zhengruifeng/tree_pred_opt. Authored-by: zhengruifeng <ruifengz@foxmail.com> Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
parent
5daf244d0f
commit
3a60b41949
|
@ -174,22 +174,32 @@ class InternalNode private[ml] (
|
|||
}
|
||||
|
||||
override private[ml] def predictImpl(features: Vector): LeafNode = {
|
||||
if (split.shouldGoLeft(features)) {
|
||||
leftChild.predictImpl(features)
|
||||
} else {
|
||||
rightChild.predictImpl(features)
|
||||
var node: Node = this
|
||||
while (node.isInstanceOf[InternalNode]) {
|
||||
val n = node.asInstanceOf[InternalNode]
|
||||
if (n.split.shouldGoLeft(features)) {
|
||||
node = n.leftChild
|
||||
} else {
|
||||
node = n.rightChild
|
||||
}
|
||||
}
|
||||
node.asInstanceOf[LeafNode]
|
||||
}
|
||||
|
||||
override private[ml] def predictBinned(
|
||||
binned: Array[Int],
|
||||
splits: Array[Array[Split]]): LeafNode = {
|
||||
val i = split.featureIndex
|
||||
if (split.shouldGoLeft(binned(i), splits(i))) {
|
||||
leftChild.predictBinned(binned, splits)
|
||||
} else {
|
||||
rightChild.predictBinned(binned, splits)
|
||||
var node: Node = this
|
||||
while (node.isInstanceOf[InternalNode]) {
|
||||
val n = node.asInstanceOf[InternalNode]
|
||||
val i = n.split.featureIndex
|
||||
if (n.split.shouldGoLeft(binned(i), splits(i))) {
|
||||
node = n.leftChild
|
||||
} else {
|
||||
node = n.rightChild
|
||||
}
|
||||
}
|
||||
node.asInstanceOf[LeafNode]
|
||||
}
|
||||
|
||||
override private[tree] def numDescendants: Int = {
|
||||
|
@ -326,27 +336,27 @@ private[tree] class LearningNode(
|
|||
* [[org.apache.spark.ml.tree.impl.RandomForest.findBestSplits()]].
|
||||
*/
|
||||
def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = {
|
||||
if (this.isLeaf || this.split.isEmpty) {
|
||||
this.id
|
||||
} else {
|
||||
val split = this.split.get
|
||||
var node = this
|
||||
while (!node.isLeaf && node.split.nonEmpty) {
|
||||
val split = node.split.get
|
||||
val featureIndex = split.featureIndex
|
||||
val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
|
||||
if (this.leftChild.isEmpty) {
|
||||
if (node.leftChild.isEmpty) {
|
||||
// Not yet split. Return next layer of nodes to train
|
||||
if (splitLeft) {
|
||||
LearningNode.leftChildIndex(this.id)
|
||||
return LearningNode.leftChildIndex(node.id)
|
||||
} else {
|
||||
LearningNode.rightChildIndex(this.id)
|
||||
return LearningNode.rightChildIndex(node.id)
|
||||
}
|
||||
} else {
|
||||
if (splitLeft) {
|
||||
this.leftChild.get.predictImpl(binnedFeatures, splits)
|
||||
node = node.leftChild.get
|
||||
} else {
|
||||
this.rightChild.get.predictImpl(binnedFeatures, splits)
|
||||
node = node.rightChild.get
|
||||
}
|
||||
}
|
||||
}
|
||||
node.id
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue