From bfd75cdfb22a8c2fb005da597621e1ccd3990e82 Mon Sep 17 00:00:00 2001
From: Lu WANG
Date: Wed, 16 May 2018 17:54:06 -0700
Subject: [PATCH] [SPARK-22210][ML] Add seed for LDA variationalTopicInference

## What changes were proposed in this pull request?

- Add a seed parameter to variationalTopicInference
- Pass the seed when calling variationalTopicInference in submitMiniBatch
- Add a var seed in LDAModel so that it can take the seed from LDA and use it in the calls to
  variationalTopicInference in logLikelihoodBound, topicDistributions, getTopicDistributionMethod,
  and topicDistribution.

## How was this patch tested?

Check the test results in mllib.clustering.LDASuite to make sure the results are repeatable with the seed.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Lu WANG

Closes #21183 from ludatabricks/SPARK-22210.
---
 .../org/apache/spark/ml/clustering/LDA.scala  |  6 ++-
 .../spark/mllib/clustering/LDAModel.scala     | 34 +++++++++++++---
 .../spark/mllib/clustering/LDAOptimizer.scala | 40 +++++++++++--------
 .../apache/spark/ml/clustering/LDASuite.scala |  6 +++
 4 files changed, 63 insertions(+), 23 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index afe599cd16..fed42c959b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -569,10 +569,14 @@ abstract class LDAModel private[ml] (
 class LocalLDAModel private[ml] (
     uid: String,
     vocabSize: Int,
-    @Since("1.6.0") override private[clustering] val oldLocalModel: OldLocalLDAModel,
+    private[clustering] val oldLocalModel_ : OldLocalLDAModel,
     sparkSession: SparkSession)
   extends LDAModel(uid, vocabSize, sparkSession) {
 
+  override private[clustering] def oldLocalModel: OldLocalLDAModel = {
+    oldLocalModel_.setSeed(getSeed)
+  }
+
   @Since("1.6.0")
   override def copy(extra: ParamMap): LocalLDAModel = {
     val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index b8a6e94248..f915062d77 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SparkSession}
-import org.apache.spark.util.BoundedPriorityQueue
+import org.apache.spark.util.{BoundedPriorityQueue, Utils}
 
 /**
  * Latent Dirichlet Allocation (LDA) model.
@@ -194,6 +194,8 @@ class LocalLDAModel private[spark] (
     override protected[spark] val gammaShape: Double = 100)
   extends LDAModel with Serializable {
 
+  private var seed: Long = Utils.random.nextLong()
+
   @Since("1.3.0")
   override def k: Int = topics.numCols
 
@@ -216,6 +218,21 @@ class LocalLDAModel private[spark] (
 
   override protected def formatVersion = "1.0"
 
+  /**
+   * Random seed for cluster initialization.
+   */
+  @Since("2.4.0")
+  def getSeed: Long = seed
+
+  /**
+   * Set the random seed for cluster initialization.
+   */
+  @Since("2.4.0")
+  def setSeed(seed: Long): this.type = {
+    this.seed = seed
+    this
+  }
+
   @Since("1.5.0")
   override def save(sc: SparkContext, path: String): Unit = {
     LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration,
@@ -298,6 +315,7 @@ class LocalLDAModel private[spark] (
     // by topic (columns of lambda)
     val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t
     val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta)
+    val gammaSeed = this.seed
 
     // Sum bound components for each document:
     //  component for prob(tokens) + component for prob(document-topic distribution)
@@ -306,7 +324,7 @@ class LocalLDAModel private[spark] (
         val localElogbeta = ElogbetaBc.value
         var docBound = 0.0D
         val (gammad: BDV[Double], _, _) = OnlineLDAOptimizer.variationalTopicInference(
-          termCounts, exp(localElogbeta), brzAlpha, gammaShape, k)
+          termCounts, exp(localElogbeta), brzAlpha, gammaShape, k, gammaSeed + id)
         val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad)
 
         // E[log p(doc | theta, beta)]
@@ -352,6 +370,7 @@ class LocalLDAModel private[spark] (
     val docConcentrationBrz = this.docConcentration.asBreeze
     val gammaShape = this.gammaShape
     val k = this.k
+    val gammaSeed = this.seed
 
     documents.map { case (id: Long, termCounts: Vector) =>
       if (termCounts.numNonzeros == 0) {
@@ -362,7 +381,8 @@ class LocalLDAModel private[spark] (
           expElogbetaBc.value,
           docConcentrationBrz,
           gammaShape,
-          k)
+          k,
+          gammaSeed + id)
         (id, Vectors.dense(normalize(gamma, 1.0).toArray))
       }
     }
@@ -376,6 +396,7 @@ class LocalLDAModel private[spark] (
     val docConcentrationBrz = this.docConcentration.asBreeze
     val gammaShape = this.gammaShape
     val k = this.k
+    val gammaSeed = this.seed
 
     (termCounts: Vector) =>
       if (termCounts.numNonzeros == 0) {
@@ -386,7 +407,8 @@ class LocalLDAModel private[spark] (
           expElogbeta,
           docConcentrationBrz,
           gammaShape,
-          k)
+          k,
+          gammaSeed)
         Vectors.dense(normalize(gamma, 1.0).toArray)
       }
   }
@@ -403,6 +425,7 @@ class LocalLDAModel private[spark] (
    */
   @Since("2.0.0")
   def topicDistribution(document: Vector): Vector = {
+    val gammaSeed = this.seed
     val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t)
     if (document.numNonzeros == 0) {
       Vectors.zeros(this.k)
@@ -412,7 +435,8 @@ class LocalLDAModel private[spark] (
         expElogbeta,
         this.docConcentration.asBreeze,
         gammaShape,
-        this.k)
+        this.k,
+        gammaSeed)
       Vectors.dense(normalize(gamma, 1.0).toArray)
     }
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 693a2a31f0..f8e5f3ed76 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
 
 /**
  * :: DeveloperApi ::
@@ -464,6 +465,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
     val alpha = this.alpha.asBreeze
     val gammaShape = this.gammaShape
     val optimizeDocConcentration = this.optimizeDocConcentration
+    val seed = randomGenerator.nextLong()
     // If and only if optimizeDocConcentration is set true,
     // we calculate logphat in the same pass as other statistics.
     // No calculation of loghat happens otherwise.
@@ -473,20 +475,21 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
       None
     }
 
-    val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs =>
-      val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
+    val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitionsWithIndex {
+      (index, docs) =>
+        val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
 
-      val stat = BDM.zeros[Double](k, vocabSize)
-      val logphatPartOption = logphatPartOptionBase()
-      var nonEmptyDocCount: Long = 0L
-      nonEmptyDocs.foreach { case (_, termCounts: Vector) =>
-        nonEmptyDocCount += 1
-        val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference(
-          termCounts, expElogbetaBc.value, alpha, gammaShape, k)
-        stat(::, ids) := stat(::, ids) + sstats
-        logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad))
-      }
-      Iterator((stat, logphatPartOption, nonEmptyDocCount))
+        val stat = BDM.zeros[Double](k, vocabSize)
+        val logphatPartOption = logphatPartOptionBase()
+        var nonEmptyDocCount: Long = 0L
+        nonEmptyDocs.foreach { case (_, termCounts: Vector) =>
+          nonEmptyDocCount += 1
+          val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference(
+            termCounts, expElogbetaBc.value, alpha, gammaShape, k, seed + index)
+          stat(::, ids) := stat(::, ids) + sstats
+          logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad))
+        }
+        Iterator((stat, logphatPartOption, nonEmptyDocCount))
     }
 
     val elementWiseSum = (
@@ -578,7 +581,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
   }
 
   override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
-    new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape)
+    new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta)
+      .setSeed(randomGenerator.nextLong())
   }
 }
 
@@ -605,18 +609,20 @@ private[clustering] object OnlineLDAOptimizer {
       expElogbeta: BDM[Double],
       alpha: breeze.linalg.Vector[Double],
       gammaShape: Double,
-      k: Int): (BDV[Double], BDM[Double], List[Int]) = {
+      k: Int,
+      seed: Long): (BDV[Double], BDM[Double], List[Int]) = {
     val (ids: List[Int], cts: Array[Double]) = termCounts match {
       case v: DenseVector => ((0 until v.size).toList, v.values)
       case v: SparseVector => (v.indices.toList, v.values)
     }
     // Initialize the variational distribution q(theta|gamma) for the mini-batch
+    val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(seed))
     val gammad: BDV[Double] =
-      new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k)                   // K
+      new Gamma(gammaShape, 1.0 / gammaShape)(randBasis).samplesVector(k)        // K
     val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad))  // K
     val expElogbetad = expElogbeta(ids, ::).toDenseMatrix                        // ids * K
 
-    val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100            // ids
+    val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100           // ids
     var meanGammaChange = 1D
     val ctsVector = new BDV[Double](cts)                                         // ids
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index 8d728f063d..4d84820503 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -253,6 +253,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
     val lda = new LDA()
     testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings,
       LDASuite.allParamSettings, checkModelData)
+
+    // Make sure the result is deterministic after saving and loading the model
+    val model = lda.fit(dataset)
+    val model2 = testDefaultReadWrite(model)
+    assert(model.logLikelihood(dataset) ~== model2.logLikelihood(dataset) absTol 1e-6)
+    assert(model.logPerplexity(dataset) ~== model2.logPerplexity(dataset) absTol 1e-6)
   }
 
   test("read/write DistributedLDAModel") {