[SPARK-14610][ML] Remove superfluous split for continuous features in decision tree training
## What changes were proposed in this pull request?

A nonsensical split is produced by the method `findSplitsForContinuousFeature` for decision trees. This PR removes the superfluous split and updates the unit tests accordingly. Additionally, the assertion that the number of found splits is `> 0` is removed; instead, features with zero possible splits are simply ignored.

## How was this patch tested?

A unit test was added to check that finding splits for a constant feature produces an empty array.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #12374 from sethah/SPARK-14610.
Parent: 29f186bfdf
Commit: 03c40202f3
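For context, the superfluous split exists because a continuous split sends a sample to the left child when the feature value is `<=` the threshold, so the largest distinct value of a feature can never separate any samples. A minimal standalone sketch of the corrected threshold selection (illustrative names only, not Spark's internal API):

```scala
// Illustrative sketch (not Spark code): candidate thresholds for a continuous
// feature when there are only a few distinct values.
object SplitSketch {
  def candidateThresholds(featureSamples: Seq[Double]): Array[Double] = {
    val sortedDistinct = featureSamples.distinct.sorted.toArray
    // A split tests `value <= threshold`, so the largest distinct value can
    // never separate samples; only the first (n - 1) distinct values are usable.
    if (sortedDistinct.length <= 1) Array.empty[Double] else sortedDistinct.init
  }

  def main(args: Array[String]): Unit = {
    println(candidateThresholds(Seq(1, 1, 2, 2, 3).map(_.toDouble)).mkString(", ")) // 1.0, 2.0
    println(candidateThresholds(Seq(5.0, 5.0, 5.0)).mkString(", "))                 // (empty)
  }
}
```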
@@ -705,14 +705,17 @@ private[spark] object RandomForest extends Logging {
       node.stats
     }
 
+    val validFeatureSplits =
+      Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx =>
+        featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
+          .getOrElse((featureIndexIdx, featureIndexIdx))
+      }.withFilter { case (_, featureIndex) =>
+        binAggregates.metadata.numSplits(featureIndex) != 0
+      }
+
     // For each (feature, split), calculate the gain, and select the best (feature, split).
     val (bestSplit, bestSplitStats) =
-      Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
-        val featureIndex = if (featuresForNode.nonEmpty) {
-          featuresForNode.get.apply(featureIndexIdx)
-        } else {
-          featureIndexIdx
-        }
+      validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
         val numSplits = binAggregates.metadata.numSplits(featureIndex)
         if (binAggregates.metadata.isContinuous(featureIndex)) {
           // Cumulative sum (scanLeft) of bin statistics.
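The hunk above replaces the per-feature index bookkeeping with `validFeatureSplits`, a lazy view that pairs each node-local feature index with its global feature index and filters out features that have no candidate splits, so a constant feature is simply skipped rather than tripping an assertion later. A self-contained sketch of the same pattern, with made-up arrays standing in for `binAggregates.metadata` and `featuresForNode`:

```scala
// Standalone illustration of the view + withFilter pattern from the diff.
// `numSplitsPerFeature` and `featuresForNode` are made-up stand-ins for the
// metadata Spark actually consults.
val numSplitsPerFeature = Array(3, 0, 2)              // feature 1 has no usable splits
val featuresForNode: Option[Array[Int]] = None        // None => node sees all features

val validFeatureSplits =
  Range(0, numSplitsPerFeature.length).view.map { featureIndexIdx =>
    featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
      .getOrElse((featureIndexIdx, featureIndexIdx))
  }.withFilter { case (_, featureIndex) =>
    numSplitsPerFeature(featureIndex) != 0             // skip zero-split features
  }

// Only features 0 and 2 survive; feature 1 is ignored instead of failing.
validFeatureSplits.foreach { case (idx, feature) => println(s"$idx -> $feature") }
```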
@@ -966,7 +969,7 @@ private[spark] object RandomForest extends Logging {
    * NOTE: `metadata.numbins` will be changed accordingly
    *       if there are not enough splits to be found
    * @param featureIndex feature index to find splits
-   * @return array of splits
+   * @return array of split thresholds
    */
   private[tree] def findSplitsForContinuousFeature(
       featureSamples: Iterable[Double],
@@ -975,7 +978,9 @@ private[spark] object RandomForest extends Logging {
     require(metadata.isContinuous(featureIndex),
       "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
 
-    val splits = {
+    val splits = if (featureSamples.isEmpty) {
+      Array.empty[Double]
+    } else {
       val numSplits = metadata.numSplits(featureIndex)
 
       // get count for each distinct value
@@ -987,9 +992,9 @@ private[spark] object RandomForest extends Logging {
       val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
 
       // if possible splits is not enough or just enough, just return all possible splits
-      val possibleSplits = valueCounts.length
+      val possibleSplits = valueCounts.length - 1
       if (possibleSplits <= numSplits) {
-        valueCounts.map(_._1)
+        valueCounts.map(_._1).init
       } else {
         // stride between splits
         val stride: Double = numSamples.toDouble / (numSplits + 1)
@@ -1023,12 +1028,6 @@ private[spark] object RandomForest extends Logging {
         splitsBuilder.result()
       }
     }
-
-    // TODO: Do not fail; just ignore the useless feature.
-    assert(splits.length > 0,
-      s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
-        " Please remove this feature and then try again.")
-
     splits
   }
 
@@ -115,7 +115,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       )
       val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits.length === 3)
+      assert(splits === Array(1.0, 2.0))
       // check returned splits are distinct
       assert(splits.distinct.length === splits.length)
     }
@@ -129,23 +129,53 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       )
       val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits.length === 2)
-      assert(splits(0) === 2.0)
-      assert(splits(1) === 3.0)
+      assert(splits === Array(2.0, 3.0))
     }
 
     // find splits when most samples close to the maximum
     {
       val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
         Map(), Set(),
-        Array(3), Gini, QuantileStrategy.Sort,
+        Array(2), Gini, QuantileStrategy.Sort,
         0, 0, 0.0, 0, 0
       )
       val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits.length === 1)
-      assert(splits(0) === 1.0)
+      assert(splits === Array(1.0))
     }
+
+    // find splits for constant feature
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(3), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array(0, 0, 0).map(_.toDouble)
+      val featureSamplesEmpty = Array.empty[Double]
+      val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits === Array[Double]())
+      val splitsEmpty =
+        RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0)
+      assert(splitsEmpty === Array[Double]())
+    }
   }
 
+  test("train with constant features") {
+    val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
+    val data = Array.fill(5)(lp)
+    val rdd = sc.parallelize(data)
+    val strategy = new OldStrategy(
+      OldAlgo.Classification,
+      Gini,
+      maxDepth = 2,
+      numClasses = 2,
+      maxBins = 100,
+      categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5))
+    val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
+    assert(tree.rootNode.impurity === -1.0)
+    assert(tree.depth === 0)
+    assert(tree.rootNode.prediction === lp.label)
+  }
+
   test("Multiclass classification with unordered categorical features: split calculations") {