[SPARK-7475] [MLLIB] adjust ldaExample for online LDA
jira: https://issues.apache.org/jira/browse/SPARK-7475
Add a new argument that specifies the inference algorithm used by LDA, to demonstrate the basic usage of LDAOptimizer.
cc jkbradley
Author: Yuhao Yang <hhbyyh@gmail.com>
Closes #6000 from hhbyyh/ldaExample and squashes the following commits:
0a7e2bc [Yuhao Yang] fix according to comments
5810b0f [Yuhao Yang] adjust ldaExample for online LDA
(cherry picked from commit b13162b364)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
This commit is contained in:
parent
5110f3efe5
commit
e96fc8630e
|
@ -26,7 +26,7 @@ import scopt.OptionParser
|
|||
import org.apache.log4j.{Level, Logger}
|
||||
|
||||
import org.apache.spark.{SparkContext, SparkConf}
|
||||
import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
|
||||
import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA}
|
||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
|
@ -48,6 +48,7 @@ object LDAExample {
|
|||
topicConcentration: Double = -1,
|
||||
vocabSize: Int = 10000,
|
||||
stopwordFile: String = "",
|
||||
algorithm: String = "em",
|
||||
checkpointDir: Option[String] = None,
|
||||
checkpointInterval: Int = 10) extends AbstractParams[Params]
|
||||
|
||||
|
@ -78,6 +79,10 @@ object LDAExample {
|
|||
.text(s"filepath for a list of stopwords. Note: This must fit on a single machine." +
|
||||
s" default: ${defaultParams.stopwordFile}")
|
||||
.action((x, c) => c.copy(stopwordFile = x))
|
||||
opt[String]("algorithm")
|
||||
.text(s"inference algorithm to use. em and online are supported." +
|
||||
s" default: ${defaultParams.algorithm}")
|
||||
.action((x, c) => c.copy(algorithm = x))
|
||||
opt[String]("checkpointDir")
|
||||
.text(s"Directory for checkpointing intermediate results." +
|
||||
s" Checkpointing helps with recovery and eliminates temporary shuffle files on disk." +
|
||||
|
@ -128,7 +133,17 @@ object LDAExample {
|
|||
|
||||
// Run LDA.
|
||||
val lda = new LDA()
|
||||
lda.setK(params.k)
|
||||
|
||||
val optimizer = params.algorithm.toLowerCase match {
|
||||
case "em" => new EMLDAOptimizer
|
||||
// add (1.0 / actualCorpusSize) to MiniBatchFraction be more robust on tiny datasets.
|
||||
case "online" => new OnlineLDAOptimizer().setMiniBatchFraction(0.05 + 1.0 / actualCorpusSize)
|
||||
case _ => throw new IllegalArgumentException(
|
||||
s"Only em, online are supported but got ${params.algorithm}.")
|
||||
}
|
||||
|
||||
lda.setOptimizer(optimizer)
|
||||
.setK(params.k)
|
||||
.setMaxIterations(params.maxIterations)
|
||||
.setDocConcentration(params.docConcentration)
|
||||
.setTopicConcentration(params.topicConcentration)
|
||||
|
@ -137,14 +152,18 @@ object LDAExample {
|
|||
sc.setCheckpointDir(params.checkpointDir.get)
|
||||
}
|
||||
val startTime = System.nanoTime()
|
||||
val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
|
||||
val ldaModel = lda.run(corpus)
|
||||
val elapsed = (System.nanoTime() - startTime) / 1e9
|
||||
|
||||
println(s"Finished training LDA model. Summary:")
|
||||
println(s"\t Training time: $elapsed sec")
|
||||
val avgLogLikelihood = ldaModel.logLikelihood / actualCorpusSize.toDouble
|
||||
|
||||
if (ldaModel.isInstanceOf[DistributedLDAModel]) {
|
||||
val distLDAModel = ldaModel.asInstanceOf[DistributedLDAModel]
|
||||
val avgLogLikelihood = distLDAModel.logLikelihood / actualCorpusSize.toDouble
|
||||
println(s"\t Training data average log likelihood: $avgLogLikelihood")
|
||||
println()
|
||||
}
|
||||
|
||||
// Print the topics, showing the top-weighted terms for each topic.
|
||||
val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
|
||||
|
|
Loading…
Reference in a new issue