[SPARK-16957][MLLIB] Use midpoints for split values.
## What changes were proposed in this pull request? Use midpoints for split values now; this may later be extended to use weighted midpoints. ## How was this patch tested? + [x] add unit test. + [x] revise Split's unit test. Author: Yan Facai (颜发才) <facai.yan@gmail.com> Author: 颜发才(Yan Facai) <facai.yan@gmail.com> Closes #17556 from facaiy/ENH/decision_tree_overflow_and_precision_in_aggregation.
This commit is contained in:
parent
16fab6b0ef
commit
7f96f2d7f2
|
@ -996,7 +996,7 @@ private[spark] object RandomForest extends Logging {
|
|||
require(metadata.isContinuous(featureIndex),
|
||||
"findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
|
||||
|
||||
val splits = if (featureSamples.isEmpty) {
|
||||
val splits: Array[Double] = if (featureSamples.isEmpty) {
|
||||
Array.empty[Double]
|
||||
} else {
|
||||
val numSplits = metadata.numSplits(featureIndex)
|
||||
|
@ -1009,10 +1009,15 @@ private[spark] object RandomForest extends Logging {
|
|||
// sort distinct values
|
||||
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
|
||||
|
||||
// if possible splits is not enough or just enough, just return all possible splits
|
||||
val possibleSplits = valueCounts.length - 1
|
||||
if (possibleSplits <= numSplits) {
|
||||
valueCounts.map(_._1).init
|
||||
if (possibleSplits == 0) {
|
||||
// constant feature
|
||||
Array.empty[Double]
|
||||
} else if (possibleSplits <= numSplits) {
|
||||
// if possible splits is not enough or just enough, just return all possible splits
|
||||
(1 to possibleSplits)
|
||||
.map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0)
|
||||
.toArray
|
||||
} else {
|
||||
// stride between splits
|
||||
val stride: Double = numSamples.toDouble / (numSplits + 1)
|
||||
|
@ -1037,7 +1042,7 @@ private[spark] object RandomForest extends Logging {
|
|||
// makes the gap between currentCount and targetCount smaller,
|
||||
// previous value is a split threshold.
|
||||
if (previousGap < currentGap) {
|
||||
splitsBuilder += valueCounts(index - 1)._1
|
||||
splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0
|
||||
targetCount += stride
|
||||
}
|
||||
index += 1
|
||||
|
|
|
@ -104,6 +104,31 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
assert(splits.distinct.length === splits.length)
|
||||
}
|
||||
|
||||
// SPARK-16957: Use midpoints for split values.
|
||||
{
|
||||
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
|
||||
Map(), Set(),
|
||||
Array(3), Gini, QuantileStrategy.Sort,
|
||||
0, 0, 0.0, 0, 0
|
||||
)
|
||||
|
||||
// possibleSplits <= numSplits
|
||||
{
|
||||
val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble)
|
||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||
val expectedSplits = Array((0.0 + 1.0) / 2)
|
||||
assert(splits === expectedSplits)
|
||||
}
|
||||
|
||||
// possibleSplits > numSplits
|
||||
{
|
||||
val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble)
|
||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||
val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2)
|
||||
assert(splits === expectedSplits)
|
||||
}
|
||||
}
|
||||
|
||||
// find splits should not return identical splits
|
||||
// when there are not enough split candidates, reduce the number of splits in metadata
|
||||
{
|
||||
|
@ -112,9 +137,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
Array(5), Gini, QuantileStrategy.Sort,
|
||||
0, 0, 0.0, 0, 0
|
||||
)
|
||||
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
|
||||
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble)
|
||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||
assert(splits === Array(1.0, 2.0))
|
||||
val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2)
|
||||
assert(splits === expectedSplits)
|
||||
// check returned splits are distinct
|
||||
assert(splits.distinct.length === splits.length)
|
||||
}
|
||||
|
@ -126,9 +152,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
Array(3), Gini, QuantileStrategy.Sort,
|
||||
0, 0, 0.0, 0, 0
|
||||
)
|
||||
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
|
||||
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5)
|
||||
.map(_.toDouble)
|
||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||
assert(splits === Array(2.0, 3.0))
|
||||
val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2)
|
||||
assert(splits === expectedSplits)
|
||||
}
|
||||
|
||||
// find splits when most samples close to the maximum
|
||||
|
@ -138,9 +166,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
Array(2), Gini, QuantileStrategy.Sort,
|
||||
0, 0, 0.0, 0, 0
|
||||
)
|
||||
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
|
||||
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
|
||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||
assert(splits === Array(1.0))
|
||||
val expectedSplits = Array((1.0 + 2.0) / 2)
|
||||
assert(splits === expectedSplits)
|
||||
}
|
||||
|
||||
// find splits for constant feature
|
||||
|
|
|
@ -199,9 +199,9 @@ class DecisionTree(object):
|
|||
|
||||
>>> print(model.toDebugString())
|
||||
DecisionTreeModel classifier of depth 1 with 3 nodes
|
||||
If (feature 0 <= 0.0)
|
||||
If (feature 0 <= 0.5)
|
||||
Predict: 0.0
|
||||
Else (feature 0 > 0.0)
|
||||
Else (feature 0 > 0.5)
|
||||
Predict: 1.0
|
||||
<BLANKLINE>
|
||||
>>> model.predict(array([1.0]))
|
||||
|
@ -383,14 +383,14 @@ class RandomForest(object):
|
|||
Tree 0:
|
||||
Predict: 1.0
|
||||
Tree 1:
|
||||
If (feature 0 <= 1.0)
|
||||
If (feature 0 <= 1.5)
|
||||
Predict: 0.0
|
||||
Else (feature 0 > 1.0)
|
||||
Else (feature 0 > 1.5)
|
||||
Predict: 1.0
|
||||
Tree 2:
|
||||
If (feature 0 <= 1.0)
|
||||
If (feature 0 <= 1.5)
|
||||
Predict: 0.0
|
||||
Else (feature 0 > 1.0)
|
||||
Else (feature 0 > 1.5)
|
||||
Predict: 1.0
|
||||
<BLANKLINE>
|
||||
>>> model.predict([2.0])
|
||||
|
|
Loading…
Reference in a new issue