[SPARK-9224] [MLLIB] OnlineLDA Performance Improvements
In-place updates, reduce number of transposes, and vectorize operations in OnlineLDA implementation. Author: Feynman Liang <fliang@databricks.com> Closes #7454 from feynmanliang/OnlineLDA-perf-improvements and squashes the following commits: 78b0f5a [Feynman Liang] Make in-place variables vals, fix BLAS error 7f62a55 [Feynman Liang] --amend c62cb1e [Feynman Liang] Outer product for stats, revert Range slicing aead650 [Feynman Liang] Range slice, in-place update, reduce transposes
This commit is contained in:
parent
e0b7ba59a1
commit
8486cd8531
|
@ -19,15 +19,15 @@ package org.apache.spark.mllib.clustering
|
|||
|
||||
import java.util.Random
|
||||
|
||||
import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kron}
|
||||
import breeze.numerics.{digamma, exp, abs}
|
||||
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
|
||||
import breeze.numerics.{abs, digamma, exp}
|
||||
import breeze.stats.distributions.{Gamma, RandBasis}
|
||||
|
||||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.graphx._
|
||||
import org.apache.spark.graphx.impl.GraphImpl
|
||||
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
|
||||
import org.apache.spark.mllib.linalg.{Matrices, SparseVector, DenseVector, Vector}
|
||||
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector}
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
/**
|
||||
|
@ -370,7 +370,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
|
|||
iteration += 1
|
||||
val k = this.k
|
||||
val vocabSize = this.vocabSize
|
||||
val Elogbeta = dirichletExpectation(lambda)
|
||||
val Elogbeta = dirichletExpectation(lambda).t
|
||||
val expElogbeta = exp(Elogbeta)
|
||||
val alpha = this.alpha
|
||||
val gammaShape = this.gammaShape
|
||||
|
@ -385,41 +385,36 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
|
|||
case v => throw new IllegalArgumentException("Online LDA does not support vector type "
|
||||
+ v.getClass)
|
||||
}
|
||||
if (!ids.isEmpty) {
|
||||
|
||||
// Initialize the variational distribution q(theta|gamma) for the mini-batch
|
||||
var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K
|
||||
var Elogthetad = digamma(gammad) - digamma(sum(gammad)) // 1 * K
|
||||
var expElogthetad = exp(Elogthetad) // 1 * K
|
||||
val expElogbetad = expElogbeta(::, ids).toDenseMatrix // K * ids
|
||||
// Initialize the variational distribution q(theta|gamma) for the mini-batch
|
||||
val gammad: BDV[Double] =
|
||||
new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K
|
||||
val expElogthetad: BDV[Double] = exp(digamma(gammad) - digamma(sum(gammad))) // K
|
||||
val expElogbetad: BDM[Double] = expElogbeta(ids, ::).toDenseMatrix // ids * K
|
||||
|
||||
var phinorm = expElogthetad * expElogbetad + 1e-100 // 1 * ids
|
||||
var meanchange = 1D
|
||||
val ctsVector = new BDV[Double](cts).t // 1 * ids
|
||||
val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids
|
||||
var meanchange = 1D
|
||||
val ctsVector = new BDV[Double](cts) // ids
|
||||
|
||||
// Iterate between gamma and phi until convergence
|
||||
while (meanchange > 1e-3) {
|
||||
val lastgamma = gammad
|
||||
// 1*K 1 * ids ids * k
|
||||
gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha
|
||||
Elogthetad = digamma(gammad) - digamma(sum(gammad))
|
||||
expElogthetad = exp(Elogthetad)
|
||||
phinorm = expElogthetad * expElogbetad + 1e-100
|
||||
meanchange = sum(abs(gammad - lastgamma)) / k
|
||||
}
|
||||
// Iterate between gamma and phi until convergence
|
||||
while (meanchange > 1e-3) {
|
||||
val lastgamma = gammad.copy
|
||||
// K K * ids ids
|
||||
gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha
|
||||
expElogthetad := exp(digamma(gammad) - digamma(sum(gammad)))
|
||||
phinorm := expElogbetad * expElogthetad :+ 1e-100
|
||||
meanchange = sum(abs(gammad - lastgamma)) / k
|
||||
}
|
||||
|
||||
val m1 = expElogthetad.t
|
||||
val m2 = (ctsVector / phinorm).t.toDenseVector
|
||||
var i = 0
|
||||
while (i < ids.size) {
|
||||
stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i)
|
||||
i += 1
|
||||
stat(::, ids) := expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix
|
||||
}
|
||||
}
|
||||
Iterator(stat)
|
||||
}
|
||||
|
||||
val statsSum: BDM[Double] = stats.reduce(_ += _)
|
||||
val batchResult = statsSum :* expElogbeta
|
||||
val batchResult = statsSum :* expElogbeta.t
|
||||
|
||||
// Note that this is an optimization to avoid batch.count
|
||||
update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt)
|
||||
|
|
Loading…
Reference in a new issue