[SPARK-14610][ML] Remove superfluous split for continuous features in decision tree training

## What changes were proposed in this pull request?

A nonsensical split is produced from method `findSplitsForContinuousFeature` for decision trees. This PR removes the superfluous split and updates unit tests accordingly. Additionally, an assertion to check that the number of found splits is `> 0` is removed, and instead features with zero possible splits are ignored.

## How was this patch tested?

A unit test was added to check that finding splits for a constant feature produces an empty array.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #12374 from sethah/SPARK-14610.
This commit is contained in:
sethah 2016-10-10 17:04:11 -07:00 committed by Joseph K. Bradley
parent 29f186bfdf
commit 03c40202f3
2 changed files with 52 additions and 23 deletions

View file

@ -705,14 +705,17 @@ private[spark] object RandomForest extends Logging {
node.stats
}
val validFeatureSplits =
Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx =>
featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
.getOrElse((featureIndexIdx, featureIndexIdx))
}.withFilter { case (_, featureIndex) =>
binAggregates.metadata.numSplits(featureIndex) != 0
}
// For each (feature, split), calculate the gain, and select the best (feature, split).
val (bestSplit, bestSplitStats) =
Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
val numSplits = binAggregates.metadata.numSplits(featureIndex)
if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
@ -966,7 +969,7 @@ private[spark] object RandomForest extends Logging {
* NOTE: `metadata.numbins` will be changed accordingly
* if there are not enough splits to be found
* @param featureIndex feature index to find splits
* @return array of splits
* @return array of split thresholds
*/
private[tree] def findSplitsForContinuousFeature(
featureSamples: Iterable[Double],
@ -975,7 +978,9 @@ private[spark] object RandomForest extends Logging {
require(metadata.isContinuous(featureIndex),
"findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
val splits = {
val splits = if (featureSamples.isEmpty) {
Array.empty[Double]
} else {
val numSplits = metadata.numSplits(featureIndex)
// get count for each distinct value
@ -987,9 +992,9 @@ private[spark] object RandomForest extends Logging {
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
// if possible splits is not enough or just enough, just return all possible splits
val possibleSplits = valueCounts.length
val possibleSplits = valueCounts.length - 1
if (possibleSplits <= numSplits) {
valueCounts.map(_._1)
valueCounts.map(_._1).init
} else {
// stride between splits
val stride: Double = numSamples.toDouble / (numSplits + 1)
@ -1023,12 +1028,6 @@ private[spark] object RandomForest extends Logging {
splitsBuilder.result()
}
}
// TODO: Do not fail; just ignore the useless feature.
assert(splits.length > 0,
s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
" Please remove this feature and then try again.")
splits
}

View file

@ -115,7 +115,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
)
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 3)
assert(splits === Array(1.0, 2.0))
// check returned splits are distinct
assert(splits.distinct.length === splits.length)
}
@ -129,23 +129,53 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
)
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 2)
assert(splits(0) === 2.0)
assert(splits(1) === 3.0)
assert(splits === Array(2.0, 3.0))
}
// find splits when most samples close to the maximum
{
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
Array(2), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
)
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 1)
assert(splits(0) === 1.0)
assert(splits === Array(1.0))
}
// find splits for constant feature
{
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
)
val featureSamples = Array(0, 0, 0).map(_.toDouble)
val featureSamplesEmpty = Array.empty[Double]
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits === Array[Double]())
val splitsEmpty =
RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0)
assert(splitsEmpty === Array[Double]())
}
}
test("train with constant features") {
val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
val data = Array.fill(5)(lp)
val rdd = sc.parallelize(data)
val strategy = new OldStrategy(
OldAlgo.Classification,
Gini,
maxDepth = 2,
numClasses = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5))
val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
assert(tree.rootNode.impurity === -1.0)
assert(tree.depth === 0)
assert(tree.rootNode.prediction === lp.label)
}
test("Multiclass classification with unordered categorical features: split calculations") {