[SPARK-8936] [MLLIB] OnlineLDA document-topic Dirichlet hyperparameter optimization
Adds `alpha` (document-topic Dirichlet parameter) hyperparameter optimization to `OnlineLDAOptimizer` following Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters. Also introduces a private `setSampleWithReplacement` to `OnlineLDAOptimizer` for unit testing purposes. Author: Feynman Liang <fliang@databricks.com> Closes #7836 from feynmanliang/SPARK-8936-alpha-optimize and squashes the following commits: 4bef484 [Feynman Liang] Documentation improvements c3c6c1d [Feynman Liang] Fix docs 151e859 [Feynman Liang] Fix style fa77518 [Feynman Liang] Hyperparameter optimization
This commit is contained in:
parent
4d5a6e7b60
commit
f51fd6fbb4
|
@ -19,8 +19,8 @@ package org.apache.spark.mllib.clustering
|
|||
|
||||
import java.util.Random
|
||||
|
||||
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
|
||||
import breeze.numerics.{abs, exp}
|
||||
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, all, normalize, sum}
|
||||
import breeze.numerics.{trigamma, abs, exp}
|
||||
import breeze.stats.distributions.{Gamma, RandBasis}
|
||||
|
||||
import org.apache.spark.annotation.DeveloperApi
|
||||
|
@ -239,22 +239,26 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
|
|||
/** alias for docConcentration */
|
||||
private var alpha: Vector = Vectors.dense(0)
|
||||
|
||||
/** (private[clustering] for debugging) Get docConcentration */
|
||||
/** (for debugging) Get docConcentration */
|
||||
private[clustering] def getAlpha: Vector = alpha
|
||||
|
||||
/** alias for topicConcentration */
|
||||
private var eta: Double = 0
|
||||
|
||||
/** (private[clustering] for debugging) Get topicConcentration */
|
||||
/** (for debugging) Get topicConcentration */
|
||||
private[clustering] def getEta: Double = eta
|
||||
|
||||
private var randomGenerator: java.util.Random = null
|
||||
|
||||
/** (for debugging) Whether to sample mini-batches with replacement. (default = true) */
|
||||
private var sampleWithReplacement: Boolean = true
|
||||
|
||||
// Online LDA specific parameters
|
||||
// Learning rate is: (tau0 + t)^{-kappa}
|
||||
private var tau0: Double = 1024
|
||||
private var kappa: Double = 0.51
|
||||
private var miniBatchFraction: Double = 0.05
|
||||
private var optimizeAlpha: Boolean = false
|
||||
|
||||
// internal data structure
|
||||
private var docs: RDD[(Long, Vector)] = null
|
||||
|
@ -262,7 +266,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
|
|||
/** Dirichlet parameter for the posterior over topics */
|
||||
private var lambda: BDM[Double] = null
|
||||
|
||||
/** (private[clustering] for debugging) Get parameter for topics */
|
||||
/** (for debugging) Get parameter for topics */
|
||||
private[clustering] def getLambda: BDM[Double] = lambda
|
||||
|
||||
/** Current iteration (count of invocations of [[next()]]) */
|
||||
|
@ -325,7 +329,22 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
|
|||
}
|
||||
|
||||
/**
|
||||
* (private[clustering])
|
||||
* Indicates whether alpha (the Dirichlet parameter for the document-topic distribution)
|
||||
* will be optimized during training.
|
||||
*/
|
||||
def getOptimzeAlpha: Boolean = this.optimizeAlpha
|
||||
|
||||
/**
 * Turns optimization of alpha during training on or off.
 *
 * Default: false
 *
 * @param optimizeAlpha true to optimize alpha while training, false to leave it fixed
 * @return this optimizer, so that setter calls can be chained
 */
def setOptimzeAlpha(optimizeAlpha: Boolean): this.type = {
  // The parameter shadows the field, hence the explicit `this.` on the left-hand side.
  this.optimizeAlpha = optimizeAlpha
  this
}
|
||||
|
||||
/**
|
||||
* Set the Dirichlet parameter for the posterior over topics.
|
||||
* This is only used for testing now. In the future, it can help support training stop/resume.
|
||||
*/
|
||||
|
@ -335,7 +354,6 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
|
|||
}
|
||||
|
||||
/**
|
||||
* (private[clustering])
|
||||
* Used for random initialization of the variational parameters.
|
||||
* Larger value produces values closer to 1.0.
|
||||
* This is only used for testing currently.
|
||||
|
@ -345,6 +363,15 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
|
|||
this
|
||||
}
|
||||
|
||||
/**
 * Chooses whether mini-batches are sampled with replacement (default = true) or without.
 * This is only used for testing currently.
 *
 * @param replace true to sample with replacement, false to sample without replacement
 * @return this optimizer, so that setter calls can be chained
 */
private[clustering] def setSampleWithReplacement(replace: Boolean): this.type = {
  sampleWithReplacement = replace
  this
}
|
||||
|
||||
override private[clustering] def initialize(
|
||||
docs: RDD[(Long, Vector)],
|
||||
lda: LDA): OnlineLDAOptimizer = {
|
||||
|
@ -376,7 +403,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
|
|||
}
|
||||
|
||||
/**
 * Samples one mini-batch from `docs` and hands it to `submitMiniBatch`.
 *
 * The sample is drawn with fraction `miniBatchFraction`, a fresh seed taken from
 * `randomGenerator`, and with or without replacement according to `sampleWithReplacement`.
 * If the sample comes back empty, nothing is submitted and this optimizer is returned as is.
 *
 * @return the optimizer after processing the mini-batch (or unchanged for an empty batch)
 */
override private[clustering] def next(): OnlineLDAOptimizer = {
  // Single definition of `batch`: the earlier variant that hard-coded
  // `withReplacement = true` is dropped so the sampleWithReplacement setting is honoured.
  val batch = docs.sample(withReplacement = sampleWithReplacement, miniBatchFraction,
    randomGenerator.nextLong())
  if (batch.isEmpty()) this else submitMiniBatch(batch)
}
|
||||
|
@ -418,6 +446,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
|
|||
|
||||
// Note that this is an optimization to avoid batch.count
|
||||
updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
|
||||
if (optimizeAlpha) updateAlpha(gammat)
|
||||
this
|
||||
}
|
||||
|
||||
|
@ -433,13 +462,39 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
|
|||
weight * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta)
|
||||
}
|
||||
|
||||
/** Calculates learning rate rho, which decays as a function of [[iteration]] */
|
||||
/**
 * Performs one update of alpha from `gammat`, the inferred topic distributions for the
 * documents in the current mini-batch, using the Newton-Raphson method.
 *
 * The Newton step is scaled by the learning rate `rho()`. It is applied only when every
 * component of the resulting alpha stays strictly positive; otherwise alpha is left as is.
 *
 * @param gammat inferred topic distributions for the current mini-batch; its row count is
 *               used as N (presumably one row per document; confirm against the caller)
 * @see Section 3.3, Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters
 *      (http://jonathan-huang.org/research/dirichlet/dirichlet.pdf)
 */
private def updateAlpha(gammat: BDM[Double]): Unit = {
  val stepSize = rho()
  val numRows = gammat.rows.toDouble
  val current = this.alpha.toBreeze.toDenseVector

  // Dirichlet expectation of gammat, summed down each column and averaged over the rows.
  val columnSums = sum(LDAUtils.dirichletExpectation(gammat)(::, breeze.linalg.*))
  val logphat: BDM[Double] = columnSums / numRows
  val gradient = numRows * (-LDAUtils.dirichletExpectation(current) + logphat.toDenseVector)

  // Terms of the Newton step, named c, q and b as in the cited reference.
  val c = numRows * trigamma(sum(current))
  val q = -numRows * trigamma(current)
  val b = sum(gradient / q) / (1D / c + sum(1D / q))

  val newtonStep = -(gradient - b) / q
  val scaledStep = stepSize * newtonStep

  // Reject the step entirely if it would make any component of alpha non-positive.
  if (all((scaledStep + current) :> 0D)) {
    current :+= scaledStep
    this.alpha = Vectors.dense(current.toArray)
  }
}
|
||||
|
||||
|
||||
/**
 * Learning rate rho for the current [[iteration]]: (tau0 + iteration)^(-kappa).
 *
 * @return the learning rate for the current iteration
 */
private def rho(): Double = {
  val base = getTau0 + this.iteration
  val exponent = -getKappa
  math.pow(base, exponent)
}
|
||||
|
||||
/**
|
||||
* Get a random matrix to initialize lambda
|
||||
* Get a random matrix to initialize lambda.
|
||||
*/
|
||||
private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
|
||||
val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(
|
||||
|
|
|
@ -400,6 +400,40 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
}
|
||||
}
|
||||
|
||||
test("OnlineLDAOptimizer alpha hyperparameter optimization") {
  val numTopics = 2
  val corpus = sc.parallelize(toyData)

  // Mini-batch fraction 1 without replacement, with alpha optimization switched on.
  val optimizer = new OnlineLDAOptimizer()
    .setMiniBatchFraction(1)
    .setTau0(1024)
    .setKappa(0.51)
    .setGammaShape(100)
    .setOptimzeAlpha(true)
    .setSampleWithReplacement(false)
  val lda = new LDA()
    .setK(numTopics)
    .setDocConcentration(1D / numTopics)
    .setTopicConcentration(0.01)
    .setMaxIterations(100)
    .setOptimizer(optimizer)
    .setSeed(12345)
  val ldaModel: LocalLDAModel = lda.run(corpus).asInstanceOf[LocalLDAModel]

  /* Verify the results with gensim:
    import numpy as np
    from gensim import models
    corpus = [
      [(0, 1.0), (1, 1.0)],
      [(1, 1.0), (2, 1.0)],
      [(0, 1.0), (2, 1.0)],
      [(3, 1.0), (4, 1.0)],
      [(3, 1.0), (5, 1.0)],
      [(4, 1.0), (5, 1.0)]]
    np.random.seed(2345)
    lda = models.ldamodel.LdaModel(
      corpus=corpus, alpha='auto', eta=0.01, num_topics=2, update_every=0, passes=100,
      decay=0.51, offset=1024)
    print(lda.alpha)
    > [ 0.42582646 0.43511073]
  */

  // Reference alpha printed by the gensim run above.
  val expectedAlpha = Vectors.dense(0.42582646, 0.43511073)
  assert(ldaModel.docConcentration ~== expectedAlpha absTol 0.05)
}
|
||||
|
||||
test("model save/load") {
|
||||
// Test for LocalLDAModel.
|
||||
val localModel = new LocalLDAModel(tinyTopics,
|
||||
|
|
Loading…
Reference in a new issue