[SPARK-3366][MLLIB] Compute best splits distributively in decision tree

Currently, all best splits are computed on the driver, which makes the driver a bottleneck for both communication and computation. This PR fixes that by computing best splits on the executors.
Instead of sending all aggregate stats to the driver node, we send the aggregate stats for each node to a particular executor using a `reduceByKey` operation, and then compute the best split for that node there.

Implementation details:

Each node now has a nodeStatsAggregator, which saves the aggregate stats for all features and bins.
First, use mapPartitions to compute the node aggregate stats for all nodes in each partition.
Then transform the node aggregate stats into (nodeIndex, nodeStatsAggregator) pairs and use a `reduceByKey` operation to combine the nodeStatsAggregators for the same node.
After all stats have been combined, the best split for each node can be computed from its aggregate stats. Only the best-split results are collected to the driver to construct the decision tree.
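As an illustration of this pattern, here is a minimal, self-contained sketch (not the PR's actual code): a hypothetical NodeStats class stands in for DTStatsAggregator, and a toy largest-count criterion stands in for binsToBestSplit; the real logic lives in DecisionTree.findBestSplits.

    import scala.util.Random

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.SparkContext._ // pair-RDD functions (reduceByKey) on Spark 1.x

    // Hypothetical stand-in for DTStatsAggregator: a histogram of weights per bin.
    case class NodeStats(counts: Array[Double]) {
      def add(bin: Int, weight: Double): Unit = counts(bin) += weight
      def merge(other: NodeStats): NodeStats = {
        var i = 0
        while (i < counts.length) { counts(i) += other.counts(i); i += 1 }
        this
      }
    }

    object DistributedBestSplitSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
        val numNodes = 4
        val numBins = 8
        // (nodeIndexInGroup, binIndex, instanceWeight) triples stand in for bagged points.
        val input = sc.parallelize(
          Seq.fill(1000)((Random.nextInt(numNodes), Random.nextInt(numBins), 1.0)), 4)

        val nodeToBestBin = input.mapPartitions { points =>
          // One aggregator per node in the group, built locally in each partition.
          val aggs = Array.fill(numNodes)(NodeStats(new Array[Double](numBins)))
          points.foreach { case (node, bin, w) => aggs(node).add(bin, w) }
          // Emit (nodeIndex, stats) pairs so reduceByKey can combine them per node.
          aggs.view.zipWithIndex.map(_.swap).iterator
        }.reduceByKey((a, b) => a.merge(b)) // all stats for a node meet on one executor
          .mapValues(stats => stats.counts.indexOf(stats.counts.max)) // toy "best split"
          .collectAsMap() // only the small per-node results reach the driver

        println(nodeToBestBin)
        sc.stop()
      }
    }

With `reduceByKey`, each node's aggregates are combined on a single executor, so the driver only receives the small per-node best-split results rather than all bin statistics.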

CC: mengxr manishamde jkbradley, please help me review this, thanks.

Author: qiping.lqp <qiping.lqp@alibaba-inc.com>
Author: chouqin <liqiping1991@gmail.com>

Closes #2595 from chouqin/dt-dist-agg and squashes the following commits:

db0d24a [chouqin] fix a minor bug and adjust code
a0d9de3 [chouqin] adjust code based on comments
9f201a6 [chouqin] fix bug: statsSize -> allStatsSize
a8a7ed0 [chouqin] Merge branch 'master' of https://github.com/apache/spark into dt-dist-agg
f13b346 [chouqin] adjust randomforest comments
c32636e [chouqin] adjust code based on comments
ac6a505 [chouqin] adjust code based on comments
7bbb787 [chouqin] add comments
bdd2a63 [qiping.lqp] fix test suite
a75df27 [qiping.lqp] fix test suite
b5b0bc2 [qiping.lqp] fix style
e76414f [qiping.lqp] fix testsuite
748bd45 [qiping.lqp] fix type-mismatch bug
24eacd8 [qiping.lqp] fix type-mismatch bug
5f63d6c [qiping.lqp] add multiclassification using One-Vs-All strategy
4f56496 [qiping.lqp] fix bug
f00fc22 [qiping.lqp] fix bug
532993a [qiping.lqp] Compute best splits distributively in decision tree
qiping.lqp authored on 2014-10-03 03:26:17 -07:00; committed by Xiangrui Meng
parent 1c90347a4b
commit 2e4eae3a52
5 changed files with 190 additions and 275 deletions

File: DecisionTree.scala

@@ -23,7 +23,6 @@ import scala.collection.mutable
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.Logging
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
import org.apache.spark.mllib.tree.configuration.Strategy
@@ -36,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.SparkContext._
/**
@@ -328,9 +328,8 @@ object DecisionTree extends Serializable with Logging {
* for each subset is updated.
*
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
* each (node, feature, bin).
* each (feature, bin).
* @param treePoint Data point being aggregated.
* @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes).
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param unorderedFeatures Set of indices of unordered features.
* @param instanceWeight Weight (importance) of instance in dataset.
@@ -338,7 +337,6 @@ object DecisionTree extends Serializable with Logging {
private def mixedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
nodeIndex: Int,
bins: Array[Array[Bin]],
unorderedFeatures: Set[Int],
instanceWeight: Double,
@@ -350,7 +348,6 @@ object DecisionTree extends Serializable with Logging {
// Use all features
agg.metadata.numFeatures
}
val nodeOffset = agg.getNodeOffset(nodeIndex)
// Iterate over features.
var featureIndexIdx = 0
while (featureIndexIdx < numFeaturesPerNode) {
@@ -363,16 +360,16 @@ object DecisionTree extends Serializable with Logging {
// Unordered feature
val featureValue = treePoint.binnedFeatures(featureIndex)
val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
agg.getLeftRightFeatureOffsets(featureIndexIdx)
// Update the left or right bin for each split.
val numSplits = agg.metadata.numSplits(featureIndex)
var splitIndex = 0
while (splitIndex < numSplits) {
if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
} else {
agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
}
splitIndex += 1
@@ -380,8 +377,7 @@ object DecisionTree extends Serializable with Logging {
} else {
// Ordered feature
val binIndex = treePoint.binnedFeatures(featureIndex)
agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, treePoint.label,
instanceWeight)
agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
}
featureIndexIdx += 1
}
@@ -393,26 +389,24 @@ object DecisionTree extends Serializable with Logging {
* For each feature, the sufficient statistics of one bin are updated.
*
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
* each (node, feature, bin).
* each (feature, bin).
* @param treePoint Data point being aggregated.
* @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes).
* @param instanceWeight Weight (importance) of instance in dataset.
*/
private def orderedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
nodeIndex: Int,
instanceWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
val label = treePoint.label
val nodeOffset = agg.getNodeOffset(nodeIndex)
// Iterate over features.
if (featuresForNode.nonEmpty) {
// Use subsampled features
var featureIndexIdx = 0
while (featureIndexIdx < featuresForNode.get.size) {
val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, label, instanceWeight)
agg.update(featureIndexIdx, binIndex, label, instanceWeight)
featureIndexIdx += 1
}
} else {
@@ -421,7 +415,7 @@ object DecisionTree extends Serializable with Logging {
var featureIndex = 0
while (featureIndex < numFeatures) {
val binIndex = treePoint.binnedFeatures(featureIndex)
agg.nodeUpdate(nodeOffset, nodeIndex, featureIndex, binIndex, label, instanceWeight)
agg.update(featureIndex, binIndex, label, instanceWeight)
featureIndex += 1
}
}
@@ -496,8 +490,8 @@ object DecisionTree extends Serializable with Logging {
* @return agg
*/
def binSeqOp(
agg: DTStatsAggregator,
baggedPoint: BaggedPoint[TreePoint]): DTStatsAggregator = {
agg: Array[DTStatsAggregator],
baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
bins, metadata.unorderedFeatures)
@@ -508,9 +502,9 @@ object DecisionTree extends Serializable with Logging {
val featuresForNode = nodeInfo.featureSubset
val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
if (metadata.unorderedFeatures.isEmpty) {
orderedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, instanceWeight, featuresForNode)
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
} else {
mixedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, bins, metadata.unorderedFeatures,
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
instanceWeight, featuresForNode)
}
}
@@ -518,30 +512,76 @@ object DecisionTree extends Serializable with Logging {
agg
}
// Calculate bin aggregates.
timer.start("aggregation")
val binAggregates: DTStatsAggregator = {
val initAgg = if (metadata.subsamplingFeatures) {
new DTStatsAggregatorSubsampledFeatures(metadata, treeToNodeToIndexInfo)
} else {
new DTStatsAggregatorFixedFeatures(metadata, numNodes)
/**
* Get a map: node index in group --> feature indices,
* which is a shortcut to find the feature indices for a node, given its index in the group
* @param treeToNodeToIndexInfo
* @return
*/
def getNodeToFeatures(treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]])
: Option[Map[Int, Array[Int]]] = if (!metadata.subsamplingFeatures) {
None
} else {
val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
assert(nodeIndexInfo.featureSubset.isDefined)
mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
}
}
input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp)
Some(mutableNodeToFeatures.toMap)
}
timer.stop("aggregation")
// Calculate best splits for all nodes in the group
timer.start("chooseSplits")
// In each partition, iterate over all instances and compute aggregate stats for each node,
// yielding a (nodeIndex, nodeAggregateStats) pair for each node.
// After a `reduceByKey` operation,
// the stats for a node are shuffled to a particular partition and combined together,
// and the best splits for the nodes are found there.
// Finally, only the best splits for the nodes are collected to the driver to construct the decision tree.
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
val nodeToBestSplits =
input.mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold node aggregate stats,
// each node will have a nodeStatsAggregator
val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
Some(nodeToFeatures(nodeIndex))
}
new DTStatsAggregator(metadata, featuresForNode)
}
// iterate over all instances in the current partition and update aggregate stats
points.foreach(binSeqOp(nodeStatsAggregators, _))
// transform the nodeStatsAggregators array into (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with those from other partitions using `reduceByKey`
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
}.reduceByKey((a, b) => a.merge(b))
.map { case (nodeIndex, aggStats) =>
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
Some(nodeToFeatures(nodeIndex))
}
// find best split for each node
val (split: Split, stats: InformationGainStats, predict: Predict) =
binsToBestSplit(aggStats, splits, featuresForNode)
(nodeIndex, (split, stats, predict))
}.collectAsMap()
timer.stop("chooseSplits")
// Iterate over all nodes in this group.
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
nodesForTree.foreach { node =>
val nodeIndex = node.id
val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val featuresForNode = nodeInfo.featureSubset
val (split: Split, stats: InformationGainStats, predict: Predict) =
binsToBestSplit(binAggregates, aggNodeIndex, splits, featuresForNode)
nodeToBestSplits(aggNodeIndex)
logDebug("best split = " + split)
// Extract info for this node. Create children if not leaf.
@@ -565,7 +605,7 @@ object DecisionTree extends Serializable with Logging {
}
}
}
timer.stop("chooseSplits")
}
/**
@@ -633,36 +673,33 @@ object DecisionTree extends Serializable with Logging {
/**
* Find the best split for a node.
* @param binAggregates Bin statistics.
* @param nodeIndex Index into aggregates for node to split in this group.
* @return tuple for best split: (Split, information gain, prediction at node)
*/
private def binsToBestSplit(
binAggregates: DTStatsAggregator,
nodeIndex: Int,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = {
val metadata: DecisionTreeMetadata = binAggregates.metadata
// calculate predict only once
var predict: Option[Predict] = None
// For each (feature, split), calculate the gain, and select the best (feature, split).
val (bestSplit, bestSplitStats) = Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx =>
val (bestSplit, bestSplitStats) =
Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
val numSplits = metadata.numSplits(featureIndex)
if (metadata.isContinuous(featureIndex)) {
val numSplits = binAggregates.metadata.numSplits(featureIndex)
if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
var splitIndex = 0
while (splitIndex < numSplits) {
binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
splitIndex += 1
}
// Find best split.
@@ -672,27 +709,29 @@ object DecisionTree extends Serializable with Logging {
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (metadata.isUnordered(featureIndex)) {
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
val numBins = metadata.numBins(featureIndex)
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val numBins = binAggregates.metadata.numBins(featureIndex)
/* Each bin is one category (feature value).
* The bins are ordered based on centroidForCategories, and this ordering determines which
@@ -700,7 +739,7 @@ object DecisionTree extends Serializable with Logging {
*
* centroidForCategories is a list: (category, centroid)
*/
val centroidForCategories = if (metadata.isMulticlass) {
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
// For categorical variables in multiclass classification,
// the bins are ordered by the impurity of their corresponding labels.
Range(0, numBins).map { case featureValue =>
@@ -741,7 +780,7 @@ object DecisionTree extends Serializable with Logging {
while (splitIndex < numSplits) {
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
binAggregates.mergeForNodeFeature(nodeFeatureOffset, nextCategory, currentCategory)
binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
splitIndex += 1
}
// lastCategory = index of bin with total aggregates for this (node, feature)
@@ -756,7 +795,8 @@ object DecisionTree extends Serializable with Logging {
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =

File: RandomForest.scala

@@ -171,8 +171,8 @@ private class RandomForest (
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
DecisionTree.findBestSplits(baggedInput,
metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
timer.stop("findBestSplits")
}
@@ -382,6 +382,7 @@ object RandomForest extends Serializable with Logging {
* @param maxMemoryUsage Bound on size of aggregate statistics.
* @return (nodesForGroup, treeToNodeToIndexInfo).
* nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
*
* treeToNodeToIndexInfo holds indices selected features for each node:
* treeIndex --> (global) node index --> (node index in group, feature indices).
* The (global) node index is the index in the tree; the node index in group is the

File: DTStatsAggregator.scala

@@ -17,17 +17,19 @@
package org.apache.spark.mllib.tree.impl
import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
import org.apache.spark.mllib.tree.impurity._
/**
* DecisionTree statistics aggregator.
* This holds a flat array of statistics for a set of (nodes, features, bins)
* DecisionTree statistics aggregator for a node.
* This holds a flat array of statistics for a set of (features, bins)
* and helps with indexing.
* This class is abstract to support learning with and without feature subsampling.
*/
private[tree] abstract class DTStatsAggregator(
val metadata: DecisionTreeMetadata) extends Serializable {
private[tree] class DTStatsAggregator(
val metadata: DecisionTreeMetadata,
featureSubset: Option[Array[Int]]) extends Serializable {
/**
* [[ImpurityAggregator]] instance specifying the impurity type.
@@ -42,7 +44,25 @@ private[tree] abstract class DTStatsAggregator(
/**
* Number of elements (Double values) used for the sufficient statistics of each bin.
*/
val statsSize: Int = impurityAggregator.statsSize
private val statsSize: Int = impurityAggregator.statsSize
/**
* Number of bins for each feature. This is indexed by the feature index.
*/
private val numBins: Array[Int] = {
if (featureSubset.isDefined) {
featureSubset.get.map(metadata.numBins(_))
} else {
metadata.numBins
}
}
/**
* Offset for each feature for calculating indices into the [[allStats]] array.
*/
private val featureOffsets: Array[Int] = {
numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
}
/**
* Indicator for each feature of whether that feature is an unordered feature.
@@ -51,107 +71,95 @@ private[tree] abstract class DTStatsAggregator(
def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
/**
* Total number of elements stored in this aggregator.
* Total number of elements stored in this aggregator
*/
def allStatsSize: Int
private val allStatsSize: Int = featureOffsets.last
/**
* Get flat array of elements stored in this aggregator.
* Flat array of elements.
* Index for start of stats for a (feature, bin) is:
* index = featureOffsets(featureIndex) + binIndex * statsSize
* Note: For unordered features,
* the left child stats have binIndex in [0, numBins(featureIndex) / 2)
* and the right child stats have binIndex in [numBins(featureIndex) / 2, numBins(featureIndex))
*/
protected def allStats: Array[Double]
private val allStats: Array[Double] = new Array[Double](allStatsSize)
/**
* Get an [[ImpurityCalculator]] for a given (node, feature, bin).
* @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
* from [[getNodeFeatureOffset]].
* @param featureOffset For ordered features, this is a pre-computed (node, feature) offset
* from [[getFeatureOffset]].
* For unordered features, this is a pre-computed
* (node, feature, left/right child) offset from
* [[getLeftRightNodeFeatureOffsets]].
* [[getLeftRightFeatureOffsets]].
*/
def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = {
impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize)
def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
}
/**
* Update the stats for a given (node, feature, bin) for ordered features, using the given label.
* Update the stats for a given (feature, bin) for ordered features, using the given label.
*/
def update(
nodeIndex: Int,
featureIndex: Int,
binIndex: Int,
label: Double,
instanceWeight: Double): Unit = {
val i = getNodeFeatureOffset(nodeIndex, featureIndex) + binIndex * statsSize
def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
val i = featureOffsets(featureIndex) + binIndex * statsSize
impurityAggregator.update(allStats, i, label, instanceWeight)
}
/**
* Pre-compute node offset for use with [[nodeUpdate]].
*/
def getNodeOffset(nodeIndex: Int): Int
/**
* Faster version of [[update]].
* Update the stats for a given (node, feature, bin) for ordered features, using the given label.
* @param nodeOffset Pre-computed node offset from [[getNodeOffset]].
*/
def nodeUpdate(
nodeOffset: Int,
nodeIndex: Int,
featureIndex: Int,
binIndex: Int,
label: Double,
instanceWeight: Double): Unit
/**
* Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
* For ordered features only.
*/
def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int
/**
* Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
* For unordered features only.
*/
def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
require(isUnordered(featureIndex),
s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
s" but was called for ordered feature $featureIndex.")
val baseOffset = getNodeFeatureOffset(nodeIndex, featureIndex)
(baseOffset, baseOffset + (metadata.numBins(featureIndex) >> 1) * statsSize)
}
/**
* Faster version of [[update]].
* Update the stats for a given (node, feature, bin), using the given label.
* @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
* from [[getNodeFeatureOffset]].
* Update the stats for a given (feature, bin), using the given label.
* @param featureOffset For ordered features, this is a pre-computed feature offset
* from [[getFeatureOffset]].
* For unordered features, this is a pre-computed
* (node, feature, left/right child) offset from
* [[getLeftRightNodeFeatureOffsets]].
* (feature, left/right child) offset from
* [[getLeftRightFeatureOffsets]].
*/
def nodeFeatureUpdate(
nodeFeatureOffset: Int,
def featureUpdate(
featureOffset: Int,
binIndex: Int,
label: Double,
instanceWeight: Double): Unit = {
impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label,
instanceWeight)
impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
label, instanceWeight)
}
/**
* For a given (node, feature), merge the stats for two bins.
* @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
* from [[getNodeFeatureOffset]].
* Pre-compute feature offset for use with [[featureUpdate]].
* For ordered features only.
*/
def getFeatureOffset(featureIndex: Int): Int = {
require(!isUnordered(featureIndex),
s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" +
s" for unordered feature $featureIndex.")
featureOffsets(featureIndex)
}
/**
* Pre-compute feature offset for use with [[featureUpdate]].
* For unordered features only.
*/
def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
require(isUnordered(featureIndex),
s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," +
s" but was called for ordered feature $featureIndex.")
val baseOffset = featureOffsets(featureIndex)
(baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
}
/**
* For a given feature, merge the stats for two bins.
* @param featureOffset For ordered features, this is a pre-computed feature offset
* from [[getFeatureOffset]].
* For unordered features, this is a pre-computed
* (node, feature, left/right child) offset from
* [[getLeftRightNodeFeatureOffsets]].
* (feature, left/right child) offset from
* [[getLeftRightFeatureOffsets]].
* @param binIndex The other bin is merged into this bin.
* @param otherBinIndex This bin is not modified.
*/
def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize,
nodeFeatureOffset + otherBinIndex * statsSize)
def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize,
featureOffset + otherBinIndex * statsSize)
}
/**
@@ -161,7 +169,7 @@ private[tree] abstract class DTStatsAggregator(
def merge(other: DTStatsAggregator): DTStatsAggregator = {
require(allStatsSize == other.allStatsSize,
s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
+ s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
+ s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
var i = 0
// TODO: Test BLAS.axpy
while (i < allStatsSize) {
@@ -171,149 +179,3 @@ private[tree] abstract class DTStatsAggregator(
this
}
}
/**
* DecisionTree statistics aggregator.
* This holds a flat array of statistics for a set of (nodes, features, bins)
* and helps with indexing.
*
* This instance of [[DTStatsAggregator]] is used when not subsampling features.
*
* @param numNodes Number of nodes to collect statistics for.
*/
private[tree] class DTStatsAggregatorFixedFeatures(
metadata: DecisionTreeMetadata,
numNodes: Int) extends DTStatsAggregator(metadata) {
/**
* Offset for each feature for calculating indices into the [[allStats]] array.
* Mapping: featureIndex --> offset
*/
private val featureOffsets: Array[Int] = {
metadata.numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
}
/**
* Number of elements for each node, corresponding to stride between nodes in [[allStats]].
*/
private val nodeStride: Int = featureOffsets.last
override val allStatsSize: Int = numNodes * nodeStride
/**
* Flat array of elements.
* Index for start of stats for a (node, feature, bin) is:
* index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
* Note: For unordered features, the left child stats precede the right child stats
* in the binIndex order.
*/
override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
override def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
override def nodeUpdate(
nodeOffset: Int,
nodeIndex: Int,
featureIndex: Int,
binIndex: Int,
label: Double,
instanceWeight: Double): Unit = {
val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
impurityAggregator.update(allStats, i, label, instanceWeight)
}
override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
nodeIndex * nodeStride + featureOffsets(featureIndex)
}
}
/**
* DecisionTree statistics aggregator.
* This holds a flat array of statistics for a set of (nodes, features, bins)
* and helps with indexing.
*
* This instance of [[DTStatsAggregator]] is used when subsampling features.
*
* @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
* where nodeIndexInfo stores the index in the group and the
* feature subsets (if using feature subsets).
*/
private[tree] class DTStatsAggregatorSubsampledFeatures(
metadata: DecisionTreeMetadata,
treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) extends DTStatsAggregator(metadata) {
/**
* For each node, offset for each feature for calculating indices into the [[allStats]] array.
* Mapping: nodeIndex --> featureIndex --> offset
*/
private val featureOffsets: Array[Array[Int]] = {
val numNodes: Int = treeToNodeToIndexInfo.values.map(_.size).sum
val offsets = new Array[Array[Int]](numNodes)
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeToIndexInfo) =>
nodeToIndexInfo.foreach { case (globalNodeIndex, nodeInfo) =>
offsets(nodeInfo.nodeIndexInGroup) = nodeInfo.featureSubset.get.map(metadata.numBins(_))
.scanLeft(0)((total, nBins) => total + statsSize * nBins)
}
}
offsets
}
/**
* For each node, offset for each feature for calculating indices into the [[allStats]] array.
*/
protected val nodeOffsets: Array[Int] = featureOffsets.map(_.last).scanLeft(0)(_ + _)
override val allStatsSize: Int = nodeOffsets.last
/**
* Flat array of elements.
* Index for start of stats for a (node, feature, bin) is:
* index = nodeOffsets(nodeIndex) + featureOffsets(featureIndex) + binIndex * statsSize
* Note: For unordered features, the left child stats precede the right child stats
* in the binIndex order.
*/
override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
override def getNodeOffset(nodeIndex: Int): Int = nodeOffsets(nodeIndex)
/**
* Faster version of [[update]].
* Update the stats for a given (node, feature, bin) for ordered features, using the given label.
* @param nodeOffset Pre-computed node offset from [[getNodeOffset]].
* @param featureIndex Index of feature in featuresForNodes(nodeIndex).
* Note: This is NOT the original feature index.
*/
override def nodeUpdate(
nodeOffset: Int,
nodeIndex: Int,
featureIndex: Int,
binIndex: Int,
label: Double,
instanceWeight: Double): Unit = {
val i = nodeOffset + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize
impurityAggregator.update(allStats, i, label, instanceWeight)
}
/**
* Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
* For ordered features only.
* @param featureIndex Index of feature in featuresForNodes(nodeIndex).
* Note: This is NOT the original feature index.
*/
override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex)
}
}
private[tree] object DTStatsAggregator extends Serializable {
/**
* Combines two aggregates (modifying the first) and returns the combination.
*/
def binCombOp(
agg1: DTStatsAggregator,
agg2: DTStatsAggregator): DTStatsAggregator = {
agg1.merge(agg2)
}
}

File: InformationGainStats.scala

@@ -38,6 +38,17 @@ class InformationGainStats(
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
.format(gain, impurity, leftImpurity, rightImpurity)
}
override def equals(o: Any) =
o match {
case other: InformationGainStats => {
gain == other.gain &&
impurity == other.impurity &&
leftImpurity == other.leftImpurity &&
rightImpurity == other.rightImpurity
}
case _ => false
}
}

File: RandomForestSuite.scala

@@ -145,6 +145,7 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
assert(nodesForGroup.size === numTrees, failString)
assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree
if (numFeaturesPerNode == numFeatures) {
// featureSubset values should all be None
assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),