[SPARK-10524][ML] Use the soft prediction to order categories' bins

JIRA: https://issues.apache.org/jira/browse/SPARK-10524

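Currently we use the hard prediction (`ImpurityCalculator.predict`) to order the bins of ordered categorical features. We should use the soft prediction instead, because ordering by the hard prediction can cause a suboptimal split to be chosen (see the tests added in this commit).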

Author: Liang-Chi Hsieh <viirya@gmail.com>
Author: Liang-Chi Hsieh <viirya@appier.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #8734 from viirya/dt-soft-centroids.
This commit is contained in:
Liang-Chi Hsieh 2016-02-09 17:10:55 -08:00 committed by Joseph K. Bradley
parent 0e5ebac3c1
commit 9267bc68fa
4 changed files with 203 additions and 142 deletions

View file

@ -650,7 +650,7 @@ private[ml] object RandomForest extends Logging {
* @param binAggregates Bin statistics. * @param binAggregates Bin statistics.
* @return tuple for best split: (Split, information gain, prediction at node) * @return tuple for best split: (Split, information gain, prediction at node)
*/ */
private def binsToBestSplit( private[tree] def binsToBestSplit(
binAggregates: DTStatsAggregator, binAggregates: DTStatsAggregator,
splits: Array[Array[Split]], splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]], featuresForNode: Option[Array[Int]],
@ -720,32 +720,30 @@ private[ml] object RandomForest extends Logging {
* *
* centroidForCategories is a list: (category, centroid) * centroidForCategories is a list: (category, centroid)
*/ */
val centroidForCategories = if (binAggregates.metadata.isMulticlass) { val centroidForCategories = Range(0, numCategories).map { case featureValue =>
// For categorical variables in multiclass classification, val categoryStats =
// the bins are ordered by the impurity of their corresponding labels. binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
Range(0, numCategories).map { case featureValue => val centroid = if (categoryStats.count != 0) {
val categoryStats = if (binAggregates.metadata.isMulticlass) {
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) // multiclass classification
val centroid = if (categoryStats.count != 0) { // For categorical variables in multiclass classification,
// the bins are ordered by the impurity of their corresponding labels.
categoryStats.calculate() categoryStats.calculate()
} else if (binAggregates.metadata.isClassification) {
// binary classification
// For categorical variables in binary classification,
// the bins are ordered by the count of class 1.
categoryStats.stats(1)
} else { } else {
Double.MaxValue // regression
} // For categorical variables in regression and binary classification,
(featureValue, centroid) // the bins are ordered by the prediction.
}
} else { // regression or binary classification
// For categorical variables in regression and binary classification,
// the bins are ordered by the centroid of their corresponding labels.
Range(0, numCategories).map { case featureValue =>
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
categoryStats.predict categoryStats.predict
} else {
Double.MaxValue
} }
(featureValue, centroid) } else {
Double.MaxValue
} }
(featureValue, centroid)
} }
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))

View file

@ -791,7 +791,7 @@ object DecisionTree extends Serializable with Logging {
* @param binAggregates Bin statistics. * @param binAggregates Bin statistics.
* @return tuple for best split: (Split, information gain, prediction at node) * @return tuple for best split: (Split, information gain, prediction at node)
*/ */
private def binsToBestSplit( private[tree] def binsToBestSplit(
binAggregates: DTStatsAggregator, binAggregates: DTStatsAggregator,
splits: Array[Array[Split]], splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]], featuresForNode: Option[Array[Int]],
@ -808,128 +808,127 @@ object DecisionTree extends Serializable with Logging {
// For each (feature, split), calculate the gain, and select the best (feature, split). // For each (feature, split), calculate the gain, and select the best (feature, split).
val (bestSplit, bestSplitStats) = val (bestSplit, bestSplitStats) =
Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
val featureIndex = if (featuresForNode.nonEmpty) { val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx) featuresForNode.get.apply(featureIndexIdx)
} else { } else {
featureIndexIdx featureIndexIdx
}
val numSplits = binAggregates.metadata.numSplits(featureIndex)
if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
var splitIndex = 0
while (splitIndex < numSplits) {
binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
splitIndex += 1
} }
// Find best split. val numSplits = binAggregates.metadata.numSplits(featureIndex)
val (bestFeatureSplitIndex, bestFeatureGainStats) = if (binAggregates.metadata.isContinuous(featureIndex)) {
Range(0, numSplits).map { case splitIdx => // Cumulative sum (scanLeft) of bin statistics.
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) // Afterwards, binAggregates for a bin is the sum of aggregates for
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) // that bin + all preceding bins.
rightChildStats.subtract(leftChildStats) val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
predictWithImpurity = Some(predictWithImpurity.getOrElse( var splitIndex = 0
calculatePredictImpurity(leftChildStats, rightChildStats))) while (splitIndex < numSplits) {
val gainStats = calculateGainForSplit(leftChildStats, binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) splitIndex += 1
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val numBins = binAggregates.metadata.numBins(featureIndex)
/* Each bin is one category (feature value).
* The bins are ordered based on centroidForCategories, and this ordering determines which
* splits are considered. (With K categories, we consider K - 1 possible splits.)
*
* centroidForCategories is a list: (category, centroid)
*/
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
// For categorical variables in multiclass classification,
// the bins are ordered by the impurity of their corresponding labels.
Range(0, numBins).map { case featureValue =>
val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
categoryStats.calculate()
} else {
Double.MaxValue
}
(featureValue, centroid)
} }
} else { // regression or binary classification // Find best split.
// For categorical variables in regression and binary classification, val (bestFeatureSplitIndex, bestFeatureGainStats) =
// the bins are ordered by the centroid of their corresponding labels. Range(0, numSplits).map { case splitIdx =>
Range(0, numBins).map { case featureValue => val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) val rightChildStats =
val centroid = if (categoryStats.count != 0) { binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
categoryStats.predict rightChildStats.subtract(leftChildStats)
} else { predictWithImpurity = Some(predictWithImpurity.getOrElse(
Double.MaxValue calculatePredictImpurity(leftChildStats, rightChildStats)))
} val gainStats = calculateGainForSplit(leftChildStats,
(featureValue, centroid) rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
} (splitIdx, gainStats)
} }.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats =
binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val numBins = binAggregates.metadata.numBins(featureIndex)
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) /* Each bin is one category (feature value).
* The bins are ordered based on centroidForCategories, and this ordering determines which
// bins sorted by centroids * splits are considered. (With K categories, we consider K - 1 possible splits.)
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) *
* centroidForCategories is a list: (category, centroid)
logDebug("Sorted centroids for categorical variable = " + */
categoriesSortedByCentroid.mkString(",")) val centroidForCategories = Range(0, numBins).map { case featureValue =>
val categoryStats =
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
var splitIndex = 0
while (splitIndex < numSplits) {
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
splitIndex += 1
}
// lastCategory = index of bin with total aggregates for this (node, feature)
val lastCategory = categoriesSortedByCentroid.last._1
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val featureValue = categoriesSortedByCentroid(splitIndex)._1
val leftChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val rightChildStats = val centroid = if (categoryStats.count != 0) {
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) if (binAggregates.metadata.isMulticlass) {
rightChildStats.subtract(leftChildStats) // For categorical variables in multiclass classification,
predictWithImpurity = Some(predictWithImpurity.getOrElse( // the bins are ordered by the impurity of their corresponding labels.
calculatePredictImpurity(leftChildStats, rightChildStats))) categoryStats.calculate()
val gainStats = calculateGainForSplit(leftChildStats, } else if (binAggregates.metadata.isClassification) {
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) // For categorical variables in binary classification,
(splitIndex, gainStats) // the bins are ordered by the count of class 1.
}.maxBy(_._2.gain) categoryStats.stats(1)
val categoriesForSplit = } else {
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) // For categorical variables in regression,
val bestFeatureSplit = // the bins are ordered by the prediction.
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) categoryStats.predict
(bestFeatureSplit, bestFeatureGainStats) }
} } else {
Double.MaxValue
}
(featureValue, centroid)
}
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
// bins sorted by centroids
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
logDebug("Sorted centroids for categorical variable = " +
categoriesSortedByCentroid.mkString(","))
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
var splitIndex = 0
while (splitIndex < numSplits) {
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
splitIndex += 1
}
// lastCategory = index of bin with total aggregates for this (node, feature)
val lastCategory = categoriesSortedByCentroid.last._1
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val featureValue = categoriesSortedByCentroid(splitIndex)._1
val leftChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
val bestFeatureSplit =
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
(bestFeatureSplit, bestFeatureGainStats)
}
}.maxBy(_._2.gain) }.maxBy(_._2.gain)
(bestSplit, bestSplitStats, predictWithImpurity.get._1) (bestSplit, bestSplitStats, predictWithImpurity.get._1)

View file

@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.regression.LabeledPoint
@ -275,6 +275,40 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
val model = dt.fit(df) val model = dt.fit(df)
} }
test("Use soft prediction for binary classification with ordered categorical features") {
  // Regression test for SPARK-10524.
  // The following dataset is set up such that the best split is {1} vs. {0, 2}.
  // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
  // Labels per category: 0 -> (0, 0, 0, 1), 1 -> (0, 0, 0, 0), 2 -> (0, 0, 0, 1),
  // i.e. every category has majority label 0 and only the class-1 counts differ.
  val arr = Array(
    LabeledPoint(0.0, Vectors.dense(0.0)),
    LabeledPoint(0.0, Vectors.dense(0.0)),
    LabeledPoint(0.0, Vectors.dense(0.0)),
    LabeledPoint(1.0, Vectors.dense(0.0)),
    LabeledPoint(0.0, Vectors.dense(1.0)),
    LabeledPoint(0.0, Vectors.dense(1.0)),
    LabeledPoint(0.0, Vectors.dense(1.0)),
    LabeledPoint(0.0, Vectors.dense(1.0)),
    LabeledPoint(0.0, Vectors.dense(2.0)),
    LabeledPoint(0.0, Vectors.dense(2.0)),
    LabeledPoint(0.0, Vectors.dense(2.0)),
    LabeledPoint(1.0, Vectors.dense(2.0)))
  val data = sc.parallelize(arr)
  // NOTE(review): presumably marks feature 0 as categorical with 3 categories and
  // declares 2 label classes -- confirm against TreeTests.setMetadata.
  val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)
  // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
  val dt = new DecisionTreeClassifier()
    .setImpurity("gini")
    .setMaxDepth(1)
    .setMaxBins(3)
  val model = dt.fit(df)
  // Fail with a descriptive message (rather than an opaque scala.MatchError) when the
  // learned tree does not have the expected shape.
  model.rootNode match {
    case n: InternalNode =>
      n.split match {
        case s: CategoricalSplit =>
          assert(s.leftCategories === Array(1.0))
        case _ =>
          fail("model.rootNode.split was not a CategoricalSplit")
      }
    case _ =>
      fail("model.rootNode was not an InternalNode")
  }
}
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load // Tests of model save/load
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////

View file

@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, Tree
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model._ import org.apache.spark.mllib.tree.model._
import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils import org.apache.spark.util.Utils
@ -337,6 +338,35 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(topNode.rightNode.get.impurity === 0.0) assert(topNode.rightNode.get.impurity === 0.0)
} }
test("Use soft prediction for binary classification with ordered categorical features") {
  // Regression test for SPARK-10524.
  // The following dataset is set up such that the best split is {1} vs. {0, 2}.
  // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
  // Labels per category: 0 -> (0, 0, 0, 1), 1 -> (0, 0, 0, 0), 2 -> (0, 0, 0, 1),
  // i.e. every category has majority label 0 and only the class-1 counts differ.
  val arr = Array(
    LabeledPoint(0.0, Vectors.dense(0.0)),
    LabeledPoint(0.0, Vectors.dense(0.0)),
    LabeledPoint(0.0, Vectors.dense(0.0)),
    LabeledPoint(1.0, Vectors.dense(0.0)),
    LabeledPoint(0.0, Vectors.dense(1.0)),
    LabeledPoint(0.0, Vectors.dense(1.0)),
    LabeledPoint(0.0, Vectors.dense(1.0)),
    LabeledPoint(0.0, Vectors.dense(1.0)),
    LabeledPoint(0.0, Vectors.dense(2.0)),
    LabeledPoint(0.0, Vectors.dense(2.0)),
    LabeledPoint(0.0, Vectors.dense(2.0)),
    LabeledPoint(1.0, Vectors.dense(2.0)))
  val input = sc.parallelize(arr)
  // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
  val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
    numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
  val model = new DecisionTree(strategy).run(input)
  // Match on the Option instead of calling .get so that a root without a split is
  // reported with a descriptive failure rather than a NoSuchElementException.
  model.topNode.split match {
    case Some(Split(_, _, _, categories)) =>
      assert(categories === List(1.0))
    case None =>
      fail("model.topNode has no split; expected a categorical split at the root")
  }
}
test("Second level node building with vs. without groups") { test("Second level node building with vs. without groups") {
val arr = DecisionTreeSuite.generateOrderedLabeledPoints() val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
assert(arr.length === 1000) assert(arr.length === 1000)