[SPARK-9224] [MLLIB] OnlineLDA Performance Improvements

Use in-place updates, reduce the number of transposes, and vectorize operations in the OnlineLDA implementation.

Author: Feynman Liang <fliang@databricks.com>

Closes #7454 from feynmanliang/OnlineLDA-perf-improvements and squashes the following commits:

78b0f5a [Feynman Liang] Make in-place variables vals, fix BLAS error
7f62a55 [Feynman Liang] --amend
c62cb1e [Feynman Liang] Outer product for stats, revert Range slicing
aead650 [Feynman Liang] Range slice, in-place update, reduce transposes
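
The "in-place" items above refer to Breeze's mutating := assignment: the inference loop now writes into preallocated val buffers instead of rebinding vars to freshly allocated vectors on every iteration. A minimal, self-contained sketch of the difference, with toy dimensions and hypothetical names rather than the actual Spark code:

import breeze.linalg.{DenseVector => BDV, sum}
import breeze.numerics.{digamma, exp}

object InPlaceUpdateSketch {
  def main(args: Array[String]): Unit = {
    val k = 4
    val gammad = BDV.rand(k) :+ 1.0 // K, strictly positive

    // Old style: a var rebound to a brand-new vector on each update.
    var rebound = exp(digamma(gammad) - digamma(sum(gammad)))
    rebound = exp(digamma(gammad) - digamma(sum(gammad))) // fresh allocation

    // New style: := copies the right-hand side into existing storage,
    // so the convergence loop stops churning out short-lived vectors.
    val inPlace = BDV.zeros[Double](k)
    inPlace := exp(digamma(gammad) - digamma(sum(gammad))) // reuses inPlace

    println(rebound == inPlace) // same values either way
  }
}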
Feynman Liang authored on 2015-07-22 13:06:01 -07:00; committed by Joseph K. Bradley
parent e0b7ba59a1
commit 8486cd8531
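
The structural change in the diff below is hoisting the transpose out of the per-document loop: Elogbeta is transposed once up front (dirichletExpectation(lambda).t), each document then takes a row slice expElogbeta(ids, ::) where the old code took a column slice and transposed inside the inner loop, and the layout is undone once at the end (statsSum :* expElogbeta.t). A small Breeze sketch, with toy sizes and illustrative names, of why the two slicings agree:

import breeze.linalg.{DenseMatrix => BDM}

object TransposeHoistingSketch {
  def main(args: Array[String]): Unit = {
    val k = 2
    val vocabSize = 5
    val Elogbeta = BDM.rand(k, vocabSize) // K x V, the pre-patch layout
    val ids = List(1, 3, 4)               // word ids seen in one document

    // Old layout: slice columns of K x V, then transpose per document.
    val oldSlice = Elogbeta(::, ids).toDenseMatrix.t // ids x K

    // New layout: transpose once outside the loop, slice rows per document.
    val ElogbetaT = Elogbeta.t                      // V x K
    val newSlice = ElogbetaT(ids, ::).toDenseMatrix // ids x K

    println(oldSlice == newSlice) // identical; the transpose left the loop
  }
}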

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala

@@ -19,15 +19,15 @@ package org.apache.spark.mllib.clustering
 import java.util.Random
 
-import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kron}
-import breeze.numerics.{digamma, exp, abs}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
+import breeze.numerics.{abs, digamma, exp}
 import breeze.stats.distributions.{Gamma, RandBasis}
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.impl.GraphImpl
 import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
-import org.apache.spark.mllib.linalg.{Matrices, SparseVector, DenseVector, Vector}
+import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector}
 import org.apache.spark.rdd.RDD
 
 /**
@@ -370,7 +370,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
     iteration += 1
     val k = this.k
     val vocabSize = this.vocabSize
-    val Elogbeta = dirichletExpectation(lambda)
+    val Elogbeta = dirichletExpectation(lambda).t
     val expElogbeta = exp(Elogbeta)
     val alpha = this.alpha
     val gammaShape = this.gammaShape
@@ -385,41 +385,36 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
           case v => throw new IllegalArgumentException("Online LDA does not support vector type "
             + v.getClass)
         }
+        if (!ids.isEmpty) {
 
-        // Initialize the variational distribution q(theta|gamma) for the mini-batch
-        var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K
-        var Elogthetad = digamma(gammad) - digamma(sum(gammad)) // 1 * K
-        var expElogthetad = exp(Elogthetad) // 1 * K
-        val expElogbetad = expElogbeta(::, ids).toDenseMatrix // K * ids
+          // Initialize the variational distribution q(theta|gamma) for the mini-batch
+          val gammad: BDV[Double] =
+            new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K
+          val expElogthetad: BDV[Double] = exp(digamma(gammad) - digamma(sum(gammad))) // K
+          val expElogbetad: BDM[Double] = expElogbeta(ids, ::).toDenseMatrix // ids * K
 
-        var phinorm = expElogthetad * expElogbetad + 1e-100 // 1 * ids
-        var meanchange = 1D
-        val ctsVector = new BDV[Double](cts).t // 1 * ids
+          val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids
+          var meanchange = 1D
+          val ctsVector = new BDV[Double](cts) // ids
 
-        // Iterate between gamma and phi until convergence
-        while (meanchange > 1e-3) {
-          val lastgamma = gammad
-          //        1*K                  1 * ids               ids * k
-          gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha
-          Elogthetad = digamma(gammad) - digamma(sum(gammad))
-          expElogthetad = exp(Elogthetad)
-          phinorm = expElogthetad * expElogbetad + 1e-100
-          meanchange = sum(abs(gammad - lastgamma)) / k
-        }
+          // Iterate between gamma and phi until convergence
+          while (meanchange > 1e-3) {
+            val lastgamma = gammad.copy
+            //       K                   K * ids                 ids
+            gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha
+            expElogthetad := exp(digamma(gammad) - digamma(sum(gammad)))
+            phinorm := expElogbetad * expElogthetad :+ 1e-100
+            meanchange = sum(abs(gammad - lastgamma)) / k
+          }
 
-        val m1 = expElogthetad.t
-        val m2 = (ctsVector / phinorm).t.toDenseVector
-        var i = 0
-        while (i < ids.size) {
-          stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i)
-          i += 1
+          stat(::, ids) := expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix
         }
       }
       Iterator(stat)
     }
 
     val statsSum: BDM[Double] = stats.reduce(_ += _)
-    val batchResult = statsSum :* expElogbeta
+    val batchResult = statsSum :* expElogbeta.t
 
     // Note that this is an optimization to avoid batch.count
     update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt)
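
The "outer product for stats" commit is the stat(::, ids) := ... line above: the old per-word-id while loop becomes a single K x ids outer product assigned into a column slice, one BLAS-level multiply in place of ids.size scaled-column updates. A standalone sketch of the equivalence, with toy values; m1 stands in for expElogthetad and m2 for ctsVector :/ phinorm:

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, sum}
import breeze.numerics.abs

object OuterProductStatsSketch {
  def main(args: Array[String]): Unit = {
    val k = 3
    val vocabSize = 6
    val ids = List(0, 2, 5)
    val m1 = BDV(0.2, 0.5, 0.3) // K, plays the role of expElogthetad
    val m2 = BDV(1.0, 2.0, 4.0) // ids, plays the role of ctsVector :/ phinorm

    // Old: one scaled-column update per word id.
    val statLoop = BDM.zeros[Double](k, vocabSize)
    var i = 0
    while (i < ids.size) {
      statLoop(::, ids(i)) := statLoop(::, ids(i)) + m1 * m2(i)
      i += 1
    }

    // New: a single K x ids outer product written into the sliced columns,
    // which Breeze dispatches as one dense matrix multiply (BLAS gemm).
    val statOuter = BDM.zeros[Double](k, vocabSize)
    statOuter(::, ids) := m1.asDenseMatrix.t * m2.asDenseMatrix

    println(sum(abs(statLoop - statOuter)) < 1e-12) // agree up to round-off
  }
}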