diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 008dd19c24..82e1ed85a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -996,7 +996,7 @@ private[spark] object RandomForest extends Logging { require(metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") - val splits = if (featureSamples.isEmpty) { + val splits: Array[Double] = if (featureSamples.isEmpty) { Array.empty[Double] } else { val numSplits = metadata.numSplits(featureIndex) @@ -1009,10 +1009,15 @@ private[spark] object RandomForest extends Logging { // sort distinct values val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray - // if possible splits is not enough or just enough, just return all possible splits val possibleSplits = valueCounts.length - 1 - if (possibleSplits <= numSplits) { - valueCounts.map(_._1).init + if (possibleSplits == 0) { + // constant feature + Array.empty[Double] + } else if (possibleSplits <= numSplits) { + // if possible splits is not enough or just enough, just return all possible splits + (1 to possibleSplits) + .map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0) + .toArray } else { // stride between splits val stride: Double = numSamples.toDouble / (numSplits + 1) @@ -1037,7 +1042,7 @@ private[spark] object RandomForest extends Logging { // makes the gap between currentCount and targetCount smaller, // previous value is a split threshold. if (previousGap < currentGap) { - splitsBuilder += valueCounts(index - 1)._1 + splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0 targetCount += stride } index += 1 diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index e1ab7c2d65..df155b464c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -104,6 +104,31 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(splits.distinct.length === splits.length) } + // SPARK-16957: Use midpoints for split values. + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + + // possibleSplits <= numSplits + { + val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + val expectedSplits = Array((0.0 + 1.0) / 2) + assert(splits === expectedSplits) + } + + // possibleSplits > numSplits + { + val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) + assert(splits === expectedSplits) + } + } + // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { @@ -112,9 +137,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Array(5), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) + val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(1.0, 2.0)) + val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2) + assert(splits === expectedSplits) // check returned splits are distinct assert(splits.distinct.length === splits.length) } @@ -126,9 +152,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) + val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5) + .map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(2.0, 3.0)) + val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2) + assert(splits === expectedSplits) } // find splits when most samples close to the maximum @@ -138,9 +166,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Array(2), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(1.0)) + val expectedSplits = Array((1.0 + 2.0) / 2) + assert(splits === expectedSplits) } // find splits for constant feature diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index a6089fc8b9..619fa16d46 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -199,9 +199,9 @@ class DecisionTree(object): >>> print(model.toDebugString()) DecisionTreeModel classifier of depth 1 with 3 nodes - If (feature 0 <= 0.0) + If (feature 0 <= 0.5) Predict: 0.0 - Else (feature 0 > 0.0) + Else (feature 0 > 0.5) Predict: 1.0 >>> model.predict(array([1.0])) @@ -383,14 +383,14 @@ class RandomForest(object): Tree 0: Predict: 1.0 Tree 1: - If (feature 0 <= 1.0) + If (feature 0 <= 1.5) Predict: 0.0 - Else (feature 0 > 1.0) + Else (feature 0 > 1.5) Predict: 1.0 Tree 2: - If (feature 0 <= 1.0) + If (feature 0 <= 1.5) Predict: 0.0 - Else (feature 0 > 1.0) + Else (feature 0 > 1.5) Predict: 1.0 >>> model.predict([2.0])