[SPARK-10524][ML] Use the soft prediction to order categories' bins

JIRA: https://issues.apache.org/jira/browse/SPARK-10524

Currently we use the hard prediction (`ImpurityCalculator.predict`) to order categories' bins, but we should use the soft prediction instead. With the hard 0/1 prediction, categories whose class distributions differ can still produce identical (tied) predictions, so the ordered split search may never consider the best split; ordering binary-classification bins by the count of class 1 preserves those distinctions.
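
For intuition, here is a minimal standalone sketch (plain Scala, not Spark code; the per-category label counts come from the tests added below) of how the two orderings differ on the new test's data:

// Per-category (count of label 0, count of label 1), from the new test's dataset.
val stats = Seq(0 -> (3, 1), 1 -> (4, 0), 2 -> (3, 1))

// Hard prediction: the majority label of each category. All three categories
// predict 0.0, so the sort degenerates into ties and the contiguous-split
// search may never consider {1} vs. {0, 2}.
val hardOrder = stats.sortBy { case (_, (c0, c1)) => if (c1 > c0) 1.0 else 0.0 }

// Soft prediction: the count of class 1. Category 1 sorts first, so
// {1} vs. {0, 2} becomes one of the contiguous candidate splits.
val softOrder = stats.sortBy { case (_, (_, c1)) => c1 }

println(hardOrder.map(_._1))  // List(0, 1, 2) -- ties, order carries no signal
println(softOrder.map(_._1))  // List(1, 0, 2)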

Author: Liang-Chi Hsieh <viirya@gmail.com>
Author: Liang-Chi Hsieh <viirya@appier.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #8734 from viirya/dt-soft-centroids.
Authored by Liang-Chi Hsieh on 2016-02-09 17:10:55 -08:00, committed by Joseph K. Bradley
parent 0e5ebac3c1
commit 9267bc68fa
4 changed files with 203 additions and 142 deletions

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala

@@ -650,7 +650,7 @@ private[ml] object RandomForest extends Logging {
    * @param binAggregates Bin statistics.
    * @return tuple for best split: (Split, information gain, prediction at node)
    */
-  private def binsToBestSplit(
+  private[tree] def binsToBestSplit(
       binAggregates: DTStatsAggregator,
       splits: Array[Array[Split]],
       featuresForNode: Option[Array[Int]],
@@ -720,33 +720,31 @@ private[ml] object RandomForest extends Logging {
            *
            * centroidForCategories is a list: (category, centroid)
            */
-          val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
-            // For categorical variables in multiclass classification,
-            // the bins are ordered by the impurity of their corresponding labels.
-            Range(0, numCategories).map { case featureValue =>
-              val categoryStats =
-                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
-              val centroid = if (categoryStats.count != 0) {
-                categoryStats.calculate()
-              } else {
-                Double.MaxValue
-              }
-              (featureValue, centroid)
-            }
-          } else { // regression or binary classification
-            // For categorical variables in regression and binary classification,
-            // the bins are ordered by the centroid of their corresponding labels.
-            Range(0, numCategories).map { case featureValue =>
-              val categoryStats =
-                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
-              val centroid = if (categoryStats.count != 0) {
-                categoryStats.predict
-              } else {
-                Double.MaxValue
-              }
-              (featureValue, centroid)
-            }
-          }
+          val centroidForCategories = Range(0, numCategories).map { case featureValue =>
+            val categoryStats =
+              binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+            val centroid = if (categoryStats.count != 0) {
+              if (binAggregates.metadata.isMulticlass) {
+                // multiclass classification
+                // For categorical variables in multiclass classification,
+                // the bins are ordered by the impurity of their corresponding labels.
+                categoryStats.calculate()
+              } else if (binAggregates.metadata.isClassification) {
+                // binary classification
+                // For categorical variables in binary classification,
+                // the bins are ordered by the count of class 1.
+                categoryStats.stats(1)
+              } else {
+                // regression
+                // For categorical variables in regression,
+                // the bins are ordered by the prediction.
+                categoryStats.predict
+              }
+            } else {
+              Double.MaxValue
+            }
+            (featureValue, centroid)
+          }
           logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

@@ -791,7 +791,7 @@ object DecisionTree extends Serializable with Logging {
    * @param binAggregates Bin statistics.
    * @return tuple for best split: (Split, information gain, prediction at node)
    */
-  private def binsToBestSplit(
+  private[tree] def binsToBestSplit(
       binAggregates: DTStatsAggregator,
       splits: Array[Array[Split]],
       featuresForNode: Option[Array[Int]],
@@ -828,7 +828,8 @@ object DecisionTree extends Serializable with Logging {
         val (bestFeatureSplitIndex, bestFeatureGainStats) =
           Range(0, numSplits).map { case splitIdx =>
             val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
-            val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+            val rightChildStats =
+              binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
             rightChildStats.subtract(leftChildStats)
             predictWithImpurity = Some(predictWithImpurity.getOrElse(
               calculatePredictImpurity(leftChildStats, rightChildStats)))
@@ -844,7 +845,8 @@ object DecisionTree extends Serializable with Logging {
         val (bestFeatureSplitIndex, bestFeatureGainStats) =
           Range(0, numSplits).map { splitIndex =>
             val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-            val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+            val rightChildStats =
+              binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
             predictWithImpurity = Some(predictWithImpurity.getOrElse(
               calculatePredictImpurity(leftChildStats, rightChildStats)))
             val gainStats = calculateGainForSplit(leftChildStats,
@@ -863,31 +865,28 @@ object DecisionTree extends Serializable with Logging {
            *
            * centroidForCategories is a list: (category, centroid)
            */
-          val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
-            // For categorical variables in multiclass classification,
-            // the bins are ordered by the impurity of their corresponding labels.
-            Range(0, numBins).map { case featureValue =>
-              val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
-              val centroid = if (categoryStats.count != 0) {
-                categoryStats.calculate()
-              } else {
-                Double.MaxValue
-              }
-              (featureValue, centroid)
-            }
-          } else { // regression or binary classification
-            // For categorical variables in regression and binary classification,
-            // the bins are ordered by the centroid of their corresponding labels.
-            Range(0, numBins).map { case featureValue =>
-              val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
-              val centroid = if (categoryStats.count != 0) {
-                categoryStats.predict
-              } else {
-                Double.MaxValue
-              }
-              (featureValue, centroid)
-            }
-          }
+          val centroidForCategories = Range(0, numBins).map { case featureValue =>
+            val categoryStats =
+              binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+            val centroid = if (categoryStats.count != 0) {
+              if (binAggregates.metadata.isMulticlass) {
+                // For categorical variables in multiclass classification,
+                // the bins are ordered by the impurity of their corresponding labels.
+                categoryStats.calculate()
+              } else if (binAggregates.metadata.isClassification) {
+                // For categorical variables in binary classification,
+                // the bins are ordered by the count of class 1.
+                categoryStats.stats(1)
+              } else {
+                // For categorical variables in regression,
+                // the bins are ordered by the prediction.
+                categoryStats.predict
+              }
+            } else {
+              Double.MaxValue
+            }
+            (featureValue, centroid)
+          }
           logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
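
To make the three branches concrete, here is a small self-contained sketch of what each centroid evaluates to for a toy per-class count vector. It assumes `stats(i)` holds the raw count of class i (consistent with `categoryStats.stats(1)` above), and `giniImpurity` is a stand-in for `categoryStats.calculate()` under Gini impurity:

// Toy stats for one category: 3 examples of class 0, 1 example of class 1.
val stats = Array(3.0, 1.0)

// Stand-in for ImpurityCalculator.calculate() with Gini impurity.
def giniImpurity(counts: Array[Double]): Double = {
  val total = counts.sum
  1.0 - counts.map { c => val f = c / total; f * f }.sum
}

val multiclassCentroid = giniImpurity(stats)      // 1 - 0.75^2 - 0.25^2 = 0.375
val binaryCentroid = stats(1)                     // count of class 1 = 1.0 (soft prediction)
val hardPrediction = stats.indexOf(stats.max)     // majority class = 0 (the old ordering key)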

mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala

@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.impl.TreeTests
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
 import org.apache.spark.ml.util.MLTestingUtils
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -275,6 +275,40 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
     val model = dt.fit(df)
   }
+
+  test("Use soft prediction for binary classification with ordered categorical features") {
+    // The following dataset is set up such that the best split is {1} vs. {0, 2}.
+    // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(1.0, Vectors.dense(2.0)))
+    val data = sc.parallelize(arr)
+    val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)
+    // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
+    val dt = new DecisionTreeClassifier()
+      .setImpurity("gini")
+      .setMaxDepth(1)
+      .setMaxBins(3)
+    val model = dt.fit(df)
+    model.rootNode match {
+      case n: InternalNode =>
+        n.split match {
+          case s: CategoricalSplit =>
+            assert(s.leftCategories === Array(1.0))
+        }
+    }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
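
The claim in the test's comments can be verified by hand. A standalone sketch (not part of the suite) that computes the weighted child impurity of the two competing splits from the per-category counts in `arr`:

// Per-category (label 0 count, label 1 count) in arr:
//   category 0 -> (3, 1), category 1 -> (4, 0), category 2 -> (3, 1)
def gini(c0: Double, c1: Double): Double = {
  val n = c0 + c1
  1.0 - (c0 / n) * (c0 / n) - (c1 / n) * (c1 / n)
}
def weightedChildImpurity(l: (Double, Double), r: (Double, Double)): Double = {
  val (nL, nR) = (l._1 + l._2, r._1 + r._2)
  (nL * gini(l._1, l._2) + nR * gini(r._1, r._2)) / (nL + nR)
}

weightedChildImpurity((4.0, 0.0), (6.0, 2.0))  // {1} vs. {0, 2}: 0.25
weightedChildImpurity((3.0, 1.0), (7.0, 1.0))  // {0} vs. {1, 2}: ~0.2708
// Lower weighted child impurity means higher gain, so {1} vs. {0, 2} is the
// better split -- and it is only reachable under the soft-prediction ordering.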

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

@@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.util.Utils
@@ -337,6 +338,35 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.rightNode.get.impurity === 0.0)
   }
+
+  test("Use soft prediction for binary classification with ordered categorical features") {
+    // The following dataset is set up such that the best split is {1} vs. {0, 2}.
+    // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(1.0, Vectors.dense(2.0)))
+    val input = sc.parallelize(arr)
+
+    // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
+      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
+    val model = new DecisionTree(strategy).run(input)
+    model.topNode.split.get match {
+      case Split(_, _, _, categories: List[Double]) =>
+        assert(categories === List(1.0))
+    }
+  }
+
   test("Second level node building with vs. without groups") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
     assert(arr.length === 1000)