[SPARK-10524][ML] Use the soft prediction to order categories' bins
JIRA: https://issues.apache.org/jira/browse/SPARK-10524 Currently we use the hard prediction (`ImpurityCalculator.predict`) to order categories' bins. But we should use the soft prediction. Author: Liang-Chi Hsieh <viirya@gmail.com> Author: Liang-Chi Hsieh <viirya@appier.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #8734 from viirya/dt-soft-centroids.
This commit is contained in:
parent
0e5ebac3c1
commit
9267bc68fa
|
@ -650,7 +650,7 @@ private[ml] object RandomForest extends Logging {
|
|||
* @param binAggregates Bin statistics.
|
||||
* @return tuple for best split: (Split, information gain, prediction at node)
|
||||
*/
|
||||
private def binsToBestSplit(
|
||||
private[tree] def binsToBestSplit(
|
||||
binAggregates: DTStatsAggregator,
|
||||
splits: Array[Array[Split]],
|
||||
featuresForNode: Option[Array[Int]],
|
||||
|
@ -720,32 +720,30 @@ private[ml] object RandomForest extends Logging {
|
|||
*
|
||||
* centroidForCategories is a list: (category, centroid)
|
||||
*/
|
||||
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
|
||||
// For categorical variables in multiclass classification,
|
||||
// the bins are ordered by the impurity of their corresponding labels.
|
||||
Range(0, numCategories).map { case featureValue =>
|
||||
val categoryStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||
val centroid = if (categoryStats.count != 0) {
|
||||
val centroidForCategories = Range(0, numCategories).map { case featureValue =>
|
||||
val categoryStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||
val centroid = if (categoryStats.count != 0) {
|
||||
if (binAggregates.metadata.isMulticlass) {
|
||||
// multiclass classification
|
||||
// For categorical variables in multiclass classification,
|
||||
// the bins are ordered by the impurity of their corresponding labels.
|
||||
categoryStats.calculate()
|
||||
} else if (binAggregates.metadata.isClassification) {
|
||||
// binary classification
|
||||
// For categorical variables in binary classification,
|
||||
// the bins are ordered by the count of class 1.
|
||||
categoryStats.stats(1)
|
||||
} else {
|
||||
Double.MaxValue
|
||||
}
|
||||
(featureValue, centroid)
|
||||
}
|
||||
} else { // regression or binary classification
|
||||
// For categorical variables in regression and binary classification,
|
||||
// the bins are ordered by the centroid of their corresponding labels.
|
||||
Range(0, numCategories).map { case featureValue =>
|
||||
val categoryStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||
val centroid = if (categoryStats.count != 0) {
|
||||
// regression
|
||||
// For categorical variables in regression and binary classification,
|
||||
// the bins are ordered by the prediction.
|
||||
categoryStats.predict
|
||||
} else {
|
||||
Double.MaxValue
|
||||
}
|
||||
(featureValue, centroid)
|
||||
} else {
|
||||
Double.MaxValue
|
||||
}
|
||||
(featureValue, centroid)
|
||||
}
|
||||
|
||||
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
|
||||
|
|
|
@ -791,7 +791,7 @@ object DecisionTree extends Serializable with Logging {
|
|||
* @param binAggregates Bin statistics.
|
||||
* @return tuple for best split: (Split, information gain, prediction at node)
|
||||
*/
|
||||
private def binsToBestSplit(
|
||||
private[tree] def binsToBestSplit(
|
||||
binAggregates: DTStatsAggregator,
|
||||
splits: Array[Array[Split]],
|
||||
featuresForNode: Option[Array[Int]],
|
||||
|
@ -808,128 +808,127 @@ object DecisionTree extends Serializable with Logging {
|
|||
// For each (feature, split), calculate the gain, and select the best (feature, split).
|
||||
val (bestSplit, bestSplitStats) =
|
||||
Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
|
||||
val featureIndex = if (featuresForNode.nonEmpty) {
|
||||
featuresForNode.get.apply(featureIndexIdx)
|
||||
} else {
|
||||
featureIndexIdx
|
||||
}
|
||||
val numSplits = binAggregates.metadata.numSplits(featureIndex)
|
||||
if (binAggregates.metadata.isContinuous(featureIndex)) {
|
||||
// Cumulative sum (scanLeft) of bin statistics.
|
||||
// Afterwards, binAggregates for a bin is the sum of aggregates for
|
||||
// that bin + all preceding bins.
|
||||
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
|
||||
var splitIndex = 0
|
||||
while (splitIndex < numSplits) {
|
||||
binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
|
||||
splitIndex += 1
|
||||
val featureIndex = if (featuresForNode.nonEmpty) {
|
||||
featuresForNode.get.apply(featureIndexIdx)
|
||||
} else {
|
||||
featureIndexIdx
|
||||
}
|
||||
// Find best split.
|
||||
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
||||
Range(0, numSplits).map { case splitIdx =>
|
||||
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
|
||||
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
|
||||
rightChildStats.subtract(leftChildStats)
|
||||
predictWithImpurity = Some(predictWithImpurity.getOrElse(
|
||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
||||
val gainStats = calculateGainForSplit(leftChildStats,
|
||||
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
|
||||
(splitIdx, gainStats)
|
||||
}.maxBy(_._2.gain)
|
||||
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
|
||||
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
|
||||
// Unordered categorical feature
|
||||
val (leftChildOffset, rightChildOffset) =
|
||||
binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
|
||||
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
||||
Range(0, numSplits).map { splitIndex =>
|
||||
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
|
||||
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
|
||||
predictWithImpurity = Some(predictWithImpurity.getOrElse(
|
||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
||||
val gainStats = calculateGainForSplit(leftChildStats,
|
||||
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
|
||||
(splitIndex, gainStats)
|
||||
}.maxBy(_._2.gain)
|
||||
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
|
||||
} else {
|
||||
// Ordered categorical feature
|
||||
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
|
||||
val numBins = binAggregates.metadata.numBins(featureIndex)
|
||||
|
||||
/* Each bin is one category (feature value).
|
||||
* The bins are ordered based on centroidForCategories, and this ordering determines which
|
||||
* splits are considered. (With K categories, we consider K - 1 possible splits.)
|
||||
*
|
||||
* centroidForCategories is a list: (category, centroid)
|
||||
*/
|
||||
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
|
||||
// For categorical variables in multiclass classification,
|
||||
// the bins are ordered by the impurity of their corresponding labels.
|
||||
Range(0, numBins).map { case featureValue =>
|
||||
val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||
val centroid = if (categoryStats.count != 0) {
|
||||
categoryStats.calculate()
|
||||
} else {
|
||||
Double.MaxValue
|
||||
}
|
||||
(featureValue, centroid)
|
||||
val numSplits = binAggregates.metadata.numSplits(featureIndex)
|
||||
if (binAggregates.metadata.isContinuous(featureIndex)) {
|
||||
// Cumulative sum (scanLeft) of bin statistics.
|
||||
// Afterwards, binAggregates for a bin is the sum of aggregates for
|
||||
// that bin + all preceding bins.
|
||||
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
|
||||
var splitIndex = 0
|
||||
while (splitIndex < numSplits) {
|
||||
binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
|
||||
splitIndex += 1
|
||||
}
|
||||
} else { // regression or binary classification
|
||||
// For categorical variables in regression and binary classification,
|
||||
// the bins are ordered by the centroid of their corresponding labels.
|
||||
Range(0, numBins).map { case featureValue =>
|
||||
val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||
val centroid = if (categoryStats.count != 0) {
|
||||
categoryStats.predict
|
||||
} else {
|
||||
Double.MaxValue
|
||||
}
|
||||
(featureValue, centroid)
|
||||
}
|
||||
}
|
||||
// Find best split.
|
||||
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
||||
Range(0, numSplits).map { case splitIdx =>
|
||||
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
|
||||
val rightChildStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
|
||||
rightChildStats.subtract(leftChildStats)
|
||||
predictWithImpurity = Some(predictWithImpurity.getOrElse(
|
||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
||||
val gainStats = calculateGainForSplit(leftChildStats,
|
||||
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
|
||||
(splitIdx, gainStats)
|
||||
}.maxBy(_._2.gain)
|
||||
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
|
||||
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
|
||||
// Unordered categorical feature
|
||||
val (leftChildOffset, rightChildOffset) =
|
||||
binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
|
||||
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
||||
Range(0, numSplits).map { splitIndex =>
|
||||
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
|
||||
val rightChildStats =
|
||||
binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
|
||||
predictWithImpurity = Some(predictWithImpurity.getOrElse(
|
||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
||||
val gainStats = calculateGainForSplit(leftChildStats,
|
||||
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
|
||||
(splitIndex, gainStats)
|
||||
}.maxBy(_._2.gain)
|
||||
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
|
||||
} else {
|
||||
// Ordered categorical feature
|
||||
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
|
||||
val numBins = binAggregates.metadata.numBins(featureIndex)
|
||||
|
||||
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
|
||||
|
||||
// bins sorted by centroids
|
||||
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
|
||||
|
||||
logDebug("Sorted centroids for categorical variable = " +
|
||||
categoriesSortedByCentroid.mkString(","))
|
||||
|
||||
// Cumulative sum (scanLeft) of bin statistics.
|
||||
// Afterwards, binAggregates for a bin is the sum of aggregates for
|
||||
// that bin + all preceding bins.
|
||||
var splitIndex = 0
|
||||
while (splitIndex < numSplits) {
|
||||
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
|
||||
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
|
||||
binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
|
||||
splitIndex += 1
|
||||
}
|
||||
// lastCategory = index of bin with total aggregates for this (node, feature)
|
||||
val lastCategory = categoriesSortedByCentroid.last._1
|
||||
// Find best split.
|
||||
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
||||
Range(0, numSplits).map { splitIndex =>
|
||||
val featureValue = categoriesSortedByCentroid(splitIndex)._1
|
||||
val leftChildStats =
|
||||
/* Each bin is one category (feature value).
|
||||
* The bins are ordered based on centroidForCategories, and this ordering determines which
|
||||
* splits are considered. (With K categories, we consider K - 1 possible splits.)
|
||||
*
|
||||
* centroidForCategories is a list: (category, centroid)
|
||||
*/
|
||||
val centroidForCategories = Range(0, numBins).map { case featureValue =>
|
||||
val categoryStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||
val rightChildStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
|
||||
rightChildStats.subtract(leftChildStats)
|
||||
predictWithImpurity = Some(predictWithImpurity.getOrElse(
|
||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
||||
val gainStats = calculateGainForSplit(leftChildStats,
|
||||
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
|
||||
(splitIndex, gainStats)
|
||||
}.maxBy(_._2.gain)
|
||||
val categoriesForSplit =
|
||||
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
|
||||
val bestFeatureSplit =
|
||||
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
|
||||
(bestFeatureSplit, bestFeatureGainStats)
|
||||
}
|
||||
val centroid = if (categoryStats.count != 0) {
|
||||
if (binAggregates.metadata.isMulticlass) {
|
||||
// For categorical variables in multiclass classification,
|
||||
// the bins are ordered by the impurity of their corresponding labels.
|
||||
categoryStats.calculate()
|
||||
} else if (binAggregates.metadata.isClassification) {
|
||||
// For categorical variables in binary classification,
|
||||
// the bins are ordered by the count of class 1.
|
||||
categoryStats.stats(1)
|
||||
} else {
|
||||
// For categorical variables in regression,
|
||||
// the bins are ordered by the prediction.
|
||||
categoryStats.predict
|
||||
}
|
||||
} else {
|
||||
Double.MaxValue
|
||||
}
|
||||
(featureValue, centroid)
|
||||
}
|
||||
|
||||
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
|
||||
|
||||
// bins sorted by centroids
|
||||
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
|
||||
|
||||
logDebug("Sorted centroids for categorical variable = " +
|
||||
categoriesSortedByCentroid.mkString(","))
|
||||
|
||||
// Cumulative sum (scanLeft) of bin statistics.
|
||||
// Afterwards, binAggregates for a bin is the sum of aggregates for
|
||||
// that bin + all preceding bins.
|
||||
var splitIndex = 0
|
||||
while (splitIndex < numSplits) {
|
||||
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
|
||||
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
|
||||
binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
|
||||
splitIndex += 1
|
||||
}
|
||||
// lastCategory = index of bin with total aggregates for this (node, feature)
|
||||
val lastCategory = categoriesSortedByCentroid.last._1
|
||||
// Find best split.
|
||||
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
||||
Range(0, numSplits).map { splitIndex =>
|
||||
val featureValue = categoriesSortedByCentroid(splitIndex)._1
|
||||
val leftChildStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||
val rightChildStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
|
||||
rightChildStats.subtract(leftChildStats)
|
||||
predictWithImpurity = Some(predictWithImpurity.getOrElse(
|
||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
||||
val gainStats = calculateGainForSplit(leftChildStats,
|
||||
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
|
||||
(splitIndex, gainStats)
|
||||
}.maxBy(_._2.gain)
|
||||
val categoriesForSplit =
|
||||
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
|
||||
val bestFeatureSplit =
|
||||
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
|
||||
(bestFeatureSplit, bestFeatureGainStats)
|
||||
}
|
||||
}.maxBy(_._2.gain)
|
||||
|
||||
(bestSplit, bestSplitStats, predictWithImpurity.get._1)
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
|
|||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.impl.TreeTests
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.tree.LeafNode
|
||||
import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
|
||||
import org.apache.spark.ml.util.MLTestingUtils
|
||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
|
@ -275,6 +275,40 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
|||
val model = dt.fit(df)
|
||||
}
|
||||
|
||||
test("Use soft prediction for binary classification with ordered categorical features") {
|
||||
// The following dataset is set up such that the best split is {1} vs. {0, 2}.
|
||||
// If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
|
||||
val arr = Array(
|
||||
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(2.0)))
|
||||
val data = sc.parallelize(arr)
|
||||
val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)
|
||||
|
||||
// Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
|
||||
val dt = new DecisionTreeClassifier()
|
||||
.setImpurity("gini")
|
||||
.setMaxDepth(1)
|
||||
.setMaxBins(3)
|
||||
val model = dt.fit(df)
|
||||
model.rootNode match {
|
||||
case n: InternalNode =>
|
||||
n.split match {
|
||||
case s: CategoricalSplit =>
|
||||
assert(s.leftCategories === Array(1.0))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Tests of model save/load
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, Tree
|
|||
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
|
||||
import org.apache.spark.mllib.tree.model._
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.mllib.util.TestingUtils._
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
|
||||
|
@ -337,6 +338,35 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
assert(topNode.rightNode.get.impurity === 0.0)
|
||||
}
|
||||
|
||||
test("Use soft prediction for binary classification with ordered categorical features") {
|
||||
// The following dataset is set up such that the best split is {1} vs. {0, 2}.
|
||||
// If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
|
||||
val arr = Array(
|
||||
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(2.0)))
|
||||
val input = sc.parallelize(arr)
|
||||
|
||||
// Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
|
||||
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
|
||||
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
|
||||
|
||||
val model = new DecisionTree(strategy).run(input)
|
||||
model.topNode.split.get match {
|
||||
case Split(_, _, _, categories: List[Double]) =>
|
||||
assert(categories === List(1.0))
|
||||
}
|
||||
}
|
||||
|
||||
test("Second level node building with vs. without groups") {
|
||||
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
|
||||
assert(arr.length === 1000)
|
||||
|
|
Loading…
Reference in a new issue