[SPARK-10524][ML] Use the soft prediction to order categories' bins
JIRA: https://issues.apache.org/jira/browse/SPARK-10524 Currently we use the hard prediction (`ImpurityCalculator.predict`) to order categories' bins. But we should use the soft prediction. Author: Liang-Chi Hsieh <viirya@gmail.com> Author: Liang-Chi Hsieh <viirya@appier.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #8734 from viirya/dt-soft-centroids.
This commit is contained in:
parent
0e5ebac3c1
commit
9267bc68fa
|
@ -650,7 +650,7 @@ private[ml] object RandomForest extends Logging {
|
|||
* @param binAggregates Bin statistics.
|
||||
* @return tuple for best split: (Split, information gain, prediction at node)
|
||||
*/
|
||||
private def binsToBestSplit(
|
||||
private[tree] def binsToBestSplit(
|
||||
binAggregates: DTStatsAggregator,
|
||||
splits: Array[Array[Split]],
|
||||
featuresForNode: Option[Array[Int]],
|
||||
|
@ -720,33 +720,31 @@ private[ml] object RandomForest extends Logging {
|
|||
*
|
||||
* centroidForCategories is a list: (category, centroid)
|
||||
*/
|
||||
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
|
||||
val centroidForCategories = Range(0, numCategories).map { case featureValue =>
|
||||
val categoryStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||
val centroid = if (categoryStats.count != 0) {
|
||||
if (binAggregates.metadata.isMulticlass) {
|
||||
// multiclass classification
|
||||
// For categorical variables in multiclass classification,
|
||||
// the bins are ordered by the impurity of their corresponding labels.
|
||||
Range(0, numCategories).map { case featureValue =>
|
||||
val categoryStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||
val centroid = if (categoryStats.count != 0) {
|
||||
categoryStats.calculate()
|
||||
} else if (binAggregates.metadata.isClassification) {
|
||||
// binary classification
|
||||
// For categorical variables in binary classification,
|
||||
// the bins are ordered by the count of class 1.
|
||||
categoryStats.stats(1)
|
||||
} else {
|
||||
Double.MaxValue
|
||||
}
|
||||
(featureValue, centroid)
|
||||
}
|
||||
} else { // regression or binary classification
|
||||
// regression
|
||||
// For categorical variables in regression and binary classification,
|
||||
// the bins are ordered by the centroid of their corresponding labels.
|
||||
Range(0, numCategories).map { case featureValue =>
|
||||
val categoryStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||
val centroid = if (categoryStats.count != 0) {
|
||||
// the bins are ordered by the prediction.
|
||||
categoryStats.predict
|
||||
}
|
||||
} else {
|
||||
Double.MaxValue
|
||||
}
|
||||
(featureValue, centroid)
|
||||
}
|
||||
}
|
||||
|
||||
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
|
||||
|
||||
|
|
|
@ -791,7 +791,7 @@ object DecisionTree extends Serializable with Logging {
|
|||
* @param binAggregates Bin statistics.
|
||||
* @return tuple for best split: (Split, information gain, prediction at node)
|
||||
*/
|
||||
private def binsToBestSplit(
|
||||
private[tree] def binsToBestSplit(
|
||||
binAggregates: DTStatsAggregator,
|
||||
splits: Array[Array[Split]],
|
||||
featuresForNode: Option[Array[Int]],
|
||||
|
@ -828,7 +828,8 @@ object DecisionTree extends Serializable with Logging {
|
|||
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
||||
Range(0, numSplits).map { case splitIdx =>
|
||||
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
|
||||
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
|
||||
val rightChildStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
|
||||
rightChildStats.subtract(leftChildStats)
|
||||
predictWithImpurity = Some(predictWithImpurity.getOrElse(
|
||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
||||
|
@ -844,7 +845,8 @@ object DecisionTree extends Serializable with Logging {
|
|||
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
||||
Range(0, numSplits).map { splitIndex =>
|
||||
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
|
||||
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
|
||||
val rightChildStats =
|
||||
binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
|
||||
predictWithImpurity = Some(predictWithImpurity.getOrElse(
|
||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
||||
val gainStats = calculateGainForSplit(leftChildStats,
|
||||
|
@ -863,31 +865,28 @@ object DecisionTree extends Serializable with Logging {
|
|||
*
|
||||
* centroidForCategories is a list: (category, centroid)
|
||||
*/
|
||||
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
|
||||
val centroidForCategories = Range(0, numBins).map { case featureValue =>
|
||||
val categoryStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||
val centroid = if (categoryStats.count != 0) {
|
||||
if (binAggregates.metadata.isMulticlass) {
|
||||
// For categorical variables in multiclass classification,
|
||||
// the bins are ordered by the impurity of their corresponding labels.
|
||||
Range(0, numBins).map { case featureValue =>
|
||||
val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||
val centroid = if (categoryStats.count != 0) {
|
||||
categoryStats.calculate()
|
||||
} else if (binAggregates.metadata.isClassification) {
|
||||
// For categorical variables in binary classification,
|
||||
// the bins are ordered by the count of class 1.
|
||||
categoryStats.stats(1)
|
||||
} else {
|
||||
Double.MaxValue
|
||||
}
|
||||
(featureValue, centroid)
|
||||
}
|
||||
} else { // regression or binary classification
|
||||
// For categorical variables in regression and binary classification,
|
||||
// the bins are ordered by the centroid of their corresponding labels.
|
||||
Range(0, numBins).map { case featureValue =>
|
||||
val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||
val centroid = if (categoryStats.count != 0) {
|
||||
// For categorical variables in regression,
|
||||
// the bins are ordered by the prediction.
|
||||
categoryStats.predict
|
||||
}
|
||||
} else {
|
||||
Double.MaxValue
|
||||
}
|
||||
(featureValue, centroid)
|
||||
}
|
||||
}
|
||||
|
||||
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
|
|||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.impl.TreeTests
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.tree.LeafNode
|
||||
import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
|
||||
import org.apache.spark.ml.util.MLTestingUtils
|
||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
|
@ -275,6 +275,40 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
|||
val model = dt.fit(df)
|
||||
}
|
||||
|
||||
test("Use soft prediction for binary classification with ordered categorical features") {
|
||||
// The following dataset is set up such that the best split is {1} vs. {0, 2}.
|
||||
// If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
|
||||
val arr = Array(
|
||||
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(2.0)))
|
||||
val data = sc.parallelize(arr)
|
||||
val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)
|
||||
|
||||
// Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
|
||||
val dt = new DecisionTreeClassifier()
|
||||
.setImpurity("gini")
|
||||
.setMaxDepth(1)
|
||||
.setMaxBins(3)
|
||||
val model = dt.fit(df)
|
||||
model.rootNode match {
|
||||
case n: InternalNode =>
|
||||
n.split match {
|
||||
case s: CategoricalSplit =>
|
||||
assert(s.leftCategories === Array(1.0))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Tests of model save/load
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, Tree
|
|||
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
|
||||
import org.apache.spark.mllib.tree.model._
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.mllib.util.TestingUtils._
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
|
||||
|
@ -337,6 +338,35 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
assert(topNode.rightNode.get.impurity === 0.0)
|
||||
}
|
||||
|
||||
test("Use soft prediction for binary classification with ordered categorical features") {
|
||||
// The following dataset is set up such that the best split is {1} vs. {0, 2}.
|
||||
// If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
|
||||
val arr = Array(
|
||||
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(0.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(2.0)))
|
||||
val input = sc.parallelize(arr)
|
||||
|
||||
// Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
|
||||
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
|
||||
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
|
||||
|
||||
val model = new DecisionTree(strategy).run(input)
|
||||
model.topNode.split.get match {
|
||||
case Split(_, _, _, categories: List[Double]) =>
|
||||
assert(categories === List(1.0))
|
||||
}
|
||||
}
|
||||
|
||||
test("Second level node building with vs. without groups") {
|
||||
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
|
||||
assert(arr.length === 1000)
|
||||
|
|
Loading…
Reference in a new issue