[SPARK-10524][ML] Use the soft prediction to order categories' bins
JIRA: https://issues.apache.org/jira/browse/SPARK-10524 Currently we use the hard prediction (`ImpurityCalculator.predict`) to order categories' bins. But we should use the soft prediction. Author: Liang-Chi Hsieh <viirya@gmail.com> Author: Liang-Chi Hsieh <viirya@appier.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #8734 from viirya/dt-soft-centroids.
This commit is contained in:
parent
0e5ebac3c1
commit
9267bc68fa
|
@ -650,7 +650,7 @@ private[ml] object RandomForest extends Logging {
|
||||||
* @param binAggregates Bin statistics.
|
* @param binAggregates Bin statistics.
|
||||||
* @return tuple for best split: (Split, information gain, prediction at node)
|
* @return tuple for best split: (Split, information gain, prediction at node)
|
||||||
*/
|
*/
|
||||||
private def binsToBestSplit(
|
private[tree] def binsToBestSplit(
|
||||||
binAggregates: DTStatsAggregator,
|
binAggregates: DTStatsAggregator,
|
||||||
splits: Array[Array[Split]],
|
splits: Array[Array[Split]],
|
||||||
featuresForNode: Option[Array[Int]],
|
featuresForNode: Option[Array[Int]],
|
||||||
|
@ -720,32 +720,30 @@ private[ml] object RandomForest extends Logging {
|
||||||
*
|
*
|
||||||
* centroidForCategories is a list: (category, centroid)
|
* centroidForCategories is a list: (category, centroid)
|
||||||
*/
|
*/
|
||||||
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
|
val centroidForCategories = Range(0, numCategories).map { case featureValue =>
|
||||||
// For categorical variables in multiclass classification,
|
val categoryStats =
|
||||||
// the bins are ordered by the impurity of their corresponding labels.
|
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||||
Range(0, numCategories).map { case featureValue =>
|
val centroid = if (categoryStats.count != 0) {
|
||||||
val categoryStats =
|
if (binAggregates.metadata.isMulticlass) {
|
||||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
// multiclass classification
|
||||||
val centroid = if (categoryStats.count != 0) {
|
// For categorical variables in multiclass classification,
|
||||||
|
// the bins are ordered by the impurity of their corresponding labels.
|
||||||
categoryStats.calculate()
|
categoryStats.calculate()
|
||||||
|
} else if (binAggregates.metadata.isClassification) {
|
||||||
|
// binary classification
|
||||||
|
// For categorical variables in binary classification,
|
||||||
|
// the bins are ordered by the count of class 1.
|
||||||
|
categoryStats.stats(1)
|
||||||
} else {
|
} else {
|
||||||
Double.MaxValue
|
// regression
|
||||||
}
|
// For categorical variables in regression and binary classification,
|
||||||
(featureValue, centroid)
|
// the bins are ordered by the prediction.
|
||||||
}
|
|
||||||
} else { // regression or binary classification
|
|
||||||
// For categorical variables in regression and binary classification,
|
|
||||||
// the bins are ordered by the centroid of their corresponding labels.
|
|
||||||
Range(0, numCategories).map { case featureValue =>
|
|
||||||
val categoryStats =
|
|
||||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
|
||||||
val centroid = if (categoryStats.count != 0) {
|
|
||||||
categoryStats.predict
|
categoryStats.predict
|
||||||
} else {
|
|
||||||
Double.MaxValue
|
|
||||||
}
|
}
|
||||||
(featureValue, centroid)
|
} else {
|
||||||
|
Double.MaxValue
|
||||||
}
|
}
|
||||||
|
(featureValue, centroid)
|
||||||
}
|
}
|
||||||
|
|
||||||
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
|
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
|
||||||
|
|
|
@ -791,7 +791,7 @@ object DecisionTree extends Serializable with Logging {
|
||||||
* @param binAggregates Bin statistics.
|
* @param binAggregates Bin statistics.
|
||||||
* @return tuple for best split: (Split, information gain, prediction at node)
|
* @return tuple for best split: (Split, information gain, prediction at node)
|
||||||
*/
|
*/
|
||||||
private def binsToBestSplit(
|
private[tree] def binsToBestSplit(
|
||||||
binAggregates: DTStatsAggregator,
|
binAggregates: DTStatsAggregator,
|
||||||
splits: Array[Array[Split]],
|
splits: Array[Array[Split]],
|
||||||
featuresForNode: Option[Array[Int]],
|
featuresForNode: Option[Array[Int]],
|
||||||
|
@ -808,128 +808,127 @@ object DecisionTree extends Serializable with Logging {
|
||||||
// For each (feature, split), calculate the gain, and select the best (feature, split).
|
// For each (feature, split), calculate the gain, and select the best (feature, split).
|
||||||
val (bestSplit, bestSplitStats) =
|
val (bestSplit, bestSplitStats) =
|
||||||
Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
|
Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
|
||||||
val featureIndex = if (featuresForNode.nonEmpty) {
|
val featureIndex = if (featuresForNode.nonEmpty) {
|
||||||
featuresForNode.get.apply(featureIndexIdx)
|
featuresForNode.get.apply(featureIndexIdx)
|
||||||
} else {
|
} else {
|
||||||
featureIndexIdx
|
featureIndexIdx
|
||||||
}
|
|
||||||
val numSplits = binAggregates.metadata.numSplits(featureIndex)
|
|
||||||
if (binAggregates.metadata.isContinuous(featureIndex)) {
|
|
||||||
// Cumulative sum (scanLeft) of bin statistics.
|
|
||||||
// Afterwards, binAggregates for a bin is the sum of aggregates for
|
|
||||||
// that bin + all preceding bins.
|
|
||||||
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
|
|
||||||
var splitIndex = 0
|
|
||||||
while (splitIndex < numSplits) {
|
|
||||||
binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
|
|
||||||
splitIndex += 1
|
|
||||||
}
|
}
|
||||||
// Find best split.
|
val numSplits = binAggregates.metadata.numSplits(featureIndex)
|
||||||
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
if (binAggregates.metadata.isContinuous(featureIndex)) {
|
||||||
Range(0, numSplits).map { case splitIdx =>
|
// Cumulative sum (scanLeft) of bin statistics.
|
||||||
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
|
// Afterwards, binAggregates for a bin is the sum of aggregates for
|
||||||
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
|
// that bin + all preceding bins.
|
||||||
rightChildStats.subtract(leftChildStats)
|
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
|
||||||
predictWithImpurity = Some(predictWithImpurity.getOrElse(
|
var splitIndex = 0
|
||||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
while (splitIndex < numSplits) {
|
||||||
val gainStats = calculateGainForSplit(leftChildStats,
|
binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
|
||||||
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
|
splitIndex += 1
|
||||||
(splitIdx, gainStats)
|
|
||||||
}.maxBy(_._2.gain)
|
|
||||||
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
|
|
||||||
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
|
|
||||||
// Unordered categorical feature
|
|
||||||
val (leftChildOffset, rightChildOffset) =
|
|
||||||
binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
|
|
||||||
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
|
||||||
Range(0, numSplits).map { splitIndex =>
|
|
||||||
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
|
|
||||||
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
|
|
||||||
predictWithImpurity = Some(predictWithImpurity.getOrElse(
|
|
||||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
|
||||||
val gainStats = calculateGainForSplit(leftChildStats,
|
|
||||||
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
|
|
||||||
(splitIndex, gainStats)
|
|
||||||
}.maxBy(_._2.gain)
|
|
||||||
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
|
|
||||||
} else {
|
|
||||||
// Ordered categorical feature
|
|
||||||
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
|
|
||||||
val numBins = binAggregates.metadata.numBins(featureIndex)
|
|
||||||
|
|
||||||
/* Each bin is one category (feature value).
|
|
||||||
* The bins are ordered based on centroidForCategories, and this ordering determines which
|
|
||||||
* splits are considered. (With K categories, we consider K - 1 possible splits.)
|
|
||||||
*
|
|
||||||
* centroidForCategories is a list: (category, centroid)
|
|
||||||
*/
|
|
||||||
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
|
|
||||||
// For categorical variables in multiclass classification,
|
|
||||||
// the bins are ordered by the impurity of their corresponding labels.
|
|
||||||
Range(0, numBins).map { case featureValue =>
|
|
||||||
val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
|
||||||
val centroid = if (categoryStats.count != 0) {
|
|
||||||
categoryStats.calculate()
|
|
||||||
} else {
|
|
||||||
Double.MaxValue
|
|
||||||
}
|
|
||||||
(featureValue, centroid)
|
|
||||||
}
|
}
|
||||||
} else { // regression or binary classification
|
// Find best split.
|
||||||
// For categorical variables in regression and binary classification,
|
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
||||||
// the bins are ordered by the centroid of their corresponding labels.
|
Range(0, numSplits).map { case splitIdx =>
|
||||||
Range(0, numBins).map { case featureValue =>
|
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
|
||||||
val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
val rightChildStats =
|
||||||
val centroid = if (categoryStats.count != 0) {
|
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
|
||||||
categoryStats.predict
|
rightChildStats.subtract(leftChildStats)
|
||||||
} else {
|
predictWithImpurity = Some(predictWithImpurity.getOrElse(
|
||||||
Double.MaxValue
|
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
||||||
}
|
val gainStats = calculateGainForSplit(leftChildStats,
|
||||||
(featureValue, centroid)
|
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
|
||||||
}
|
(splitIdx, gainStats)
|
||||||
}
|
}.maxBy(_._2.gain)
|
||||||
|
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
|
||||||
|
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
|
||||||
|
// Unordered categorical feature
|
||||||
|
val (leftChildOffset, rightChildOffset) =
|
||||||
|
binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
|
||||||
|
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
||||||
|
Range(0, numSplits).map { splitIndex =>
|
||||||
|
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
|
||||||
|
val rightChildStats =
|
||||||
|
binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
|
||||||
|
predictWithImpurity = Some(predictWithImpurity.getOrElse(
|
||||||
|
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
||||||
|
val gainStats = calculateGainForSplit(leftChildStats,
|
||||||
|
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
|
||||||
|
(splitIndex, gainStats)
|
||||||
|
}.maxBy(_._2.gain)
|
||||||
|
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
|
||||||
|
} else {
|
||||||
|
// Ordered categorical feature
|
||||||
|
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
|
||||||
|
val numBins = binAggregates.metadata.numBins(featureIndex)
|
||||||
|
|
||||||
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
|
/* Each bin is one category (feature value).
|
||||||
|
* The bins are ordered based on centroidForCategories, and this ordering determines which
|
||||||
// bins sorted by centroids
|
* splits are considered. (With K categories, we consider K - 1 possible splits.)
|
||||||
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
|
*
|
||||||
|
* centroidForCategories is a list: (category, centroid)
|
||||||
logDebug("Sorted centroids for categorical variable = " +
|
*/
|
||||||
categoriesSortedByCentroid.mkString(","))
|
val centroidForCategories = Range(0, numBins).map { case featureValue =>
|
||||||
|
val categoryStats =
|
||||||
// Cumulative sum (scanLeft) of bin statistics.
|
|
||||||
// Afterwards, binAggregates for a bin is the sum of aggregates for
|
|
||||||
// that bin + all preceding bins.
|
|
||||||
var splitIndex = 0
|
|
||||||
while (splitIndex < numSplits) {
|
|
||||||
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
|
|
||||||
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
|
|
||||||
binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
|
|
||||||
splitIndex += 1
|
|
||||||
}
|
|
||||||
// lastCategory = index of bin with total aggregates for this (node, feature)
|
|
||||||
val lastCategory = categoriesSortedByCentroid.last._1
|
|
||||||
// Find best split.
|
|
||||||
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
|
||||||
Range(0, numSplits).map { splitIndex =>
|
|
||||||
val featureValue = categoriesSortedByCentroid(splitIndex)._1
|
|
||||||
val leftChildStats =
|
|
||||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||||
val rightChildStats =
|
val centroid = if (categoryStats.count != 0) {
|
||||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
|
if (binAggregates.metadata.isMulticlass) {
|
||||||
rightChildStats.subtract(leftChildStats)
|
// For categorical variables in multiclass classification,
|
||||||
predictWithImpurity = Some(predictWithImpurity.getOrElse(
|
// the bins are ordered by the impurity of their corresponding labels.
|
||||||
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
categoryStats.calculate()
|
||||||
val gainStats = calculateGainForSplit(leftChildStats,
|
} else if (binAggregates.metadata.isClassification) {
|
||||||
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
|
// For categorical variables in binary classification,
|
||||||
(splitIndex, gainStats)
|
// the bins are ordered by the count of class 1.
|
||||||
}.maxBy(_._2.gain)
|
categoryStats.stats(1)
|
||||||
val categoriesForSplit =
|
} else {
|
||||||
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
|
// For categorical variables in regression,
|
||||||
val bestFeatureSplit =
|
// the bins are ordered by the prediction.
|
||||||
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
|
categoryStats.predict
|
||||||
(bestFeatureSplit, bestFeatureGainStats)
|
}
|
||||||
}
|
} else {
|
||||||
|
Double.MaxValue
|
||||||
|
}
|
||||||
|
(featureValue, centroid)
|
||||||
|
}
|
||||||
|
|
||||||
|
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
|
||||||
|
|
||||||
|
// bins sorted by centroids
|
||||||
|
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
|
||||||
|
|
||||||
|
logDebug("Sorted centroids for categorical variable = " +
|
||||||
|
categoriesSortedByCentroid.mkString(","))
|
||||||
|
|
||||||
|
// Cumulative sum (scanLeft) of bin statistics.
|
||||||
|
// Afterwards, binAggregates for a bin is the sum of aggregates for
|
||||||
|
// that bin + all preceding bins.
|
||||||
|
var splitIndex = 0
|
||||||
|
while (splitIndex < numSplits) {
|
||||||
|
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
|
||||||
|
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
|
||||||
|
binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
|
||||||
|
splitIndex += 1
|
||||||
|
}
|
||||||
|
// lastCategory = index of bin with total aggregates for this (node, feature)
|
||||||
|
val lastCategory = categoriesSortedByCentroid.last._1
|
||||||
|
// Find best split.
|
||||||
|
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
||||||
|
Range(0, numSplits).map { splitIndex =>
|
||||||
|
val featureValue = categoriesSortedByCentroid(splitIndex)._1
|
||||||
|
val leftChildStats =
|
||||||
|
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
|
||||||
|
val rightChildStats =
|
||||||
|
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
|
||||||
|
rightChildStats.subtract(leftChildStats)
|
||||||
|
predictWithImpurity = Some(predictWithImpurity.getOrElse(
|
||||||
|
calculatePredictImpurity(leftChildStats, rightChildStats)))
|
||||||
|
val gainStats = calculateGainForSplit(leftChildStats,
|
||||||
|
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
|
||||||
|
(splitIndex, gainStats)
|
||||||
|
}.maxBy(_._2.gain)
|
||||||
|
val categoriesForSplit =
|
||||||
|
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
|
||||||
|
val bestFeatureSplit =
|
||||||
|
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
|
||||||
|
(bestFeatureSplit, bestFeatureGainStats)
|
||||||
|
}
|
||||||
}.maxBy(_._2.gain)
|
}.maxBy(_._2.gain)
|
||||||
|
|
||||||
(bestSplit, bestSplitStats, predictWithImpurity.get._1)
|
(bestSplit, bestSplitStats, predictWithImpurity.get._1)
|
||||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.impl.TreeTests
|
import org.apache.spark.ml.impl.TreeTests
|
||||||
import org.apache.spark.ml.param.ParamsSuite
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.apache.spark.ml.tree.LeafNode
|
import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
|
||||||
import org.apache.spark.ml.util.MLTestingUtils
|
import org.apache.spark.ml.util.MLTestingUtils
|
||||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
|
@ -275,6 +275,40 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
||||||
val model = dt.fit(df)
|
val model = dt.fit(df)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("Use soft prediction for binary classification with ordered categorical features") {
|
||||||
|
// The following dataset is set up such that the best split is {1} vs. {0, 2}.
|
||||||
|
// If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
|
||||||
|
val arr = Array(
|
||||||
|
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||||
|
LabeledPoint(1.0, Vectors.dense(0.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||||
|
LabeledPoint(1.0, Vectors.dense(2.0)))
|
||||||
|
val data = sc.parallelize(arr)
|
||||||
|
val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)
|
||||||
|
|
||||||
|
// Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
|
||||||
|
val dt = new DecisionTreeClassifier()
|
||||||
|
.setImpurity("gini")
|
||||||
|
.setMaxDepth(1)
|
||||||
|
.setMaxBins(3)
|
||||||
|
val model = dt.fit(df)
|
||||||
|
model.rootNode match {
|
||||||
|
case n: InternalNode =>
|
||||||
|
n.split match {
|
||||||
|
case s: CategoricalSplit =>
|
||||||
|
assert(s.leftCategories === Array(1.0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Tests of model save/load
|
// Tests of model save/load
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, Tree
|
||||||
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
|
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
|
||||||
import org.apache.spark.mllib.tree.model._
|
import org.apache.spark.mllib.tree.model._
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.util.Utils
|
import org.apache.spark.util.Utils
|
||||||
|
|
||||||
|
|
||||||
|
@ -337,6 +338,35 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
assert(topNode.rightNode.get.impurity === 0.0)
|
assert(topNode.rightNode.get.impurity === 0.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("Use soft prediction for binary classification with ordered categorical features") {
|
||||||
|
// The following dataset is set up such that the best split is {1} vs. {0, 2}.
|
||||||
|
// If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
|
||||||
|
val arr = Array(
|
||||||
|
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(0.0)),
|
||||||
|
LabeledPoint(1.0, Vectors.dense(0.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(1.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||||
|
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||||
|
LabeledPoint(1.0, Vectors.dense(2.0)))
|
||||||
|
val input = sc.parallelize(arr)
|
||||||
|
|
||||||
|
// Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
|
||||||
|
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
|
||||||
|
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
|
||||||
|
|
||||||
|
val model = new DecisionTree(strategy).run(input)
|
||||||
|
model.topNode.split.get match {
|
||||||
|
case Split(_, _, _, categories: List[Double]) =>
|
||||||
|
assert(categories === List(1.0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
test("Second level node building with vs. without groups") {
|
test("Second level node building with vs. without groups") {
|
||||||
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
|
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
|
||||||
assert(arr.length === 1000)
|
assert(arr.length === 1000)
|
||||||
|
|
Loading…
Reference in a new issue