[SPARK-16447][ML][SPARKR] LDA wrapper in SparkR

## What changes were proposed in this pull request?

Add an LDA wrapper to SparkR with the following interfaces (a brief usage sketch follows the list):

- spark.lda(data, ...)

- spark.posterior(object, newData, ...)

- spark.perplexity(object, ...)

- summary(object)

- write.ml(object)

- read.ml(path)
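
For orientation, here is a minimal usage sketch adapted from the roxygen examples and unit tests in this patch; the sample data file is the one the new tests read, and the temp-file path is illustrative:

```r
library(SparkR)
sparkR.session()

# Fit an LDA model on the libsvm-format sample data used by the unit tests
text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm")
model <- spark.lda(data = text, optimizer = "em", k = 10)

# Inspect the fitted model: doc/topic concentration, log likelihood/perplexity, topics, ...
summary(model)

# Posterior topic distributions and log perplexity for a dataset
posterior <- spark.posterior(model, text)
spark.perplexity(model, text)

# Save and reload the fitted model
modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
write.ml(model, modelPath)
savedModel <- read.ml(modelPath)
```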

## How was this patch tested?

Tested with SparkR unit tests (see the `test_that` cases added in the diff below).
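
As a condensed sketch of what those tests check (the full `test_that` blocks appear in the diff further down):

```r
test_that("spark.lda with libsvm", {
  text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm")
  model <- spark.lda(text, optimizer = "em")
  stats <- summary(model, 10)

  # Basic sanity checks on the summary output
  expect_false(stats$isDistributed)
  expect_true(stats$logLikelihood <= 0 & is.finite(stats$logLikelihood))
  expect_true(stats$logPerplexity >= 0 & is.finite(stats$logPerplexity))
  expect_equal(stats$vocabSize, 11)
})
```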

Author: Xusen Yin <yinxusen@gmail.com>

Closes #14229 from yinxusen/SPARK-16447.
Xusen Yin 2016-08-18 05:33:52 -07:00 committed by Felix Cheung
parent 68f5087d21
commit b72bb62d42
7 changed files with 490 additions and 2 deletions


@@ -25,6 +25,9 @@ exportMethods("glm",
"fitted",
"spark.naiveBayes",
"spark.survreg",
"spark.lda",
"spark.posterior",
"spark.perplexity",
"spark.isoreg",
"spark.gaussianMixture")


@@ -1304,6 +1304,19 @@ setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("s
#' @export
setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })
#' @rdname spark.lda
#' @param ... Additional parameters to tune LDA.
#' @export
setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") })
#' @rdname spark.lda
#' @export
setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark.posterior") })
#' @rdname spark.lda
#' @export
setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") })
#' @rdname spark.isoreg
#' @export
setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })
@@ -1315,6 +1328,7 @@
standardGeneric("spark.gaussianMixture")
})
#' write.ml
#' @rdname write.ml
#' @export
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })


@@ -39,6 +39,13 @@ setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj"))
#' @note NaiveBayesModel since 2.0.0
setClass("NaiveBayesModel", representation(jobj = "jobj"))
#' S4 class that represents an LDAModel
#'
#' @param jobj a Java object reference to the backing Scala LDAWrapper
#' @export
#' @note LDAModel since 2.1.0
setClass("LDAModel", representation(jobj = "jobj"))
#' S4 class that represents a AFTSurvivalRegressionModel
#'
#' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper
@@ -75,7 +82,7 @@ setClass("GaussianMixtureModel", representation(jobj = "jobj"))
#' @name write.ml
#' @export
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.lda}
#' @seealso \link{spark.isoreg}
#' @seealso \link{read.ml}
NULL
@@ -315,6 +322,94 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
return(list(apriori = apriori, tables = tables))
})
# Returns posterior probabilities from a Latent Dirichlet Allocation model produced by spark.lda()
#' @param newData A SparkDataFrame for testing
#' @return \code{spark.posterior} returns a SparkDataFrame containing posterior probabilities
#' vectors named "topicDistribution"
#' @rdname spark.lda
#' @aliases spark.posterior,LDAModel,SparkDataFrame-method
#' @export
#' @note spark.posterior(LDAModel) since 2.1.0
setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"),
function(object, newData) {
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
})
# Returns the summary of a Latent Dirichlet Allocation model produced by \code{spark.lda}
#' @param object A Latent Dirichlet Allocation model fitted by \code{spark.lda}.
#' @param maxTermsPerTopic Maximum number of terms to collect for each topic. Default value of 10.
#' @return \code{summary} returns a list containing
#' \item{\code{docConcentration}}{concentration parameter commonly named \code{alpha} for
#' the prior placed on documents distributions over topics \code{theta}}
#' \item{\code{topicConcentration}}{concentration parameter commonly named \code{beta} or
#' \code{eta} for the prior placed on topic distributions over terms}
#' \item{\code{logLikelihood}}{log likelihood of the entire corpus}
#' \item{\code{logPerplexity}}{log perplexity}
#' \item{\code{isDistributed}}{TRUE for distributed model while FALSE for local model}
#' \item{\code{vocabSize}}{number of terms in the corpus}
#' \item{\code{topics}}{top 10 terms and their weights of all topics}
#' \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file
#' used as training set}
#' @rdname spark.lda
#' @aliases summary,LDAModel-method
#' @export
#' @note summary(LDAModel) since 2.1.0
setMethod("summary", signature(object = "LDAModel"),
function(object, maxTermsPerTopic) {
maxTermsPerTopic <- as.integer(ifelse(missing(maxTermsPerTopic), 10, maxTermsPerTopic))
jobj <- object@jobj
docConcentration <- callJMethod(jobj, "docConcentration")
topicConcentration <- callJMethod(jobj, "topicConcentration")
logLikelihood <- callJMethod(jobj, "logLikelihood")
logPerplexity <- callJMethod(jobj, "logPerplexity")
isDistributed <- callJMethod(jobj, "isDistributed")
vocabSize <- callJMethod(jobj, "vocabSize")
topics <- dataFrame(callJMethod(jobj, "topics", maxTermsPerTopic))
vocabulary <- callJMethod(jobj, "vocabulary")
return(list(docConcentration = unlist(docConcentration),
topicConcentration = topicConcentration,
logLikelihood = logLikelihood, logPerplexity = logPerplexity,
isDistributed = isDistributed, vocabSize = vocabSize,
topics = topics,
vocabulary = unlist(vocabulary)))
})
# Returns the log perplexity of a Latent Dirichlet Allocation model produced by \code{spark.lda}
#' @return \code{spark.perplexity} returns the log perplexity of given SparkDataFrame, or the log
#' perplexity of the training data if missing argument "data".
#' @rdname spark.lda
#' @aliases spark.perplexity,LDAModel-method
#' @export
#' @note spark.perplexity(LDAModel) since 2.1.0
setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"),
function(object, data) {
return(ifelse(missing(data), callJMethod(object@jobj, "logPerplexity"),
callJMethod(object@jobj, "computeLogPerplexity", data@sdf)))
})
# Saves the Latent Dirichlet Allocation model to the input path.
#' @param path The directory where the model is saved
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
#'
#' @rdname spark.lda
#' @aliases write.ml,LDAModel,character-method
#' @export
#' @seealso \link{read.ml}
#' @note write.ml(LDAModel, character) since 2.1.0
setMethod("write.ml", signature(object = "LDAModel", path = "character"),
function(object, path, overwrite = FALSE) {
writer <- callJMethod(object@jobj, "write")
if (overwrite) {
writer <- callJMethod(writer, "overwrite")
}
invisible(callJMethod(writer, "save", path))
})
#' Isotonic Regression Model
#'
#' Fits an Isotonic Regression model against a Spark DataFrame, similarly to R's isoreg().
@@ -700,6 +795,8 @@ read.ml <- function(path) {
return(new("GeneralizedLinearRegressionModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) {
return(new("KMeansModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LDAWrapper")) {
return(new("LDAModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) {
return(new("IsotonicRegressionModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
@@ -751,6 +848,71 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula
return(new("AFTSurvivalRegressionModel", jobj = jobj))
})
#' Latent Dirichlet Allocation
#'
#' \code{spark.lda} fits a Latent Dirichlet Allocation model on a SparkDataFrame. Users can call
#' \code{summary} to get a summary of the fitted LDA model, \code{spark.posterior} to compute
#' posterior probabilities on new data, \code{spark.perplexity} to compute log perplexity on new
#' data and \code{write.ml}/\code{read.ml} to save/load fitted models.
#'
#' @param data A SparkDataFrame for training
#' @param features Features column name, default "features". Either libSVM-format column or
#' character-format column is valid.
#' @param k Number of topics, default 10
#' @param maxIter Maximum iterations, default 20
#' @param optimizer Optimizer to train an LDA model, "online" or "em", default "online"
#' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in
#' each iteration of mini-batch gradient descent, in range (0, 1], default 0.05
#' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for
#' the prior placed on topic distributions over terms, default -1 to set automatically on the
#' Spark side. Use \code{summary} to retrieve the effective topicConcentration. Only 1-size
#' numeric is accepted.
#' @param docConcentration concentration parameter (commonly named \code{alpha}) for the
#' prior placed on documents distributions over topics (\code{theta}), default -1 to set
#' automatically on the Spark side. Use \code{summary} to retrieve the effective
#' docConcentration. Only 1-size or \code{k}-size numeric is accepted.
#' @param customizedStopWords stopwords that need to be removed from the given corpus. Ignore the
#' parameter if libSVM-format column is used as the features column.
#' @param maxVocabSize maximum vocabulary size, default 1 << 18
#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model
#' @rdname spark.lda
#' @aliases spark.lda,SparkDataFrame-method
#' @seealso topicmodels: \url{https://cran.r-project.org/web/packages/topicmodels/}
#' @export
#' @examples
#' \dontrun{
#' text <- read.df("path/to/data", source = "libsvm")
#' model <- spark.lda(data = text, optimizer = "em")
#'
#' # get a summary of the model
#' summary(model)
#'
#' # compute posterior probabilities
#' posterior <- spark.posterior(model, df)
#' showDF(posterior)
#'
#' # compute perplexity
#' perplexity <- spark.perplexity(model, df)
#'
#' # save and load the model
#' path <- "path/to/model"
#' write.ml(model, path)
#' savedModel <- read.ml(path)
#' summary(savedModel)
#' }
#' @note spark.lda since 2.1.0
setMethod("spark.lda", signature(data = "SparkDataFrame"),
function(data, features = "features", k = 10, maxIter = 20, optimizer = c("online", "em"),
subsamplingRate = 0.05, topicConcentration = -1, docConcentration = -1,
customizedStopWords = "", maxVocabSize = bitwShiftL(1, 18)) {
optimizer <- match.arg(optimizer)
jobj <- callJStatic("org.apache.spark.ml.r.LDAWrapper", "fit", data@sdf, features,
as.integer(k), as.integer(maxIter), optimizer,
as.numeric(subsamplingRate), topicConcentration,
as.array(docConcentration), as.array(customizedStopWords),
maxVocabSize)
return(new("LDAModel", jobj = jobj))
})
# Returns a summary of the AFT survival regression model produced by spark.survreg,
# similarly to R's summary().
@@ -891,4 +1053,4 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
setMethod("predict", signature(object = "GaussianMixtureModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
})


@@ -570,4 +570,91 @@ test_that("spark.gaussianMixture", {
unlink(modelPath)
})
test_that("spark.lda with libsvm", {
text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm")
model <- spark.lda(text, optimizer = "em")
stats <- summary(model, 10)
isDistributed <- stats$isDistributed
logLikelihood <- stats$logLikelihood
logPerplexity <- stats$logPerplexity
vocabSize <- stats$vocabSize
topics <- stats$topicTopTerms
weights <- stats$topicTopTermsWeights
vocabulary <- stats$vocabulary
expect_false(isDistributed)
expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
expect_equal(vocabSize, 11)
expect_true(is.null(vocabulary))
# Test model save/load
modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
write.ml(model, modelPath)
expect_error(write.ml(model, modelPath))
write.ml(model, modelPath, overwrite = TRUE)
model2 <- read.ml(modelPath)
stats2 <- summary(model2)
expect_false(stats2$isDistributed)
expect_equal(logLikelihood, stats2$logLikelihood)
expect_equal(logPerplexity, stats2$logPerplexity)
expect_equal(vocabSize, stats2$vocabSize)
expect_equal(vocabulary, stats2$vocabulary)
unlink(modelPath)
})
test_that("spark.lda with text input", {
text <- read.text("data/mllib/sample_lda_data.txt")
model <- spark.lda(text, optimizer = "online", features = "value")
stats <- summary(model)
isDistributed <- stats$isDistributed
logLikelihood <- stats$logLikelihood
logPerplexity <- stats$logPerplexity
vocabSize <- stats$vocabSize
topics <- stats$topicTopTerms
weights <- stats$topicTopTermsWeights
vocabulary <- stats$vocabulary
expect_false(isDistributed)
expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
expect_equal(vocabSize, 10)
expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")))
# Test model save/load
modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp")
write.ml(model, modelPath)
expect_error(write.ml(model, modelPath))
write.ml(model, modelPath, overwrite = TRUE)
model2 <- read.ml(modelPath)
stats2 <- summary(model2)
expect_false(stats2$isDistributed)
expect_equal(logLikelihood, stats2$logLikelihood)
expect_equal(logPerplexity, stats2$logPerplexity)
expect_equal(vocabSize, stats2$vocabSize)
expect_true(all.equal(vocabulary, stats2$vocabulary))
unlink(modelPath)
})
test_that("spark.posterior and spark.perplexity", {
text <- read.text("data/mllib/sample_lda_data.txt")
model <- spark.lda(text, features = "value", k = 3)
# Assert perplexities are equal
stats <- summary(model)
logPerplexity <- spark.perplexity(model, text)
expect_equal(logPerplexity, stats$logPerplexity)
# Assert the sum of every topic distribution is equal to 1
posterior <- spark.posterior(model, text)
local.posterior <- collect(posterior)$topicDistribution
expect_equal(length(local.posterior), sum(unlist(local.posterior)))
})
sparkR.session.stop()


@@ -386,6 +386,10 @@ sealed abstract class LDAModel private[ml] (
@Since("1.6.0")
protected def getModel: OldLDAModel
private[ml] def getEffectiveDocConcentration: Array[Double] = getModel.docConcentration.toArray
private[ml] def getEffectiveTopicConcentration: Double = getModel.topicConcentration
/**
* The features for LDA should be a [[Vector]] representing the word counts in a document.
* The vector should be of length vocabSize, with counts for each term (word).


@@ -0,0 +1,216 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.r
import scala.collection.mutable
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkException
import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.clustering.{LDA, LDAModel}
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.ParamPair
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StringType
private[r] class LDAWrapper private (
val pipeline: PipelineModel,
val logLikelihood: Double,
val logPerplexity: Double,
val vocabulary: Array[String]) extends MLWritable {
import LDAWrapper._
private val lda: LDAModel = pipeline.stages.last.asInstanceOf[LDAModel]
private val preprocessor: PipelineModel =
new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", pipeline.stages.dropRight(1))
def transform(data: Dataset[_]): DataFrame = {
val vec2ary = udf { vec: Vector => vec.toArray }
val outputCol = lda.getTopicDistributionCol
val tempCol = s"${Identifiable.randomUID(outputCol)}"
val preprocessed = preprocessor.transform(data)
lda.transform(preprocessed, ParamPair(lda.topicDistributionCol, tempCol))
.withColumn(outputCol, vec2ary(col(tempCol)))
.drop(TOKENIZER_COL, STOPWORDS_REMOVER_COL, COUNT_VECTOR_COL, tempCol)
}
def computeLogPerplexity(data: Dataset[_]): Double = {
lda.logPerplexity(preprocessor.transform(data))
}
def topics(maxTermsPerTopic: Int): DataFrame = {
val topicIndices: DataFrame = lda.describeTopics(maxTermsPerTopic)
if (vocabulary.isEmpty || vocabulary.length < vocabSize) {
topicIndices
} else {
val index2term = udf { indices: mutable.WrappedArray[Int] => indices.map(i => vocabulary(i)) }
topicIndices
.select(col("topic"), index2term(col("termIndices")).as("term"), col("termWeights"))
}
}
lazy val isDistributed: Boolean = lda.isDistributed
lazy val vocabSize: Int = lda.vocabSize
lazy val docConcentration: Array[Double] = lda.getEffectiveDocConcentration
lazy val topicConcentration: Double = lda.getEffectiveTopicConcentration
override def write: MLWriter = new LDAWrapper.LDAWrapperWriter(this)
}
private[r] object LDAWrapper extends MLReadable[LDAWrapper] {
val TOKENIZER_COL = s"${Identifiable.randomUID("rawTokens")}"
val STOPWORDS_REMOVER_COL = s"${Identifiable.randomUID("tokens")}"
val COUNT_VECTOR_COL = s"${Identifiable.randomUID("features")}"
private def getPreStages(
features: String,
customizedStopWords: Array[String],
maxVocabSize: Int): Array[PipelineStage] = {
val tokenizer = new RegexTokenizer()
.setInputCol(features)
.setOutputCol(TOKENIZER_COL)
val stopWordsRemover = new StopWordsRemover()
.setInputCol(TOKENIZER_COL)
.setOutputCol(STOPWORDS_REMOVER_COL)
stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
val countVectorizer = new CountVectorizer()
.setVocabSize(maxVocabSize)
.setInputCol(STOPWORDS_REMOVER_COL)
.setOutputCol(COUNT_VECTOR_COL)
Array(tokenizer, stopWordsRemover, countVectorizer)
}
def fit(
data: DataFrame,
features: String,
k: Int,
maxIter: Int,
optimizer: String,
subsamplingRate: Double,
topicConcentration: Double,
docConcentration: Array[Double],
customizedStopWords: Array[String],
maxVocabSize: Int): LDAWrapper = {
val lda = new LDA()
.setK(k)
.setMaxIter(maxIter)
.setSubsamplingRate(subsamplingRate)
val featureSchema = data.schema(features)
val stages = featureSchema.dataType match {
case d: StringType =>
getPreStages(features, customizedStopWords, maxVocabSize) ++
Array(lda.setFeaturesCol(COUNT_VECTOR_COL))
case d: VectorUDT =>
Array(lda.setFeaturesCol(features))
case _ =>
throw new SparkException(
s"Unsupported input features type of ${featureSchema.dataType.typeName}," +
s" only String type and Vector type are supported now.")
}
if (topicConcentration != -1) {
lda.setTopicConcentration(topicConcentration)
} else {
// Auto-set topicConcentration
}
if (docConcentration.length == 1) {
if (docConcentration.head != -1) {
lda.setDocConcentration(docConcentration.head)
} else {
// Auto-set docConcentration
}
} else {
lda.setDocConcentration(docConcentration)
}
val pipeline = new Pipeline().setStages(stages)
val model = pipeline.fit(data)
val vocabulary: Array[String] = featureSchema.dataType match {
case d: StringType =>
val countVectorModel = model.stages(2).asInstanceOf[CountVectorizerModel]
countVectorModel.vocabulary
case _ => Array.empty[String]
}
val ldaModel: LDAModel = model.stages.last.asInstanceOf[LDAModel]
val preprocessor: PipelineModel =
new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", model.stages.dropRight(1))
val preprocessedData = preprocessor.transform(data)
new LDAWrapper(
model,
ldaModel.logLikelihood(preprocessedData),
ldaModel.logPerplexity(preprocessedData),
vocabulary)
}
override def read: MLReader[LDAWrapper] = new LDAWrapperReader
override def load(path: String): LDAWrapper = super.load(path)
class LDAWrapperWriter(instance: LDAWrapper) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
val rMetadataPath = new Path(path, "rMetadata").toString
val pipelinePath = new Path(path, "pipeline").toString
val rMetadata = ("class" -> instance.getClass.getName) ~
("logLikelihood" -> instance.logLikelihood) ~
("logPerplexity" -> instance.logPerplexity) ~
("vocabulary" -> instance.vocabulary.toList)
val rMetadataJson: String = compact(render(rMetadata))
sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
instance.pipeline.save(pipelinePath)
}
}
class LDAWrapperReader extends MLReader[LDAWrapper] {
override def load(path: String): LDAWrapper = {
implicit val format = DefaultFormats
val rMetadataPath = new Path(path, "rMetadata").toString
val pipelinePath = new Path(path, "pipeline").toString
val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
val rMetadata = parse(rMetadataStr)
val logLikelihood = (rMetadata \ "logLikelihood").extract[Double]
val logPerplexity = (rMetadata \ "logPerplexity").extract[Double]
val vocabulary = (rMetadata \ "vocabulary").extract[List[String]].toArray
val pipeline = PipelineModel.load(pipelinePath)
new LDAWrapper(pipeline, logLikelihood, logPerplexity, vocabulary)
}
}
}


@@ -44,6 +44,8 @@ private[r] object RWrappers extends MLReader[Object] {
GeneralizedLinearRegressionWrapper.load(path)
case "org.apache.spark.ml.r.KMeansWrapper" =>
KMeansWrapper.load(path)
case "org.apache.spark.ml.r.LDAWrapper" =>
LDAWrapper.load(path)
case "org.apache.spark.ml.r.IsotonicRegressionWrapper" => case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
IsotonicRegressionWrapper.load(path) IsotonicRegressionWrapper.load(path)
case "org.apache.spark.ml.r.GaussianMixtureWrapper" => case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>