[SPARK-10788][MLLIB][ML] Remove duplicate bins for decision trees

Decision trees in spark.ml (RandomForest.scala) communicate twice as much data as needed for unordered categorical features. Here's an example.

Say there are 3 categories A, B, C. We consider 3 splits:

* A vs. B, C
* A, B vs. C
* A, C vs. B

Currently, we collect statistics for all 6 subsets that appear on either side of these splits (3 splits * 2 sides = 6). Instead, we could collect statistics only for the 3 subsets on the left-hand side of each split: {A}, {A,B}, and {A,C}. If we also keep statistics for the entire node, we can recover the statistics for the 3 right-hand subsets by subtraction. In pseudo-math: stats(B,C) = stats(A,B,C) - stats(A).

This patch adds a parent stats array to `DTStatsAggregator` so that right child stats no longer need to be stored: for unordered categorical features, they are computed by subtracting the left child stats from the parent stats.
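
To see the arithmetic outside of Spark's internals, here is a minimal, self-contained Scala sketch; the object, values, and per-class counts are made up for illustration and are not Spark APIs:

object ParentStatsSketch {

  // Element-wise subtraction over sufficient statistics:
  // stats(parent) - stats(left) = stats(right).
  def subtract(parent: Array[Double], left: Array[Double]): Array[Double] =
    parent.zip(left).map { case (p, l) => p - l }

  def main(args: Array[String]): Unit = {
    // Per-class weighted counts for a node containing categories A, B, C.
    val parentStats = Array(10.0, 7.0) // stats(A,B,C): collected once per node
    val leftStats = Array(4.0, 1.0)    // stats(A): left side of the split A vs. B,C
    val rightStats = subtract(parentStats, leftStats) // stats(B,C): never collected
    println(rightStats.mkString(", ")) // prints: 6.0, 6.0
  }
}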

Author: sethah <seth.hendrickson16@gmail.com>

Closes #9474 from sethah/SPARK-10788.
Authored by sethah on 2016-03-17 16:44:41 -07:00; committed by Joseph K. Bradley
parent b39e80d39d
commit 1614485fd9
9 changed files with 54 additions and 49 deletions


@@ -244,8 +244,7 @@ private[ml] object RandomForest extends Logging {
       if (unorderedFeatures.contains(featureIndex)) {
         // Unordered feature
         val featureValue = treePoint.binnedFeatures(featureIndex)
-        val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
-          agg.getLeftRightFeatureOffsets(featureIndexIdx)
+        val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
         // Update the left or right bin for each split.
         val numSplits = agg.metadata.numSplits(featureIndex)
         val featureSplits = splits(featureIndex)
@@ -253,8 +252,6 @@ private[ml] object RandomForest extends Logging {
         while (splitIndex < numSplits) {
           if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
             agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
-          } else {
-            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
           }
           splitIndex += 1
         }
@@ -394,6 +391,7 @@ private[ml] object RandomForest extends Logging {
           mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
             metadata.unorderedFeatures, instanceWeight, featuresForNode)
         }
+        agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
       }
     }
@@ -658,7 +656,7 @@ private[ml] object RandomForest extends Logging {
     // Calculate InformationGain and ImpurityStats if current node is top node
     val level = LearningNode.indexToLevel(node.id)
-    var gainAndImpurityStats: ImpurityStats = if (level ==0) {
+    var gainAndImpurityStats: ImpurityStats = if (level == 0) {
       null
     } else {
       node.stats
@@ -697,13 +695,12 @@ private[ml] object RandomForest extends Logging {
           (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
         } else if (binAggregates.metadata.isUnordered(featureIndex)) {
           // Unordered categorical feature
-          val (leftChildOffset, rightChildOffset) =
-            binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+          val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
           val (bestFeatureSplitIndex, bestFeatureGainStats) =
             Range(0, numSplits).map { splitIndex =>
               val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-              val rightChildStats =
-                binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+              val rightChildStats = binAggregates.getParentImpurityCalculator()
+                .subtract(leftChildStats)
               gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
                 leftChildStats, rightChildStats, binAggregates.metadata)
               (splitIndex, gainAndImpurityStats)


@@ -52,6 +52,7 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
   /**
    * Method to train a decision tree model over an RDD
    *
+   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    * @return DecisionTreeModel that can be used for prediction.
    */
@@ -368,8 +369,7 @@ object DecisionTree extends Serializable with Logging {
       if (unorderedFeatures.contains(featureIndex)) {
         // Unordered feature
         val featureValue = treePoint.binnedFeatures(featureIndex)
-        val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
-          agg.getLeftRightFeatureOffsets(featureIndexIdx)
+        val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
         // Update the left or right bin for each split.
         val numSplits = agg.metadata.numSplits(featureIndex)
         var splitIndex = 0
@@ -377,9 +377,6 @@ object DecisionTree extends Serializable with Logging {
           if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
             agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
               instanceWeight)
-          } else {
-            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
-              instanceWeight)
           }
           splitIndex += 1
         }
@@ -521,6 +518,7 @@ object DecisionTree extends Serializable with Logging {
           mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
             metadata.unorderedFeatures, instanceWeight, featuresForNode)
         }
+        agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
       }
     }
@@ -847,13 +845,12 @@ object DecisionTree extends Serializable with Logging {
           (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
         } else if (binAggregates.metadata.isUnordered(featureIndex)) {
           // Unordered categorical feature
-          val (leftChildOffset, rightChildOffset) =
-            binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+          val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
           val (bestFeatureSplitIndex, bestFeatureGainStats) =
             Range(0, numSplits).map { splitIndex =>
               val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-              val rightChildStats =
-                binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+              val rightChildStats = binAggregates.getParentImpurityCalculator()
+                .subtract(leftChildStats)
               predictWithImpurity = Some(predictWithImpurity.getOrElse(
                 calculatePredictImpurity(leftChildStats, rightChildStats)))
               val gainStats = calculateGainForSplit(leftChildStats,


@@ -73,25 +73,33 @@ private[spark] class DTStatsAggregator(
    * Flat array of elements.
    * Index for start of stats for a (feature, bin) is:
    *   index = featureOffsets(featureIndex) + binIndex * statsSize
-   * Note: For unordered features,
-   *       the left child stats have binIndex in [0, numBins(featureIndex) / 2))
-   *       and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex))
    */
   private val allStats: Array[Double] = new Array[Double](allStatsSize)
 
+  /**
+   * Array of parent node sufficient stats.
+   *
+   * Note: this is necessary because stats for the parent node are not available
+   *       on the first iteration of tree learning.
+   */
+  private val parentStats: Array[Double] = new Array[Double](statsSize)
+
   /**
    * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
-   * @param featureOffset For ordered features, this is a pre-computed (node, feature) offset
+   * @param featureOffset This is a pre-computed (node, feature) offset
    *                      from [[getFeatureOffset]].
-   *                      For unordered features, this is a pre-computed
-   *                      (node, feature, left/right child) offset from
-   *                      [[getLeftRightFeatureOffsets]].
    */
   def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
     impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
   }
 
+  /**
+   * Get an [[ImpurityCalculator]] for the parent node.
+   */
+  def getParentImpurityCalculator(): ImpurityCalculator = {
+    impurityAggregator.getCalculator(parentStats, 0)
+  }
+
   /**
    * Update the stats for a given (feature, bin) for ordered features, using the given label.
    */
@@ -100,14 +108,18 @@ private[spark] class DTStatsAggregator(
     impurityAggregator.update(allStats, i, label, instanceWeight)
   }
 
+  /**
+   * Update the parent node stats using the given label.
+   */
+  def updateParent(label: Double, instanceWeight: Double): Unit = {
+    impurityAggregator.update(parentStats, 0, label, instanceWeight)
+  }
+
   /**
    * Faster version of [[update]].
    * Update the stats for a given (feature, bin), using the given label.
-   * @param featureOffset For ordered features, this is a pre-computed feature offset
+   * @param featureOffset This is a pre-computed feature offset
    *                      from [[getFeatureOffset]].
-   *                      For unordered features, this is a pre-computed
-   *                      (feature, left/right child) offset from
-   *                      [[getLeftRightFeatureOffsets]].
    */
   def featureUpdate(
       featureOffset: Int,
@@ -124,22 +136,10 @@ private[spark] class DTStatsAggregator(
    */
   def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)
 
-  /**
-   * Pre-compute feature offset for use with [[featureUpdate]].
-   * For unordered features only.
-   */
-  def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
-    val baseOffset = featureOffsets(featureIndex)
-    (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
-  }
-
   /**
    * For a given feature, merge the stats for two bins.
-   * @param featureOffset For ordered features, this is a pre-computed feature offset
+   * @param featureOffset This is a pre-computed feature offset
    *                      from [[getFeatureOffset]].
-   *                      For unordered features, this is a pre-computed
-   *                      (feature, left/right child) offset from
-   *                      [[getLeftRightFeatureOffsets]].
    * @param binIndex The other bin is merged into this bin.
    * @param otherBinIndex This bin is not modified.
    */
@@ -162,6 +162,17 @@ private[spark] class DTStatsAggregator(
       allStats(i) += other.allStats(i)
       i += 1
     }
+
+    require(statsSize == other.statsSize,
+      s"DTStatsAggregator.merge requires that both aggregators have the same length parent " +
+      s"stats vectors. This aggregator's parent stats are length $statsSize, " +
+      s"but the other is ${other.statsSize}.")
+    var j = 0
+    while (j < statsSize) {
+      parentStats(j) += other.parentStats(j)
+      j += 1
+    }
+
     this
   }
 }
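
A note on the merge addition above: summing parentStats element-wise across aggregators is sound because each partition's aggregator sees a disjoint subset of the rows, so the per-class weighted counts simply add. A toy check (assumed values, not Spark code):

val partitionA = Array(4.0, 1.0) // parent stats accumulated on one partition
val partitionB = Array(6.0, 6.0) // parent stats accumulated on another
val merged = partitionA.zip(partitionB).map { case (a, b) => a + b }
// merged == Array(10.0, 7.0): the parent stats over the full dataset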


@@ -67,11 +67,11 @@ private[spark] class DecisionTreeMetadata(
   /**
    * Number of splits for the given feature.
-   * For unordered features, there are 2 bins per split.
+   * For unordered features, there is 1 bin per split.
    * For ordered features, there is 1 more bin than split.
    */
   def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
-    numBins(featureIndex) >> 1
+    numBins(featureIndex)
   } else {
     numBins(featureIndex) - 1
   }
@@ -212,6 +212,6 @@ private[spark] object DecisionTreeMetadata extends Logging {
    * there are math.pow(2, arity - 1) - 1 such splits.
    * Each split has 2 corresponding bins.
    */
-  def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
+  def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1
 }
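
For the commit message's running example (arity = 3), this change halves the bin count. A quick check of the arithmetic (note that Scala parses `1 << arity - 1` as `1 << (arity - 1)`, since `-` binds tighter than `<<`):

val arity = 3
val numSplits = (1 << (arity - 1)) - 1 // 3 candidate splits: A|BC, AB|C, AC|B
val binsBefore = 2 * numSplits         // 6 bins: one per side of each split
val binsAfter = numSplits              // 3 bins: left-hand sides only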


@@ -113,7 +113,6 @@ private[tree] class EntropyAggregator(numClasses: Int)
   def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
     new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
 }
 
 /**


@@ -109,7 +109,6 @@ private[tree] class GiniAggregator(numClasses: Int)
   def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
     new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
 }
 
 /**


@@ -89,7 +89,6 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
    * @param offset Start index of stats for this (node, feature, bin).
    */
   def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator
-
 }
 
 /**


@@ -93,7 +93,6 @@ private[tree] class VarianceAggregator()
   def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
     new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
 }
 
 /**


@@ -189,6 +189,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(bins.length === 2)
     assert(splits(0).length === 3)
     assert(bins(0).length === 0)
+    assert(metadata.numSplits(0) === 3)
+    assert(metadata.numBins(0) === 3)
+    assert(metadata.numSplits(1) === 3)
+    assert(metadata.numBins(1) === 3)
 
     // Expecting 2^2 - 1 = 3 bins/splits
     assert(splits(0)(0).feature === 0)