[SPARK-4355][MLLIB] fix OnlineSummarizer.merge when other.mean is zero
See inline comment about the bug. I also did some code clean-up. dbtsai I moved `update` to a private method of `MultivariateOnlineSummarizer`. I don't think it will cause performance regression, but it would be great if you have some time to test. Author: Xiangrui Meng <meng@databricks.com> Closes #3220 from mengxr/SPARK-4355 and squashes the following commits: 5ef601f [Xiangrui Meng] fix OnlineSummarizer.merge when other.mean is zero and some code clean-up
This commit is contained in:
parent
faeb41de21
commit
84324fbcb9
|
@ -49,6 +49,29 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
|
|||
private var currMax: BDV[Double] = _
|
||||
private var currMin: BDV[Double] = _
|
||||
|
||||
/**
|
||||
* Adds input value to position i.
|
||||
*/
|
||||
private[this] def add(i: Int, value: Double) = {
|
||||
if (value != 0.0) {
|
||||
if (currMax(i) < value) {
|
||||
currMax(i) = value
|
||||
}
|
||||
if (currMin(i) > value) {
|
||||
currMin(i) = value
|
||||
}
|
||||
|
||||
val prevMean = currMean(i)
|
||||
val diff = value - prevMean
|
||||
currMean(i) = prevMean + diff / (nnz(i) + 1.0)
|
||||
currM2n(i) += (value - currMean(i)) * diff
|
||||
currM2(i) += value * value
|
||||
currL1(i) += math.abs(value)
|
||||
|
||||
nnz(i) += 1.0
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a new sample to this summarizer, and update the statistical summary.
|
||||
*
|
||||
|
@ -72,37 +95,18 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
|
|||
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
|
||||
s" Expecting $n but got ${sample.size}.")
|
||||
|
||||
@inline def update(i: Int, value: Double) = {
|
||||
if (value != 0.0) {
|
||||
if (currMax(i) < value) {
|
||||
currMax(i) = value
|
||||
}
|
||||
if (currMin(i) > value) {
|
||||
currMin(i) = value
|
||||
}
|
||||
|
||||
val tmpPrevMean = currMean(i)
|
||||
currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
|
||||
currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
|
||||
currM2(i) += value * value
|
||||
currL1(i) += math.abs(value)
|
||||
|
||||
nnz(i) += 1.0
|
||||
}
|
||||
}
|
||||
|
||||
sample match {
|
||||
case dv: DenseVector => {
|
||||
var j = 0
|
||||
while (j < dv.size) {
|
||||
update(j, dv.values(j))
|
||||
add(j, dv.values(j))
|
||||
j += 1
|
||||
}
|
||||
}
|
||||
case sv: SparseVector =>
|
||||
var j = 0
|
||||
while (j < sv.indices.size) {
|
||||
update(sv.indices(j), sv.values(j))
|
||||
add(sv.indices(j), sv.values(j))
|
||||
j += 1
|
||||
}
|
||||
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
|
||||
|
@ -124,37 +128,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
|
|||
require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
|
||||
s"Expecting $n but got ${other.n}.")
|
||||
totalCnt += other.totalCnt
|
||||
val deltaMean: BDV[Double] = currMean - other.currMean
|
||||
var i = 0
|
||||
while (i < n) {
|
||||
// merge mean together
|
||||
if (other.currMean(i) != 0.0) {
|
||||
currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
|
||||
(nnz(i) + other.nnz(i))
|
||||
}
|
||||
// merge m2n together
|
||||
if (nnz(i) + other.nnz(i) != 0.0) {
|
||||
currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
|
||||
(nnz(i) + other.nnz(i))
|
||||
}
|
||||
// merge m2 together
|
||||
if (nnz(i) + other.nnz(i) != 0.0) {
|
||||
val thisNnz = nnz(i)
|
||||
val otherNnz = other.nnz(i)
|
||||
val totalNnz = thisNnz + otherNnz
|
||||
if (totalNnz != 0.0) {
|
||||
val deltaMean = other.currMean(i) - currMean(i)
|
||||
// merge mean together
|
||||
currMean(i) += deltaMean * otherNnz / totalNnz
|
||||
// merge m2n together
|
||||
currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
|
||||
// merge m2 together
|
||||
currM2(i) += other.currM2(i)
|
||||
}
|
||||
// merge l1 together
|
||||
if (nnz(i) + other.nnz(i) != 0.0) {
|
||||
// merge l1 together
|
||||
currL1(i) += other.currL1(i)
|
||||
// merge max and min
|
||||
currMax(i) = math.max(currMax(i), other.currMax(i))
|
||||
currMin(i) = math.min(currMin(i), other.currMin(i))
|
||||
}
|
||||
|
||||
if (currMax(i) < other.currMax(i)) {
|
||||
currMax(i) = other.currMax(i)
|
||||
}
|
||||
if (currMin(i) > other.currMin(i)) {
|
||||
currMin(i) = other.currMin(i)
|
||||
}
|
||||
nnz(i) = totalNnz
|
||||
i += 1
|
||||
}
|
||||
nnz += other.nnz
|
||||
} else if (totalCnt == 0 && other.totalCnt != 0) {
|
||||
this.n = other.n
|
||||
this.currMean = other.currMean.copy
|
||||
|
|
|
@ -208,4 +208,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
|
|||
|
||||
assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch")
|
||||
}
|
||||
|
||||
test("merging summarizer when one side has zero mean (SPARK-4355)") {
|
||||
val s0 = new MultivariateOnlineSummarizer()
|
||||
.add(Vectors.dense(2.0))
|
||||
.add(Vectors.dense(2.0))
|
||||
val s1 = new MultivariateOnlineSummarizer()
|
||||
.add(Vectors.dense(1.0))
|
||||
.add(Vectors.dense(-1.0))
|
||||
s0.merge(s1)
|
||||
assert(s0.mean(0) ~== 1.0 absTol 1e-14)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue