[SPARK-7390] [SQL] Only merge other CovarianceCounter when its count is greater than zero
JIRA: https://issues.apache.org/jira/browse/SPARK-7390
Also fix a minor typo in the comment on MkY (it said "col1" but describes col2).
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #5931 from viirya/fix_covariancecounter and squashes the following commits:
352eda6 [Liang-Chi Hsieh] Only merge other CovarianceCounter when its count is greater than zero.
(cherry picked from commit 90527f5604)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
This commit is contained in:
parent
3024f6b01d
commit
5205eb4c29
|
@ -38,7 +38,7 @@ private[sql] object StatFunctions extends Logging {
|
|||
var yAvg = 0.0 // the mean of all examples seen so far in col2
|
||||
var Ck = 0.0 // the co-moment after k examples
|
||||
var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
|
||||
var MkY = 0.0 // sum of squares of differences from the (current) mean for col1
|
||||
var MkY = 0.0 // sum of squares of differences from the (current) mean for col2
|
||||
var count = 0L // count of observed examples
|
||||
// add an example to the calculation
|
||||
def add(x: Double, y: Double): this.type = {
|
||||
|
@ -55,15 +55,17 @@ private[sql] object StatFunctions extends Logging {
|
|||
// merge counters from other partitions. Formula can be found at:
|
||||
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
|
||||
def merge(other: CovarianceCounter): this.type = {
|
||||
val totalCount = count + other.count
|
||||
val deltaX = xAvg - other.xAvg
|
||||
val deltaY = yAvg - other.yAvg
|
||||
Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
|
||||
xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
|
||||
yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
|
||||
MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
|
||||
MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
|
||||
count = totalCount
|
||||
if (other.count > 0) {
|
||||
val totalCount = count + other.count
|
||||
val deltaX = xAvg - other.xAvg
|
||||
val deltaY = yAvg - other.yAvg
|
||||
Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
|
||||
xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
|
||||
yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
|
||||
MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
|
||||
MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
|
||||
count = totalCount
|
||||
}
|
||||
this
|
||||
}
|
||||
// return the sample covariance for the observed examples
|
||||
|
|
Loading…
Reference in a new issue