[SPARK-5145][Mllib] Add BLAS.dsyr and use it in GaussianMixtureEM
This PR uses BLAS.dsyr to replace a few implementations in GaussianMixtureEM. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #3949 from viirya/blas_dsyr and squashes the following commits: 4e4d6cf [Liang-Chi Hsieh] Add unit test. Rename function name, modify doc and style. 3f57fd2 [Liang-Chi Hsieh] Add BLAS.dsyr and use it in GaussianMixtureEM.
This commit is contained in:
parent
b6aa557300
commit
e9ca16ec94
|
@ -21,7 +21,7 @@ import scala.collection.mutable.IndexedSeq
|
|||
|
||||
import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
|
||||
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS}
|
||||
import org.apache.spark.mllib.stat.impl.MultivariateGaussian
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
|
||||
|
@ -151,9 +151,10 @@ class GaussianMixtureEM private (
|
|||
var i = 0
|
||||
while (i < k) {
|
||||
val mu = sums.means(i) / sums.weights(i)
|
||||
val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose(mu) // TODO: Use BLAS.dsyr
|
||||
BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu).asInstanceOf[DenseVector],
|
||||
Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
|
||||
weights(i) = sums.weights(i) / sumWeights
|
||||
gaussians(i) = new MultivariateGaussian(mu, sigma)
|
||||
gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
|
||||
i = i + 1
|
||||
}
|
||||
|
||||
|
@ -211,7 +212,8 @@ private object ExpectationSum {
|
|||
p(i) /= pSum
|
||||
sums.weights(i) += p(i)
|
||||
sums.means(i) += x * p(i)
|
||||
sums.sigmas(i) += xxt * p(i) // TODO: use BLAS.dsyr
|
||||
BLAS.syr(p(i), Vectors.fromBreeze(x).asInstanceOf[DenseVector],
|
||||
Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
|
||||
i = i + 1
|
||||
}
|
||||
sums
|
||||
|
|
|
@ -228,6 +228,32 @@ private[spark] object BLAS extends Serializable with Logging {
|
|||
}
|
||||
_nativeBLAS
|
||||
}
|
||||
|
||||
/**
 * Symmetric rank-1 update, performed in place: A := alpha * x * x^T^ + A
 *
 * Native `dsyr` is invoked on the upper triangle only ("U"); afterwards the strictly lower
 * triangle of A is overwritten with a mirror of the upper triangle. Consequently any values
 * originally stored below the diagonal are discarded.
 *
 * @param alpha a real scalar that will be multiplied to x * x^T^.
 * @param x the vector x that contains the n elements.
 * @param A the symmetric matrix A. Size of n x n. Modified in place.
 * @throws IllegalArgumentException if A is not square or if the size of x differs from the
 *                                  dimension of A.
 */
def syr(alpha: Double, x: DenseVector, A: DenseMatrix): Unit = {
  val mA = A.numRows
  val nA = A.numCols
  // Only squareness can be verified cheaply here; symmetry itself is assumed, not checked.
  require(mA == nA,
    s"A is not a square matrix (and hence is not a symmetric matrix). A: $mA x $nA")
  require(mA == x.size,
    s"The size of x doesn't match the dimension of A. A: $mA x $nA, x: ${x.size}")

  // The last argument is the leading dimension of A; mA == nA is guaranteed by the check above.
  nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, nA)

  // Fill lower triangular part of A by mirroring the freshly updated upper triangle.
  var i = 0
  while (i < mA) {
    var j = i + 1
    while (j < nA) {
      A(j, i) = A(i, j)
      j += 1
    }
    i += 1
  }
}
|
||||
|
||||
/**
|
||||
* C := alpha * A * B + beta * C
|
||||
|
|
|
@ -127,6 +127,47 @@ class BLASSuite extends FunSuite {
|
|||
}
|
||||
}
|
||||
|
||||
// Verifies the symmetric rank-1 update against a hand-computed result, then checks that
// shape mismatches are rejected by syr's argument validation.
test("syr") {
  val dA = new DenseMatrix(4, 4,
    Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8))
  val x = new DenseVector(Array(0.0, 2.7, 3.5, 2.1))
  val alpha = 0.15

  // expected = dA + alpha * x * x^T, e.g. entry (1, 1): 3.2 + 0.15 * 2.7 * 2.7 = 4.2935
  val expected = new DenseMatrix(4, 4,
    Array(0.0, 1.2, 2.2, 3.1, 1.2, 4.2935, 6.7175, 5.4505, 2.2, 6.7175, 3.6375, 4.1025, 3.1,
      5.4505, 4.1025, 1.4615))

  syr(alpha, x, dA)

  assert(dA ~== expected absTol 1e-15)

  // 3 x 4 matrix: not square, so it cannot be symmetric.
  val dB =
    new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0))

  // `require` failures surface as IllegalArgumentException; intercepting that type (instead of
  // any Exception) ensures the test fails if the call dies for an unrelated reason.
  withClue("Matrix A must be a square (symmetric) matrix") {
    intercept[IllegalArgumentException] {
      syr(alpha, x, dB)
    }
  }

  // 3 x 3 matrix with a vector of 4 elements: vector is longer than the matrix dimension.
  val dC =
    new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8))

  withClue("Size of vector must match the dimension of a smaller matrix") {
    intercept[IllegalArgumentException] {
      syr(alpha, x, dC)
    }
  }

  // 4 x 4 matrix with a vector of 5 elements.
  val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5))

  withClue("Size of a longer vector must match the dimension of the matrix") {
    intercept[IllegalArgumentException] {
      syr(alpha, y, dA)
    }
  }
}
|
||||
|
||||
test("gemm") {
|
||||
|
||||
val dA =
|
||||
|
|
Loading…
Reference in a new issue