[SPARK-16957][MLLIB] Use midpoints for split values.

## What changes were proposed in this pull request?

Use midpoints for split values now, and maybe later to make it weighted.

## How was this patch tested?

+ [x] add unit test.
+ [x] revise Split's unit test.

Author: Yan Facai (颜发才) <facai.yan@gmail.com>
Author: 颜发才(Yan Facai) <facai.yan@gmail.com>

Closes #17556 from facaiy/ENH/decision_tree_overflow_and_precision_in_aggregation.
This commit is contained in:
Yan Facai (颜发才) 2017-05-03 10:54:40 +01:00 committed by Sean Owen
parent 16fab6b0ef
commit 7f96f2d7f2
3 changed files with 51 additions and 17 deletions

View file

@ -996,7 +996,7 @@ private[spark] object RandomForest extends Logging {
require(metadata.isContinuous(featureIndex),
"findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
val splits = if (featureSamples.isEmpty) {
val splits: Array[Double] = if (featureSamples.isEmpty) {
Array.empty[Double]
} else {
val numSplits = metadata.numSplits(featureIndex)
@ -1009,10 +1009,15 @@ private[spark] object RandomForest extends Logging {
// sort distinct values
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
// if possible splits is not enough or just enough, just return all possible splits
val possibleSplits = valueCounts.length - 1
if (possibleSplits <= numSplits) {
valueCounts.map(_._1).init
if (possibleSplits == 0) {
// constant feature
Array.empty[Double]
} else if (possibleSplits <= numSplits) {
// if possible splits is not enough or just enough, just return all possible splits
(1 to possibleSplits)
.map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0)
.toArray
} else {
// stride between splits
val stride: Double = numSamples.toDouble / (numSplits + 1)
@ -1037,7 +1042,7 @@ private[spark] object RandomForest extends Logging {
// makes the gap between currentCount and targetCount smaller,
// previous value is a split threshold.
if (previousGap < currentGap) {
splitsBuilder += valueCounts(index - 1)._1
splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0
targetCount += stride
}
index += 1

View file

@ -104,6 +104,31 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(splits.distinct.length === splits.length)
}
// SPARK-16957: Use midpoints for split values.
{
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
)
// possibleSplits <= numSplits
{
val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((0.0 + 1.0) / 2)
assert(splits === expectedSplits)
}
// possibleSplits > numSplits
{
val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2)
assert(splits === expectedSplits)
}
}
// find splits should not return identical splits
// when there are not enough split candidates, reduce the number of splits in metadata
{
@ -112,9 +137,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
Array(5), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
)
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits === Array(1.0, 2.0))
val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2)
assert(splits === expectedSplits)
// check returned splits are distinct
assert(splits.distinct.length === splits.length)
}
@ -126,9 +152,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
)
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5)
.map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits === Array(2.0, 3.0))
val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2)
assert(splits === expectedSplits)
}
// find splits when most samples close to the maximum
@ -138,9 +166,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
Array(2), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
)
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits === Array(1.0))
val expectedSplits = Array((1.0 + 2.0) / 2)
assert(splits === expectedSplits)
}
// find splits for constant feature

View file

@ -199,9 +199,9 @@ class DecisionTree(object):
>>> print(model.toDebugString())
DecisionTreeModel classifier of depth 1 with 3 nodes
If (feature 0 <= 0.0)
If (feature 0 <= 0.5)
Predict: 0.0
Else (feature 0 > 0.0)
Else (feature 0 > 0.5)
Predict: 1.0
<BLANKLINE>
>>> model.predict(array([1.0]))
@ -383,14 +383,14 @@ class RandomForest(object):
Tree 0:
Predict: 1.0
Tree 1:
If (feature 0 <= 1.0)
If (feature 0 <= 1.5)
Predict: 0.0
Else (feature 0 > 1.0)
Else (feature 0 > 1.5)
Predict: 1.0
Tree 2:
If (feature 0 <= 1.0)
If (feature 0 <= 1.5)
Predict: 0.0
Else (feature 0 > 1.0)
Else (feature 0 > 1.5)
Predict: 1.0
<BLANKLINE>
>>> model.predict([2.0])