[SPARK-9440] [MLLIB] Add hyperparameters to LocalLDAModel save/load
jkbradley MechCoder Resolves blocking issue for SPARK-6793. Please review after #7705 is merged. Author: Feynman Liang <fliang@databricks.com> Closes #7757 from feynmanliang/SPARK-9940-localSaveLoad and squashes the following commits: d0d8cf4 [Feynman Liang] Fix thisClassName 0f30109 [Feynman Liang] Fix tests after changing LDAModel public API dc61981 [Feynman Liang] Add hyperparams to LocalLDAModel save/load
This commit is contained in:
parent
2a9fe4a4e7
commit
a200e64561
|
@ -215,7 +215,8 @@ class LocalLDAModel private[clustering] (
|
||||||
override protected def formatVersion = "1.0"
|
override protected def formatVersion = "1.0"
|
||||||
|
|
||||||
override def save(sc: SparkContext, path: String): Unit = {
|
override def save(sc: SparkContext, path: String): Unit = {
|
||||||
LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix)
|
LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration,
|
||||||
|
gammaShape)
|
||||||
}
|
}
|
||||||
// TODO
|
// TODO
|
||||||
// override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
|
// override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
|
||||||
|
@ -312,16 +313,23 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
|
||||||
// as a Row in data.
|
// as a Row in data.
|
||||||
case class Data(topic: Vector, index: Int)
|
case class Data(topic: Vector, index: Int)
|
||||||
|
|
||||||
// TODO: explicitly save docConcentration, topicConcentration, and gammaShape for use in
|
def save(
|
||||||
// model.predict()
|
sc: SparkContext,
|
||||||
def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {
|
path: String,
|
||||||
|
topicsMatrix: Matrix,
|
||||||
|
docConcentration: Vector,
|
||||||
|
topicConcentration: Double,
|
||||||
|
gammaShape: Double): Unit = {
|
||||||
val sqlContext = SQLContext.getOrCreate(sc)
|
val sqlContext = SQLContext.getOrCreate(sc)
|
||||||
import sqlContext.implicits._
|
import sqlContext.implicits._
|
||||||
|
|
||||||
val k = topicsMatrix.numCols
|
val k = topicsMatrix.numCols
|
||||||
val metadata = compact(render
|
val metadata = compact(render
|
||||||
(("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
|
(("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
|
||||||
("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows)))
|
("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows) ~
|
||||||
|
("docConcentration" -> docConcentration.toArray.toSeq) ~
|
||||||
|
("topicConcentration" -> topicConcentration) ~
|
||||||
|
("gammaShape" -> gammaShape)))
|
||||||
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
|
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
|
||||||
|
|
||||||
val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
|
val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
|
||||||
|
@ -331,7 +339,12 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
|
||||||
sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
|
sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
|
||||||
}
|
}
|
||||||
|
|
||||||
def load(sc: SparkContext, path: String): LocalLDAModel = {
|
def load(
|
||||||
|
sc: SparkContext,
|
||||||
|
path: String,
|
||||||
|
docConcentration: Vector,
|
||||||
|
topicConcentration: Double,
|
||||||
|
gammaShape: Double): LocalLDAModel = {
|
||||||
val dataPath = Loader.dataPath(path)
|
val dataPath = Loader.dataPath(path)
|
||||||
val sqlContext = SQLContext.getOrCreate(sc)
|
val sqlContext = SQLContext.getOrCreate(sc)
|
||||||
val dataFrame = sqlContext.read.parquet(dataPath)
|
val dataFrame = sqlContext.read.parquet(dataPath)
|
||||||
|
@ -348,8 +361,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
|
||||||
val topicsMat = Matrices.fromBreeze(brzTopics)
|
val topicsMat = Matrices.fromBreeze(brzTopics)
|
||||||
|
|
||||||
// TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940
|
// TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940
|
||||||
new LocalLDAModel(topicsMat,
|
new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape)
|
||||||
Vectors.dense(Array.fill(topicsMat.numRows)(1.0 / topicsMat.numRows)), 1D, 100D)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -358,11 +370,15 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
|
||||||
implicit val formats = DefaultFormats
|
implicit val formats = DefaultFormats
|
||||||
val expectedK = (metadata \ "k").extract[Int]
|
val expectedK = (metadata \ "k").extract[Int]
|
||||||
val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
|
val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
|
||||||
|
val docConcentration =
|
||||||
|
Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
|
||||||
|
val topicConcentration = (metadata \ "topicConcentration").extract[Double]
|
||||||
|
val gammaShape = (metadata \ "gammaShape").extract[Double]
|
||||||
val classNameV1_0 = SaveLoadV1_0.thisClassName
|
val classNameV1_0 = SaveLoadV1_0.thisClassName
|
||||||
|
|
||||||
val model = (loadedClassName, loadedVersion) match {
|
val model = (loadedClassName, loadedVersion) match {
|
||||||
case (className, "1.0") if className == classNameV1_0 =>
|
case (className, "1.0") if className == classNameV1_0 =>
|
||||||
SaveLoadV1_0.load(sc, path)
|
SaveLoadV1_0.load(sc, path, docConcentration, topicConcentration, gammaShape)
|
||||||
case _ => throw new Exception(
|
case _ => throw new Exception(
|
||||||
s"LocalLDAModel.load did not recognize model with (className, format version):" +
|
s"LocalLDAModel.load did not recognize model with (className, format version):" +
|
||||||
s"($loadedClassName, $loadedVersion). Supported:\n" +
|
s"($loadedClassName, $loadedVersion). Supported:\n" +
|
||||||
|
@ -565,7 +581,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
|
||||||
|
|
||||||
val thisFormatVersion = "1.0"
|
val thisFormatVersion = "1.0"
|
||||||
|
|
||||||
val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"
|
val thisClassName = "org.apache.spark.mllib.clustering.DistributedLDAModel"
|
||||||
|
|
||||||
// Store globalTopicTotals as a Vector.
|
// Store globalTopicTotals as a Vector.
|
||||||
case class Data(globalTopicTotals: Vector)
|
case class Data(globalTopicTotals: Vector)
|
||||||
|
@ -591,7 +607,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
|
||||||
import sqlContext.implicits._
|
import sqlContext.implicits._
|
||||||
|
|
||||||
val metadata = compact(render
|
val metadata = compact(render
|
||||||
(("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~
|
(("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
|
||||||
("k" -> k) ~ ("vocabSize" -> vocabSize) ~
|
("k" -> k) ~ ("vocabSize" -> vocabSize) ~
|
||||||
("docConcentration" -> docConcentration.toArray.toSeq) ~
|
("docConcentration" -> docConcentration.toArray.toSeq) ~
|
||||||
("topicConcentration" -> topicConcentration) ~
|
("topicConcentration" -> topicConcentration) ~
|
||||||
|
@ -660,7 +676,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
|
||||||
val topicConcentration = (metadata \ "topicConcentration").extract[Double]
|
val topicConcentration = (metadata \ "topicConcentration").extract[Double]
|
||||||
val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
|
val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
|
||||||
val gammaShape = (metadata \ "gammaShape").extract[Double]
|
val gammaShape = (metadata \ "gammaShape").extract[Double]
|
||||||
val classNameV1_0 = SaveLoadV1_0.classNameV1_0
|
val classNameV1_0 = SaveLoadV1_0.thisClassName
|
||||||
|
|
||||||
val model = (loadedClassName, loadedVersion) match {
|
val model = (loadedClassName, loadedVersion) match {
|
||||||
case (className, "1.0") if className == classNameV1_0 => {
|
case (className, "1.0") if className == classNameV1_0 => {
|
||||||
|
|
|
@ -334,7 +334,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
test("model save/load") {
|
test("model save/load") {
|
||||||
// Test for LocalLDAModel.
|
// Test for LocalLDAModel.
|
||||||
val localModel = new LocalLDAModel(tinyTopics,
|
val localModel = new LocalLDAModel(tinyTopics,
|
||||||
Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D)
|
Vectors.dense(Array.fill(tinyTopics.numRows)(0.01)), 0.5D, 10D)
|
||||||
val tempDir1 = Utils.createTempDir()
|
val tempDir1 = Utils.createTempDir()
|
||||||
val path1 = tempDir1.toURI.toString
|
val path1 = tempDir1.toURI.toString
|
||||||
|
|
||||||
|
@ -360,6 +360,9 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
assert(samelocalModel.topicsMatrix === localModel.topicsMatrix)
|
assert(samelocalModel.topicsMatrix === localModel.topicsMatrix)
|
||||||
assert(samelocalModel.k === localModel.k)
|
assert(samelocalModel.k === localModel.k)
|
||||||
assert(samelocalModel.vocabSize === localModel.vocabSize)
|
assert(samelocalModel.vocabSize === localModel.vocabSize)
|
||||||
|
assert(samelocalModel.docConcentration === localModel.docConcentration)
|
||||||
|
assert(samelocalModel.topicConcentration === localModel.topicConcentration)
|
||||||
|
assert(samelocalModel.gammaShape === localModel.gammaShape)
|
||||||
|
|
||||||
val sameDistributedModel = DistributedLDAModel.load(sc, path2)
|
val sameDistributedModel = DistributedLDAModel.load(sc, path2)
|
||||||
assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix)
|
assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix)
|
||||||
|
@ -368,6 +371,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes)
|
assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes)
|
||||||
assert(distributedModel.docConcentration === sameDistributedModel.docConcentration)
|
assert(distributedModel.docConcentration === sameDistributedModel.docConcentration)
|
||||||
assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration)
|
assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration)
|
||||||
|
assert(distributedModel.gammaShape === sameDistributedModel.gammaShape)
|
||||||
assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals)
|
assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals)
|
||||||
|
|
||||||
val graph = distributedModel.graph
|
val graph = distributedModel.graph
|
||||||
|
|
Loading…
Reference in a new issue