diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 059b52ef20..ece28848aa 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -215,7 +215,8 @@ class LocalLDAModel private[clustering] (
   override protected def formatVersion = "1.0"
 
   override def save(sc: SparkContext, path: String): Unit = {
-    LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix)
+    LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration,
+      gammaShape)
   }
   // TODO
   // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
@@ -312,16 +313,23 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
     // as a Row in data.
     case class Data(topic: Vector, index: Int)
 
-    // TODO: explicitly save docConcentration, topicConcentration, and gammaShape for use in
-    // model.predict()
-    def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {
+    def save(
+        sc: SparkContext,
+        path: String,
+        topicsMatrix: Matrix,
+        docConcentration: Vector,
+        topicConcentration: Double,
+        gammaShape: Double): Unit = {
       val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._
 
       val k = topicsMatrix.numCols
       val metadata = compact(render
         (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
-          ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows)))
+          ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows) ~
+          ("docConcentration" -> docConcentration.toArray.toSeq) ~
+          ("topicConcentration" -> topicConcentration) ~
+          ("gammaShape" -> gammaShape)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
 
       val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
@@ -331,7 +339,12 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
       sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
     }
 
-    def load(sc: SparkContext, path: String): LocalLDAModel = {
+    def load(
+        sc: SparkContext,
+        path: String,
+        docConcentration: Vector,
+        topicConcentration: Double,
+        gammaShape: Double): LocalLDAModel = {
       val dataPath = Loader.dataPath(path)
       val sqlContext = SQLContext.getOrCreate(sc)
       val dataFrame = sqlContext.read.parquet(dataPath)
@@ -348,8 +361,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
       val topicsMat = Matrices.fromBreeze(brzTopics)
 
       // TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940
-      new LocalLDAModel(topicsMat,
-        Vectors.dense(Array.fill(topicsMat.numRows)(1.0 / topicsMat.numRows)), 1D, 100D)
+      new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape)
     }
   }
 
@@ -358,11 +370,15 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
     implicit val formats = DefaultFormats
     val expectedK = (metadata \ "k").extract[Int]
     val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
+    val docConcentration =
+      Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
+    val topicConcentration = (metadata \ "topicConcentration").extract[Double]
+    val gammaShape = (metadata \ "gammaShape").extract[Double]
    val classNameV1_0 = SaveLoadV1_0.thisClassName
 
     val model = (loadedClassName, loadedVersion) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        SaveLoadV1_0.load(sc, path)
+        SaveLoadV1_0.load(sc, path, docConcentration, topicConcentration, gammaShape)
       case _ => throw new Exception(
         s"LocalLDAModel.load did not recognize model with (className, format version):" +
           s"($loadedClassName, $loadedVersion). Supported:\n" +
@@ -565,7 +581,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
     val thisFormatVersion = "1.0"
 
-    val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"
+    val thisClassName = "org.apache.spark.mllib.clustering.DistributedLDAModel"
 
     // Store globalTopicTotals as a Vector.
     case class Data(globalTopicTotals: Vector)
 
@@ -591,7 +607,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
       import sqlContext.implicits._
 
       val metadata = compact(render
-        (("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~
+        (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
           ("k" -> k) ~ ("vocabSize" -> vocabSize) ~
           ("docConcentration" -> docConcentration.toArray.toSeq) ~
           ("topicConcentration" -> topicConcentration) ~
@@ -660,7 +676,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
     val topicConcentration = (metadata \ "topicConcentration").extract[Double]
     val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
     val gammaShape = (metadata \ "gammaShape").extract[Double]
-    val classNameV1_0 = SaveLoadV1_0.classNameV1_0
+    val classNameV1_0 = SaveLoadV1_0.thisClassName
 
     val model = (loadedClassName, loadedVersion) match {
       case (className, "1.0") if className == classNameV1_0 => {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index aa36336ebb..b91c7cefed 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -334,7 +334,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
   test("model save/load") {
     // Test for LocalLDAModel.
     val localModel = new LocalLDAModel(tinyTopics,
-      Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D)
+      Vectors.dense(Array.fill(tinyTopics.numRows)(0.01)), 0.5D, 10D)
     val tempDir1 = Utils.createTempDir()
     val path1 = tempDir1.toURI.toString
 
@@ -360,6 +360,9 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(samelocalModel.topicsMatrix === localModel.topicsMatrix)
     assert(samelocalModel.k === localModel.k)
     assert(samelocalModel.vocabSize === localModel.vocabSize)
+    assert(samelocalModel.docConcentration === localModel.docConcentration)
+    assert(samelocalModel.topicConcentration === localModel.topicConcentration)
+    assert(samelocalModel.gammaShape === localModel.gammaShape)
 
     val sameDistributedModel = DistributedLDAModel.load(sc, path2)
     assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix)
@@ -368,6 +371,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes)
     assert(distributedModel.docConcentration === sameDistributedModel.docConcentration)
     assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration)
+    assert(distributedModel.gammaShape === sameDistributedModel.gammaShape)
     assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals)
 
     val graph = distributedModel.graph
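
For context, here is a minimal round-trip sketch of what this patch enables from user code. It is illustrative only: the corpus values, the save path, and the variable names are made up, and it assumes a live SparkContext `sc`. The online optimizer returns a LocalLDAModel, whose save/load now preserves the model's hyperparameters instead of substituting the previous hard-coded defaults (uniform docConcentration, topicConcentration = 1.0, gammaShape = 100). Note that gammaShape is protected[clustering], so the gammaShape assertion in the patch lives in LDASuite, which is in the same package; user code can only check the public docConcentration and topicConcentration.

import org.apache.spark.mllib.clustering.{LDA, LocalLDAModel}
import org.apache.spark.mllib.linalg.Vectors

// Hypothetical toy corpus of (docId, termCountVector) pairs.
val corpus = sc.parallelize(Seq(
  (0L, Vectors.dense(1.0, 2.0, 0.0)),
  (1L, Vectors.dense(0.0, 1.0, 3.0))))

// The online optimizer yields a LocalLDAModel with non-default hyperparameters.
val model = new LDA()
  .setK(2)
  .setDocConcentration(0.5)
  .setTopicConcentration(0.5)
  .setOptimizer("online")
  .run(corpus)
  .asInstanceOf[LocalLDAModel]

val path = "/tmp/lda-model"  // hypothetical scratch location
model.save(sc, path)

// With this patch, the loaded model reports the saved hyperparameters
// rather than hard-coded defaults, so inference via predict() behaves
// the same before and after a save/load cycle.
val loaded = LocalLDAModel.load(sc, path)
assert(loaded.docConcentration == model.docConcentration)
assert(loaded.topicConcentration == model.topicConcentration)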