[SPARK-7090] [MLLIB] Introduce LDAOptimizer to LDA to further improve extensibility

jira: https://issues.apache.org/jira/browse/SPARK-7090

LDA was implemented with extensibility in mind, and with the development of OnlineLDA and Gibbs sampling we are collecting more detailed requirements from the different algorithms.
As Joseph Bradley (jkbradley) proposed in https://github.com/apache/spark/pull/4807, and after some further discussion, we'd like to adjust the code structure a little to present the common interface and extension point clearly.
Basically, class LDA is the common entry point for LDA computation, and each LDA object refers to an LDAOptimizer for the concrete algorithm implementation. Users can configure an LDAOptimizer with specific parameters and assign it to LDA.
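
As an illustration only, here is a minimal sketch of the intended flow after this change, assuming a pre-built corpus RDD; the setters and the cast mirror the diff below:

```scala
import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

def trainWithEM(corpus: RDD[(Long, Vector)]): DistributedLDAModel = {
  // LDA stays the single entry point; the optimizer carries the algorithm-specific state.
  val lda = new LDA()
    .setK(10)
    .setMaxIterations(20)
    .setOptimizer(new EMLDAOptimizer)  // the default optimizer, set explicitly for clarity
  // run() now returns the general LDAModel; EM callers cast back to the distributed model.
  lda.run(corpus).asInstanceOf[DistributedLDAModel]
}
```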

Concrete changes:

1. Add a trait `LDAOptimizer`, which defines the common interface for concrete implementations. Each subclass is a wrapper for a specific LDA algorithm.

2. Move EMOptimizer into the file LDAOptimizer, have it extend LDAOptimizer, and rename it to EMLDAOptimizer (in case a more generic EMOptimizer comes along in the future).
        - Adjust the constructor of EMLDAOptimizer, since all parameters should be passed in through the initialState method. This avoids unwanted confusion or accidental overwrites.
        - Move the code from LDA.initialState into EMLDAOptimizer.initialState.

3. Add an ldaOptimizer property to LDA with its getter/setter, including a name-based setter (see the sketch after this list); EMLDAOptimizer is the default optimizer.

4. Change the return type of LDA.run from DistributedLDAModel to LDAModel, so callers that need the distributed (EM) model cast the result back.
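
For item 3, a small hedged sketch of the name-based setter; per the diff below, only "em" is recognized at this point:

```scala
import org.apache.spark.mllib.clustering.{EMLDAOptimizer, LDA}

object OptimizerSelection {
  def main(args: Array[String]): Unit = {
    // The string form is case-insensitive and currently maps only "em" to EMLDAOptimizer.
    val lda = new LDA().setOptimizer("em")
    assert(lda.getOptimizer.isInstanceOf[EMLDAOptimizer])
    // Any other name is rejected with:
    //   IllegalArgumentException: "Only em is supported but got <name>."
  }
}
```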

Further work:
Add OnlineLDAOptimizer and other possible optimizers once they are ready.
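
To show the extension point, here is a purely hypothetical skeleton of what such a future optimizer would implement; the three private[clustering] hooks come from the new LDAOptimizer trait below, while the class name and method bodies are placeholders, not part of this patch:

```scala
package org.apache.spark.mllib.clustering

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Hypothetical sketch: a new optimizer only has to fill in the trait's three hooks.
class MyLDAOptimizer extends LDAOptimizer {

  private[clustering] override def initialState(
      docs: RDD[(Long, Vector)],
      k: Int,
      docConcentration: Double,
      topicConcentration: Double,
      randomSeed: Long,
      checkpointInterval: Int): LDAOptimizer = {
    // Build whatever internal structure (graph, matrices, ...) the algorithm needs.
    this
  }

  private[clustering] override def next(): LDAOptimizer = {
    // Perform one iteration of inference/learning.
    this
  }

  private[clustering] override def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
    // Wrap the learned parameters in an LDAModel.
    ???
  }
}
```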

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #5661 from hhbyyh/ldaRefactor and squashes the following commits:

0e2e006 [Yuhao Yang] respond to review comments
08a45da [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor
e756ce4 [Yuhao Yang] solve mima exception
d74fd8f [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor
0bb8400 [Yuhao Yang] refactor LDA with Optimizer
ec2f857 [Yuhao Yang] protoptype for discussion
Authored by Yuhao Yang on 2015-04-27 19:02:51 -07:00; committed by Joseph K. Bradley
parent 62888a4ded
commit 4d9e560b54
8 changed files with 256 additions and 151 deletions

JavaLDAExample.java

@@ -58,7 +58,7 @@ public class JavaLDAExample {
corpus.cache();
// Cluster the documents into three topics using LDA
DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus);
DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus);
// Output topics. Each is a distribution over words (matching word count vectors)
System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()

LDAExample.scala

@@ -26,7 +26,7 @@ import scopt.OptionParser
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
@@ -137,7 +137,7 @@ object LDAExample {
sc.setCheckpointDir(params.checkpointDir.get)
}
val startTime = System.nanoTime()
val ldaModel = lda.run(corpus)
val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
val elapsed = (System.nanoTime() - startTime) / 1e9
println(s"Finished training LDA model. Summary:")

LDA.scala

@@ -17,16 +17,11 @@
package org.apache.spark.mllib.clustering
import java.util.Random
import breeze.linalg.{DenseVector => BDV, normalize}
import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -42,16 +37,9 @@ import org.apache.spark.util.Utils
* - "token": instance of a term appearing in a document
* - "topic": multinomial distribution over words representing some concept
*
* Currently, the underlying implementation uses Expectation-Maximization (EM), implemented
* according to the Asuncion et al. (2009) paper referenced below.
*
* References:
* - Original LDA paper (journal version):
* Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
* - This class implements their "smoothed" LDA model.
* - Paper which clearly explains several algorithms, including EM:
* Asuncion, Welling, Smyth, and Teh.
* "On Smoothing and Inference for Topic Models." UAI, 2009.
*
* @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation
* (Wikipedia)]]
@@ -63,10 +51,11 @@ class LDA private (
private var docConcentration: Double,
private var topicConcentration: Double,
private var seed: Long,
private var checkpointInterval: Int) extends Logging {
private var checkpointInterval: Int,
private var ldaOptimizer: LDAOptimizer) extends Logging {
def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
seed = Utils.random.nextLong(), checkpointInterval = 10)
seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer)
/**
* Number of topics to infer. I.e., the number of soft cluster centers.
@@ -220,6 +209,32 @@ class LDA private (
this
}
/** LDAOptimizer used to perform the actual calculation */
def getOptimizer: LDAOptimizer = ldaOptimizer
/**
* LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer)
*/
def setOptimizer(optimizer: LDAOptimizer): this.type = {
this.ldaOptimizer = optimizer
this
}
/**
* Set the LDAOptimizer used to perform the actual calculation by algorithm name.
* Currently "em" is supported.
*/
def setOptimizer(optimizerName: String): this.type = {
this.ldaOptimizer =
optimizerName.toLowerCase match {
case "em" => new EMLDAOptimizer
case other =>
throw new IllegalArgumentException(s"Only em is supported but got $other.")
}
this
}
/**
* Learn an LDA model using the given dataset.
*
@@ -229,9 +244,9 @@ class LDA private (
* Document IDs must be unique and >= 0.
* @return Inferred LDA model
*/
def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
checkpointInterval)
def run(documents: RDD[(Long, Vector)]): LDAModel = {
val state = ldaOptimizer.initialState(documents, k, getDocConcentration, getTopicConcentration,
seed, checkpointInterval)
var iter = 0
val iterationTimes = Array.fill[Double](maxIterations)(0)
while (iter < maxIterations) {
@@ -241,12 +256,11 @@ class LDA private (
iterationTimes(iter) = elapsedSeconds
iter += 1
}
state.graphCheckpointer.deleteAllCheckpoints()
new DistributedLDAModel(state, iterationTimes)
state.getLDAModel(iterationTimes)
}
/** Java-friendly version of [[run()]] */
def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
}
}
@@ -320,88 +334,10 @@ private[clustering] object LDA {
private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0
/**
* Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters.
*
* @param graph EM graph, storing current parameter estimates in vertex descriptors and
* data (token counts) in edge descriptors.
* @param k Number of topics
* @param vocabSize Number of unique terms
* @param docConcentration "alpha"
* @param topicConcentration "beta" or "eta"
*/
private[clustering] class EMOptimizer(
var graph: Graph[TopicCounts, TokenCount],
val k: Int,
val vocabSize: Int,
val docConcentration: Double,
val topicConcentration: Double,
checkpointInterval: Int) {
private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
graph, checkpointInterval)
def next(): EMOptimizer = {
val eta = topicConcentration
val W = vocabSize
val alpha = docConcentration
val N_k = globalTopicTotals
val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit =
(edgeContext) => {
// Compute N_{wj} gamma_{wjk}
val N_wj = edgeContext.attr
// E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count
// N_{wj}.
val scaledTopicDistribution: TopicCounts =
computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj
edgeContext.sendToDst((false, scaledTopicDistribution))
edgeContext.sendToSrc((false, scaledTopicDistribution))
}
// This is a hack to detect whether we could modify the values in-place.
// TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438)
val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) =
(m0, m1) => {
val sum =
if (m0._1) {
m0._2 += m1._2
} else if (m1._1) {
m1._2 += m0._2
} else {
m0._2 + m1._2
}
(true, sum)
}
// M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
val docTopicDistributions: VertexRDD[TopicCounts] =
graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg)
.mapValues(_._2)
// Update the vertex descriptors with the new counts.
val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
graph = newGraph
graphCheckpointer.updateGraph(newGraph)
globalTopicTotals = computeGlobalTopicTotals()
this
}
/**
* Aggregate distributions over topics from all term vertices.
*
* Note: This executes an action on the graph RDDs.
*/
var globalTopicTotals: TopicCounts = computeGlobalTopicTotals()
private def computeGlobalTopicTotals(): TopicCounts = {
val numTopics = k
graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _)
}
}
/**
* Compute gamma_{wjk}, a distribution over topics k.
*/
private def computePTopic(
private[clustering] def computePTopic(
docTopicCounts: TopicCounts,
termTopicCounts: TopicCounts,
totalTopicCounts: TopicCounts,
@@ -427,49 +363,4 @@ private[clustering] object LDA {
// normalize
BDV(gamma_wj) /= sum
}
/**
* Compute bipartite term/doc graph.
*/
private def initialState(
docs: RDD[(Long, Vector)],
k: Int,
docConcentration: Double,
topicConcentration: Double,
randomSeed: Long,
checkpointInterval: Int): EMOptimizer = {
// For each document, create an edge (Document -> Term) for each unique term in the document.
val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) =>
// Add edges for terms with non-zero counts.
termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) =>
Edge(docID, term2index(term), cnt)
}
}
val vocabSize = docs.take(1).head._2.size
// Create vertices.
// Initially, we use random soft assignments of tokens to topics (random gamma).
def createVertices(): RDD[(VertexId, TopicCounts)] = {
val verticesTMP: RDD[(VertexId, TopicCounts)] =
edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
val random = new Random(partIndex + randomSeed)
partEdges.flatMap { edge =>
val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)
val sum = gamma * edge.attr
Seq((edge.srcId, sum), (edge.dstId, sum))
}
}
verticesTMP.reduceByKey(_ + _)
}
val docTermVertices = createVertices()
// Partition such that edges are grouped by document
val graph = Graph(docTermVertices, edges)
.partitionBy(PartitionStrategy.EdgePartition1D)
new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval)
}
}

LDAModel.scala

@@ -203,7 +203,7 @@ class DistributedLDAModel private (
import LDA._
private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = {
private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = {
this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration,
state.topicConcentration, iterationTimes)
}

LDAOptimizer.scala (new file)

@@ -0,0 +1,210 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.clustering
import java.util.Random
import breeze.linalg.{DenseVector => BDV, normalize}
import org.apache.spark.annotation.Experimental
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
/**
* :: Experimental ::
*
* An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can
* hold optimizer-specific parameters for users to set.
*/
@Experimental
trait LDAOptimizer{
/*
DEVELOPERS NOTE:
An LDAOptimizer contains an algorithm for LDA and performs the actual computation, which
stores internal data structure (Graph or Matrix) and other parameters for the algorithm.
The interface is isolated to improve the extensibility of LDA.
*/
/**
* Initializer for the optimizer. LDA passes the common parameters to the optimizer and
* the internal structure can be initialized properly.
*/
private[clustering] def initialState(
docs: RDD[(Long, Vector)],
k: Int,
docConcentration: Double,
topicConcentration: Double,
randomSeed: Long,
checkpointInterval: Int): LDAOptimizer
private[clustering] def next(): LDAOptimizer
private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel
}
/**
* :: Experimental ::
*
* Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters.
*
* Currently, the underlying implementation uses Expectation-Maximization (EM), implemented
* according to the Asuncion et al. (2009) paper referenced below.
*
* References:
* - Original LDA paper (journal version):
* Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
* - This class implements their "smoothed" LDA model.
* - Paper which clearly explains several algorithms, including EM:
* Asuncion, Welling, Smyth, and Teh.
* "On Smoothing and Inference for Topic Models." UAI, 2009.
*
*/
@Experimental
class EMLDAOptimizer extends LDAOptimizer{
import LDA._
/**
* Following fields will only be initialized through initialState method
*/
private[clustering] var graph: Graph[TopicCounts, TokenCount] = null
private[clustering] var k: Int = 0
private[clustering] var vocabSize: Int = 0
private[clustering] var docConcentration: Double = 0
private[clustering] var topicConcentration: Double = 0
private[clustering] var checkpointInterval: Int = 10
private var graphCheckpointer: PeriodicGraphCheckpointer[TopicCounts, TokenCount] = null
/**
* Compute bipartite term/doc graph.
*/
private[clustering] override def initialState(
docs: RDD[(Long, Vector)],
k: Int,
docConcentration: Double,
topicConcentration: Double,
randomSeed: Long,
checkpointInterval: Int): LDAOptimizer = {
// For each document, create an edge (Document -> Term) for each unique term in the document.
val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) =>
// Add edges for terms with non-zero counts.
termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) =>
Edge(docID, term2index(term), cnt)
}
}
val vocabSize = docs.take(1).head._2.size
// Create vertices.
// Initially, we use random soft assignments of tokens to topics (random gamma).
def createVertices(): RDD[(VertexId, TopicCounts)] = {
val verticesTMP: RDD[(VertexId, TopicCounts)] =
edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
val random = new Random(partIndex + randomSeed)
partEdges.flatMap { edge =>
val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)
val sum = gamma * edge.attr
Seq((edge.srcId, sum), (edge.dstId, sum))
}
}
verticesTMP.reduceByKey(_ + _)
}
val docTermVertices = createVertices()
// Partition such that edges are grouped by document
this.graph = Graph(docTermVertices, edges).partitionBy(PartitionStrategy.EdgePartition1D)
this.k = k
this.vocabSize = vocabSize
this.docConcentration = docConcentration
this.topicConcentration = topicConcentration
this.checkpointInterval = checkpointInterval
this.graphCheckpointer = new
PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval)
this.globalTopicTotals = computeGlobalTopicTotals()
this
}
private[clustering] override def next(): EMLDAOptimizer = {
require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
val eta = topicConcentration
val W = vocabSize
val alpha = docConcentration
val N_k = globalTopicTotals
val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit =
(edgeContext) => {
// Compute N_{wj} gamma_{wjk}
val N_wj = edgeContext.attr
// E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count
// N_{wj}.
val scaledTopicDistribution: TopicCounts =
computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj
edgeContext.sendToDst((false, scaledTopicDistribution))
edgeContext.sendToSrc((false, scaledTopicDistribution))
}
// This is a hack to detect whether we could modify the values in-place.
// TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438)
val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) =
(m0, m1) => {
val sum =
if (m0._1) {
m0._2 += m1._2
} else if (m1._1) {
m1._2 += m0._2
} else {
m0._2 + m1._2
}
(true, sum)
}
// M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
val docTopicDistributions: VertexRDD[TopicCounts] =
graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg)
.mapValues(_._2)
// Update the vertex descriptors with the new counts.
val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
graph = newGraph
graphCheckpointer.updateGraph(newGraph)
globalTopicTotals = computeGlobalTopicTotals()
this
}
/**
* Aggregate distributions over topics from all term vertices.
*
* Note: This executes an action on the graph RDDs.
*/
private[clustering] var globalTopicTotals: TopicCounts = null
private def computeGlobalTopicTotals(): TopicCounts = {
val numTopics = k
graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _)
}
private[clustering] override def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
this.graphCheckpointer.deleteAllCheckpoints()
new DistributedLDAModel(this, iterationTimes)
}
}

JavaLDASuite.java

@@ -88,7 +88,7 @@ public class JavaLDASuite implements Serializable {
.setMaxIterations(5)
.setSeed(12345);
DistributedLDAModel model = lda.run(corpus);
DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus);
// Check: basic parameters
LocalLDAModel localModel = model.toLocal();

LDASuite.scala

@@ -68,7 +68,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
.setSeed(12345)
val corpus = sc.parallelize(tinyCorpus, 2)
val model: DistributedLDAModel = lda.run(corpus)
val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
// Check: basic parameters
val localModel = model.toLocal

MimaExcludes.scala

@@ -72,6 +72,10 @@ object MimaExcludes {
// SPARK-6703 Add getOrCreate method to SparkContext
ProblemFilters.exclude[IncompatibleResultTypeProblem]
("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext")
)++ Seq(
// SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.mllib.clustering.LDA$EMOptimizer")
)
case v if v.startsWith("1.3") =>