[SPARK-2152][MLlib] fix bin offset in DecisionTree node aggregations (also resolves SPARK-2160)

Hi, this pull fixes (what I believe to be) a bug in DecisionTree.scala.

In the extractLeftRightNodeAggregates function, the first set of rightNodeAgg values for Regression is set at line 792 as follows:

rightNodeAgg(featureIndex)(2 * (numBins - 2))
  = binData(shift + (2 * (numBins - 1)))

Then there is a loop that sets the rest of the values, as in line 809:

rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
  binData(shift + (2 *(numBins - 2 - splitIndex))) +
  rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))

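But since splitIndex starts at 1, the first iteration reads binData at bin index numBins - 3 instead of numBins - 2, so the binData values for bin numBins - 2 are never added to any right-node aggregate, and every later iteration is off by one bin in the same way.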

The changes here address this issue, for both the Regression and Classification cases.
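
The off-by-one can be reproduced in isolation. The following is a minimal sketch, not the code from DecisionTree.scala: the object name RightNodeAggSketch, the function rightAgg, and the `stride` and `fixed` parameters are hypothetical, it handles a single feature with shift = 0, and the loop bound `splitIndex < numBins - 1` is an assumption about the surrounding loop. `stride` stands for the number of statistics stored per bin (2 and 3 in the two hunks of this commit).

```scala
object RightNodeAggSketch {
  // binData holds `stride` statistics per bin, laid out bin by bin.
  // The result holds `stride` statistics per split (numBins - 1 splits);
  // the aggregate for split s should cover bins s + 1 .. numBins - 1.
  def rightAgg(
      binData: Array[Double],
      numBins: Int,
      stride: Int,
      fixed: Boolean): Array[Double] = {
    val agg = new Array[Double](stride * (numBins - 1))

    // Highest split: the right side is just the last bin.
    for (k <- 0 until stride) {
      agg(stride * (numBins - 2) + k) = binData(stride * (numBins - 1) + k)
    }

    var splitIndex = 1
    while (splitIndex < numBins - 1) { // assumed loop bound
      val target = numBins - 2 - splitIndex
      // Old code read the bin at `target`; the fix reads the bin at `target + 1`.
      val bin = if (fixed) numBins - 1 - splitIndex else numBins - 2 - splitIndex
      for (k <- 0 until stride) {
        agg(stride * target + k) =
          binData(stride * bin + k) + agg(stride * (target + 1) + k)
      }
      splitIndex += 1
    }
    agg
  }
}
```

With four bins holding 1, 10, 100 and 1000 (one statistic per bin), the fixed recurrence gives 1110, 1100, 1000 for splits 0, 1, 2, while the old recurrence gives 1011, 1010, 1000: bin 2 (value 100) never contributes, and bin 0 is wrongly counted on the right of split 0.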

Author: johnnywalleye <jsondag@gmail.com>

Closes #1316 from johnnywalleye/master and squashes the following commits:

73809da [johnnywalleye] fix bin offset in DecisionTree node aggregations
johnnywalleye 2014-07-08 19:17:26 -07:00 committed by Xiangrui Meng
parent ac9cdc116e
commit 1114207cc8

@@ -807,10 +807,10 @@ object DecisionTree extends Serializable with Logging {
   // calculating right node aggregate for a split as a sum of right node aggregate of a
   // higher split and the right bin aggregate of a bin where the split is a low split
   rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
-    binData(shift + (2 *(numBins - 2 - splitIndex))) +
+    binData(shift + (2 *(numBins - 1 - splitIndex))) +
     rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
   rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) =
-    binData(shift + (2* (numBins - 2 - splitIndex) + 1)) +
+    binData(shift + (2* (numBins - 1 - splitIndex) + 1)) +
     rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
   splitIndex += 1
@@ -855,13 +855,13 @@ object DecisionTree extends Serializable with Logging {
   // calculating right node aggregate for a split as a sum of right node aggregate of a
   // higher split and the right bin aggregate of a bin where the split is a low split
   rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) =
-    binData(shift + (3 * (numBins - 2 - splitIndex))) +
+    binData(shift + (3 * (numBins - 1 - splitIndex))) +
     rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
   rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) =
-    binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) +
+    binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) +
     rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
   rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) =
-    binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) +
+    binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) +
     rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
   splitIndex += 1