[SPARK-3381] [MLlib] Eliminate bins for unordered features in DecisionTrees
For unordered features, it is sufficient to use splits since the threshold of the split corresponds the threshold of the HighSplit of the bin and there is no use of the LowSplit. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #4231 from MechCoder/spark-3381 and squashes the following commits: 58c19a5 [MechCoder] COSMIT c274b74 [MechCoder] Remove unordered feature calculation in labeledPointToTreePoint b2b9b89 [MechCoder] COSMIT d3ee042 [MechCoder] [SPARK-3381] [MLlib] Eliminate bins for unordered features
This commit is contained in:
parent
b271c265b7
commit
9b746f3808
|
@ -327,14 +327,14 @@ object DecisionTree extends Serializable with Logging {
|
|||
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
|
||||
* each (feature, bin).
|
||||
* @param treePoint Data point being aggregated.
|
||||
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
|
||||
* @param splits possible splits indexed (numFeatures)(numSplits)
|
||||
* @param unorderedFeatures Set of indices of unordered features.
|
||||
* @param instanceWeight Weight (importance) of instance in dataset.
|
||||
*/
|
||||
private def mixedBinSeqOp(
|
||||
agg: DTStatsAggregator,
|
||||
treePoint: TreePoint,
|
||||
bins: Array[Array[Bin]],
|
||||
splits: Array[Array[Split]],
|
||||
unorderedFeatures: Set[Int],
|
||||
instanceWeight: Double,
|
||||
featuresForNode: Option[Array[Int]]): Unit = {
|
||||
|
@ -362,7 +362,7 @@ object DecisionTree extends Serializable with Logging {
|
|||
val numSplits = agg.metadata.numSplits(featureIndex)
|
||||
var splitIndex = 0
|
||||
while (splitIndex < numSplits) {
|
||||
if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
|
||||
if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
|
||||
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
|
||||
instanceWeight)
|
||||
} else {
|
||||
|
@ -506,8 +506,8 @@ object DecisionTree extends Serializable with Logging {
|
|||
if (metadata.unorderedFeatures.isEmpty) {
|
||||
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
|
||||
} else {
|
||||
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
|
||||
instanceWeight, featuresForNode)
|
||||
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
|
||||
metadata.unorderedFeatures, instanceWeight, featuresForNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1024,35 +1024,15 @@ object DecisionTree extends Serializable with Logging {
|
|||
// Categorical feature
|
||||
val featureArity = metadata.featureArity(featureIndex)
|
||||
if (metadata.isUnordered(featureIndex)) {
|
||||
// TODO: The second half of the bins are unused. Actually, we could just use
|
||||
// splits and not build bins for unordered features. That should be part of
|
||||
// a later PR since it will require changing other code (using splits instead
|
||||
// of bins in a few places).
|
||||
// Unordered features
|
||||
// 2^(maxFeatureValue - 1) - 1 combinations
|
||||
splits(featureIndex) = new Array[Split](numSplits)
|
||||
bins(featureIndex) = new Array[Bin](numBins)
|
||||
var splitIndex = 0
|
||||
while (splitIndex < numSplits) {
|
||||
val categories: List[Double] =
|
||||
extractMultiClassCategories(splitIndex + 1, featureArity)
|
||||
splits(featureIndex)(splitIndex) =
|
||||
new Split(featureIndex, Double.MinValue, Categorical, categories)
|
||||
bins(featureIndex)(splitIndex) = {
|
||||
if (splitIndex == 0) {
|
||||
new Bin(
|
||||
new DummyCategoricalSplit(featureIndex, Categorical),
|
||||
splits(featureIndex)(0),
|
||||
Categorical,
|
||||
Double.MinValue)
|
||||
} else {
|
||||
new Bin(
|
||||
splits(featureIndex)(splitIndex - 1),
|
||||
splits(featureIndex)(splitIndex),
|
||||
Categorical,
|
||||
Double.MinValue)
|
||||
}
|
||||
}
|
||||
splitIndex += 1
|
||||
}
|
||||
} else {
|
||||
|
@ -1060,8 +1040,11 @@ object DecisionTree extends Serializable with Logging {
|
|||
// Bins correspond to feature values, so we do not need to compute splits or bins
|
||||
// beforehand. Splits are constructed as needed during training.
|
||||
splits(featureIndex) = new Array[Split](0)
|
||||
bins(featureIndex) = new Array[Bin](0)
|
||||
}
|
||||
// For ordered features, bins correspond to feature values.
|
||||
// For unordered categorical features, there is no need to construct the bins.
|
||||
// since there is a one-to-one correspondence between the splits and the bins.
|
||||
bins(featureIndex) = new Array[Bin](0)
|
||||
}
|
||||
featureIndex += 1
|
||||
}
|
||||
|
|
|
@ -55,17 +55,15 @@ private[tree] object TreePoint {
|
|||
input: RDD[LabeledPoint],
|
||||
bins: Array[Array[Bin]],
|
||||
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
|
||||
// Construct arrays for featureArity and isUnordered for efficiency in the inner loop.
|
||||
// Construct arrays for featureArity for efficiency in the inner loop.
|
||||
val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
|
||||
val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures)
|
||||
var featureIndex = 0
|
||||
while (featureIndex < metadata.numFeatures) {
|
||||
featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
|
||||
isUnordered(featureIndex) = metadata.isUnordered(featureIndex)
|
||||
featureIndex += 1
|
||||
}
|
||||
input.map { x =>
|
||||
TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered)
|
||||
TreePoint.labeledPointToTreePoint(x, bins, featureArity)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -74,19 +72,17 @@ private[tree] object TreePoint {
|
|||
* @param bins Bins for features, of size (numFeatures, numBins).
|
||||
* @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
|
||||
* for categorical features.
|
||||
* @param isUnordered Array index by feature, with value true for unordered categorical features.
|
||||
*/
|
||||
private def labeledPointToTreePoint(
|
||||
labeledPoint: LabeledPoint,
|
||||
bins: Array[Array[Bin]],
|
||||
featureArity: Array[Int],
|
||||
isUnordered: Array[Boolean]): TreePoint = {
|
||||
featureArity: Array[Int]): TreePoint = {
|
||||
val numFeatures = labeledPoint.features.size
|
||||
val arr = new Array[Int](numFeatures)
|
||||
var featureIndex = 0
|
||||
while (featureIndex < numFeatures) {
|
||||
arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
|
||||
isUnordered(featureIndex), bins)
|
||||
bins)
|
||||
featureIndex += 1
|
||||
}
|
||||
new TreePoint(labeledPoint.label, arr)
|
||||
|
@ -96,14 +92,12 @@ private[tree] object TreePoint {
|
|||
* Find bin for one (labeledPoint, feature).
|
||||
*
|
||||
* @param featureArity 0 for continuous features; number of categories for categorical features.
|
||||
* @param isUnorderedFeature (only applies if feature is categorical)
|
||||
* @param bins Bins for features, of size (numFeatures, numBins).
|
||||
*/
|
||||
private def findBin(
|
||||
featureIndex: Int,
|
||||
labeledPoint: LabeledPoint,
|
||||
featureArity: Int,
|
||||
isUnorderedFeature: Boolean,
|
||||
bins: Array[Array[Bin]]): Int = {
|
||||
|
||||
/**
|
||||
|
|
|
@ -190,7 +190,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
|
|||
assert(splits.length === 2)
|
||||
assert(bins.length === 2)
|
||||
assert(splits(0).length === 3)
|
||||
assert(bins(0).length === 6)
|
||||
assert(bins(0).length === 0)
|
||||
|
||||
// Expecting 2^2 - 1 = 3 bins/splits
|
||||
assert(splits(0)(0).feature === 0)
|
||||
|
@ -228,41 +228,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
|
|||
assert(splits(1)(2).categories.contains(0.0))
|
||||
assert(splits(1)(2).categories.contains(1.0))
|
||||
|
||||
// Check bins.
|
||||
|
||||
assert(bins(0)(0).category === Double.MinValue)
|
||||
assert(bins(0)(0).lowSplit.categories.length === 0)
|
||||
assert(bins(0)(0).highSplit.categories.length === 1)
|
||||
assert(bins(0)(0).highSplit.categories.contains(0.0))
|
||||
assert(bins(1)(0).category === Double.MinValue)
|
||||
assert(bins(1)(0).lowSplit.categories.length === 0)
|
||||
assert(bins(1)(0).highSplit.categories.length === 1)
|
||||
assert(bins(1)(0).highSplit.categories.contains(0.0))
|
||||
|
||||
assert(bins(0)(1).category === Double.MinValue)
|
||||
assert(bins(0)(1).lowSplit.categories.length === 1)
|
||||
assert(bins(0)(1).lowSplit.categories.contains(0.0))
|
||||
assert(bins(0)(1).highSplit.categories.length === 1)
|
||||
assert(bins(0)(1).highSplit.categories.contains(1.0))
|
||||
assert(bins(1)(1).category === Double.MinValue)
|
||||
assert(bins(1)(1).lowSplit.categories.length === 1)
|
||||
assert(bins(1)(1).lowSplit.categories.contains(0.0))
|
||||
assert(bins(1)(1).highSplit.categories.length === 1)
|
||||
assert(bins(1)(1).highSplit.categories.contains(1.0))
|
||||
|
||||
assert(bins(0)(2).category === Double.MinValue)
|
||||
assert(bins(0)(2).lowSplit.categories.length === 1)
|
||||
assert(bins(0)(2).lowSplit.categories.contains(1.0))
|
||||
assert(bins(0)(2).highSplit.categories.length === 2)
|
||||
assert(bins(0)(2).highSplit.categories.contains(1.0))
|
||||
assert(bins(0)(2).highSplit.categories.contains(0.0))
|
||||
assert(bins(1)(2).category === Double.MinValue)
|
||||
assert(bins(1)(2).lowSplit.categories.length === 1)
|
||||
assert(bins(1)(2).lowSplit.categories.contains(1.0))
|
||||
assert(bins(1)(2).highSplit.categories.length === 2)
|
||||
assert(bins(1)(2).highSplit.categories.contains(1.0))
|
||||
assert(bins(1)(2).highSplit.categories.contains(0.0))
|
||||
|
||||
}
|
||||
|
||||
test("Multiclass classification with ordered categorical features: split and bin calculations") {
|
||||
|
|
Loading…
Reference in a new issue