diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 8e5154b902..b960ae6c07 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -19,15 +19,15 @@ package org.apache.spark.mllib.clustering
 
 import java.util.Random
 
-import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kron}
-import breeze.numerics.{digamma, exp, abs}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
+import breeze.numerics.{abs, digamma, exp}
 import breeze.stats.distributions.{Gamma, RandBasis}
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.impl.GraphImpl
 import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
-import org.apache.spark.mllib.linalg.{Matrices, SparseVector, DenseVector, Vector}
+import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector}
 import org.apache.spark.rdd.RDD
 
 /**
@@ -370,7 +370,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
     iteration += 1
     val k = this.k
     val vocabSize = this.vocabSize
-    val Elogbeta = dirichletExpectation(lambda)
+    val Elogbeta = dirichletExpectation(lambda).t
     val expElogbeta = exp(Elogbeta)
     val alpha = this.alpha
     val gammaShape = this.gammaShape
@@ -385,41 +385,36 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
           case v => throw new IllegalArgumentException("Online LDA does not support vector type "
             + v.getClass)
         }
+        if (!ids.isEmpty) {
 
-        // Initialize the variational distribution q(theta|gamma) for the mini-batch
-        var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K
-        var Elogthetad = digamma(gammad) - digamma(sum(gammad)) // 1 * K
-        var expElogthetad = exp(Elogthetad) // 1 * K
-        val expElogbetad = expElogbeta(::, ids).toDenseMatrix // K * ids
+          // Initialize the variational distribution q(theta|gamma) for the mini-batch
+          val gammad: BDV[Double] =
+            new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K
+          val expElogthetad: BDV[Double] = exp(digamma(gammad) - digamma(sum(gammad))) // K
+          val expElogbetad: BDM[Double] = expElogbeta(ids, ::).toDenseMatrix // ids * K
 
-        var phinorm = expElogthetad * expElogbetad + 1e-100 // 1 * ids
-        var meanchange = 1D
-        val ctsVector = new BDV[Double](cts).t // 1 * ids
+          val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids
+          var meanchange = 1D
+          val ctsVector = new BDV[Double](cts) // ids
 
-        // Iterate between gamma and phi until convergence
-        while (meanchange > 1e-3) {
-          val lastgamma = gammad
-          //        1*K                   1 * ids              ids * k
-          gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha
-          Elogthetad = digamma(gammad) - digamma(sum(gammad))
-          expElogthetad = exp(Elogthetad)
-          phinorm = expElogthetad * expElogbetad + 1e-100
-          meanchange = sum(abs(gammad - lastgamma)) / k
-        }
+          // Iterate between gamma and phi until convergence
+          while (meanchange > 1e-3) {
+            val lastgamma = gammad.copy
+            //          K                   K * ids                 ids
+            gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha
+            expElogthetad := exp(digamma(gammad) - digamma(sum(gammad)))
+            phinorm := expElogbetad * expElogthetad :+ 1e-100
+            meanchange = sum(abs(gammad - lastgamma)) / k
+          }
 
-        val m1 = expElogthetad.t
-        val m2 = (ctsVector / phinorm).t.toDenseVector
-        var i = 0
-        while (i < ids.size) {
-          stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i)
-          i += 1
+          stat(::, ids) := expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix
         }
       }
      Iterator(stat)
     }
 
     val statsSum: BDM[Double] = stats.reduce(_ += _)
-    val batchResult = statsSum :* expElogbeta
+    val batchResult = statsSum :* expElogbeta.t
 
     // Note that this is an optimization to avoid batch.count
     update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt)
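For reference, here is a minimal, self-contained sketch of the per-document variational update that the patched loop now performs with Breeze vector operations. It is not part of the patch: `OnlineLDAEStepSketch` and `inferDocument` are hypothetical names, the gamma vector is initialized deterministically instead of being sampled from `Gamma(gammaShape, 1.0 / gammaShape)` as the patch does, and it assumes the Breeze version the patch targets (the `:+`, `:*`, `:/` element-wise operators).

```scala
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, sum}
import breeze.numerics.{abs, digamma, exp}

object OnlineLDAEStepSketch {

  /**
   * Gamma/phi fixed-point iteration for a single document, mirroring the
   * patched loop: every quantity is a K-vector or an (ids x K) matrix, so
   * no per-term Scala loop is needed.
   *
   * @param expElogbeta exp(E[log beta]), vocabSize x K (already transposed,
   *                    matching `dirichletExpectation(lambda).t` in the patch)
   * @param ids         indices of terms occurring in the document
   * @param cts         counts for those terms, aligned with ids
   * @return (gammad, sstats) where sstats is the K x ids block the patch
   *         writes into stat(::, ids)
   */
  def inferDocument(
      expElogbeta: BDM[Double],
      ids: Seq[Int],
      cts: Array[Double],
      alpha: Double,
      k: Int): (BDV[Double], BDM[Double]) = {
    // Deterministic init for the sketch; the patch samples from a Gamma distribution.
    val gammad = BDV.ones[Double](k)                                  // K
    val expElogthetad = exp(digamma(gammad) - digamma(sum(gammad)))   // K
    val expElogbetad = expElogbeta(ids, ::).toDenseMatrix             // ids x K
    val ctsVector = new BDV[Double](cts)                              // ids
    val phinorm = expElogbetad * expElogthetad :+ 1e-100              // ids
    var meanchange = 1.0
    while (meanchange > 1e-3) {
      val lastgamma = gammad.copy  // copy is required: := below mutates gammad in place
      gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha
      expElogthetad := exp(digamma(gammad) - digamma(sum(gammad)))
      phinorm := expElogbetad * expElogthetad :+ 1e-100
      meanchange = sum(abs(gammad - lastgamma)) / k
    }
    // (K x 1) * (1 x ids) outer product: the sufficient statistics for this document.
    val sstats = expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix
    (gammad, sstats)
  }
}
```

Under these assumptions, the body of the `foreach` in the patch reduces to slicing the counts, running the fixed-point iteration, and writing the returned K x ids block into `stat(::, ids)` as a single matrix product, which is what replaces the old per-column `while (i < ids.size)` loop.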
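One subtlety worth calling out in the patch: the old loop rebound `gammad` to a fresh vector on each iteration, so `val lastgamma = gammad` was safe, but the new loop updates `gammad` in place with Breeze's `:=`. A plain reference would alias the mutated storage, `meanchange` would always come out 0, and the loop would exit after one iteration; hence `val lastgamma = gammad.copy`. A standalone illustration (the object name is hypothetical, not from the patch):

```scala
import breeze.linalg.{DenseVector => BDV}

object CopyVsAlias extends App {
  val g = BDV(1.0, 2.0)
  val alias = g          // shares g's underlying array
  val snapshot = g.copy  // independent storage
  g := g :* 2.0          // in-place update, as in the patched while loop
  println(alias)         // DenseVector(2.0, 4.0) -- followed the mutation
  println(snapshot)      // DenseVector(1.0, 2.0) -- preserved, like lastgamma
}
```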