[SPARK-4355][MLLIB] fix OnlineSummarizer.merge when other.mean is zero

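See the inline comment about the bug: the old `merge` only updated the mean when `other.currMean(i) != 0.0`, so a summarizer whose non-zero values average to zero (e.g. 1.0 and -1.0) was ignored when merging means. I also did some code clean-up. @dbtsai, I moved `update` to a private method of `MultivariateOnlineSummarizer`. I don't think it will cause a performance regression, but it would be great if you could find some time to test it.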

Author: Xiangrui Meng <meng@databricks.com>

Closes #3220 from mengxr/SPARK-4355 and squashes the following commits:

5ef601f [Xiangrui Meng] fix OnlineSummarizer.merge when other.mean is zero and some code clean-up
This commit is contained in:
Xiangrui Meng 2014-11-12 01:50:11 -08:00
parent faeb41de21
commit 84324fbcb9
2 changed files with 51 additions and 45 deletions

View file

@ -49,6 +49,29 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
private var currMax: BDV[Double] = _
private var currMin: BDV[Double] = _
/**
 * Folds a single value into the running statistics kept for column `i`.
 *
 * Values equal to 0.0 are skipped entirely: none of the per-column accumulators
 * (mean, M2n, M2, L1, non-zero count, max, min) is touched for them.
 *
 * @param i index of the column being updated
 * @param value the observed value for that column
 */
private[this] def add(i: Int, value: Double) = {
  if (value != 0.0) {
    val seen = nnz(i)
    val oldMean = currMean(i)
    val delta = value - oldMean

    // Incremental mean update: move the mean towards `value` by delta / (new count).
    val newMean = oldMean + delta / (seen + 1.0)
    currMean(i) = newMean
    // Uses both the new and the old mean, so no second pass over the data is needed.
    currM2n(i) += (value - newMean) * delta

    currM2(i) += value * value
    currL1(i) += math.abs(value)
    nnz(i) = seen + 1.0

    if (value > currMax(i)) currMax(i) = value
    if (value < currMin(i)) currMin(i) = value
  }
}
/**
* Add a new sample to this summarizer, and update the statistical summary.
*
@ -72,37 +95,18 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.")
// Folds one value into the running statistics for column `i`; values equal to 0.0 are ignored.
@inline def update(i: Int, value: Double) = {
  if (value != 0.0) {
    val count = nnz(i)
    val oldMean = currMean(i)

    // Recompute the mean from the previous weighted sum plus the new value.
    val newMean = (oldMean * count + value) / (count + 1.0)
    currMean(i) = newMean
    // Uses both the new and the old mean of this column.
    currM2n(i) += (value - newMean) * (value - oldMean)

    currM2(i) += value * value
    currL1(i) += math.abs(value)
    nnz(i) = count + 1.0

    if (value > currMax(i)) currMax(i) = value
    if (value < currMin(i)) currMin(i) = value
  }
}
sample match {
case dv: DenseVector => {
var j = 0
while (j < dv.size) {
update(j, dv.values(j))
add(j, dv.values(j))
j += 1
}
}
case sv: SparseVector =>
var j = 0
while (j < sv.indices.size) {
update(sv.indices(j), sv.values(j))
add(sv.indices(j), sv.values(j))
j += 1
}
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
@ -124,37 +128,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
s"Expecting $n but got ${other.n}.")
totalCnt += other.totalCnt
val deltaMean: BDV[Double] = currMean - other.currMean
var i = 0
while (i < n) {
// merge mean together
if (other.currMean(i) != 0.0) {
currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
(nnz(i) + other.nnz(i))
}
// merge m2n together
if (nnz(i) + other.nnz(i) != 0.0) {
currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
(nnz(i) + other.nnz(i))
}
// merge m2 together
if (nnz(i) + other.nnz(i) != 0.0) {
val thisNnz = nnz(i)
val otherNnz = other.nnz(i)
val totalNnz = thisNnz + otherNnz
if (totalNnz != 0.0) {
val deltaMean = other.currMean(i) - currMean(i)
// merge mean together
currMean(i) += deltaMean * otherNnz / totalNnz
// merge m2n together
currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
// merge m2 together
currM2(i) += other.currM2(i)
}
// merge l1 together
if (nnz(i) + other.nnz(i) != 0.0) {
// merge l1 together
currL1(i) += other.currL1(i)
// merge max and min
currMax(i) = math.max(currMax(i), other.currMax(i))
currMin(i) = math.min(currMin(i), other.currMin(i))
}
if (currMax(i) < other.currMax(i)) {
currMax(i) = other.currMax(i)
}
if (currMin(i) > other.currMin(i)) {
currMin(i) = other.currMin(i)
}
nnz(i) = totalNnz
i += 1
}
nnz += other.nnz
} else if (totalCnt == 0 && other.totalCnt != 0) {
this.n = other.n
this.currMean = other.currMean.copy

View file

@ -208,4 +208,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch")
}
test("merging summarizer when one side has zero mean (SPARK-4355)") {
  // Left side: two samples of 2.0, so its mean is 2.0.
  val constantSide = new MultivariateOnlineSummarizer()
    .add(Vectors.dense(2.0))
    .add(Vectors.dense(2.0))
  // Right side: non-zero samples 1.0 and -1.0 that cancel out to a mean of exactly 0.0.
  val zeroMeanSide = new MultivariateOnlineSummarizer()
    .add(Vectors.dense(1.0))
    .add(Vectors.dense(-1.0))

  constantSide.merge(zeroMeanSide)

  // Combined samples are (2, 2, 1, -1), so the merged mean must be 1.0.
  assert(constantSide.mean(0) ~== 1.0 absTol 1e-14)
}
}