[SPARK-32298][ML] tree models prediction optimization

### What changes were proposed in this pull request?
use while-loop instead of the recursive way

### Why are the changes needed?
3% ~ 10% faster

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
existing testsuites

Closes #29095 from zhengruifeng/tree_pred_opt.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
zhengruifeng 2020-07-17 12:00:49 -05:00 committed by Sean Owen
parent 5daf244d0f
commit 3a60b41949

View file

@ -174,22 +174,32 @@ class InternalNode private[ml] (
}
override private[ml] def predictImpl(features: Vector): LeafNode = {
if (split.shouldGoLeft(features)) {
leftChild.predictImpl(features)
} else {
rightChild.predictImpl(features)
var node: Node = this
while (node.isInstanceOf[InternalNode]) {
val n = node.asInstanceOf[InternalNode]
if (n.split.shouldGoLeft(features)) {
node = n.leftChild
} else {
node = n.rightChild
}
}
node.asInstanceOf[LeafNode]
}
override private[ml] def predictBinned(
binned: Array[Int],
splits: Array[Array[Split]]): LeafNode = {
val i = split.featureIndex
if (split.shouldGoLeft(binned(i), splits(i))) {
leftChild.predictBinned(binned, splits)
} else {
rightChild.predictBinned(binned, splits)
var node: Node = this
while (node.isInstanceOf[InternalNode]) {
val n = node.asInstanceOf[InternalNode]
val i = n.split.featureIndex
if (n.split.shouldGoLeft(binned(i), splits(i))) {
node = n.leftChild
} else {
node = n.rightChild
}
}
node.asInstanceOf[LeafNode]
}
override private[tree] def numDescendants: Int = {
@ -326,27 +336,27 @@ private[tree] class LearningNode(
* [[org.apache.spark.ml.tree.impl.RandomForest.findBestSplits()]].
*/
def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = {
if (this.isLeaf || this.split.isEmpty) {
this.id
} else {
val split = this.split.get
var node = this
while (!node.isLeaf && node.split.nonEmpty) {
val split = node.split.get
val featureIndex = split.featureIndex
val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
if (this.leftChild.isEmpty) {
if (node.leftChild.isEmpty) {
// Not yet split. Return next layer of nodes to train
if (splitLeft) {
LearningNode.leftChildIndex(this.id)
return LearningNode.leftChildIndex(node.id)
} else {
LearningNode.rightChildIndex(this.id)
return LearningNode.rightChildIndex(node.id)
}
} else {
if (splitLeft) {
this.leftChild.get.predictImpl(binnedFeatures, splits)
node = node.leftChild.get
} else {
this.rightChild.get.predictImpl(binnedFeatures, splits)
node = node.rightChild.get
}
}
}
node.id
}
}