[SPARK-5989] [MLLIB] Model save/load for LDA
Add support for saving and loading LDA both the local and distributed versions. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #6948 from MechCoder/lda_save_load and squashes the following commits: 49bcdce [MechCoder] minor style fixes cc14054 [MechCoder] minor 4587d1d [MechCoder] Minor changes c753122 [MechCoder] Load and save the model in private methods 2782326 [MechCoder] [SPARK-5989] Model save/load for LDA
This commit is contained in:
parent
7f072c3d5e
commit
89db3c0b6e
|
@ -472,7 +472,7 @@ to the algorithm. We then output the topics, represented as probability distribu
|
|||
<div data-lang="scala" markdown="1">
|
||||
|
||||
{% highlight scala %}
|
||||
import org.apache.spark.mllib.clustering.LDA
|
||||
import org.apache.spark.mllib.clustering.{LDA, DistributedLDAModel}
|
||||
import org.apache.spark.mllib.linalg.Vectors
|
||||
|
||||
// Load and parse the data
|
||||
|
@ -492,6 +492,11 @@ for (topic <- Range(0, 3)) {
|
|||
for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); }
|
||||
println()
|
||||
}
|
||||
|
||||
// Save and load model.
|
||||
ldaModel.save(sc, "myLDAModel")
|
||||
val sameModel = DistributedLDAModel.load(sc, "myLDAModel")
|
||||
|
||||
{% endhighlight %}
|
||||
</div>
|
||||
|
||||
|
@ -551,6 +556,9 @@ public class JavaLDAExample {
|
|||
}
|
||||
System.out.println();
|
||||
}
|
||||
|
||||
ldaModel.save(sc.sc(), "myLDAModel");
|
||||
DistributedLDAModel sameModel = DistributedLDAModel.load(sc.sc(), "myLDAModel");
|
||||
}
|
||||
}
|
||||
{% endhighlight %}
|
||||
|
|
|
@ -17,15 +17,25 @@
|
|||
|
||||
package org.apache.spark.mllib.clustering
|
||||
|
||||
import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
|
||||
import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum, DenseVector => BDV}
|
||||
|
||||
import org.apache.hadoop.fs.Path
|
||||
|
||||
import org.json4s.DefaultFormats
|
||||
import org.json4s.JsonDSL._
|
||||
import org.json4s.jackson.JsonMethods._
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.api.java.JavaPairRDD
|
||||
import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
|
||||
import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
|
||||
import org.apache.spark.graphx.{VertexId, Edge, EdgeContext, Graph}
|
||||
import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix, DenseVector}
|
||||
import org.apache.spark.mllib.util.{Saveable, Loader}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{SQLContext, Row}
|
||||
import org.apache.spark.util.BoundedPriorityQueue
|
||||
|
||||
|
||||
/**
|
||||
* :: Experimental ::
|
||||
*
|
||||
|
@ -35,7 +45,7 @@ import org.apache.spark.util.BoundedPriorityQueue
|
|||
* including local and distributed data structures.
|
||||
*/
|
||||
@Experimental
|
||||
abstract class LDAModel private[clustering] {
|
||||
abstract class LDAModel private[clustering] extends Saveable {
|
||||
|
||||
/** Number of topics */
|
||||
def k: Int
|
||||
|
@ -176,6 +186,11 @@ class LocalLDAModel private[clustering] (
|
|||
}.toArray
|
||||
}
|
||||
|
||||
override protected def formatVersion = "1.0"
|
||||
|
||||
override def save(sc: SparkContext, path: String): Unit = {
|
||||
LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix)
|
||||
}
|
||||
// TODO
|
||||
// override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
|
||||
|
||||
|
@ -184,6 +199,80 @@ class LocalLDAModel private[clustering] (
|
|||
|
||||
}
|
||||
|
||||
@Experimental
|
||||
object LocalLDAModel extends Loader[LocalLDAModel] {
|
||||
|
||||
private object SaveLoadV1_0 {
|
||||
|
||||
val thisFormatVersion = "1.0"
|
||||
|
||||
val thisClassName = "org.apache.spark.mllib.clustering.LocalLDAModel"
|
||||
|
||||
// Store the distribution of terms of each topic and the column index in topicsMatrix
|
||||
// as a Row in data.
|
||||
case class Data(topic: Vector, index: Int)
|
||||
|
||||
def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
import sqlContext.implicits._
|
||||
|
||||
val k = topicsMatrix.numCols
|
||||
val metadata = compact(render
|
||||
(("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
|
||||
("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows)))
|
||||
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
|
||||
|
||||
val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
|
||||
val topics = Range(0, k).map { topicInd =>
|
||||
Data(Vectors.dense((topicsDenseMatrix(::, topicInd).toArray)), topicInd)
|
||||
}.toSeq
|
||||
sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
|
||||
}
|
||||
|
||||
def load(sc: SparkContext, path: String): LocalLDAModel = {
|
||||
val dataPath = Loader.dataPath(path)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
val dataFrame = sqlContext.read.parquet(dataPath)
|
||||
|
||||
Loader.checkSchema[Data](dataFrame.schema)
|
||||
val topics = dataFrame.collect()
|
||||
val vocabSize = topics(0).getAs[Vector](0).size
|
||||
val k = topics.size
|
||||
|
||||
val brzTopics = BDM.zeros[Double](vocabSize, k)
|
||||
topics.foreach { case Row(vec: Vector, ind: Int) =>
|
||||
brzTopics(::, ind) := vec.toBreeze
|
||||
}
|
||||
new LocalLDAModel(Matrices.fromBreeze(brzTopics))
|
||||
}
|
||||
}
|
||||
|
||||
override def load(sc: SparkContext, path: String): LocalLDAModel = {
|
||||
val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
|
||||
implicit val formats = DefaultFormats
|
||||
val expectedK = (metadata \ "k").extract[Int]
|
||||
val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
|
||||
val classNameV1_0 = SaveLoadV1_0.thisClassName
|
||||
|
||||
val model = (loadedClassName, loadedVersion) match {
|
||||
case (className, "1.0") if className == classNameV1_0 =>
|
||||
SaveLoadV1_0.load(sc, path)
|
||||
case _ => throw new Exception(
|
||||
s"LocalLDAModel.load did not recognize model with (className, format version):" +
|
||||
s"($loadedClassName, $loadedVersion). Supported:\n" +
|
||||
s" ($classNameV1_0, 1.0)")
|
||||
}
|
||||
|
||||
val topicsMatrix = model.topicsMatrix
|
||||
require(expectedK == topicsMatrix.numCols,
|
||||
s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics")
|
||||
require(expectedVocabSize == topicsMatrix.numRows,
|
||||
s"LocalLDAModel requires $expectedVocabSize terms for each topic, " +
|
||||
s"but got ${topicsMatrix.numRows}")
|
||||
model
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* :: Experimental ::
|
||||
*
|
||||
|
@ -354,4 +443,135 @@ class DistributedLDAModel private (
|
|||
// TODO:
|
||||
// override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
|
||||
|
||||
override protected def formatVersion = "1.0"
|
||||
|
||||
override def save(sc: SparkContext, path: String): Unit = {
|
||||
DistributedLDAModel.SaveLoadV1_0.save(
|
||||
sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
|
||||
iterationTimes)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Experimental
|
||||
object DistributedLDAModel extends Loader[DistributedLDAModel] {
|
||||
|
||||
private object SaveLoadV1_0 {
|
||||
|
||||
val thisFormatVersion = "1.0"
|
||||
|
||||
val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"
|
||||
|
||||
// Store globalTopicTotals as a Vector.
|
||||
case class Data(globalTopicTotals: Vector)
|
||||
|
||||
// Store each term and document vertex with an id and the topicWeights.
|
||||
case class VertexData(id: Long, topicWeights: Vector)
|
||||
|
||||
// Store each edge with the source id, destination id and tokenCounts.
|
||||
case class EdgeData(srcId: Long, dstId: Long, tokenCounts: Double)
|
||||
|
||||
def save(
|
||||
sc: SparkContext,
|
||||
path: String,
|
||||
graph: Graph[LDA.TopicCounts, LDA.TokenCount],
|
||||
globalTopicTotals: LDA.TopicCounts,
|
||||
k: Int,
|
||||
vocabSize: Int,
|
||||
docConcentration: Double,
|
||||
topicConcentration: Double,
|
||||
iterationTimes: Array[Double]): Unit = {
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
import sqlContext.implicits._
|
||||
|
||||
val metadata = compact(render
|
||||
(("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~
|
||||
("k" -> k) ~ ("vocabSize" -> vocabSize) ~ ("docConcentration" -> docConcentration) ~
|
||||
("topicConcentration" -> topicConcentration) ~
|
||||
("iterationTimes" -> iterationTimes.toSeq)))
|
||||
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
|
||||
|
||||
val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
|
||||
sc.parallelize(Seq(Data(Vectors.fromBreeze(globalTopicTotals)))).toDF()
|
||||
.write.parquet(newPath)
|
||||
|
||||
val verticesPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
|
||||
graph.vertices.map { case (ind, vertex) =>
|
||||
VertexData(ind, Vectors.fromBreeze(vertex))
|
||||
}.toDF().write.parquet(verticesPath)
|
||||
|
||||
val edgesPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
|
||||
graph.edges.map { case Edge(srcId, dstId, prop) =>
|
||||
EdgeData(srcId, dstId, prop)
|
||||
}.toDF().write.parquet(edgesPath)
|
||||
}
|
||||
|
||||
def load(
|
||||
sc: SparkContext,
|
||||
path: String,
|
||||
vocabSize: Int,
|
||||
docConcentration: Double,
|
||||
topicConcentration: Double,
|
||||
iterationTimes: Array[Double]): DistributedLDAModel = {
|
||||
val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
|
||||
val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
|
||||
val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
val dataFrame = sqlContext.read.parquet(dataPath)
|
||||
val vertexDataFrame = sqlContext.read.parquet(vertexDataPath)
|
||||
val edgeDataFrame = sqlContext.read.parquet(edgeDataPath)
|
||||
|
||||
Loader.checkSchema[Data](dataFrame.schema)
|
||||
Loader.checkSchema[VertexData](vertexDataFrame.schema)
|
||||
Loader.checkSchema[EdgeData](edgeDataFrame.schema)
|
||||
val globalTopicTotals: LDA.TopicCounts =
|
||||
dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector
|
||||
val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map {
|
||||
case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector)
|
||||
}
|
||||
|
||||
val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.map {
|
||||
case Row(srcId: Long, dstId: Long, prop: Double) => Edge(srcId, dstId, prop)
|
||||
}
|
||||
val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)
|
||||
|
||||
new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
|
||||
docConcentration, topicConcentration, iterationTimes)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
override def load(sc: SparkContext, path: String): DistributedLDAModel = {
|
||||
val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
|
||||
implicit val formats = DefaultFormats
|
||||
val expectedK = (metadata \ "k").extract[Int]
|
||||
val vocabSize = (metadata \ "vocabSize").extract[Int]
|
||||
val docConcentration = (metadata \ "docConcentration").extract[Double]
|
||||
val topicConcentration = (metadata \ "topicConcentration").extract[Double]
|
||||
val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
|
||||
val classNameV1_0 = SaveLoadV1_0.classNameV1_0
|
||||
|
||||
val model = (loadedClassName, loadedVersion) match {
|
||||
case (className, "1.0") if className == classNameV1_0 => {
|
||||
DistributedLDAModel.SaveLoadV1_0.load(
|
||||
sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray)
|
||||
}
|
||||
case _ => throw new Exception(
|
||||
s"DistributedLDAModel.load did not recognize model with (className, format version):" +
|
||||
s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)")
|
||||
}
|
||||
|
||||
require(model.vocabSize == vocabSize,
|
||||
s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize")
|
||||
require(model.docConcentration == docConcentration,
|
||||
s"DistributedLDAModel requires $docConcentration docConcentration, " +
|
||||
s"got ${model.docConcentration} docConcentration")
|
||||
require(model.topicConcentration == topicConcentration,
|
||||
s"DistributedLDAModel requires $topicConcentration docConcentration, " +
|
||||
s"got ${model.topicConcentration} docConcentration")
|
||||
require(expectedK == model.k,
|
||||
s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics")
|
||||
model
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite
|
|||
import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors}
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.mllib.util.TestingUtils._
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||
|
||||
|
@ -217,6 +218,46 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
}
|
||||
}
|
||||
|
||||
test("model save/load") {
|
||||
// Test for LocalLDAModel.
|
||||
val localModel = new LocalLDAModel(tinyTopics)
|
||||
val tempDir1 = Utils.createTempDir()
|
||||
val path1 = tempDir1.toURI.toString
|
||||
|
||||
// Test for DistributedLDAModel.
|
||||
val k = 3
|
||||
val docConcentration = 1.2
|
||||
val topicConcentration = 1.5
|
||||
val lda = new LDA()
|
||||
lda.setK(k)
|
||||
.setDocConcentration(docConcentration)
|
||||
.setTopicConcentration(topicConcentration)
|
||||
.setMaxIterations(5)
|
||||
.setSeed(12345)
|
||||
val corpus = sc.parallelize(tinyCorpus, 2)
|
||||
val distributedModel: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
|
||||
val tempDir2 = Utils.createTempDir()
|
||||
val path2 = tempDir2.toURI.toString
|
||||
|
||||
try {
|
||||
localModel.save(sc, path1)
|
||||
distributedModel.save(sc, path2)
|
||||
val samelocalModel = LocalLDAModel.load(sc, path1)
|
||||
assert(samelocalModel.topicsMatrix === localModel.topicsMatrix)
|
||||
assert(samelocalModel.k === localModel.k)
|
||||
assert(samelocalModel.vocabSize === localModel.vocabSize)
|
||||
|
||||
val sameDistributedModel = DistributedLDAModel.load(sc, path2)
|
||||
assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix)
|
||||
assert(distributedModel.k === sameDistributedModel.k)
|
||||
assert(distributedModel.vocabSize === sameDistributedModel.vocabSize)
|
||||
assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes)
|
||||
} finally {
|
||||
Utils.deleteRecursively(tempDir1)
|
||||
Utils.deleteRecursively(tempDir2)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private[clustering] object LDASuite {
|
||||
|
|
Loading…
Reference in a new issue