[SPARK-2359][MLlib] Correlations
Implementation for Pearson and Spearman's correlation.
Author: Doris Xin <doris.s.xin@gmail.com>
Closes #1367 from dorx/correlation and squashes the following commits:
c0dd7dc [Doris Xin] here we go
32d83a3 [Doris Xin] Reviewer comments
4db0da1 [Doris Xin] added private[stat] to Spearman
b716f70 [Doris Xin] minor fixes
6e1b42a [Doris Xin] More comments addressed. Still some open questions
8104f44 [Doris Xin] addressed comments. some open questions still
39387c2 [Doris Xin] added missing header
bd3cf19 [Doris Xin] Merge branch 'master' into correlation
6341884 [Doris Xin] race condition bug squished
bd2bacf [Doris Xin] Race condition bug
b775ff9 [Doris Xin] old wrong impl
534ebf2 [Doris Xin] Merge branch 'master' into correlation
818fa31 [Doris Xin] wip units
9d808ee [Doris Xin] wip units
b843a13 [Doris Xin] revert change in stat counter
28561b6 [Doris Xin] wip
bb2e977 [Doris Xin] minor fix
8e02c63 [Doris Xin] Merge branch 'master' into correlation
2a40aa1 [Doris Xin] initial, untested implementation of Pearson
dfc4854
[Doris Xin] WIP
This commit is contained in:
parent
7b971b91ca
commit
a243364b22
|
@ -0,0 +1,78 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.mllib.stat
|
||||
|
||||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.mllib.linalg.{Matrix, Vector}
|
||||
import org.apache.spark.mllib.stat.correlation.Correlations
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
/**
|
||||
* API for statistical functions in MLlib
|
||||
*/
|
||||
@Experimental
|
||||
object Statistics {
|
||||
|
||||
/**
|
||||
* Compute the Pearson correlation matrix for the input RDD of Vectors.
|
||||
* Returns NaN if either vector has 0 variance.
|
||||
*
|
||||
* @param X an RDD[Vector] for which the correlation matrix is to be computed.
|
||||
* @return Pearson correlation matrix comparing columns in X.
|
||||
*/
|
||||
def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X)
|
||||
|
||||
/**
|
||||
* Compute the correlation matrix for the input RDD of Vectors using the specified method.
|
||||
* Methods currently supported: `pearson` (default), `spearman`
|
||||
*
|
||||
* Note that for Spearman, a rank correlation, we need to create an RDD[Double] for each column
|
||||
* and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
|
||||
* which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to
|
||||
* avoid recomputing the common lineage.
|
||||
*
|
||||
* @param X an RDD[Vector] for which the correlation matrix is to be computed.
|
||||
* @param method String specifying the method to use for computing correlation.
|
||||
* Supported: `pearson` (default), `spearman`
|
||||
* @return Correlation matrix comparing columns in X.
|
||||
*/
|
||||
def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method)
|
||||
|
||||
/**
|
||||
* Compute the Pearson correlation for the input RDDs.
|
||||
* Columns with 0 covariance produce NaN entries in the correlation matrix.
|
||||
*
|
||||
* @param x RDD[Double] of the same cardinality as y
|
||||
* @param y RDD[Double] of the same cardinality as x
|
||||
* @return A Double containing the Pearson correlation between the two input RDD[Double]s
|
||||
*/
|
||||
def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y)
|
||||
|
||||
/**
|
||||
* Compute the correlation for the input RDDs using the specified method.
|
||||
* Methods currently supported: pearson (default), spearman
|
||||
*
|
||||
* @param x RDD[Double] of the same cardinality as y
|
||||
* @param y RDD[Double] of the same cardinality as x
|
||||
* @param method String specifying the method to use for computing correlation.
|
||||
* Supported: `pearson` (default), `spearman`
|
||||
*@return A Double containing the correlation between the two input RDD[Double]s using the
|
||||
* specified method.
|
||||
*/
|
||||
def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method)
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.mllib.stat.correlation
|
||||
|
||||
import org.apache.spark.mllib.linalg.{DenseVector, Matrix, Vector}
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
/**
|
||||
* Trait for correlation algorithms.
|
||||
*/
|
||||
private[stat] trait Correlation {
|
||||
|
||||
/**
|
||||
* Compute correlation for two datasets.
|
||||
*/
|
||||
def computeCorrelation(x: RDD[Double], y: RDD[Double]): Double
|
||||
|
||||
/**
|
||||
* Compute the correlation matrix S, for the input matrix, where S(i, j) is the correlation
|
||||
* between column i and j. S(i, j) can be NaN if the correlation is undefined for column i and j.
|
||||
*/
|
||||
def computeCorrelationMatrix(X: RDD[Vector]): Matrix
|
||||
|
||||
/**
|
||||
* Combine the two input RDD[Double]s into an RDD[Vector] and compute the correlation using the
|
||||
* correlation implementation for RDD[Vector]. Can be NaN if correlation is undefined for the
|
||||
* input vectors.
|
||||
*/
|
||||
def computeCorrelationWithMatrixImpl(x: RDD[Double], y: RDD[Double]): Double = {
|
||||
val mat: RDD[Vector] = x.zip(y).map { case (xi, yi) => new DenseVector(Array(xi, yi)) }
|
||||
computeCorrelationMatrix(mat)(0, 1)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Delegates computation to the specific correlation object based on the input method name
|
||||
*
|
||||
* Currently supported correlations: pearson, spearman.
|
||||
* After new correlation algorithms are added, please update the documentation here and in
|
||||
* Statistics.scala for the correlation APIs.
|
||||
*
|
||||
* Maintains the default correlation type, pearson
|
||||
*/
|
||||
private[stat] object Correlations {
|
||||
|
||||
// Note: after new types of correlations are implemented, please update this map
|
||||
val nameToObjectMap = Map(("pearson", PearsonCorrelation), ("spearman", SpearmanCorrelation))
|
||||
val defaultCorrName: String = "pearson"
|
||||
val defaultCorr: Correlation = nameToObjectMap(defaultCorrName)
|
||||
|
||||
def corr(x: RDD[Double], y: RDD[Double], method: String = defaultCorrName): Double = {
|
||||
val correlation = getCorrelationFromName(method)
|
||||
correlation.computeCorrelation(x, y)
|
||||
}
|
||||
|
||||
def corrMatrix(X: RDD[Vector], method: String = defaultCorrName): Matrix = {
|
||||
val correlation = getCorrelationFromName(method)
|
||||
correlation.computeCorrelationMatrix(X)
|
||||
}
|
||||
|
||||
/**
|
||||
* Match input correlation name with a known name via simple string matching
|
||||
*
|
||||
* private to stat for ease of unit testing
|
||||
*/
|
||||
private[stat] def getCorrelationFromName(method: String): Correlation = {
|
||||
try {
|
||||
nameToObjectMap(method)
|
||||
} catch {
|
||||
case nse: NoSuchElementException =>
|
||||
throw new IllegalArgumentException("Unrecognized method name. Supported correlations: "
|
||||
+ nameToObjectMap.keys.mkString(", "))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,107 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.mllib.stat.correlation
|
||||
|
||||
import breeze.linalg.{DenseMatrix => BDM}
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector}
|
||||
import org.apache.spark.mllib.linalg.distributed.RowMatrix
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
/**
|
||||
* Compute Pearson correlation for two RDDs of the type RDD[Double] or the correlation matrix
|
||||
* for an RDD of the type RDD[Vector].
|
||||
*
|
||||
* Definition of Pearson correlation can be found at
|
||||
* http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
|
||||
*/
|
||||
private[stat] object PearsonCorrelation extends Correlation with Logging {
|
||||
|
||||
/**
|
||||
* Compute the Pearson correlation for two datasets. NaN if either vector has 0 variance.
|
||||
*/
|
||||
override def computeCorrelation(x: RDD[Double], y: RDD[Double]): Double = {
|
||||
computeCorrelationWithMatrixImpl(x, y)
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the Pearson correlation matrix S, for the input matrix, where S(i, j) is the
|
||||
* correlation between column i and j. 0 covariance results in a correlation value of Double.NaN.
|
||||
*/
|
||||
override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = {
|
||||
val rowMatrix = new RowMatrix(X)
|
||||
val cov = rowMatrix.computeCovariance()
|
||||
computeCorrelationMatrixFromCovariance(cov)
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the Pearson correlation matrix from the covariance matrix.
|
||||
* 0 covariance results in a correlation value of Double.NaN.
|
||||
*/
|
||||
def computeCorrelationMatrixFromCovariance(covarianceMatrix: Matrix): Matrix = {
|
||||
val cov = covarianceMatrix.toBreeze.asInstanceOf[BDM[Double]]
|
||||
val n = cov.cols
|
||||
|
||||
// Compute the standard deviation on the diagonals first
|
||||
var i = 0
|
||||
while (i < n) {
|
||||
// TODO remove once covariance numerical issue resolved.
|
||||
cov(i, i) = if (closeToZero(cov(i, i))) 0.0 else math.sqrt(cov(i, i))
|
||||
i +=1
|
||||
}
|
||||
|
||||
// Loop through columns since cov is column major
|
||||
var j = 0
|
||||
var sigma = 0.0
|
||||
var containNaN = false
|
||||
while (j < n) {
|
||||
sigma = cov(j, j)
|
||||
i = 0
|
||||
while (i < j) {
|
||||
val corr = if (sigma == 0.0 || cov(i, i) == 0.0) {
|
||||
containNaN = true
|
||||
Double.NaN
|
||||
} else {
|
||||
cov(i, j) / (sigma * cov(i, i))
|
||||
}
|
||||
cov(i, j) = corr
|
||||
cov(j, i) = corr
|
||||
i += 1
|
||||
}
|
||||
j += 1
|
||||
}
|
||||
|
||||
// put 1.0 on the diagonals
|
||||
i = 0
|
||||
while (i < n) {
|
||||
cov(i, i) = 1.0
|
||||
i +=1
|
||||
}
|
||||
|
||||
if (containNaN) {
|
||||
logWarning("Pearson correlation matrix contains NaN values.")
|
||||
}
|
||||
|
||||
Matrices.fromBreeze(cov)
|
||||
}
|
||||
|
||||
private def closeToZero(value: Double, threshhold: Double = 1e-12): Boolean = {
|
||||
math.abs(value) <= threshhold
|
||||
}
|
||||
}
|
|
@ -0,0 +1,127 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.mllib.stat.correlation
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark.{Logging, HashPartitioner}
|
||||
import org.apache.spark.SparkContext._
|
||||
import org.apache.spark.mllib.linalg.{DenseVector, Matrix, Vector}
|
||||
import org.apache.spark.rdd.{CoGroupedRDD, RDD}
|
||||
|
||||
/**
|
||||
* Compute Spearman's correlation for two RDDs of the type RDD[Double] or the correlation matrix
|
||||
* for an RDD of the type RDD[Vector].
|
||||
*
|
||||
* Definition of Spearman's correlation can be found at
|
||||
* http://en.wikipedia.org/wiki/Spearman's_rank_correlation_coefficient
|
||||
*/
|
||||
private[stat] object SpearmanCorrelation extends Correlation with Logging {
|
||||
|
||||
/**
|
||||
* Compute Spearman's correlation for two datasets.
|
||||
*/
|
||||
override def computeCorrelation(x: RDD[Double], y: RDD[Double]): Double = {
|
||||
computeCorrelationWithMatrixImpl(x, y)
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute Spearman's correlation matrix S, for the input matrix, where S(i, j) is the
|
||||
* correlation between column i and j.
|
||||
*
|
||||
* Input RDD[Vector] should be cached or checkpointed if possible since it would be split into
|
||||
* numCol RDD[Double]s, each of which sorted, and the joined back into a single RDD[Vector].
|
||||
*/
|
||||
override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = {
|
||||
val indexed = X.zipWithUniqueId()
|
||||
|
||||
val numCols = X.first.size
|
||||
if (numCols > 50) {
|
||||
logWarning("Computing the Spearman correlation matrix can be slow for large RDDs with more"
|
||||
+ " than 50 columns.")
|
||||
}
|
||||
val ranks = new Array[RDD[(Long, Double)]](numCols)
|
||||
|
||||
// Note: we use a for loop here instead of a while loop with a single index variable
|
||||
// to avoid race condition caused by closure serialization
|
||||
for (k <- 0 until numCols) {
|
||||
val column = indexed.map { case (vector, index) => (vector(k), index) }
|
||||
ranks(k) = getRanks(column)
|
||||
}
|
||||
|
||||
val ranksMat: RDD[Vector] = makeRankMatrix(ranks, X)
|
||||
PearsonCorrelation.computeCorrelationMatrix(ranksMat)
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the ranks for elements in the input RDD, using the average method for ties.
|
||||
*
|
||||
* With the average method, elements with the same value receive the same rank that's computed
|
||||
* by taking the average of their positions in the sorted list.
|
||||
* e.g. ranks([2, 1, 0, 2]) = [2.5, 1.0, 0.0, 2.5]
|
||||
* Note that positions here are 0-indexed, instead of the 1-indexed as in the definition for
|
||||
* ranks in the standard definition for Spearman's correlation. This does not affect the final
|
||||
* results and is slightly more performant.
|
||||
*
|
||||
* @param indexed RDD[(Double, Long)] containing pairs of the format (originalValue, uniqueId)
|
||||
* @return RDD[(Long, Double)] containing pairs of the format (uniqueId, rank), where uniqueId is
|
||||
* copied from the input RDD.
|
||||
*/
|
||||
private def getRanks(indexed: RDD[(Double, Long)]): RDD[(Long, Double)] = {
|
||||
// Get elements' positions in the sorted list for computing average rank for duplicate values
|
||||
val sorted = indexed.sortByKey().zipWithIndex()
|
||||
|
||||
val ranks: RDD[(Long, Double)] = sorted.mapPartitions { iter =>
|
||||
// add an extra element to signify the end of the list so that flatMap can flush the last
|
||||
// batch of duplicates
|
||||
val padded = iter ++
|
||||
Iterator[((Double, Long), Long)](((Double.NaN, -1L), -1L))
|
||||
var lastVal = 0.0
|
||||
var firstRank = 0.0
|
||||
val idBuffer = new ArrayBuffer[Long]()
|
||||
padded.flatMap { case ((v, id), rank) =>
|
||||
if (v == lastVal && id != Long.MinValue) {
|
||||
idBuffer += id
|
||||
Iterator.empty
|
||||
} else {
|
||||
val entries = if (idBuffer.size == 0) {
|
||||
// edge case for the first value matching the initial value of lastVal
|
||||
Iterator.empty
|
||||
} else if (idBuffer.size == 1) {
|
||||
Iterator((idBuffer(0), firstRank))
|
||||
} else {
|
||||
val averageRank = firstRank + (idBuffer.size - 1.0) / 2.0
|
||||
idBuffer.map(id => (id, averageRank))
|
||||
}
|
||||
lastVal = v
|
||||
firstRank = rank
|
||||
idBuffer.clear()
|
||||
idBuffer += id
|
||||
entries
|
||||
}
|
||||
}
|
||||
}
|
||||
ranks
|
||||
}
|
||||
|
||||
private def makeRankMatrix(ranks: Array[RDD[(Long, Double)]], input: RDD[Vector]): RDD[Vector] = {
|
||||
val partitioner = new HashPartitioner(input.partitions.size)
|
||||
val cogrouped = new CoGroupedRDD[Long](ranks, partitioner)
|
||||
cogrouped.map { case (_, values: Seq[Seq[Double]]) => new DenseVector(values.flatten.toArray) }
|
||||
}
|
||||
}
|
|
@ -0,0 +1,116 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.mllib.stat
|
||||
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import breeze.linalg.{DenseMatrix => BDM, Matrix => BM}
|
||||
|
||||
import org.apache.spark.mllib.linalg.Vectors
|
||||
import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation,
|
||||
SpearmanCorrelation}
|
||||
import org.apache.spark.mllib.util.LocalSparkContext
|
||||
|
||||
class CorrelationSuite extends FunSuite with LocalSparkContext {
|
||||
|
||||
// test input data
|
||||
val xData = Array(1.0, 0.0, -2.0)
|
||||
val yData = Array(4.0, 5.0, 3.0)
|
||||
val data = Seq(
|
||||
Vectors.dense(1.0, 0.0, 0.0, -2.0),
|
||||
Vectors.dense(4.0, 5.0, 0.0, 3.0),
|
||||
Vectors.dense(6.0, 7.0, 0.0, 8.0),
|
||||
Vectors.dense(9.0, 0.0, 0.0, 1.0)
|
||||
)
|
||||
|
||||
test("corr(x, y) default, pearson") {
|
||||
val x = sc.parallelize(xData)
|
||||
val y = sc.parallelize(yData)
|
||||
val expected = 0.6546537
|
||||
val default = Statistics.corr(x, y)
|
||||
val p1 = Statistics.corr(x, y, "pearson")
|
||||
assert(approxEqual(expected, default))
|
||||
assert(approxEqual(expected, p1))
|
||||
}
|
||||
|
||||
test("corr(x, y) spearman") {
|
||||
val x = sc.parallelize(xData)
|
||||
val y = sc.parallelize(yData)
|
||||
val expected = 0.5
|
||||
val s1 = Statistics.corr(x, y, "spearman")
|
||||
assert(approxEqual(expected, s1))
|
||||
}
|
||||
|
||||
test("corr(X) default, pearson") {
|
||||
val X = sc.parallelize(data)
|
||||
val defaultMat = Statistics.corr(X)
|
||||
val pearsonMat = Statistics.corr(X, "pearson")
|
||||
val expected = BDM(
|
||||
(1.00000000, 0.05564149, Double.NaN, 0.4004714),
|
||||
(0.05564149, 1.00000000, Double.NaN, 0.9135959),
|
||||
(Double.NaN, Double.NaN, 1.00000000, Double.NaN),
|
||||
(0.40047142, 0.91359586, Double.NaN,1.0000000))
|
||||
assert(matrixApproxEqual(defaultMat.toBreeze, expected))
|
||||
assert(matrixApproxEqual(pearsonMat.toBreeze, expected))
|
||||
}
|
||||
|
||||
test("corr(X) spearman") {
|
||||
val X = sc.parallelize(data)
|
||||
val spearmanMat = Statistics.corr(X, "spearman")
|
||||
val expected = BDM(
|
||||
(1.0000000, 0.1054093, Double.NaN, 0.4000000),
|
||||
(0.1054093, 1.0000000, Double.NaN, 0.9486833),
|
||||
(Double.NaN, Double.NaN, 1.00000000, Double.NaN),
|
||||
(0.4000000, 0.9486833, Double.NaN, 1.0000000))
|
||||
assert(matrixApproxEqual(spearmanMat.toBreeze, expected))
|
||||
}
|
||||
|
||||
test("method identification") {
|
||||
val pearson = PearsonCorrelation
|
||||
val spearman = SpearmanCorrelation
|
||||
|
||||
assert(Correlations.getCorrelationFromName("pearson") === pearson)
|
||||
assert(Correlations.getCorrelationFromName("spearman") === spearman)
|
||||
|
||||
// Should throw IllegalArgumentException
|
||||
try {
|
||||
Correlations.getCorrelationFromName("kendall")
|
||||
assert(false)
|
||||
} catch {
|
||||
case ie: IllegalArgumentException =>
|
||||
}
|
||||
}
|
||||
|
||||
def approxEqual(v1: Double, v2: Double, threshold: Double = 1e-6): Boolean = {
|
||||
if (v1.isNaN) {
|
||||
v2.isNaN
|
||||
} else {
|
||||
math.abs(v1 - v2) <= threshold
|
||||
}
|
||||
}
|
||||
|
||||
def matrixApproxEqual(A: BM[Double], B: BM[Double], threshold: Double = 1e-6): Boolean = {
|
||||
for (i <- 0 until A.rows; j <- 0 until A.cols) {
|
||||
if (!approxEqual(A(i, j), B(i, j), threshold)) {
|
||||
println("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j))
|
||||
return false
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue